Auto warm-up downloaded Whisper models after downloading
This commit is contained in:
parent
3f01f49f56
commit
b568fb9fa0
@ -7,12 +7,12 @@ struct LocalModelCardView: View {
|
||||
let isCurrent: Bool
|
||||
let downloadProgress: [String: Double]
|
||||
let modelURL: URL?
|
||||
let isWarming: Bool
|
||||
|
||||
// Actions
|
||||
var deleteAction: () -> Void
|
||||
var setDefaultAction: () -> Void
|
||||
var downloadAction: () -> Void
|
||||
|
||||
private var isDownloading: Bool {
|
||||
downloadProgress.keys.contains(model.name + "_main") ||
|
||||
downloadProgress.keys.contains(model.name + "_coreml")
|
||||
@ -134,12 +134,22 @@ struct LocalModelCardView: View {
|
||||
.font(.system(size: 12))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
} else if isDownloaded {
|
||||
Button(action: setDefaultAction) {
|
||||
Text("Set as Default")
|
||||
.font(.system(size: 12))
|
||||
if isWarming {
|
||||
HStack(spacing: 6) {
|
||||
ProgressView()
|
||||
.controlSize(.small)
|
||||
Text("Optimizing model for your device...")
|
||||
.font(.system(size: 12))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
}
|
||||
} else {
|
||||
Button(action: setDefaultAction) {
|
||||
Text("Set as Default")
|
||||
.font(.system(size: 12))
|
||||
}
|
||||
.buttonStyle(.bordered)
|
||||
.controlSize(.small)
|
||||
}
|
||||
.buttonStyle(.bordered)
|
||||
.controlSize(.small)
|
||||
} else {
|
||||
Button(action: downloadAction) {
|
||||
HStack(spacing: 4) {
|
||||
|
||||
@ -8,13 +8,13 @@ struct ModelCardRowView: View {
|
||||
let isCurrent: Bool
|
||||
let downloadProgress: [String: Double]
|
||||
let modelURL: URL?
|
||||
let isWarming: Bool
|
||||
|
||||
// Actions
|
||||
var deleteAction: () -> Void
|
||||
var setDefaultAction: () -> Void
|
||||
var downloadAction: () -> Void
|
||||
var editAction: ((CustomCloudModel) -> Void)?
|
||||
|
||||
var body: some View {
|
||||
Group {
|
||||
switch model.provider {
|
||||
@ -26,6 +26,7 @@ struct ModelCardRowView: View {
|
||||
isCurrent: isCurrent,
|
||||
downloadProgress: downloadProgress,
|
||||
modelURL: modelURL,
|
||||
isWarming: isWarming,
|
||||
deleteAction: deleteAction,
|
||||
setDefaultAction: setDefaultAction,
|
||||
downloadAction: downloadAction
|
||||
@ -76,4 +77,4 @@ struct ModelCardRowView: View {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ struct ModelManagementView: View {
|
||||
@EnvironmentObject private var enhancementService: AIEnhancementService
|
||||
@Environment(\.modelContext) private var modelContext
|
||||
@StateObject private var whisperPrompt = WhisperPrompt()
|
||||
@ObservedObject private var warmupCoordinator = WhisperModelWarmupCoordinator.shared
|
||||
|
||||
@State private var selectedFilter: ModelFilter = .recommended
|
||||
@State private var isShowingSettings = false
|
||||
@ -118,6 +119,10 @@ struct ModelManagementView: View {
|
||||
} else {
|
||||
VStack(spacing: 12) {
|
||||
ForEach(filteredModels, id: \.id) { model in
|
||||
let isWarming = (model as? LocalModel).map { localModel in
|
||||
warmupCoordinator.isWarming(modelNamed: localModel.name)
|
||||
} ?? false
|
||||
|
||||
ModelCardRowView(
|
||||
model: model,
|
||||
whisperState: whisperState,
|
||||
@ -125,6 +130,7 @@ struct ModelManagementView: View {
|
||||
isCurrent: whisperState.currentTranscriptionModel?.name == model.name,
|
||||
downloadProgress: whisperState.downloadProgress,
|
||||
modelURL: whisperState.availableModels.first { $0.name == model.name }?.url,
|
||||
isWarming: isWarming,
|
||||
deleteAction: {
|
||||
if let customModel = model as? CustomCloudModel {
|
||||
alertTitle = "Delete Custom Model"
|
||||
|
||||
68
VoiceInk/Whisper/WhisperModelWarmupCoordinator.swift
Normal file
68
VoiceInk/Whisper/WhisperModelWarmupCoordinator.swift
Normal file
@ -0,0 +1,68 @@
|
||||
import Foundation
|
||||
import Combine
|
||||
|
||||
@MainActor
|
||||
final class WhisperModelWarmupCoordinator: ObservableObject {
|
||||
static let shared = WhisperModelWarmupCoordinator()
|
||||
|
||||
@Published private(set) var warmingModels: Set<String> = []
|
||||
|
||||
private init() {}
|
||||
|
||||
func isWarming(modelNamed name: String) -> Bool {
|
||||
warmingModels.contains(name)
|
||||
}
|
||||
|
||||
func scheduleWarmup(for model: LocalModel, whisperState: WhisperState) {
|
||||
guard shouldWarmup(modelName: model.name),
|
||||
!warmingModels.contains(model.name) else {
|
||||
return
|
||||
}
|
||||
|
||||
warmingModels.insert(model.name)
|
||||
|
||||
Task {
|
||||
do {
|
||||
try await runWarmup(for: model, whisperState: whisperState)
|
||||
} catch {
|
||||
await MainActor.run {
|
||||
whisperState.logger.error("Warmup failed for \(model.name): \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
|
||||
await MainActor.run {
|
||||
self.warmingModels.remove(model.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func runWarmup(for model: LocalModel, whisperState: WhisperState) async throws {
|
||||
guard let sampleURL = warmupSampleURL() else { return }
|
||||
let service = LocalTranscriptionService(
|
||||
modelsDirectory: whisperState.modelsDirectory,
|
||||
whisperState: whisperState
|
||||
)
|
||||
_ = try await service.transcribe(audioURL: sampleURL, model: model)
|
||||
}
|
||||
|
||||
private func warmupSampleURL() -> URL? {
|
||||
let bundle = Bundle.main
|
||||
let candidates: [URL?] = [
|
||||
bundle.url(forResource: "esc", withExtension: "wav", subdirectory: "Resources/Sounds"),
|
||||
bundle.url(forResource: "esc", withExtension: "wav", subdirectory: "Sounds"),
|
||||
bundle.url(forResource: "esc", withExtension: "wav")
|
||||
]
|
||||
|
||||
for candidate in candidates {
|
||||
if let url = candidate {
|
||||
return url
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
private func shouldWarmup(modelName: String) -> Bool {
|
||||
!modelName.contains("q5") && !modelName.contains("q8")
|
||||
}
|
||||
}
|
||||
@ -157,57 +157,26 @@ extension WhisperState {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shows an alert about Core ML support and first-run optimization
|
||||
private func showCoreMLAlert(for model: LocalModel, completion: @escaping () -> Void) {
|
||||
Task { @MainActor in
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Core ML Support for \(model.displayName) Model"
|
||||
alert.informativeText = "This Whisper model supports Core ML, which can improve performance by 2-4x on Apple Silicon devices.\n\nDuring the first run, it can take several minutes to optimize the model for your system. Subsequent runs will be much faster."
|
||||
alert.alertStyle = .informational
|
||||
alert.addButton(withTitle: "Download")
|
||||
alert.addButton(withTitle: "Cancel")
|
||||
|
||||
let response = alert.runModal()
|
||||
if response == .alertFirstButtonReturn {
|
||||
completion()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func downloadModel(_ model: LocalModel) async {
|
||||
guard let url = URL(string: model.downloadURL) else { return }
|
||||
|
||||
// Check if model supports Core ML (non-quantized models)
|
||||
let supportsCoreML = !model.name.contains("q5") && !model.name.contains("q8")
|
||||
|
||||
if supportsCoreML {
|
||||
// Show the CoreML alert for models that support it
|
||||
await MainActor.run {
|
||||
showCoreMLAlert(for: model) {
|
||||
// This completion handler is called when user clicks "Download"
|
||||
Task {
|
||||
await self.performModelDownload(model, url)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Directly download the model if it doesn't support Core ML
|
||||
await performModelDownload(model, url)
|
||||
}
|
||||
await performModelDownload(model, url)
|
||||
}
|
||||
|
||||
private func performModelDownload(_ model: LocalModel, _ url: URL) async {
|
||||
do {
|
||||
let whisperModel = try await downloadMainModel(model, from: url)
|
||||
var whisperModel = try await downloadMainModel(model, from: url)
|
||||
|
||||
if let coreMLZipURL = whisperModel.coreMLZipDownloadURL,
|
||||
let coreMLURL = URL(string: coreMLZipURL) {
|
||||
try await downloadAndSetupCoreMLModel(for: whisperModel, from: coreMLURL)
|
||||
whisperModel = try await downloadAndSetupCoreMLModel(for: whisperModel, from: coreMLURL)
|
||||
}
|
||||
|
||||
availableModels.append(whisperModel)
|
||||
self.downloadProgress.removeValue(forKey: model.name + "_main")
|
||||
|
||||
if shouldWarmup(model) {
|
||||
WhisperModelWarmupCoordinator.shared.scheduleWarmup(for: model, whisperState: self)
|
||||
}
|
||||
} catch {
|
||||
handleModelDownloadError(model, error)
|
||||
}
|
||||
@ -223,22 +192,22 @@ extension WhisperState {
|
||||
return WhisperModel(name: model.name, url: destinationURL)
|
||||
}
|
||||
|
||||
private func downloadAndSetupCoreMLModel(for model: WhisperModel, from url: URL) async throws {
|
||||
private func downloadAndSetupCoreMLModel(for model: WhisperModel, from url: URL) async throws -> WhisperModel {
|
||||
let progressKeyCoreML = model.name + "_coreml"
|
||||
let coreMLData = try await downloadFileWithProgress(from: url, progressKey: progressKeyCoreML)
|
||||
|
||||
let coreMLZipPath = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc.zip")
|
||||
try coreMLData.write(to: coreMLZipPath)
|
||||
|
||||
try await unzipAndSetupCoreMLModel(for: model, zipPath: coreMLZipPath, progressKey: progressKeyCoreML)
|
||||
return try await unzipAndSetupCoreMLModel(for: model, zipPath: coreMLZipPath, progressKey: progressKeyCoreML)
|
||||
}
|
||||
|
||||
private func unzipAndSetupCoreMLModel(for model: WhisperModel, zipPath: URL, progressKey: String) async throws {
|
||||
private func unzipAndSetupCoreMLModel(for model: WhisperModel, zipPath: URL, progressKey: String) async throws -> WhisperModel {
|
||||
let coreMLDestination = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc")
|
||||
|
||||
try? FileManager.default.removeItem(at: coreMLDestination)
|
||||
try await unzipCoreMLFile(zipPath, to: modelsDirectory)
|
||||
try verifyAndCleanupCoreMLFiles(model, coreMLDestination, zipPath, progressKey)
|
||||
return try verifyAndCleanupCoreMLFiles(model, coreMLDestination, zipPath, progressKey)
|
||||
}
|
||||
|
||||
private func unzipCoreMLFile(_ zipPath: URL, to destination: URL) async throws {
|
||||
@ -268,6 +237,10 @@ extension WhisperState {
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
private func shouldWarmup(_ model: LocalModel) -> Bool {
|
||||
!model.name.contains("q5") && !model.name.contains("q8")
|
||||
}
|
||||
|
||||
private func handleModelDownloadError(_ model: LocalModel, _ error: Error) {
|
||||
self.downloadProgress.removeValue(forKey: model.name + "_main")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user