From b568fb9fa0932728f2284f07b16299667b204077 Mon Sep 17 00:00:00 2001 From: Beingpax Date: Sun, 19 Oct 2025 19:56:07 +0545 Subject: [PATCH] Auto warm-up downloaded Whisper models after downloading --- .../AI Models/LocalModelCardRowView.swift | 22 ++++-- .../Views/AI Models/ModelCardRowView.swift | 5 +- .../Views/AI Models/ModelManagementView.swift | 6 ++ .../WhisperModelWarmupCoordinator.swift | 68 +++++++++++++++++++ .../WhisperState+LocalModelManager.swift | 57 ++++------------ 5 files changed, 108 insertions(+), 50 deletions(-) create mode 100644 VoiceInk/Whisper/WhisperModelWarmupCoordinator.swift diff --git a/VoiceInk/Views/AI Models/LocalModelCardRowView.swift b/VoiceInk/Views/AI Models/LocalModelCardRowView.swift index b836320..ba5b874 100644 --- a/VoiceInk/Views/AI Models/LocalModelCardRowView.swift +++ b/VoiceInk/Views/AI Models/LocalModelCardRowView.swift @@ -7,12 +7,12 @@ struct LocalModelCardView: View { let isCurrent: Bool let downloadProgress: [String: Double] let modelURL: URL? + let isWarming: Bool // Actions var deleteAction: () -> Void var setDefaultAction: () -> Void var downloadAction: () -> Void - private var isDownloading: Bool { downloadProgress.keys.contains(model.name + "_main") || downloadProgress.keys.contains(model.name + "_coreml") @@ -134,12 +134,22 @@ struct LocalModelCardView: View { .font(.system(size: 12)) .foregroundColor(Color(.secondaryLabelColor)) } else if isDownloaded { - Button(action: setDefaultAction) { - Text("Set as Default") - .font(.system(size: 12)) + if isWarming { + HStack(spacing: 6) { + ProgressView() + .controlSize(.small) + Text("Optimizing model for your device...") + .font(.system(size: 12)) + .foregroundColor(Color(.secondaryLabelColor)) + } + } else { + Button(action: setDefaultAction) { + Text("Set as Default") + .font(.system(size: 12)) + } + .buttonStyle(.bordered) + .controlSize(.small) } - .buttonStyle(.bordered) - .controlSize(.small) } else { Button(action: downloadAction) { HStack(spacing: 4) { diff --git a/VoiceInk/Views/AI Models/ModelCardRowView.swift b/VoiceInk/Views/AI Models/ModelCardRowView.swift index e33fc71..012a750 100644 --- a/VoiceInk/Views/AI Models/ModelCardRowView.swift +++ b/VoiceInk/Views/AI Models/ModelCardRowView.swift @@ -8,13 +8,13 @@ struct ModelCardRowView: View { let isCurrent: Bool let downloadProgress: [String: Double] let modelURL: URL? + let isWarming: Bool // Actions var deleteAction: () -> Void var setDefaultAction: () -> Void var downloadAction: () -> Void var editAction: ((CustomCloudModel) -> Void)? - var body: some View { Group { switch model.provider { @@ -26,6 +26,7 @@ struct ModelCardRowView: View { isCurrent: isCurrent, downloadProgress: downloadProgress, modelURL: modelURL, + isWarming: isWarming, deleteAction: deleteAction, setDefaultAction: setDefaultAction, downloadAction: downloadAction @@ -76,4 +77,4 @@ struct ModelCardRowView: View { } } } -} \ No newline at end of file +} diff --git a/VoiceInk/Views/AI Models/ModelManagementView.swift b/VoiceInk/Views/AI Models/ModelManagementView.swift index 65b2964..3e1cbe4 100644 --- a/VoiceInk/Views/AI Models/ModelManagementView.swift +++ b/VoiceInk/Views/AI Models/ModelManagementView.swift @@ -19,6 +19,7 @@ struct ModelManagementView: View { @EnvironmentObject private var enhancementService: AIEnhancementService @Environment(\.modelContext) private var modelContext @StateObject private var whisperPrompt = WhisperPrompt() + @ObservedObject private var warmupCoordinator = WhisperModelWarmupCoordinator.shared @State private var selectedFilter: ModelFilter = .recommended @State private var isShowingSettings = false @@ -118,6 +119,10 @@ struct ModelManagementView: View { } else { VStack(spacing: 12) { ForEach(filteredModels, id: \.id) { model in + let isWarming = (model as? LocalModel).map { localModel in + warmupCoordinator.isWarming(modelNamed: localModel.name) + } ?? false + ModelCardRowView( model: model, whisperState: whisperState, @@ -125,6 +130,7 @@ struct ModelManagementView: View { isCurrent: whisperState.currentTranscriptionModel?.name == model.name, downloadProgress: whisperState.downloadProgress, modelURL: whisperState.availableModels.first { $0.name == model.name }?.url, + isWarming: isWarming, deleteAction: { if let customModel = model as? CustomCloudModel { alertTitle = "Delete Custom Model" diff --git a/VoiceInk/Whisper/WhisperModelWarmupCoordinator.swift b/VoiceInk/Whisper/WhisperModelWarmupCoordinator.swift new file mode 100644 index 0000000..e150f0e --- /dev/null +++ b/VoiceInk/Whisper/WhisperModelWarmupCoordinator.swift @@ -0,0 +1,68 @@ +import Foundation +import Combine + +@MainActor +final class WhisperModelWarmupCoordinator: ObservableObject { + static let shared = WhisperModelWarmupCoordinator() + + @Published private(set) var warmingModels: Set = [] + + private init() {} + + func isWarming(modelNamed name: String) -> Bool { + warmingModels.contains(name) + } + + func scheduleWarmup(for model: LocalModel, whisperState: WhisperState) { + guard shouldWarmup(modelName: model.name), + !warmingModels.contains(model.name) else { + return + } + + warmingModels.insert(model.name) + + Task { + do { + try await runWarmup(for: model, whisperState: whisperState) + } catch { + await MainActor.run { + whisperState.logger.error("Warmup failed for \(model.name): \(error.localizedDescription)") + } + } + + await MainActor.run { + self.warmingModels.remove(model.name) + } + } + } + + private func runWarmup(for model: LocalModel, whisperState: WhisperState) async throws { + guard let sampleURL = warmupSampleURL() else { return } + let service = LocalTranscriptionService( + modelsDirectory: whisperState.modelsDirectory, + whisperState: whisperState + ) + _ = try await service.transcribe(audioURL: sampleURL, model: model) + } + + private func warmupSampleURL() -> URL? { + let bundle = Bundle.main + let candidates: [URL?] = [ + bundle.url(forResource: "esc", withExtension: "wav", subdirectory: "Resources/Sounds"), + bundle.url(forResource: "esc", withExtension: "wav", subdirectory: "Sounds"), + bundle.url(forResource: "esc", withExtension: "wav") + ] + + for candidate in candidates { + if let url = candidate { + return url + } + } + + return nil + } + + private func shouldWarmup(modelName: String) -> Bool { + !modelName.contains("q5") && !modelName.contains("q8") + } +} diff --git a/VoiceInk/Whisper/WhisperState+LocalModelManager.swift b/VoiceInk/Whisper/WhisperState+LocalModelManager.swift index 6ba1c1b..d3b65a9 100644 --- a/VoiceInk/Whisper/WhisperState+LocalModelManager.swift +++ b/VoiceInk/Whisper/WhisperState+LocalModelManager.swift @@ -157,57 +157,26 @@ extension WhisperState { } } } - - // Shows an alert about Core ML support and first-run optimization - private func showCoreMLAlert(for model: LocalModel, completion: @escaping () -> Void) { - Task { @MainActor in - let alert = NSAlert() - alert.messageText = "Core ML Support for \(model.displayName) Model" - alert.informativeText = "This Whisper model supports Core ML, which can improve performance by 2-4x on Apple Silicon devices.\n\nDuring the first run, it can take several minutes to optimize the model for your system. Subsequent runs will be much faster." - alert.alertStyle = .informational - alert.addButton(withTitle: "Download") - alert.addButton(withTitle: "Cancel") - - let response = alert.runModal() - if response == .alertFirstButtonReturn { - completion() - } - } - } - func downloadModel(_ model: LocalModel) async { guard let url = URL(string: model.downloadURL) else { return } - - // Check if model supports Core ML (non-quantized models) - let supportsCoreML = !model.name.contains("q5") && !model.name.contains("q8") - - if supportsCoreML { - // Show the CoreML alert for models that support it - await MainActor.run { - showCoreMLAlert(for: model) { - // This completion handler is called when user clicks "Download" - Task { - await self.performModelDownload(model, url) - } - } - } - } else { - // Directly download the model if it doesn't support Core ML - await performModelDownload(model, url) - } + await performModelDownload(model, url) } private func performModelDownload(_ model: LocalModel, _ url: URL) async { do { - let whisperModel = try await downloadMainModel(model, from: url) + var whisperModel = try await downloadMainModel(model, from: url) if let coreMLZipURL = whisperModel.coreMLZipDownloadURL, let coreMLURL = URL(string: coreMLZipURL) { - try await downloadAndSetupCoreMLModel(for: whisperModel, from: coreMLURL) + whisperModel = try await downloadAndSetupCoreMLModel(for: whisperModel, from: coreMLURL) } availableModels.append(whisperModel) self.downloadProgress.removeValue(forKey: model.name + "_main") + + if shouldWarmup(model) { + WhisperModelWarmupCoordinator.shared.scheduleWarmup(for: model, whisperState: self) + } } catch { handleModelDownloadError(model, error) } @@ -223,22 +192,22 @@ extension WhisperState { return WhisperModel(name: model.name, url: destinationURL) } - private func downloadAndSetupCoreMLModel(for model: WhisperModel, from url: URL) async throws { + private func downloadAndSetupCoreMLModel(for model: WhisperModel, from url: URL) async throws -> WhisperModel { let progressKeyCoreML = model.name + "_coreml" let coreMLData = try await downloadFileWithProgress(from: url, progressKey: progressKeyCoreML) let coreMLZipPath = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc.zip") try coreMLData.write(to: coreMLZipPath) - try await unzipAndSetupCoreMLModel(for: model, zipPath: coreMLZipPath, progressKey: progressKeyCoreML) + return try await unzipAndSetupCoreMLModel(for: model, zipPath: coreMLZipPath, progressKey: progressKeyCoreML) } - private func unzipAndSetupCoreMLModel(for model: WhisperModel, zipPath: URL, progressKey: String) async throws { + private func unzipAndSetupCoreMLModel(for model: WhisperModel, zipPath: URL, progressKey: String) async throws -> WhisperModel { let coreMLDestination = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc") try? FileManager.default.removeItem(at: coreMLDestination) try await unzipCoreMLFile(zipPath, to: modelsDirectory) - try verifyAndCleanupCoreMLFiles(model, coreMLDestination, zipPath, progressKey) + return try verifyAndCleanupCoreMLFiles(model, coreMLDestination, zipPath, progressKey) } private func unzipCoreMLFile(_ zipPath: URL, to destination: URL) async throws { @@ -268,6 +237,10 @@ extension WhisperState { return model } + + private func shouldWarmup(_ model: LocalModel) -> Bool { + !model.name.contains("q5") && !model.name.contains("q8") + } private func handleModelDownloadError(_ model: LocalModel, _ error: Error) { self.downloadProgress.removeValue(forKey: model.name + "_main")