Improved cleanup and model loading for parakeet
This commit is contained in:
parent
0af071aede
commit
eb364416ea
@ -1,116 +1,107 @@
|
||||
import Foundation
|
||||
import CoreML
|
||||
import AVFoundation
|
||||
import FluidAudio
|
||||
import os.log
|
||||
|
||||
|
||||
|
||||
class ParakeetTranscriptionService: TranscriptionService {
|
||||
private var asrManager: AsrManager?
|
||||
private var vadManager: VadManager?
|
||||
private let customModelsDirectory: URL?
|
||||
@Published var isModelLoaded = false
|
||||
private let logger = Logger(subsystem: "com.voiceink.app", category: "ParakeetTranscriptionService")
|
||||
|
||||
init(customModelsDirectory: URL? = nil) {
|
||||
self.customModelsDirectory = customModelsDirectory
|
||||
}
|
||||
|
||||
func loadModel() async throws {
|
||||
if isModelLoaded {
|
||||
return
|
||||
}
|
||||
|
||||
if let customModelsDirectory {
|
||||
do {
|
||||
asrManager = AsrManager(config: .default)
|
||||
let models = try await AsrModels.load(from: customModelsDirectory)
|
||||
try await asrManager?.initialize(models: models)
|
||||
isModelLoaded = true
|
||||
} catch {
|
||||
isModelLoaded = false
|
||||
asrManager = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
|
||||
if asrManager == nil || !isModelLoaded {
|
||||
try await loadModel()
|
||||
}
|
||||
|
||||
guard let asrManager = asrManager else {
|
||||
throw ASRError.notInitialized
|
||||
}
|
||||
|
||||
let audioSamples = try readAudioSamples(from: audioURL)
|
||||
|
||||
let durationSeconds = Double(audioSamples.count) / 16000.0
|
||||
|
||||
let isVADEnabled = UserDefaults.standard.object(forKey: "IsVADEnabled") as? Bool ?? true
|
||||
|
||||
let speechAudio: [Float]
|
||||
if durationSeconds < 20.0 || !isVADEnabled {
|
||||
speechAudio = audioSamples
|
||||
} else {
|
||||
let vadConfig = VadConfig(threshold: 0.7)
|
||||
if vadManager == nil, let customModelsDirectory {
|
||||
do {
|
||||
vadManager = try await VadManager(
|
||||
config: vadConfig,
|
||||
modelDirectory: customModelsDirectory.deletingLastPathComponent()
|
||||
)
|
||||
import Foundation
|
||||
import CoreML
|
||||
import AVFoundation
|
||||
import FluidAudio
|
||||
import os.log
|
||||
|
||||
class ParakeetTranscriptionService: TranscriptionService {
|
||||
private var asrManager: AsrManager?
|
||||
private var vadManager: VadManager?
|
||||
private let customModelsDirectory: URL?
|
||||
private let logger = Logger(subsystem: "com.voiceink.app", category: "ParakeetTranscriptionService")
|
||||
|
||||
init(customModelsDirectory: URL? = nil) {
|
||||
self.customModelsDirectory = customModelsDirectory
|
||||
}
|
||||
|
||||
func loadModel() async throws {
|
||||
guard asrManager == nil else {
|
||||
return
|
||||
}
|
||||
|
||||
guard let customModelsDirectory else {
|
||||
throw ASRError.modelLoadFailed
|
||||
}
|
||||
|
||||
let manager = AsrManager(config: .default)
|
||||
let models = try await AsrModels.load(from: customModelsDirectory)
|
||||
try await manager.initialize(models: models)
|
||||
|
||||
self.asrManager = manager
|
||||
logger.notice("🦜 Parakeet ASR models loaded successfully")
|
||||
}
|
||||
|
||||
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
|
||||
try await loadModel()
|
||||
|
||||
guard let asrManager = asrManager else {
|
||||
logger.notice("🦜 ASR manager not initialized, cannot transcribe")
|
||||
throw ASRError.notInitialized
|
||||
}
|
||||
|
||||
let audioSamples = try readAudioSamples(from: audioURL)
|
||||
|
||||
let durationSeconds = Double(audioSamples.count) / 16000.0
|
||||
let isVADEnabled = UserDefaults.standard.object(forKey: "IsVADEnabled") as? Bool ?? true
|
||||
|
||||
let speechAudio: [Float]
|
||||
if durationSeconds < 20.0 || !isVADEnabled {
|
||||
speechAudio = audioSamples
|
||||
} else {
|
||||
let vadConfig = VadConfig(threshold: 0.7)
|
||||
if vadManager == nil, let customModelsDirectory {
|
||||
do {
|
||||
vadManager = try await VadManager(
|
||||
config: vadConfig,
|
||||
modelDirectory: customModelsDirectory.deletingLastPathComponent()
|
||||
)
|
||||
} catch {
|
||||
// Silent failure
|
||||
}
|
||||
}
|
||||
|
||||
logger.notice("🦜 VAD initialization failed, using full audio: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
|
||||
do {
|
||||
if let vadManager {
|
||||
let segments = try await vadManager.segmentSpeechAudio(audioSamples)
|
||||
if segments.isEmpty {
|
||||
speechAudio = audioSamples
|
||||
} else {
|
||||
speechAudio = segments.flatMap { $0 }
|
||||
}
|
||||
speechAudio = segments.isEmpty ? audioSamples : segments.flatMap { $0 }
|
||||
} else {
|
||||
speechAudio = audioSamples
|
||||
}
|
||||
} catch {
|
||||
logger.notice("🦜 VAD segmentation failed, using full audio: \(error.localizedDescription)")
|
||||
speechAudio = audioSamples
|
||||
}
|
||||
}
|
||||
|
||||
let result = try await asrManager.transcribe(speechAudio)
|
||||
|
||||
asrManager.cleanup()
|
||||
isModelLoaded = false
|
||||
logger.notice("🦜 Parakeet ASR models cleaned up from memory")
|
||||
|
||||
let text = result.text
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
private func readAudioSamples(from url: URL) throws -> [Float] {
|
||||
do {
|
||||
let data = try Data(contentsOf: url)
|
||||
guard data.count > 44 else {
|
||||
throw ASRError.invalidAudioData
|
||||
}
|
||||
|
||||
let floats = stride(from: 44, to: data.count, by: 2).map {
|
||||
return data[$0..<$0 + 2].withUnsafeBytes {
|
||||
let short = Int16(littleEndian: $0.load(as: Int16.self))
|
||||
return max(-1.0, min(Float(short) / 32767.0, 1.0))
|
||||
}
|
||||
}
|
||||
|
||||
return floats
|
||||
} catch {
|
||||
throw ASRError.invalidAudioData
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let result = try await asrManager.transcribe(speechAudio, source: .system)
|
||||
|
||||
return result.text
|
||||
}
|
||||
|
||||
private func readAudioSamples(from url: URL) throws -> [Float] {
|
||||
do {
|
||||
let data = try Data(contentsOf: url)
|
||||
guard data.count > 44 else {
|
||||
throw ASRError.invalidAudioData
|
||||
}
|
||||
|
||||
let floats = stride(from: 44, to: data.count, by: 2).map {
|
||||
return data[$0..<$0 + 2].withUnsafeBytes {
|
||||
let short = Int16(littleEndian: $0.load(as: Int16.self))
|
||||
return max(-1.0, min(Float(short) / 32767.0, 1.0))
|
||||
}
|
||||
}
|
||||
|
||||
return floats
|
||||
} catch {
|
||||
throw ASRError.invalidAudioData
|
||||
}
|
||||
}
|
||||
|
||||
func cleanup() {
|
||||
asrManager?.cleanup()
|
||||
asrManager = nil
|
||||
vadManager = nil
|
||||
logger.notice("🦜 Parakeet ASR models cleaned up from memory")
|
||||
}
|
||||
}
|
||||
@ -341,6 +341,8 @@ extension WhisperState {
|
||||
await whisperContext?.releaseResources()
|
||||
whisperContext = nil
|
||||
isModelLoaded = false
|
||||
|
||||
parakeetTranscriptionService.cleanup()
|
||||
}
|
||||
|
||||
// MARK: - Helper Methods
|
||||
|
||||
@ -59,7 +59,7 @@ class WhisperState: NSObject, ObservableObject {
|
||||
private var localTranscriptionService: LocalTranscriptionService!
|
||||
private lazy var cloudTranscriptionService = CloudTranscriptionService()
|
||||
private lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
|
||||
private lazy var parakeetTranscriptionService = ParakeetTranscriptionService(customModelsDirectory: parakeetModelsDirectory)
|
||||
internal lazy var parakeetTranscriptionService = ParakeetTranscriptionService(customModelsDirectory: parakeetModelsDirectory)
|
||||
|
||||
private var modelUrl: URL? {
|
||||
let possibleURLs = [
|
||||
@ -200,8 +200,8 @@ class WhisperState: NSObject, ObservableObject {
|
||||
self.logger.error("❌ Model loading failed: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
} else if let model = self.currentTranscriptionModel, model.provider == .parakeet {
|
||||
try? await parakeetTranscriptionService.loadModel()
|
||||
} else if let model = self.currentTranscriptionModel, model.provider == .parakeet {
|
||||
try? await self.parakeetTranscriptionService.loadModel()
|
||||
}
|
||||
|
||||
if let enhancementService = self.enhancementService,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user