Improved cleanup and model loading for parakeet

This commit is contained in:
Beingpax 2025-10-04 13:54:10 +05:45
parent 0af071aede
commit eb364416ea
3 changed files with 103 additions and 110 deletions

View File

@ -1,116 +1,107 @@
import Foundation
import CoreML
import AVFoundation
import FluidAudio
import os.log
class ParakeetTranscriptionService: TranscriptionService {
private var asrManager: AsrManager?
private var vadManager: VadManager?
private let customModelsDirectory: URL?
@Published var isModelLoaded = false
private let logger = Logger(subsystem: "com.voiceink.app", category: "ParakeetTranscriptionService")
init(customModelsDirectory: URL? = nil) {
self.customModelsDirectory = customModelsDirectory
}
func loadModel() async throws {
if isModelLoaded {
return
}
if let customModelsDirectory {
do {
asrManager = AsrManager(config: .default)
let models = try await AsrModels.load(from: customModelsDirectory)
try await asrManager?.initialize(models: models)
isModelLoaded = true
} catch {
isModelLoaded = false
asrManager = nil
}
}
}
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
if asrManager == nil || !isModelLoaded {
try await loadModel()
}
guard let asrManager = asrManager else {
throw ASRError.notInitialized
}
let audioSamples = try readAudioSamples(from: audioURL)
let durationSeconds = Double(audioSamples.count) / 16000.0
let isVADEnabled = UserDefaults.standard.object(forKey: "IsVADEnabled") as? Bool ?? true
let speechAudio: [Float]
if durationSeconds < 20.0 || !isVADEnabled {
speechAudio = audioSamples
} else {
let vadConfig = VadConfig(threshold: 0.7)
if vadManager == nil, let customModelsDirectory {
do {
vadManager = try await VadManager(
config: vadConfig,
modelDirectory: customModelsDirectory.deletingLastPathComponent()
)
import Foundation
import CoreML
import AVFoundation
import FluidAudio
import os.log
class ParakeetTranscriptionService: TranscriptionService {
private var asrManager: AsrManager?
private var vadManager: VadManager?
private let customModelsDirectory: URL?
private let logger = Logger(subsystem: "com.voiceink.app", category: "ParakeetTranscriptionService")
init(customModelsDirectory: URL? = nil) {
self.customModelsDirectory = customModelsDirectory
}
func loadModel() async throws {
guard asrManager == nil else {
return
}
guard let customModelsDirectory else {
throw ASRError.modelLoadFailed
}
let manager = AsrManager(config: .default)
let models = try await AsrModels.load(from: customModelsDirectory)
try await manager.initialize(models: models)
self.asrManager = manager
logger.notice("🦜 Parakeet ASR models loaded successfully")
}
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
try await loadModel()
guard let asrManager = asrManager else {
logger.notice("🦜 ASR manager not initialized, cannot transcribe")
throw ASRError.notInitialized
}
let audioSamples = try readAudioSamples(from: audioURL)
let durationSeconds = Double(audioSamples.count) / 16000.0
let isVADEnabled = UserDefaults.standard.object(forKey: "IsVADEnabled") as? Bool ?? true
let speechAudio: [Float]
if durationSeconds < 20.0 || !isVADEnabled {
speechAudio = audioSamples
} else {
let vadConfig = VadConfig(threshold: 0.7)
if vadManager == nil, let customModelsDirectory {
do {
vadManager = try await VadManager(
config: vadConfig,
modelDirectory: customModelsDirectory.deletingLastPathComponent()
)
} catch {
// Silent failure
}
}
logger.notice("🦜 VAD initialization failed, using full audio: \(error.localizedDescription)")
}
}
do {
if let vadManager {
let segments = try await vadManager.segmentSpeechAudio(audioSamples)
if segments.isEmpty {
speechAudio = audioSamples
} else {
speechAudio = segments.flatMap { $0 }
}
speechAudio = segments.isEmpty ? audioSamples : segments.flatMap { $0 }
} else {
speechAudio = audioSamples
}
} catch {
logger.notice("🦜 VAD segmentation failed, using full audio: \(error.localizedDescription)")
speechAudio = audioSamples
}
}
let result = try await asrManager.transcribe(speechAudio)
asrManager.cleanup()
isModelLoaded = false
logger.notice("🦜 Parakeet ASR models cleaned up from memory")
let text = result.text
return text
}
private func readAudioSamples(from url: URL) throws -> [Float] {
do {
let data = try Data(contentsOf: url)
guard data.count > 44 else {
throw ASRError.invalidAudioData
}
let floats = stride(from: 44, to: data.count, by: 2).map {
return data[$0..<$0 + 2].withUnsafeBytes {
let short = Int16(littleEndian: $0.load(as: Int16.self))
return max(-1.0, min(Float(short) / 32767.0, 1.0))
}
}
return floats
} catch {
throw ASRError.invalidAudioData
}
}
}
}
}
let result = try await asrManager.transcribe(speechAudio, source: .system)
return result.text
}
private func readAudioSamples(from url: URL) throws -> [Float] {
do {
let data = try Data(contentsOf: url)
guard data.count > 44 else {
throw ASRError.invalidAudioData
}
let floats = stride(from: 44, to: data.count, by: 2).map {
return data[$0..<$0 + 2].withUnsafeBytes {
let short = Int16(littleEndian: $0.load(as: Int16.self))
return max(-1.0, min(Float(short) / 32767.0, 1.0))
}
}
return floats
} catch {
throw ASRError.invalidAudioData
}
}
func cleanup() {
asrManager?.cleanup()
asrManager = nil
vadManager = nil
logger.notice("🦜 Parakeet ASR models cleaned up from memory")
}
}

View File

@ -341,6 +341,8 @@ extension WhisperState {
await whisperContext?.releaseResources()
whisperContext = nil
isModelLoaded = false
parakeetTranscriptionService.cleanup()
}
// MARK: - Helper Methods

View File

@ -59,7 +59,7 @@ class WhisperState: NSObject, ObservableObject {
private var localTranscriptionService: LocalTranscriptionService!
private lazy var cloudTranscriptionService = CloudTranscriptionService()
private lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
private lazy var parakeetTranscriptionService = ParakeetTranscriptionService(customModelsDirectory: parakeetModelsDirectory)
internal lazy var parakeetTranscriptionService = ParakeetTranscriptionService(customModelsDirectory: parakeetModelsDirectory)
private var modelUrl: URL? {
let possibleURLs = [
@ -200,8 +200,8 @@ class WhisperState: NSObject, ObservableObject {
self.logger.error("❌ Model loading failed: \(error.localizedDescription)")
}
}
} else if let model = self.currentTranscriptionModel, model.provider == .parakeet {
try? await parakeetTranscriptionService.loadModel()
} else if let model = self.currentTranscriptionModel, model.provider == .parakeet {
try? await self.parakeetTranscriptionService.loadModel()
}
if let enhancementService = self.enhancementService,