From eb364416ea660b59f70008bb8b96ad82200bbea4 Mon Sep 17 00:00:00 2001 From: Beingpax Date: Sat, 4 Oct 2025 13:54:10 +0545 Subject: [PATCH] Improved cleanup and model loading for parakeet --- .../ParakeetTranscriptionService.swift | 205 +++++++++--------- .../WhisperState+LocalModelManager.swift | 2 + VoiceInk/Whisper/WhisperState.swift | 6 +- 3 files changed, 103 insertions(+), 110 deletions(-) diff --git a/VoiceInk/Services/ParakeetTranscriptionService.swift b/VoiceInk/Services/ParakeetTranscriptionService.swift index aada335..f855e75 100644 --- a/VoiceInk/Services/ParakeetTranscriptionService.swift +++ b/VoiceInk/Services/ParakeetTranscriptionService.swift @@ -1,116 +1,107 @@ -import Foundation -import CoreML -import AVFoundation -import FluidAudio -import os.log - - - -class ParakeetTranscriptionService: TranscriptionService { - private var asrManager: AsrManager? - private var vadManager: VadManager? - private let customModelsDirectory: URL? - @Published var isModelLoaded = false - private let logger = Logger(subsystem: "com.voiceink.app", category: "ParakeetTranscriptionService") - - init(customModelsDirectory: URL? = nil) { - self.customModelsDirectory = customModelsDirectory - } - - func loadModel() async throws { - if isModelLoaded { - return - } - - if let customModelsDirectory { - do { - asrManager = AsrManager(config: .default) - let models = try await AsrModels.load(from: customModelsDirectory) - try await asrManager?.initialize(models: models) - isModelLoaded = true - } catch { - isModelLoaded = false - asrManager = nil - } - } - } - - func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String { - if asrManager == nil || !isModelLoaded { - try await loadModel() - } - - guard let asrManager = asrManager else { - throw ASRError.notInitialized - } - - let audioSamples = try readAudioSamples(from: audioURL) - - let durationSeconds = Double(audioSamples.count) / 16000.0 - - let isVADEnabled = UserDefaults.standard.object(forKey: "IsVADEnabled") as? Bool ?? true - - let speechAudio: [Float] - if durationSeconds < 20.0 || !isVADEnabled { - speechAudio = audioSamples - } else { - let vadConfig = VadConfig(threshold: 0.7) - if vadManager == nil, let customModelsDirectory { - do { - vadManager = try await VadManager( - config: vadConfig, - modelDirectory: customModelsDirectory.deletingLastPathComponent() - ) +import Foundation +import CoreML +import AVFoundation +import FluidAudio +import os.log + +class ParakeetTranscriptionService: TranscriptionService { + private var asrManager: AsrManager? + private var vadManager: VadManager? + private let customModelsDirectory: URL? + private let logger = Logger(subsystem: "com.voiceink.app", category: "ParakeetTranscriptionService") + + init(customModelsDirectory: URL? = nil) { + self.customModelsDirectory = customModelsDirectory + } + + func loadModel() async throws { + guard asrManager == nil else { + return + } + + guard let customModelsDirectory else { + throw ASRError.modelLoadFailed + } + + let manager = AsrManager(config: .default) + let models = try await AsrModels.load(from: customModelsDirectory) + try await manager.initialize(models: models) + + self.asrManager = manager + logger.notice("🦜 Parakeet ASR models loaded successfully") + } + + func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String { + try await loadModel() + + guard let asrManager = asrManager else { + logger.notice("🦜 ASR manager not initialized, cannot transcribe") + throw ASRError.notInitialized + } + + let audioSamples = try readAudioSamples(from: audioURL) + + let durationSeconds = Double(audioSamples.count) / 16000.0 + let isVADEnabled = UserDefaults.standard.object(forKey: "IsVADEnabled") as? Bool ?? true + + let speechAudio: [Float] + if durationSeconds < 20.0 || !isVADEnabled { + speechAudio = audioSamples + } else { + let vadConfig = VadConfig(threshold: 0.7) + if vadManager == nil, let customModelsDirectory { + do { + vadManager = try await VadManager( + config: vadConfig, + modelDirectory: customModelsDirectory.deletingLastPathComponent() + ) } catch { - // Silent failure - } - } - + logger.notice("🦜 VAD initialization failed, using full audio: \(error.localizedDescription)") + } + } + do { if let vadManager { let segments = try await vadManager.segmentSpeechAudio(audioSamples) - if segments.isEmpty { - speechAudio = audioSamples - } else { - speechAudio = segments.flatMap { $0 } - } + speechAudio = segments.isEmpty ? audioSamples : segments.flatMap { $0 } } else { speechAudio = audioSamples } } catch { + logger.notice("🦜 VAD segmentation failed, using full audio: \(error.localizedDescription)") speechAudio = audioSamples - } - } - - let result = try await asrManager.transcribe(speechAudio) - - asrManager.cleanup() - isModelLoaded = false - logger.notice("🦜 Parakeet ASR models cleaned up from memory") - - let text = result.text - - return text - } - - private func readAudioSamples(from url: URL) throws -> [Float] { - do { - let data = try Data(contentsOf: url) - guard data.count > 44 else { - throw ASRError.invalidAudioData - } - - let floats = stride(from: 44, to: data.count, by: 2).map { - return data[$0..<$0 + 2].withUnsafeBytes { - let short = Int16(littleEndian: $0.load(as: Int16.self)) - return max(-1.0, min(Float(short) / 32767.0, 1.0)) - } - } - - return floats - } catch { - throw ASRError.invalidAudioData - } - } - -} + } + } + + let result = try await asrManager.transcribe(speechAudio, source: .system) + + return result.text + } + + private func readAudioSamples(from url: URL) throws -> [Float] { + do { + let data = try Data(contentsOf: url) + guard data.count > 44 else { + throw ASRError.invalidAudioData + } + + let floats = stride(from: 44, to: data.count, by: 2).map { + return data[$0..<$0 + 2].withUnsafeBytes { + let short = Int16(littleEndian: $0.load(as: Int16.self)) + return max(-1.0, min(Float(short) / 32767.0, 1.0)) + } + } + + return floats + } catch { + throw ASRError.invalidAudioData + } + } + + func cleanup() { + asrManager?.cleanup() + asrManager = nil + vadManager = nil + logger.notice("🦜 Parakeet ASR models cleaned up from memory") + } +} \ No newline at end of file diff --git a/VoiceInk/Whisper/WhisperState+LocalModelManager.swift b/VoiceInk/Whisper/WhisperState+LocalModelManager.swift index ecd760f..6ba1c1b 100644 --- a/VoiceInk/Whisper/WhisperState+LocalModelManager.swift +++ b/VoiceInk/Whisper/WhisperState+LocalModelManager.swift @@ -341,6 +341,8 @@ extension WhisperState { await whisperContext?.releaseResources() whisperContext = nil isModelLoaded = false + + parakeetTranscriptionService.cleanup() } // MARK: - Helper Methods diff --git a/VoiceInk/Whisper/WhisperState.swift b/VoiceInk/Whisper/WhisperState.swift index d9ff183..23d883c 100644 --- a/VoiceInk/Whisper/WhisperState.swift +++ b/VoiceInk/Whisper/WhisperState.swift @@ -59,7 +59,7 @@ class WhisperState: NSObject, ObservableObject { private var localTranscriptionService: LocalTranscriptionService! private lazy var cloudTranscriptionService = CloudTranscriptionService() private lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService() - private lazy var parakeetTranscriptionService = ParakeetTranscriptionService(customModelsDirectory: parakeetModelsDirectory) + internal lazy var parakeetTranscriptionService = ParakeetTranscriptionService(customModelsDirectory: parakeetModelsDirectory) private var modelUrl: URL? { let possibleURLs = [ @@ -200,8 +200,8 @@ class WhisperState: NSObject, ObservableObject { self.logger.error("❌ Model loading failed: \(error.localizedDescription)") } } - } else if let model = self.currentTranscriptionModel, model.provider == .parakeet { - try? await parakeetTranscriptionService.loadModel() + } else if let model = self.currentTranscriptionModel, model.provider == .parakeet { + try? await self.parakeetTranscriptionService.loadModel() } if let enhancementService = self.enhancementService,