This commit is contained in:
Beingpax 2025-10-19 16:59:23 +05:45
parent a0e4dd1367
commit 3f01f49f56

View File

@ -1,35 +1,28 @@
import Foundation import Foundation
import CoreML import CoreML
import AVFoundation import AVFoundation
import FluidAudio import FluidAudio
import os.log import os.log
class ParakeetTranscriptionService: TranscriptionService { class ParakeetTranscriptionService: TranscriptionService {
private var asrManager: AsrManager? private var asrManager: AsrManager?
private var vadManager: VadManager? private var vadManager: VadManager?
private var activeVersion: AsrModelVersion? private var activeVersion: AsrModelVersion?
private let logger = Logger( private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink.parakeet", category: "ParakeetTranscriptionService")
subsystem: "com.prakashjoshipax.voiceink.parakeet",
category: "ParakeetTranscriptionService"
)
/// Maps a transcription model to the Parakeet model version it requires.
///
/// - Parameter model: The model whose `name` is inspected.
/// - Returns: `.v2` when the lower-cased name contains `"v2"`, otherwise `.v3`.
private func version(for model: any TranscriptionModel) -> AsrModelVersion {
    let normalizedName = model.name.lowercased()
    if normalizedName.contains("v2") {
        return .v2
    }
    return .v3
}
private func ensureModelsLoaded(for version: AsrModelVersion) async throws { private func ensureModelsLoaded(for version: AsrModelVersion) async throws {
if let manager = asrManager, activeVersion == version { if let manager = asrManager, activeVersion == version {
logger.notice("🦜 Parakeet \(version == .v2 ? "V2" : "V3") models already loaded, skipping")
return return
} }
cleanup() cleanup()
let manager = AsrManager(config: .default) let manager = AsrManager(config: .default)
let cacheDirectory = AsrModels.defaultCacheDirectory(for: version) let models = try await AsrModels.loadFromCache(
logger.notice("🦜 Loading Parakeet models from \(cacheDirectory.path)")
let models = try await AsrModels.load(
from: cacheDirectory,
configuration: nil, configuration: nil,
version: version version: version
) )
@ -41,78 +34,72 @@ class ParakeetTranscriptionService: TranscriptionService {
/// Loads the Parakeet models required by `model`.
///
/// - Parameter model: The Parakeet model to prepare.
/// - Throws: Whatever `ensureModelsLoaded(for:)` throws.
func loadModel(for model: ParakeetModel) async throws {
    let requestedVersion = version(for: model)
    try await ensureModelsLoaded(for: requestedVersion)
}
/// Transcribes the audio file at `audioURL` with the Parakeet version matching `model`.
///
/// When the recording is at least 20 seconds long and the `"IsVADEnabled"` user default
/// is not `false`, speech segments are extracted with the VAD manager first; any VAD
/// failure (or an empty segment list) falls back to the full audio.
///
/// - Parameters:
///   - audioURL: Location of the audio file to read.
///   - model: Model used to select the Parakeet version.
/// - Returns: The recognized text.
/// - Throws: `ASRError.notInitialized` if no ASR manager is available after loading,
///   plus any error from model loading, audio reading, or transcription.
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
    try await ensureModelsLoaded(for: version(for: model))

    guard let recognizer = asrManager else {
        throw ASRError.notInitialized
    }

    let samples = try readAudioSamples(from: audioURL)

    // The duration is derived with a fixed divisor, i.e. the samples are assumed to be 16 kHz.
    let lengthInSeconds = Double(samples.count) / 16000.0
    let vadEnabled = UserDefaults.standard.object(forKey: "IsVADEnabled") as? Bool ?? true

    // Defaults to the full recording; only replaced when VAD yields segments.
    var speechAudio = samples

    if lengthInSeconds >= 20.0, vadEnabled {
        let vadConfig = VadConfig(threshold: 0.7)

        // The VAD manager is created once and reused until `cleanup()` clears it.
        if vadManager == nil {
            do {
                vadManager = try await VadManager(config: vadConfig)
            } catch {
                logger.notice("VAD init failed; falling back to full audio: \(error.localizedDescription)")
            }
        }

        if let vadManager {
            do {
                let segments = try await vadManager.segmentSpeechAudio(samples)
                if !segments.isEmpty {
                    speechAudio = segments.flatMap { $0 }
                }
            } catch {
                logger.notice("VAD segmentation failed; using full audio: \(error.localizedDescription)")
            }
        }
    }

    let transcription = try await recognizer.transcribe(speechAudio)
    return transcription.text
}
/// Reads a WAV file and converts its 16-bit little-endian PCM payload to normalized floats.
///
/// The first 44 bytes are skipped as the header.
/// NOTE(review): WAV files with extra chunks before the audio data have a longer header;
/// this reader assumes the canonical 44-byte layout — confirm against the recorder's output.
///
/// - Parameter url: Location of the audio file.
/// - Returns: One `Float` per sample, clamped to `-1.0...1.0`.
/// - Throws: `ASRError.invalidAudioData` if the file cannot be read or does not contain
///   at least one complete 16-bit sample after the header.
private func readAudioSamples(from url: URL) throws -> [Float] {
    let headerByteCount = 44
    let bytesPerSample = 2

    let fileData: Data
    do {
        fileData = try Data(contentsOf: url)
    } catch {
        throw ASRError.invalidAudioData
    }

    // Integer division drops a trailing odd byte instead of slicing past the end of
    // the buffer, which previously trapped at runtime.
    let sampleCount = (fileData.count - headerByteCount) / bytesPerSample
    guard sampleCount > 0 else {
        throw ASRError.invalidAudioData
    }

    var samples = [Float]()
    samples.reserveCapacity(sampleCount)

    fileData.withUnsafeBytes { (bytes: UnsafeRawBufferPointer) in
        for index in 0..<sampleCount {
            let offset = headerByteCount + index * bytesPerSample
            // Assemble the little-endian value byte by byte so no aligned load is needed.
            let bits = UInt16(bytes[offset]) | (UInt16(bytes[offset + 1]) << 8)
            let sample = Int16(bitPattern: bits)
            samples.append(max(-1.0, min(Float(sample) / 32767.0, 1.0)))
        }
    }

    return samples
}
/// Releases the loaded models and resets the service to its unloaded state.
///
/// Calls `cleanup()` on the ASR manager if one exists, then clears the ASR manager,
/// the VAD manager, and the recorded active version.
func cleanup() {
    if let manager = asrManager {
        manager.cleanup()
    }
    asrManager = nil
    vadManager = nil
    activeVersion = nil
}
} }