Auto warm-up downloaded Whisper models after downloading

This commit is contained in:
Beingpax 2025-10-19 19:56:07 +05:45
parent 3f01f49f56
commit b568fb9fa0
5 changed files with 108 additions and 50 deletions

View File

@ -7,12 +7,12 @@ struct LocalModelCardView: View {
let isCurrent: Bool
let downloadProgress: [String: Double]
let modelURL: URL?
let isWarming: Bool
// Actions
var deleteAction: () -> Void
var setDefaultAction: () -> Void
var downloadAction: () -> Void
private var isDownloading: Bool {
downloadProgress.keys.contains(model.name + "_main") ||
downloadProgress.keys.contains(model.name + "_coreml")
@ -134,12 +134,22 @@ struct LocalModelCardView: View {
.font(.system(size: 12))
.foregroundColor(Color(.secondaryLabelColor))
} else if isDownloaded {
Button(action: setDefaultAction) {
Text("Set as Default")
.font(.system(size: 12))
if isWarming {
HStack(spacing: 6) {
ProgressView()
.controlSize(.small)
Text("Optimizing model for your device...")
.font(.system(size: 12))
.foregroundColor(Color(.secondaryLabelColor))
}
} else {
Button(action: setDefaultAction) {
Text("Set as Default")
.font(.system(size: 12))
}
.buttonStyle(.bordered)
.controlSize(.small)
}
.buttonStyle(.bordered)
.controlSize(.small)
} else {
Button(action: downloadAction) {
HStack(spacing: 4) {

View File

@ -8,13 +8,13 @@ struct ModelCardRowView: View {
let isCurrent: Bool
let downloadProgress: [String: Double]
let modelURL: URL?
let isWarming: Bool
// Actions
var deleteAction: () -> Void
var setDefaultAction: () -> Void
var downloadAction: () -> Void
var editAction: ((CustomCloudModel) -> Void)?
var body: some View {
Group {
switch model.provider {
@ -26,6 +26,7 @@ struct ModelCardRowView: View {
isCurrent: isCurrent,
downloadProgress: downloadProgress,
modelURL: modelURL,
isWarming: isWarming,
deleteAction: deleteAction,
setDefaultAction: setDefaultAction,
downloadAction: downloadAction
@ -76,4 +77,4 @@ struct ModelCardRowView: View {
}
}
}
}
}

View File

@ -19,6 +19,7 @@ struct ModelManagementView: View {
@EnvironmentObject private var enhancementService: AIEnhancementService
@Environment(\.modelContext) private var modelContext
@StateObject private var whisperPrompt = WhisperPrompt()
@ObservedObject private var warmupCoordinator = WhisperModelWarmupCoordinator.shared
@State private var selectedFilter: ModelFilter = .recommended
@State private var isShowingSettings = false
@ -118,6 +119,10 @@ struct ModelManagementView: View {
} else {
VStack(spacing: 12) {
ForEach(filteredModels, id: \.id) { model in
let isWarming = (model as? LocalModel).map { localModel in
warmupCoordinator.isWarming(modelNamed: localModel.name)
} ?? false
ModelCardRowView(
model: model,
whisperState: whisperState,
@ -125,6 +130,7 @@ struct ModelManagementView: View {
isCurrent: whisperState.currentTranscriptionModel?.name == model.name,
downloadProgress: whisperState.downloadProgress,
modelURL: whisperState.availableModels.first { $0.name == model.name }?.url,
isWarming: isWarming,
deleteAction: {
if let customModel = model as? CustomCloudModel {
alertTitle = "Delete Custom Model"

View File

@ -0,0 +1,68 @@
import Foundation
import Combine
@MainActor
final class WhisperModelWarmupCoordinator: ObservableObject {
static let shared = WhisperModelWarmupCoordinator()
@Published private(set) var warmingModels: Set<String> = []
private init() {}
func isWarming(modelNamed name: String) -> Bool {
warmingModels.contains(name)
}
func scheduleWarmup(for model: LocalModel, whisperState: WhisperState) {
guard shouldWarmup(modelName: model.name),
!warmingModels.contains(model.name) else {
return
}
warmingModels.insert(model.name)
Task {
do {
try await runWarmup(for: model, whisperState: whisperState)
} catch {
await MainActor.run {
whisperState.logger.error("Warmup failed for \(model.name): \(error.localizedDescription)")
}
}
await MainActor.run {
self.warmingModels.remove(model.name)
}
}
}
private func runWarmup(for model: LocalModel, whisperState: WhisperState) async throws {
guard let sampleURL = warmupSampleURL() else { return }
let service = LocalTranscriptionService(
modelsDirectory: whisperState.modelsDirectory,
whisperState: whisperState
)
_ = try await service.transcribe(audioURL: sampleURL, model: model)
}
private func warmupSampleURL() -> URL? {
let bundle = Bundle.main
let candidates: [URL?] = [
bundle.url(forResource: "esc", withExtension: "wav", subdirectory: "Resources/Sounds"),
bundle.url(forResource: "esc", withExtension: "wav", subdirectory: "Sounds"),
bundle.url(forResource: "esc", withExtension: "wav")
]
for candidate in candidates {
if let url = candidate {
return url
}
}
return nil
}
private func shouldWarmup(modelName: String) -> Bool {
!modelName.contains("q5") && !modelName.contains("q8")
}
}

View File

@ -157,57 +157,26 @@ extension WhisperState {
}
}
}
// Shows an alert about Core ML support and first-run optimization
private func showCoreMLAlert(for model: LocalModel, completion: @escaping () -> Void) {
Task { @MainActor in
let alert = NSAlert()
alert.messageText = "Core ML Support for \(model.displayName) Model"
alert.informativeText = "This Whisper model supports Core ML, which can improve performance by 2-4x on Apple Silicon devices.\n\nDuring the first run, it can take several minutes to optimize the model for your system. Subsequent runs will be much faster."
alert.alertStyle = .informational
alert.addButton(withTitle: "Download")
alert.addButton(withTitle: "Cancel")
let response = alert.runModal()
if response == .alertFirstButtonReturn {
completion()
}
}
}
func downloadModel(_ model: LocalModel) async {
guard let url = URL(string: model.downloadURL) else { return }
// Check if model supports Core ML (non-quantized models)
let supportsCoreML = !model.name.contains("q5") && !model.name.contains("q8")
if supportsCoreML {
// Show the CoreML alert for models that support it
await MainActor.run {
showCoreMLAlert(for: model) {
// This completion handler is called when user clicks "Download"
Task {
await self.performModelDownload(model, url)
}
}
}
} else {
// Directly download the model if it doesn't support Core ML
await performModelDownload(model, url)
}
await performModelDownload(model, url)
}
private func performModelDownload(_ model: LocalModel, _ url: URL) async {
do {
let whisperModel = try await downloadMainModel(model, from: url)
var whisperModel = try await downloadMainModel(model, from: url)
if let coreMLZipURL = whisperModel.coreMLZipDownloadURL,
let coreMLURL = URL(string: coreMLZipURL) {
try await downloadAndSetupCoreMLModel(for: whisperModel, from: coreMLURL)
whisperModel = try await downloadAndSetupCoreMLModel(for: whisperModel, from: coreMLURL)
}
availableModels.append(whisperModel)
self.downloadProgress.removeValue(forKey: model.name + "_main")
if shouldWarmup(model) {
WhisperModelWarmupCoordinator.shared.scheduleWarmup(for: model, whisperState: self)
}
} catch {
handleModelDownloadError(model, error)
}
@ -223,22 +192,22 @@ extension WhisperState {
return WhisperModel(name: model.name, url: destinationURL)
}
private func downloadAndSetupCoreMLModel(for model: WhisperModel, from url: URL) async throws {
private func downloadAndSetupCoreMLModel(for model: WhisperModel, from url: URL) async throws -> WhisperModel {
let progressKeyCoreML = model.name + "_coreml"
let coreMLData = try await downloadFileWithProgress(from: url, progressKey: progressKeyCoreML)
let coreMLZipPath = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc.zip")
try coreMLData.write(to: coreMLZipPath)
try await unzipAndSetupCoreMLModel(for: model, zipPath: coreMLZipPath, progressKey: progressKeyCoreML)
return try await unzipAndSetupCoreMLModel(for: model, zipPath: coreMLZipPath, progressKey: progressKeyCoreML)
}
private func unzipAndSetupCoreMLModel(for model: WhisperModel, zipPath: URL, progressKey: String) async throws {
private func unzipAndSetupCoreMLModel(for model: WhisperModel, zipPath: URL, progressKey: String) async throws -> WhisperModel {
let coreMLDestination = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc")
try? FileManager.default.removeItem(at: coreMLDestination)
try await unzipCoreMLFile(zipPath, to: modelsDirectory)
try verifyAndCleanupCoreMLFiles(model, coreMLDestination, zipPath, progressKey)
return try verifyAndCleanupCoreMLFiles(model, coreMLDestination, zipPath, progressKey)
}
private func unzipCoreMLFile(_ zipPath: URL, to destination: URL) async throws {
@ -268,6 +237,10 @@ extension WhisperState {
return model
}
private func shouldWarmup(_ model: LocalModel) -> Bool {
!model.name.contains("q5") && !model.name.contains("q8")
}
private func handleModelDownloadError(_ model: LocalModel, _ error: Error) {
self.downloadProgress.removeValue(forKey: model.name + "_main")