Improved VAD for Parakeet model
This commit is contained in:
parent
106fd653ea
commit
c0ed2dc78a
@ -78,11 +78,29 @@ class ParakeetTranscriptionService: TranscriptionService {
|
||||
}
|
||||
|
||||
// Use VAD to get speech segments
|
||||
let speechAudio: [Float]
|
||||
var speechAudio: [Float] = []
|
||||
if let modelPath = await VADModelManager.shared.getModelPath() {
|
||||
if let vad = VoiceActivityDetector(modelPath: modelPath) {
|
||||
speechAudio = vad.process(audioSamples: audioSamples)
|
||||
logger.notice("🦜 VAD processed audio, resulting in \(speechAudio.count) samples.")
|
||||
let speechSegments = vad.process(audioSamples: audioSamples)
|
||||
logger.notice("🦜 VAD detected \(speechSegments.count) speech segments.")
|
||||
|
||||
let sampleRate = 16000 // Assuming 16kHz sample rate
|
||||
for segment in speechSegments {
|
||||
let startSample = Int(segment.start * Double(sampleRate))
|
||||
var endSample = Int(segment.end * Double(sampleRate))
|
||||
|
||||
// Cap endSample to the audio buffer size
|
||||
if endSample > audioSamples.count {
|
||||
endSample = audioSamples.count
|
||||
}
|
||||
|
||||
if startSample < endSample {
|
||||
speechAudio.append(contentsOf: audioSamples[startSample..<endSample])
|
||||
} else {
|
||||
logger.warning("🦜 Invalid sample range for segment: start=\(startSample), end=\(endSample). Skipping.")
|
||||
}
|
||||
}
|
||||
logger.notice("🦜 Extracted \(speechAudio.count) samples from VAD segments.")
|
||||
} else {
|
||||
logger.warning("🦜 VAD could not be initialized. Transcribing original audio.")
|
||||
speechAudio = audioSamples
|
||||
|
||||
@ -86,8 +86,8 @@ class VoiceActivityDetector {
|
||||
logger.notice("VAD context freed.")
|
||||
}
|
||||
|
||||
/// Processes audio samples to detect speech segments and returns the stitched audio containing only speech.
|
||||
func process(audioSamples: [Float]) -> [Float] {
|
||||
/// Processes audio samples to detect speech segments and returns an array of (start: TimeInterval, end: TimeInterval) tuples.
|
||||
func process(audioSamples: [Float]) -> [(start: TimeInterval, end: TimeInterval)] {
|
||||
// 1. Detect speech and get probabilities internally in the context
|
||||
let success = audioSamples.withUnsafeBufferPointer { buffer in
|
||||
whisper_vad_detect_speech(vadContext, buffer.baseAddress!, Int32(audioSamples.count))
|
||||
@ -100,10 +100,12 @@ class VoiceActivityDetector {
|
||||
|
||||
// 2. Get segments from probabilities
|
||||
var vadParams = whisper_vad_default_params()
|
||||
vadParams.threshold = 0.5
|
||||
vadParams.min_speech_duration_ms = 250
|
||||
vadParams.min_silence_duration_ms = 100
|
||||
vadParams.speech_pad_ms = 30
|
||||
vadParams.threshold = 0.45
|
||||
vadParams.min_speech_duration_ms = 150
|
||||
vadParams.min_silence_duration_ms = 750
|
||||
vadParams.max_speech_duration_s = Float.greatestFiniteMagnitude // Use the largest representable Float value for no max duration
|
||||
vadParams.speech_pad_ms = 100
|
||||
vadParams.samples_overlap = 0.1 // Add samples_overlap parameter
|
||||
|
||||
guard let segments = whisper_vad_segments_from_probs(vadContext, vadParams) else {
|
||||
logger.error("Failed to get VAD segments from probabilities.")
|
||||
@ -117,36 +119,15 @@ class VoiceActivityDetector {
|
||||
let nSegments = whisper_vad_segments_n_segments(segments)
|
||||
logger.notice("Detected \(nSegments) speech segments.")
|
||||
|
||||
// 3. Stitch audio segments together
|
||||
var stitchedAudio = [Float]()
|
||||
let sampleRate = 16000 // Assuming 16kHz sample rate
|
||||
|
||||
var speechSegments: [(start: TimeInterval, end: TimeInterval)] = []
|
||||
for i in 0..<nSegments {
|
||||
// Timestamps from C are mysteriously multiplied by 100, so we correct them here.
|
||||
let startTimeSec = whisper_vad_segments_get_segment_t0(segments, i) / 100.0
|
||||
let endTimeSec = whisper_vad_segments_get_segment_t1(segments, i) / 100.0
|
||||
|
||||
logger.debug("Segment \(i): start=\(startTimeSec, privacy: .public)s, end=\(endTimeSec, privacy: .public)s")
|
||||
|
||||
let startSample = Int(startTimeSec * Float(sampleRate))
|
||||
var endSample = Int(endTimeSec * Float(sampleRate))
|
||||
|
||||
logger.debug("Segment \(i): startSample=\(startSample, privacy: .public), endSample=\(endSample, privacy: .public)")
|
||||
|
||||
// Cap endSample to the audio buffer size
|
||||
if endSample > audioSamples.count {
|
||||
logger.debug("Capping endSample from \(endSample, privacy: .public) to \(audioSamples.count, privacy: .public)")
|
||||
endSample = audioSamples.count
|
||||
}
|
||||
|
||||
if startSample < endSample {
|
||||
stitchedAudio.append(contentsOf: audioSamples[startSample..<endSample])
|
||||
} else {
|
||||
logger.warning("Segment \(i): Invalid sample range, skipping.")
|
||||
}
|
||||
speechSegments.append((start: TimeInterval(startTimeSec), end: TimeInterval(endTimeSec)))
|
||||
}
|
||||
|
||||
logger.notice("Stitched audio contains \(stitchedAudio.count) samples.")
|
||||
return stitchedAudio
|
||||
logger.notice("Returning \(speechSegments.count) speech segments.")
|
||||
return speechSegments
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user