commit
df97832104
@ -7,7 +7,7 @@
|
|||||||
"location" : "https://github.com/FluidInference/FluidAudio",
|
"location" : "https://github.com/FluidInference/FluidAudio",
|
||||||
"state" : {
|
"state" : {
|
||||||
"branch" : "main",
|
"branch" : "main",
|
||||||
"revision" : "052cbb27cf073a9407251d74ef3459ea258e41b3"
|
"revision" : "328036d255ef76b8d661eacc16ac108eb45f9218"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import Foundation
|
import Foundation
|
||||||
|
import CoreML
|
||||||
import AVFoundation
|
import AVFoundation
|
||||||
import FluidAudio
|
import FluidAudio
|
||||||
import os.log
|
import os.log
|
||||||
@ -7,15 +8,13 @@ import os.log
|
|||||||
|
|
||||||
class ParakeetTranscriptionService: TranscriptionService {
|
class ParakeetTranscriptionService: TranscriptionService {
|
||||||
private var asrManager: AsrManager?
|
private var asrManager: AsrManager?
|
||||||
|
private var vadManager: VadManager?
|
||||||
private let customModelsDirectory: URL?
|
private let customModelsDirectory: URL?
|
||||||
@Published var isModelLoaded = false
|
@Published var isModelLoaded = false
|
||||||
|
|
||||||
// Logger for Parakeet transcription service
|
|
||||||
private let logger = Logger(subsystem: "com.voiceink.app", category: "ParakeetTranscriptionService")
|
private let logger = Logger(subsystem: "com.voiceink.app", category: "ParakeetTranscriptionService")
|
||||||
|
|
||||||
init(customModelsDirectory: URL? = nil) {
|
init(customModelsDirectory: URL? = nil) {
|
||||||
self.customModelsDirectory = customModelsDirectory
|
self.customModelsDirectory = customModelsDirectory
|
||||||
logger.notice("🦜 ParakeetTranscriptionService initialized with directory: \(customModelsDirectory?.path ?? "default")")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadModel() async throws {
|
func loadModel() async throws {
|
||||||
@ -23,39 +22,16 @@ class ParakeetTranscriptionService: TranscriptionService {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.notice("🦜 Starting Parakeet model loading")
|
if let customModelsDirectory {
|
||||||
|
do {
|
||||||
do {
|
asrManager = AsrManager(config: .default)
|
||||||
|
let models = try await AsrModels.load(from: customModelsDirectory)
|
||||||
asrManager = AsrManager(config: .default)
|
try await asrManager?.initialize(models: models)
|
||||||
let models: AsrModels
|
isModelLoaded = true
|
||||||
if let customDirectory = customModelsDirectory {
|
} catch {
|
||||||
logger.notice("🦜 Loading models from custom directory: \(customDirectory.path)")
|
isModelLoaded = false
|
||||||
models = try await AsrModels.downloadAndLoad(to: customDirectory)
|
asrManager = nil
|
||||||
} else {
|
|
||||||
logger.notice("🦜 Loading models from default directory")
|
|
||||||
models = try await AsrModels.downloadAndLoad()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
try await asrManager?.initialize(models: models)
|
|
||||||
isModelLoaded = true
|
|
||||||
logger.notice("🦜 Parakeet model loaded successfully")
|
|
||||||
|
|
||||||
} catch let error as ASRError {
|
|
||||||
logger.notice("🦜 Parakeet-specific error loading model: \(error.localizedDescription)")
|
|
||||||
isModelLoaded = false
|
|
||||||
asrManager = nil
|
|
||||||
throw error
|
|
||||||
} catch let error as AsrModelsError {
|
|
||||||
logger.notice("🦜 Parakeet model management error loading model: \(error.localizedDescription)")
|
|
||||||
isModelLoaded = false
|
|
||||||
asrManager = nil
|
|
||||||
throw error
|
|
||||||
} catch {
|
|
||||||
logger.notice("🦜 Unexpected error loading Parakeet model: \(error.localizedDescription)")
|
|
||||||
isModelLoaded = false
|
|
||||||
asrManager = nil
|
|
||||||
throw error
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,81 +40,57 @@ class ParakeetTranscriptionService: TranscriptionService {
|
|||||||
try await loadModel()
|
try await loadModel()
|
||||||
}
|
}
|
||||||
|
|
||||||
guard let asrManager = asrManager else {
|
guard let asrManager = asrManager else {
|
||||||
logger.notice("🦜 Parakeet manager is still nil after attempting to load the model.")
|
throw ASRError.notInitialized
|
||||||
throw ASRError.notInitialized
|
}
|
||||||
}
|
|
||||||
|
|
||||||
let audioSamples = try readAudioSamples(from: audioURL)
|
let audioSamples = try readAudioSamples(from: audioURL)
|
||||||
|
|
||||||
// Validate audio data before VAD
|
|
||||||
guard !audioSamples.isEmpty else {
|
|
||||||
logger.notice("🦜 Audio is empty, skipping transcription.")
|
|
||||||
throw ASRError.invalidAudioData
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use VAD to get speech segments
|
let durationSeconds = Double(audioSamples.count) / 16000.0
|
||||||
var speechAudio: [Float] = []
|
|
||||||
let isVADEnabled = UserDefaults.standard.object(forKey: "IsVADEnabled") as? Bool ?? true
|
let isVADEnabled = UserDefaults.standard.object(forKey: "IsVADEnabled") as? Bool ?? true
|
||||||
|
|
||||||
if isVADEnabled {
|
let speechAudio: [Float]
|
||||||
if let modelPath = await VADModelManager.shared.getModelPath() {
|
if durationSeconds < 20.0 || !isVADEnabled {
|
||||||
if let vad = VoiceActivityDetector(modelPath: modelPath) {
|
speechAudio = audioSamples
|
||||||
let speechSegments = vad.process(audioSamples: audioSamples)
|
} else {
|
||||||
logger.notice("🦜 VAD detected \(speechSegments.count) speech segments.")
|
let vadConfig = VadConfig(threshold: 0.7)
|
||||||
|
if vadManager == nil, let customModelsDirectory {
|
||||||
|
do {
|
||||||
|
vadManager = try await VadManager(
|
||||||
|
config: vadConfig,
|
||||||
|
modelDirectory: customModelsDirectory.deletingLastPathComponent()
|
||||||
|
)
|
||||||
|
} catch {
|
||||||
|
// Silent failure
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let sampleRate = 16000 // Assuming 16kHz sample rate
|
do {
|
||||||
for segment in speechSegments {
|
if let vadManager {
|
||||||
let startSample = Int(segment.start * Double(sampleRate))
|
let segments = try await vadManager.segmentSpeechAudio(audioSamples)
|
||||||
var endSample = Int(segment.end * Double(sampleRate))
|
if segments.isEmpty {
|
||||||
|
speechAudio = audioSamples
|
||||||
// Cap endSample to the audio buffer size
|
} else {
|
||||||
if endSample > audioSamples.count {
|
speechAudio = segments.flatMap { $0 }
|
||||||
endSample = audioSamples.count
|
|
||||||
}
|
|
||||||
|
|
||||||
if startSample < endSample {
|
|
||||||
speechAudio.append(contentsOf: audioSamples[startSample..<endSample])
|
|
||||||
} else {
|
|
||||||
logger.warning("🦜 Invalid sample range for segment: start=\(startSample), end=\(endSample). Skipping.")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
logger.notice("🦜 Extracted \(speechAudio.count) samples from VAD segments.")
|
|
||||||
} else {
|
} else {
|
||||||
logger.warning("🦜 VAD could not be initialized. Transcribing original audio.")
|
|
||||||
speechAudio = audioSamples
|
speechAudio = audioSamples
|
||||||
}
|
}
|
||||||
} else {
|
} catch {
|
||||||
logger.warning("🦜 VAD model path not found. Transcribing original audio.")
|
|
||||||
speechAudio = audioSamples
|
speechAudio = audioSamples
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
logger.notice("🦜 VAD is disabled by user setting. Transcribing original audio.")
|
|
||||||
speechAudio = audioSamples
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate audio data after VAD
|
|
||||||
guard speechAudio.count >= 16000 else {
|
|
||||||
logger.notice("🦜 Audio too short for transcription after VAD: \(speechAudio.count) samples")
|
|
||||||
throw ASRError.invalidAudioData
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = try await asrManager.transcribe(speechAudio)
|
let result = try await asrManager.transcribe(speechAudio)
|
||||||
print(result.text)
|
|
||||||
|
Task {
|
||||||
|
asrManager.cleanup()
|
||||||
|
isModelLoaded = false
|
||||||
|
logger.notice("🦜 Parakeet ASR models cleaned up from memory")
|
||||||
|
}
|
||||||
|
|
||||||
// Reset decoder state and cleanup after transcription to avoid blocking the transcription start
|
let text = result.text
|
||||||
Task {
|
|
||||||
asrManager.cleanup()
|
|
||||||
isModelLoaded = false
|
|
||||||
logger.notice("🦜 Parakeet ASR models cleaned up from memory")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for empty results (vocabulary issue indicator)
|
|
||||||
if result.text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
|
||||||
logger.notice("🦜 Warning: Empty transcription result for \(audioSamples.count) samples - possible vocabulary issue")
|
|
||||||
}
|
|
||||||
|
|
||||||
var text = result.text
|
|
||||||
|
|
||||||
return text
|
return text
|
||||||
}
|
}
|
||||||
@ -146,12 +98,9 @@ class ParakeetTranscriptionService: TranscriptionService {
|
|||||||
private func readAudioSamples(from url: URL) throws -> [Float] {
|
private func readAudioSamples(from url: URL) throws -> [Float] {
|
||||||
do {
|
do {
|
||||||
let data = try Data(contentsOf: url)
|
let data = try Data(contentsOf: url)
|
||||||
|
guard data.count > 44 else {
|
||||||
// Check minimum file size for valid WAV header
|
throw ASRError.invalidAudioData
|
||||||
guard data.count > 44 else {
|
}
|
||||||
logger.notice("🦜 Audio file too small (\(data.count) bytes), expected > 44 bytes")
|
|
||||||
throw ASRError.invalidAudioData
|
|
||||||
}
|
|
||||||
|
|
||||||
let floats = stride(from: 44, to: data.count, by: 2).map {
|
let floats = stride(from: 44, to: data.count, by: 2).map {
|
||||||
return data[$0..<$0 + 2].withUnsafeBytes {
|
return data[$0..<$0 + 2].withUnsafeBytes {
|
||||||
@ -161,10 +110,9 @@ class ParakeetTranscriptionService: TranscriptionService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return floats
|
return floats
|
||||||
} catch {
|
} catch {
|
||||||
logger.notice("🦜 Failed to read audio file: \(error.localizedDescription)")
|
throw ASRError.invalidAudioData
|
||||||
throw ASRError.invalidAudioData
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -33,6 +33,14 @@ extension WhisperState {
|
|||||||
|
|
||||||
do {
|
do {
|
||||||
_ = try await AsrModels.downloadAndLoad(to: parakeetModelsDirectory)
|
_ = try await AsrModels.downloadAndLoad(to: parakeetModelsDirectory)
|
||||||
|
|
||||||
|
// Also download VAD model into the same parent directory as ASR models
|
||||||
|
let parentDir = parakeetModelsDirectory.deletingLastPathComponent()
|
||||||
|
_ = try await DownloadUtils.loadModels(
|
||||||
|
.vad,
|
||||||
|
modelNames: Array(ModelNames.VAD.requiredModels),
|
||||||
|
directory: parentDir
|
||||||
|
)
|
||||||
self.isParakeetModelDownloaded = true
|
self.isParakeetModelDownloaded = true
|
||||||
downloadProgress["parakeet-tdt-0.6b"] = 1.0
|
downloadProgress["parakeet-tdt-0.6b"] = 1.0
|
||||||
} catch {
|
} catch {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user