vOOice/VoiceInk/Whisper/LibWhisper.swift

146 lines
4.6 KiB
Swift

import Foundation
#if canImport(whisper)
import whisper
#else
#error("Unable to import whisper module. Please check your project configuration.")
#endif
import os
/// Owns a whisper.cpp context and serializes every call into it.
///
/// Meets the whisper.cpp constraint that a context must not be accessed from more
/// than one thread at a time: making this an actor confines all access to one executor.
/// Note that actor methods can still interleave at `await` points (reentrancy), which
/// `fullTranscribe(samples:)` accounts for.
actor WhisperContext {
    /// Handle from `whisper_init_from_file_with_params`; `nil` before load and after release.
    private var context: OpaquePointer?
    /// Optional text passed to whisper as `initial_prompt`; set via `setPrompt(_:)`.
    private var prompt: String?
    private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "WhisperContext")

    private init() {}

    /// Wraps an already-initialized whisper context. This actor frees it in
    /// `releaseResources()` or on deinit.
    init(context: OpaquePointer) {
        self.context = context
    }

    deinit {
        if let context = context {
            whisper_free(context)
        }
    }

    /// Runs whisper.cpp over `samples`. Read the result afterwards with `getTranscription()`.
    ///
    /// - Parameter samples: Audio samples handed directly to `whisper_full`.
    ///   NOTE(review): presumably mono PCM at whisper.cpp's expected sample rate — confirm at caller.
    ///
    /// The language comes from the `"SelectedLanguage"` user default (`"auto"` or missing
    /// means auto-detect). The prompt set via `setPrompt(_:)` is used as the initial prompt.
    /// Voice-activity detection is enabled only when a VAD model path is available.
    /// A non-zero status from `whisper_full` is logged, not thrown. Does nothing if the
    /// context has already been released.
    func fullTranscribe(samples: [Float]) async {
        // Take the only suspension point *before* reading `context`. If `context`
        // were captured first, a reentrant `releaseResources()` running during the
        // `await` could free it and leave `whisper_full` with a dangling pointer.
        let vadModelPath = await VADModelManager.shared.getModelPath()
        guard let context = context else { return }

        let selectedLanguage = UserDefaults.standard.string(forKey: "SelectedLanguage") ?? "auto"
        // `nil` must be set explicitly: the default params do not mean auto-detect.
        let language: String? = selectedLanguage == "auto" ? nil : selectedLanguage
        let initialPrompt = prompt

        var params = Self.makeFullParams()
        if vadModelPath != nil {
            params.vad = true
            params.vad_params = Self.makeVADParams()
        } else {
            params.vad = false
        }

        whisper_reset_timings(context)

        // `whisper_full_params` stores raw `const char *` pointers, so each string's
        // storage must stay alive for the entire `whisper_full` call. Nested
        // `withCString` scopes guarantee that. Pointers that escape
        // `withUnsafeBufferPointer` or come from `NSString.utf8String` on a
        // temporary are not guaranteed to.
        Self.withOptionalCString(language) { languagePointer in
            Self.withOptionalCString(initialPrompt) { promptPointer in
                Self.withOptionalCString(vadModelPath) { vadModelPathPointer in
                    params.language = languagePointer
                    params.initial_prompt = promptPointer
                    if let vadModelPathPointer {
                        params.vad_model_path = vadModelPathPointer
                    }
                    let status = samples.withUnsafeBufferPointer { samplesBuffer in
                        whisper_full(context, params, samplesBuffer.baseAddress, Int32(samplesBuffer.count))
                    }
                    if status != 0 {
                        self.logger.error("Failed to run whisper_full")
                    }
                }
            }
        }
    }

    /// Concatenates the text of every segment from the last `fullTranscribe` run,
    /// passed through `WhisperHallucinationFilter.filter`.
    ///
    /// - Returns: The filtered transcription, or `""` if the context has been released.
    func getTranscription() -> String {
        guard let context = context else { return "" }
        let segmentCount = whisper_full_n_segments(context)
        let transcription = (0..<segmentCount)
            .map { String(cString: whisper_full_get_segment_text(context, $0)) }
            .joined()
        return WhisperHallucinationFilter.filter(transcription)
    }

    /// Loads the model at `path` into a new context.
    ///
    /// - Throws: `WhisperStateError.modelLoadFailed` if whisper.cpp cannot load the file.
    static func createContext(path: String) async throws -> WhisperContext {
        let whisperContext = WhisperContext()
        try await whisperContext.initializeModel(path: path)
        return whisperContext
    }

    /// Initializes `context` from the model file at `path`. GPU use is disabled on the simulator.
    private func initializeModel(path: String) throws {
        var params = whisper_context_default_params()
        #if targetEnvironment(simulator)
        params.use_gpu = false
        #endif
        guard let context = whisper_init_from_file_with_params(path, params) else {
            logger.error("❌ Couldn't load model at \(path)")
            throw WhisperStateError.modelLoadFailed
        }
        self.context = context
    }

    /// Frees the whisper context, if any. Later calls become no-ops or return empty results.
    func releaseResources() {
        if let context = context {
            whisper_free(context)
            self.context = nil
        }
    }

    /// Sets (or clears, with `nil`) the initial prompt used by subsequent transcriptions.
    func setPrompt(_ prompt: String?) {
        self.prompt = prompt
    }

    // MARK: - Parameter helpers

    /// Greedy-sampling decode parameters with this app's settings.
    ///
    /// String-pointer fields (`language`, `initial_prompt`, `vad_model_path`) are left
    /// for the caller, which must keep their storage alive during `whisper_full`.
    private static func makeFullParams() -> whisper_full_params {
        var params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY)
        params.print_realtime = true
        params.print_progress = false
        params.print_timestamps = true
        params.print_special = false
        params.translate = false
        // Between 1 and 8 threads, presumably leaving two cores free for the rest of the system.
        params.n_threads = Int32(max(1, min(8, cpuCount() - 2)))
        params.offset_ms = 0
        params.no_context = true
        params.single_segment = false
        params.suppress_nst = true
        params.entropy_thold = 2.0
        params.logprob_thold = -0.8
        params.no_speech_thold = 0.6
        return params
    }

    /// Voice-activity-detection tuning used when a VAD model is available.
    private static func makeVADParams() -> whisper_vad_params {
        var vadParams = whisper_vad_default_params()
        vadParams.threshold = 0.50
        vadParams.min_speech_duration_ms = 250
        vadParams.min_silence_duration_ms = 100
        vadParams.max_speech_duration_s = Float.greatestFiniteMagnitude
        vadParams.speech_pad_ms = 30
        vadParams.samples_overlap = 0.1
        return vadParams
    }

    /// Calls `body` with a NUL-terminated copy of `string` that is valid only for the
    /// duration of the call, or with `nil` when `string` is `nil`.
    private static func withOptionalCString<R>(
        _ string: String?,
        _ body: (UnsafePointer<CChar>?) throws -> R
    ) rethrows -> R {
        guard let string else { return try body(nil) }
        return try string.withCString { try body($0) }
    }
}
/// Logical processor count as reported by `ProcessInfo`.
fileprivate func cpuCount() -> Int {
    let info = ProcessInfo.processInfo
    return info.processorCount
}