Clean up
This commit is contained in:
parent
a0e4dd1367
commit
3f01f49f56
@ -1,35 +1,28 @@
|
|||||||
import Foundation
|
import Foundation
|
||||||
import CoreML
|
import CoreML
|
||||||
import AVFoundation
|
import AVFoundation
|
||||||
import FluidAudio
|
import FluidAudio
|
||||||
import os.log
|
import os.log
|
||||||
|
|
||||||
class ParakeetTranscriptionService: TranscriptionService {
|
class ParakeetTranscriptionService: TranscriptionService {
|
||||||
private var asrManager: AsrManager?
|
private var asrManager: AsrManager?
|
||||||
private var vadManager: VadManager?
|
private var vadManager: VadManager?
|
||||||
private var activeVersion: AsrModelVersion?
|
private var activeVersion: AsrModelVersion?
|
||||||
private let logger = Logger(
|
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink.parakeet", category: "ParakeetTranscriptionService")
|
||||||
subsystem: "com.prakashjoshipax.voiceink.parakeet",
|
|
||||||
category: "ParakeetTranscriptionService"
|
|
||||||
)
|
|
||||||
|
|
||||||
private func version(for model: any TranscriptionModel) -> AsrModelVersion {
|
private func version(for model: any TranscriptionModel) -> AsrModelVersion {
|
||||||
model.name.lowercased().contains("v2") ? .v2 : .v3
|
model.name.lowercased().contains("v2") ? .v2 : .v3
|
||||||
}
|
}
|
||||||
|
|
||||||
private func ensureModelsLoaded(for version: AsrModelVersion) async throws {
|
private func ensureModelsLoaded(for version: AsrModelVersion) async throws {
|
||||||
if let manager = asrManager, activeVersion == version {
|
if let manager = asrManager, activeVersion == version {
|
||||||
logger.notice("🦜 Parakeet \(version == .v2 ? "V2" : "V3") models already loaded, skipping")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanup()
|
cleanup()
|
||||||
|
|
||||||
let manager = AsrManager(config: .default)
|
let manager = AsrManager(config: .default)
|
||||||
let cacheDirectory = AsrModels.defaultCacheDirectory(for: version)
|
let models = try await AsrModels.loadFromCache(
|
||||||
logger.notice("🦜 Loading Parakeet models from \(cacheDirectory.path)")
|
|
||||||
let models = try await AsrModels.load(
|
|
||||||
from: cacheDirectory,
|
|
||||||
configuration: nil,
|
configuration: nil,
|
||||||
version: version
|
version: version
|
||||||
)
|
)
|
||||||
@ -41,78 +34,72 @@ class ParakeetTranscriptionService: TranscriptionService {
|
|||||||
func loadModel(for model: ParakeetModel) async throws {
|
func loadModel(for model: ParakeetModel) async throws {
|
||||||
try await ensureModelsLoaded(for: version(for: model))
|
try await ensureModelsLoaded(for: version(for: model))
|
||||||
}
|
}
|
||||||
|
|
||||||
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
|
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
|
||||||
let targetVersion = version(for: model)
|
let targetVersion = version(for: model)
|
||||||
try await ensureModelsLoaded(for: targetVersion)
|
try await ensureModelsLoaded(for: targetVersion)
|
||||||
|
|
||||||
guard let asrManager = asrManager else {
|
guard let asrManager = asrManager else {
|
||||||
logger.notice("🦜 ASR manager not initialized, cannot transcribe")
|
|
||||||
throw ASRError.notInitialized
|
throw ASRError.notInitialized
|
||||||
}
|
}
|
||||||
|
|
||||||
let audioSamples = try readAudioSamples(from: audioURL)
|
let audioSamples = try readAudioSamples(from: audioURL)
|
||||||
|
|
||||||
let durationSeconds = Double(audioSamples.count) / 16000.0
|
let durationSeconds = Double(audioSamples.count) / 16000.0
|
||||||
let isVADEnabled = UserDefaults.standard.object(forKey: "IsVADEnabled") as? Bool ?? true
|
let isVADEnabled = UserDefaults.standard.object(forKey: "IsVADEnabled") as? Bool ?? true
|
||||||
|
|
||||||
let speechAudio: [Float]
|
var speechAudio = audioSamples
|
||||||
if durationSeconds < 20.0 || !isVADEnabled {
|
if durationSeconds >= 20.0, isVADEnabled {
|
||||||
speechAudio = audioSamples
|
|
||||||
} else {
|
|
||||||
let vadConfig = VadConfig(threshold: 0.7)
|
let vadConfig = VadConfig(threshold: 0.7)
|
||||||
if vadManager == nil {
|
if vadManager == nil {
|
||||||
do {
|
do {
|
||||||
vadManager = try await VadManager(config: vadConfig)
|
vadManager = try await VadManager(config: vadConfig)
|
||||||
} catch {
|
} catch {
|
||||||
logger.notice("🦜 VAD initialization failed, using full audio: \(error.localizedDescription)")
|
logger.notice("VAD init failed; falling back to full audio: \(error.localizedDescription)")
|
||||||
|
vadManager = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
do {
|
if let vadManager {
|
||||||
if let vadManager {
|
do {
|
||||||
let segments = try await vadManager.segmentSpeechAudio(audioSamples)
|
let segments = try await vadManager.segmentSpeechAudio(audioSamples)
|
||||||
speechAudio = segments.isEmpty ? audioSamples : segments.flatMap { $0 }
|
speechAudio = segments.isEmpty ? audioSamples : segments.flatMap { $0 }
|
||||||
} else {
|
} catch {
|
||||||
|
logger.notice("VAD segmentation failed; using full audio: \(error.localizedDescription)")
|
||||||
speechAudio = audioSamples
|
speechAudio = audioSamples
|
||||||
}
|
}
|
||||||
} catch {
|
}
|
||||||
logger.notice("🦜 VAD segmentation failed, using full audio: \(error.localizedDescription)")
|
}
|
||||||
speechAudio = audioSamples
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = try await asrManager.transcribe(speechAudio)
|
let result = try await asrManager.transcribe(speechAudio)
|
||||||
|
|
||||||
logger.notice("🦜 Parakeet transcription result: \(result.text)")
|
return result.text
|
||||||
|
}
|
||||||
return result.text
|
|
||||||
}
|
private func readAudioSamples(from url: URL) throws -> [Float] {
|
||||||
|
do {
|
||||||
private func readAudioSamples(from url: URL) throws -> [Float] {
|
let data = try Data(contentsOf: url)
|
||||||
do {
|
guard data.count > 44 else {
|
||||||
let data = try Data(contentsOf: url)
|
throw ASRError.invalidAudioData
|
||||||
guard data.count > 44 else {
|
}
|
||||||
throw ASRError.invalidAudioData
|
|
||||||
}
|
let floats = stride(from: 44, to: data.count, by: 2).map {
|
||||||
|
return data[$0..<$0 + 2].withUnsafeBytes {
|
||||||
let floats = stride(from: 44, to: data.count, by: 2).map {
|
let short = Int16(littleEndian: $0.load(as: Int16.self))
|
||||||
return data[$0..<$0 + 2].withUnsafeBytes {
|
return max(-1.0, min(Float(short) / 32767.0, 1.0))
|
||||||
let short = Int16(littleEndian: $0.load(as: Int16.self))
|
}
|
||||||
return max(-1.0, min(Float(short) / 32767.0, 1.0))
|
}
|
||||||
}
|
|
||||||
}
|
return floats
|
||||||
|
} catch {
|
||||||
return floats
|
throw ASRError.invalidAudioData
|
||||||
} catch {
|
}
|
||||||
throw ASRError.invalidAudioData
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func cleanup() {
|
func cleanup() {
|
||||||
asrManager?.cleanup()
|
asrManager?.cleanup()
|
||||||
asrManager = nil
|
asrManager = nil
|
||||||
vadManager = nil
|
vadManager = nil
|
||||||
activeVersion = nil
|
activeVersion = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user