Improved VAD for Parakeet model

This commit is contained in:
Beingpax 2025-09-06 07:13:06 +05:45
parent 106fd653ea
commit c0ed2dc78a
2 changed files with 34 additions and 35 deletions

View File

@ -78,11 +78,29 @@ class ParakeetTranscriptionService: TranscriptionService {
}
// Use VAD to get speech segments
let speechAudio: [Float]
var speechAudio: [Float] = []
if let modelPath = await VADModelManager.shared.getModelPath() {
if let vad = VoiceActivityDetector(modelPath: modelPath) {
speechAudio = vad.process(audioSamples: audioSamples)
logger.notice("🦜 VAD processed audio, resulting in \(speechAudio.count) samples.")
let speechSegments = vad.process(audioSamples: audioSamples)
logger.notice("🦜 VAD detected \(speechSegments.count) speech segments.")
let sampleRate = 16000 // Assuming 16kHz sample rate
for segment in speechSegments {
let startSample = Int(segment.start * Double(sampleRate))
var endSample = Int(segment.end * Double(sampleRate))
// Cap endSample to the audio buffer size
if endSample > audioSamples.count {
endSample = audioSamples.count
}
if startSample < endSample {
speechAudio.append(contentsOf: audioSamples[startSample..<endSample])
} else {
logger.warning("🦜 Invalid sample range for segment: start=\(startSample), end=\(endSample). Skipping.")
}
}
logger.notice("🦜 Extracted \(speechAudio.count) samples from VAD segments.")
} else {
logger.warning("🦜 VAD could not be initialized. Transcribing original audio.")
speechAudio = audioSamples

View File

@ -86,8 +86,8 @@ class VoiceActivityDetector {
logger.notice("VAD context freed.")
}
/// Processes audio samples to detect speech segments and returns the stitched audio containing only speech.
func process(audioSamples: [Float]) -> [Float] {
/// Processes audio samples to detect speech segments and returns an array of (start: TimeInterval, end: TimeInterval) tuples.
func process(audioSamples: [Float]) -> [(start: TimeInterval, end: TimeInterval)] {
// 1. Detect speech and get probabilities internally in the context
let success = audioSamples.withUnsafeBufferPointer { buffer in
whisper_vad_detect_speech(vadContext, buffer.baseAddress!, Int32(audioSamples.count))
@ -100,10 +100,12 @@ class VoiceActivityDetector {
// 2. Get segments from probabilities
var vadParams = whisper_vad_default_params()
vadParams.threshold = 0.5
vadParams.min_speech_duration_ms = 250
vadParams.min_silence_duration_ms = 100
vadParams.speech_pad_ms = 30
vadParams.threshold = 0.45
vadParams.min_speech_duration_ms = 150
vadParams.min_silence_duration_ms = 750
vadParams.max_speech_duration_s = Float.greatestFiniteMagnitude // Use the largest representable Float value for no max duration
vadParams.speech_pad_ms = 100
vadParams.samples_overlap = 0.1 // Add samples_overlap parameter
guard let segments = whisper_vad_segments_from_probs(vadContext, vadParams) else {
logger.error("Failed to get VAD segments from probabilities.")
@ -117,36 +119,15 @@ class VoiceActivityDetector {
let nSegments = whisper_vad_segments_n_segments(segments)
logger.notice("Detected \(nSegments) speech segments.")
// 3. Stitch audio segments together
var stitchedAudio = [Float]()
let sampleRate = 16000 // Assuming 16kHz sample rate
var speechSegments: [(start: TimeInterval, end: TimeInterval)] = []
for i in 0..<nSegments {
// Timestamps from C are mysteriously multiplied by 100, so we correct them here.
let startTimeSec = whisper_vad_segments_get_segment_t0(segments, i) / 100.0
let endTimeSec = whisper_vad_segments_get_segment_t1(segments, i) / 100.0
logger.debug("Segment \(i): start=\(startTimeSec, privacy: .public)s, end=\(endTimeSec, privacy: .public)s")
let startSample = Int(startTimeSec * Float(sampleRate))
var endSample = Int(endTimeSec * Float(sampleRate))
logger.debug("Segment \(i): startSample=\(startSample, privacy: .public), endSample=\(endSample, privacy: .public)")
// Cap endSample to the audio buffer size
if endSample > audioSamples.count {
logger.debug("Capping endSample from \(endSample, privacy: .public) to \(audioSamples.count, privacy: .public)")
endSample = audioSamples.count
}
if startSample < endSample {
stitchedAudio.append(contentsOf: audioSamples[startSample..<endSample])
} else {
logger.warning("Segment \(i): Invalid sample range, skipping.")
}
speechSegments.append((start: TimeInterval(startTimeSec), end: TimeInterval(endTimeSec)))
}
logger.notice("Stitched audio contains \(stitchedAudio.count) samples.")
return stitchedAudio
logger.notice("Returning \(speechSegments.count) speech segments.")
return speechSegments
}
}
}