diff --git a/VoiceInk/Services/ParakeetTranscriptionService.swift b/VoiceInk/Services/ParakeetTranscriptionService.swift
index f3e4a09..cb39801 100644
--- a/VoiceInk/Services/ParakeetTranscriptionService.swift
+++ b/VoiceInk/Services/ParakeetTranscriptionService.swift
@@ -71,13 +71,34 @@ class ParakeetTranscriptionService: TranscriptionService {
         let audioSamples = try readAudioSamples(from: audioURL)
 
-        // Validate audio data before transcription
-        guard audioSamples.count >= 16000 else {
-            logger.notice("🦜 Audio too short for transcription: \(audioSamples.count) samples")
+        // Validate audio data before VAD
+        guard !audioSamples.isEmpty else {
+            logger.notice("🦜 Audio is empty, skipping transcription.")
+            throw ASRError.invalidAudioData
+        }
+
+        // Use VAD to get speech segments
+        let speechAudio: [Float]
+        if let modelPath = await VADModelManager.shared.getModelPath() {
+            if let vad = VoiceActivityDetector(modelPath: modelPath) {
+                speechAudio = vad.process(audioSamples: audioSamples)
+                logger.notice("🦜 VAD processed audio, resulting in \(speechAudio.count) samples.")
+            } else {
+                logger.warning("🦜 VAD could not be initialized. Transcribing original audio.")
+                speechAudio = audioSamples
+            }
+        } else {
+            logger.warning("🦜 VAD model path not found. Transcribing original audio.")
+            speechAudio = audioSamples
+        }
+
+        // Validate audio data after VAD
+        guard speechAudio.count >= 16000 else {
+            logger.notice("🦜 Audio too short for transcription after VAD: \(speechAudio.count) samples")
             throw ASRError.invalidAudioData
         }
 
-        let result = try await asrManager.transcribe(audioSamples)
+        let result = try await asrManager.transcribe(speechAudio)
 
         // Reset decoder state and cleanup after transcription to avoid blocking the transcription start
         Task {
diff --git a/VoiceInk/Services/VoiceActivityDetector.swift b/VoiceInk/Services/VoiceActivityDetector.swift
new file mode 100644
index 0000000..658deaf
--- /dev/null
+++ b/VoiceInk/Services/VoiceActivityDetector.swift
@@ -0,0 +1,152 @@
+import Foundation
+import AVFoundation
+import os.log
+
+// MARK: - C API Bridge
+
+// Opaque pointers for the C contexts
+fileprivate typealias WhisperVADContext = OpaquePointer
+fileprivate typealias WhisperVADSegments = OpaquePointer
+
+// Define the C function signatures for Swift, scoped to this file
+
+@_silgen_name("whisper_vad_default_params")
+fileprivate func whisper_vad_default_params() -> whisper_vad_params
+
+@_silgen_name("whisper_vad_default_context_params")
+fileprivate func whisper_vad_default_context_params() -> whisper_vad_context_params
+
+@_silgen_name("whisper_vad_init_from_file_with_params")
+fileprivate func whisper_vad_init_from_file_with_params(_ path_model: UnsafePointer<CChar>, _ params: whisper_vad_context_params) -> WhisperVADContext?
+
+@_silgen_name("whisper_vad_detect_speech")
+fileprivate func whisper_vad_detect_speech(_ vctx: WhisperVADContext, _ samples: UnsafePointer<Float>, _ n_samples: Int32) -> Bool
+
+@_silgen_name("whisper_vad_n_probs")
+fileprivate func whisper_vad_n_probs(_ vctx: WhisperVADContext) -> Int32
+
+@_silgen_name("whisper_vad_probs")
+fileprivate func whisper_vad_probs(_ vctx: WhisperVADContext) -> UnsafeMutablePointer<Float>
+
+@_silgen_name("whisper_vad_segments_from_probs")
+fileprivate func whisper_vad_segments_from_probs(_ vctx: WhisperVADContext, _ params: whisper_vad_params) -> WhisperVADSegments?
+
+@_silgen_name("whisper_vad_segments_n_segments")
+fileprivate func whisper_vad_segments_n_segments(_ segments: WhisperVADSegments) -> Int32
+
+@_silgen_name("whisper_vad_segments_get_segment_t0")
+fileprivate func whisper_vad_segments_get_segment_t0(_ segments: WhisperVADSegments, _ i_segment: Int32) -> Float
+
+@_silgen_name("whisper_vad_segments_get_segment_t1")
+fileprivate func whisper_vad_segments_get_segment_t1(_ segments: WhisperVADSegments, _ i_segment: Int32) -> Float
+
+@_silgen_name("whisper_vad_free_segments")
+fileprivate func whisper_vad_free_segments(_ segments: WhisperVADSegments)
+
+@_silgen_name("whisper_vad_free")
+fileprivate func whisper_vad_free(_ ctx: WhisperVADContext)
+
+// Structs matching whisper.h, scoped to this file
+fileprivate struct whisper_vad_params {
+    var threshold: Float
+    var min_speech_duration_ms: Int32
+    var min_silence_duration_ms: Int32
+    var max_speech_duration_s: Float
+    var speech_pad_ms: Int32
+    var samples_overlap: Float
+}
+
+fileprivate struct whisper_vad_context_params {
+    var n_threads: Int32
+    var use_gpu: Bool
+    var gpu_device: Int32
+}
+
+
+// MARK: - VoiceActivityDetector Class
+
+class VoiceActivityDetector {
+    private var vadContext: WhisperVADContext
+    private let logger = Logger(subsystem: "com.voiceink.app", category: "VoiceActivityDetector")
+
+    init?(modelPath: String) {
+        var contextParams = whisper_vad_default_context_params()
+        contextParams.n_threads = max(1, min(8, Int32(ProcessInfo.processInfo.processorCount) - 2))
+
+        guard let context = whisper_vad_init_from_file_with_params(modelPath, contextParams) else {
+            logger.error("Failed to initialize VAD context.")
+            return nil
+        }
+        self.vadContext = context
+        logger.notice("VAD context initialized successfully.")
+    }
+
+    deinit {
+        whisper_vad_free(vadContext)
+        logger.notice("VAD context freed.")
+    }
+
+    /// Processes audio samples to detect speech segments and returns the stitched audio containing only speech.
+    func process(audioSamples: [Float]) -> [Float] {
+        // 1. Detect speech and get probabilities internally in the context
+        let success = audioSamples.withUnsafeBufferPointer { buffer in
+            whisper_vad_detect_speech(vadContext, buffer.baseAddress!, Int32(audioSamples.count))
+        }
+
+        guard success else {
+            logger.error("Failed to detect speech probabilities.")
+            return []
+        }
+
+        // 2. Get segments from probabilities
+        var vadParams = whisper_vad_default_params()
+        vadParams.threshold = 0.5
+        vadParams.min_speech_duration_ms = 250
+        vadParams.min_silence_duration_ms = 100
+        vadParams.speech_pad_ms = 30
+
+        guard let segments = whisper_vad_segments_from_probs(vadContext, vadParams) else {
+            logger.error("Failed to get VAD segments from probabilities.")
+            return []
+        }
+        defer {
+            // Ensure segments are freed
+            whisper_vad_free_segments(segments)
+        }
+
+        let nSegments = whisper_vad_segments_n_segments(segments)
+        logger.notice("Detected \(nSegments) speech segments.")
+
+        // 3. Stitch audio segments together
+        var stitchedAudio = [Float]()
+        let sampleRate = 16000 // Assuming 16kHz sample rate
+
+        for i in 0..<nSegments {
+            let t0 = whisper_vad_segments_get_segment_t0(segments, i)
+            let t1 = whisper_vad_segments_get_segment_t1(segments, i)
+
+            // NOTE(review): reconstructed span — t0/t1 assumed to be whisper centisecond
+            // timestamps (1 unit = 10 ms); confirm against whisper.h before relying on this.
+            let startSample = max(0, Int(t0 / 100.0 * Float(sampleRate)))
+            var endSample = Int(t1 / 100.0 * Float(sampleRate))
+
+            if endSample > audioSamples.count {
+                logger.debug("Capping endSample from \(endSample, privacy: .public) to \(audioSamples.count, privacy: .public)")
+                endSample = audioSamples.count
+            }
+
+            if startSample < endSample {
+                stitchedAudio.append(contentsOf: audioSamples[startSample..<endSample])
+            }
+        }
+
+        logger.notice("Stitched speech audio contains \(stitchedAudio.count) samples.")
+        return stitchedAudio
+    }
+}