Simplified Parakeet transcription without Whisper VAD
This commit is contained in:
parent
afd6e91207
commit
6e6773068f
@ -15,7 +15,6 @@ class ParakeetTranscriptionService: TranscriptionService {
|
||||
|
||||
init(customModelsDirectory: URL? = nil) {
|
||||
self.customModelsDirectory = customModelsDirectory
|
||||
logger.notice("🦜 ParakeetTranscriptionService initialized with directory: \(customModelsDirectory?.path ?? "default")")
|
||||
}
|
||||
|
||||
func loadModel() async throws {
|
||||
@ -23,36 +22,27 @@ class ParakeetTranscriptionService: TranscriptionService {
|
||||
return
|
||||
}
|
||||
|
||||
logger.notice("🦜 Starting Parakeet model loading")
|
||||
|
||||
|
||||
do {
|
||||
|
||||
asrManager = AsrManager(config: .default)
|
||||
let models: AsrModels
|
||||
if let customDirectory = customModelsDirectory {
|
||||
logger.notice("🦜 Loading models from custom directory: \(customDirectory.path)")
|
||||
models = try await AsrModels.downloadAndLoad(to: customDirectory)
|
||||
} else {
|
||||
logger.notice("🦜 Loading models from default directory")
|
||||
models = try await AsrModels.downloadAndLoad()
|
||||
}
|
||||
if let customDirectory = customModelsDirectory {
|
||||
logger.notice("🦜 Loading Parakeet models from: \(customDirectory.path)")
|
||||
models = try await AsrModels.downloadAndLoad(to: customDirectory)
|
||||
} else {
|
||||
logger.notice("🦜 Loading Parakeet models from default directory")
|
||||
models = try await AsrModels.downloadAndLoad()
|
||||
}
|
||||
|
||||
try await asrManager?.initialize(models: models)
|
||||
isModelLoaded = true
|
||||
logger.notice("🦜 Parakeet model loaded successfully")
|
||||
isModelLoaded = true
|
||||
logger.notice("🦜 Parakeet model loaded successfully")
|
||||
|
||||
} catch let error as ASRError {
|
||||
logger.notice("🦜 Parakeet-specific error loading model: \(error.localizedDescription)")
|
||||
isModelLoaded = false
|
||||
asrManager = nil
|
||||
throw error
|
||||
} catch let error as AsrModelsError {
|
||||
logger.notice("🦜 Parakeet model management error loading model: \(error.localizedDescription)")
|
||||
isModelLoaded = false
|
||||
asrManager = nil
|
||||
throw error
|
||||
} catch {
|
||||
logger.notice("🦜 Unexpected error loading Parakeet model: \(error.localizedDescription)")
|
||||
} catch {
|
||||
let description = (error as? LocalizedError)?.errorDescription ?? error.localizedDescription
|
||||
logger.error("🦜 Failed to load Parakeet model: \(description)")
|
||||
isModelLoaded = false
|
||||
asrManager = nil
|
||||
throw error
|
||||
@ -64,81 +54,26 @@ class ParakeetTranscriptionService: TranscriptionService {
|
||||
try await loadModel()
|
||||
}
|
||||
|
||||
guard let asrManager = asrManager else {
|
||||
logger.notice("🦜 Parakeet manager is still nil after attempting to load the model.")
|
||||
throw ASRError.notInitialized
|
||||
}
|
||||
guard let asrManager = asrManager else {
|
||||
throw ASRError.notInitialized
|
||||
}
|
||||
|
||||
let audioSamples = try readAudioSamples(from: audioURL)
|
||||
|
||||
// Validate audio data before VAD
|
||||
guard !audioSamples.isEmpty else {
|
||||
logger.notice("🦜 Audio is empty, skipping transcription.")
|
||||
throw ASRError.invalidAudioData
|
||||
}
|
||||
|
||||
// Use VAD to get speech segments
|
||||
var speechAudio: [Float] = []
|
||||
let isVADEnabled = UserDefaults.standard.object(forKey: "IsVADEnabled") as? Bool ?? true
|
||||
// Use full audio for transcription
|
||||
let speechAudio: [Float] = audioSamples
|
||||
|
||||
if isVADEnabled {
|
||||
if let modelPath = await VADModelManager.shared.getModelPath() {
|
||||
if let vad = VoiceActivityDetector(modelPath: modelPath) {
|
||||
let speechSegments = vad.process(audioSamples: audioSamples)
|
||||
logger.notice("🦜 VAD detected \(speechSegments.count) speech segments.")
|
||||
|
||||
let sampleRate = 16000 // Assuming 16kHz sample rate
|
||||
for segment in speechSegments {
|
||||
let startSample = Int(segment.start * Double(sampleRate))
|
||||
var endSample = Int(segment.end * Double(sampleRate))
|
||||
|
||||
// Cap endSample to the audio buffer size
|
||||
if endSample > audioSamples.count {
|
||||
endSample = audioSamples.count
|
||||
}
|
||||
|
||||
if startSample < endSample {
|
||||
speechAudio.append(contentsOf: audioSamples[startSample..<endSample])
|
||||
} else {
|
||||
logger.warning("🦜 Invalid sample range for segment: start=\(startSample), end=\(endSample). Skipping.")
|
||||
}
|
||||
}
|
||||
logger.notice("🦜 Extracted \(speechAudio.count) samples from VAD segments.")
|
||||
} else {
|
||||
logger.warning("🦜 VAD could not be initialized. Transcribing original audio.")
|
||||
speechAudio = audioSamples
|
||||
}
|
||||
} else {
|
||||
logger.warning("🦜 VAD model path not found. Transcribing original audio.")
|
||||
speechAudio = audioSamples
|
||||
}
|
||||
} else {
|
||||
logger.notice("🦜 VAD is disabled by user setting. Transcribing original audio.")
|
||||
speechAudio = audioSamples
|
||||
}
|
||||
|
||||
// Validate audio data after VAD
|
||||
guard speechAudio.count >= 16000 else {
|
||||
logger.notice("🦜 Audio too short for transcription after VAD: \(speechAudio.count) samples")
|
||||
throw ASRError.invalidAudioData
|
||||
}
|
||||
|
||||
let result = try await asrManager.transcribe(speechAudio)
|
||||
print(result.text)
|
||||
|
||||
|
||||
// Reset decoder state and cleanup after transcription to avoid blocking the transcription start
|
||||
Task {
|
||||
asrManager.cleanup()
|
||||
isModelLoaded = false
|
||||
logger.notice("🦜 Parakeet ASR models cleaned up from memory")
|
||||
}
|
||||
Task {
|
||||
asrManager.cleanup()
|
||||
isModelLoaded = false
|
||||
logger.notice("🦜 Parakeet ASR models cleaned up from memory")
|
||||
}
|
||||
|
||||
// Check for empty results (vocabulary issue indicator)
|
||||
if result.text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
logger.notice("🦜 Warning: Empty transcription result for \(audioSamples.count) samples - possible vocabulary issue")
|
||||
}
|
||||
|
||||
var text = result.text
|
||||
let text = result.text
|
||||
|
||||
return text
|
||||
}
|
||||
@ -147,11 +82,10 @@ class ParakeetTranscriptionService: TranscriptionService {
|
||||
do {
|
||||
let data = try Data(contentsOf: url)
|
||||
|
||||
// Check minimum file size for valid WAV header
|
||||
guard data.count > 44 else {
|
||||
logger.notice("🦜 Audio file too small (\(data.count) bytes), expected > 44 bytes")
|
||||
throw ASRError.invalidAudioData
|
||||
}
|
||||
// Check minimum file size for valid WAV header
|
||||
guard data.count > 44 else {
|
||||
throw ASRError.invalidAudioData
|
||||
}
|
||||
|
||||
let floats = stride(from: 44, to: data.count, by: 2).map {
|
||||
return data[$0..<$0 + 2].withUnsafeBytes {
|
||||
@ -161,10 +95,9 @@ class ParakeetTranscriptionService: TranscriptionService {
|
||||
}
|
||||
|
||||
return floats
|
||||
} catch {
|
||||
logger.notice("🦜 Failed to read audio file: \(error.localizedDescription)")
|
||||
throw ASRError.invalidAudioData
|
||||
}
|
||||
} catch {
|
||||
throw ASRError.invalidAudioData
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user