vOOice/VoiceInk/Whisper/WhisperState+ModelManager.swift
2025-05-24 23:19:36 +05:45

377 lines
15 KiB
Swift

import Foundation
import os
import Zip
import SwiftUI
// MARK: - Model Management Extension
extension WhisperState {
// MARK: - Model Directory Management
func createModelsDirectoryIfNeeded() {
do {
try FileManager.default.createDirectory(at: modelsDirectory, withIntermediateDirectories: true, attributes: nil)
} catch {
logError("Error creating models directory", error)
}
}
func loadAvailableModels() {
do {
let fileURLs = try FileManager.default.contentsOfDirectory(at: modelsDirectory, includingPropertiesForKeys: nil)
availableModels = fileURLs.compactMap { url in
guard url.pathExtension == "bin" else { return nil }
return WhisperModel(name: url.deletingPathExtension().lastPathComponent, url: url)
}
} catch {
logError("Error loading available models", error)
}
}
// MARK: - Model Loading
func loadModel(_ model: WhisperModel) async throws {
guard whisperContext == nil else { return }
isModelLoading = true
defer { isModelLoading = false }
do {
whisperContext = try await WhisperContext.createContext(path: model.url.path)
// Set the prompt from UserDefaults to ensure we have the latest
let currentPrompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? whisperPrompt.transcriptionPrompt
await whisperContext?.setPrompt(currentPrompt)
isModelLoaded = true
currentModel = model
} catch {
throw WhisperStateError.modelLoadFailed
}
}
func setDefaultModel(_ model: WhisperModel) async {
currentModel = model
UserDefaults.standard.set(model.name, forKey: "CurrentModel")
canTranscribe = true
}
// MARK: - Model Download & Management
/// Helper function to download a file from a URL with progress tracking
private func downloadFileWithProgress(from url: URL, progressKey: String) async throws -> Data {
let destinationURL = modelsDirectory.appendingPathComponent(UUID().uuidString)
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Data, Error>) in
let task = URLSession.shared.downloadTask(with: url) { tempURL, response, error in
if let error = error {
continuation.resume(throwing: error)
return
}
guard let httpResponse = response as? HTTPURLResponse,
(200...299).contains(httpResponse.statusCode),
let tempURL = tempURL else {
continuation.resume(throwing: URLError(.badServerResponse))
return
}
do {
// Move the downloaded file to the final destination
try FileManager.default.moveItem(at: tempURL, to: destinationURL)
// Read the file in chunks to avoid memory pressure
let data = try Data(contentsOf: destinationURL, options: .mappedIfSafe)
continuation.resume(returning: data)
// Clean up the temporary file
try? FileManager.default.removeItem(at: destinationURL)
} catch {
continuation.resume(throwing: error)
}
}
task.resume()
var lastUpdateTime = Date()
var lastProgressValue: Double = 0
let observation = task.progress.observe(\.fractionCompleted) { progress, _ in
let currentTime = Date()
let timeSinceLastUpdate = currentTime.timeIntervalSince(lastUpdateTime)
let currentProgress = round(progress.fractionCompleted * 100) / 100
if timeSinceLastUpdate >= 0.5 && abs(currentProgress - lastProgressValue) >= 0.01 {
lastUpdateTime = currentTime
lastProgressValue = currentProgress
DispatchQueue.main.async {
self.downloadProgress[progressKey] = currentProgress
}
}
}
Task {
await withTaskCancellationHandler {
observation.invalidate()
} operation: {
await withCheckedContinuation { (_: CheckedContinuation<Void, Never>) in }
}
}
}
}
// Shows an alert about Core ML support and first-run optimization
private func showCoreMLAlert(for model: PredefinedModel, completion: @escaping () -> Void) {
Task { @MainActor in
let alert = NSAlert()
alert.messageText = "Core ML Support for \(model.displayName) Model"
alert.informativeText = "This Whisper model supports Core ML, which can improve performance by 2-4x on Apple Silicon devices.\n\nDuring the first run, it can take several minutes to optimize the model for your system. Subsequent runs will be much faster."
alert.alertStyle = .informational
alert.addButton(withTitle: "Download")
alert.addButton(withTitle: "Cancel")
let response = alert.runModal()
if response == .alertFirstButtonReturn {
completion()
}
}
}
func downloadModel(_ model: PredefinedModel) async {
guard let url = URL(string: model.downloadURL) else { return }
// Check if model supports Core ML (non-quantized models)
let supportsCoreML = !model.name.contains("q5") && !model.name.contains("q8")
if supportsCoreML {
// Show the CoreML alert for models that support it
await MainActor.run {
showCoreMLAlert(for: model) {
// This completion handler is called when user clicks "Download"
Task {
await self.performModelDownload(model, url)
}
}
}
} else {
// Directly download the model if it doesn't support Core ML
await performModelDownload(model, url)
}
}
private func performModelDownload(_ model: PredefinedModel, _ url: URL) async {
do {
let whisperModel = try await downloadMainModel(model, from: url)
if let coreMLZipURL = whisperModel.coreMLZipDownloadURL,
let coreMLURL = URL(string: coreMLZipURL) {
try await downloadAndSetupCoreMLModel(for: whisperModel, from: coreMLURL)
}
availableModels.append(whisperModel)
self.downloadProgress.removeValue(forKey: model.name + "_main")
} catch {
handleModelDownloadError(model, error)
}
}
private func downloadMainModel(_ model: PredefinedModel, from url: URL) async throws -> WhisperModel {
let progressKeyMain = model.name + "_main"
let data = try await downloadFileWithProgress(from: url, progressKey: progressKeyMain)
let destinationURL = modelsDirectory.appendingPathComponent(model.filename)
try data.write(to: destinationURL)
return WhisperModel(name: model.name, url: destinationURL)
}
private func downloadAndSetupCoreMLModel(for model: WhisperModel, from url: URL) async throws {
let progressKeyCoreML = model.name + "_coreml"
let coreMLData = try await downloadFileWithProgress(from: url, progressKey: progressKeyCoreML)
let coreMLZipPath = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc.zip")
try coreMLData.write(to: coreMLZipPath)
try await unzipAndSetupCoreMLModel(for: model, zipPath: coreMLZipPath, progressKey: progressKeyCoreML)
}
private func unzipAndSetupCoreMLModel(for model: WhisperModel, zipPath: URL, progressKey: String) async throws {
let coreMLDestination = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc")
try? FileManager.default.removeItem(at: coreMLDestination)
try await unzipCoreMLFile(zipPath, to: modelsDirectory)
try verifyAndCleanupCoreMLFiles(model, coreMLDestination, zipPath, progressKey)
}
private func unzipCoreMLFile(_ zipPath: URL, to destination: URL) async throws {
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
do {
try FileManager.default.createDirectory(at: destination, withIntermediateDirectories: true)
try Zip.unzipFile(zipPath, destination: destination, overwrite: true, password: nil)
continuation.resume()
} catch {
continuation.resume(throwing: error)
}
}
}
private func verifyAndCleanupCoreMLFiles(_ model: WhisperModel, _ destination: URL, _ zipPath: URL, _ progressKey: String) throws -> WhisperModel {
var model = model
var isDirectory: ObjCBool = false
guard FileManager.default.fileExists(atPath: destination.path, isDirectory: &isDirectory), isDirectory.boolValue else {
try? FileManager.default.removeItem(at: zipPath)
throw WhisperStateError.unzipFailed
}
try? FileManager.default.removeItem(at: zipPath)
model.coreMLEncoderURL = destination
self.downloadProgress.removeValue(forKey: progressKey)
return model
}
private func handleModelDownloadError(_ model: PredefinedModel, _ error: Error) {
currentError = .modelDownloadFailed
self.downloadProgress.removeValue(forKey: model.name + "_main")
self.downloadProgress.removeValue(forKey: model.name + "_coreml")
}
func deleteModel(_ model: WhisperModel) async {
do {
// Delete main model file
try FileManager.default.removeItem(at: model.url)
// Delete CoreML model if it exists
if let coreMLURL = model.coreMLEncoderURL {
try? FileManager.default.removeItem(at: coreMLURL)
} else {
// Check if there's a CoreML directory matching the model name
let coreMLDir = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc")
if FileManager.default.fileExists(atPath: coreMLDir.path) {
try? FileManager.default.removeItem(at: coreMLDir)
}
}
// Update model state
availableModels.removeAll { $0.id == model.id }
if currentModel?.id == model.id {
currentModel = nil
canTranscribe = false
}
} catch {
logError("Error deleting model: \(model.name)", error)
currentError = .modelDeletionFailed
}
}
func unloadModel() {
Task {
await whisperContext?.releaseResources()
whisperContext = nil
isModelLoaded = false
if let recordedFile = recordedFile {
try? FileManager.default.removeItem(at: recordedFile)
self.recordedFile = nil
}
}
}
func clearDownloadedModels() async {
for model in availableModels {
do {
try FileManager.default.removeItem(at: model.url)
} catch {
logError("Error deleting model during cleanup", error)
}
}
availableModels.removeAll()
}
// MARK: - Resource Management
func cleanupModelResources() async {
try? await Task.sleep(nanoseconds: 300_000_000)
await whisperContext?.releaseResources()
whisperContext = nil
isModelLoaded = false
}
// MARK: - Helper Methods
private func logError(_ message: String, _ error: Error) {
self.logger.error("\(message): \(error.localizedDescription)")
}
}
// MARK: - Download Progress View
struct DownloadProgressView: View {
let modelName: String
let downloadProgress: [String: Double]
@Environment(\.colorScheme) private var colorScheme
private var mainProgress: Double {
downloadProgress[modelName + "_main"] ?? 0
}
private var coreMLProgress: Double {
supportsCoreML ? (downloadProgress[modelName + "_coreml"] ?? 0) : 0
}
private var supportsCoreML: Bool {
!modelName.contains("q5") && !modelName.contains("q8")
}
private var totalProgress: Double {
supportsCoreML ? (mainProgress * 0.5) + (coreMLProgress * 0.5) : mainProgress
}
private var downloadPhase: String {
// Check if we're currently downloading the CoreML model
if supportsCoreML && downloadProgress[modelName + "_coreml"] != nil {
return "Downloading Core ML Model for \(modelName)"
}
// Otherwise, we're downloading the main model
return "Downloading \(modelName) Model"
}
var body: some View {
VStack(alignment: .leading, spacing: 8) {
// Status text with clean typography
Text(downloadPhase)
.font(.system(size: 12, weight: .medium))
.foregroundColor(Color(.secondaryLabelColor))
// Clean progress bar
GeometryReader { geometry in
ZStack(alignment: .leading) {
// Background track
RoundedRectangle(cornerRadius: 4)
.fill(Color(.separatorColor).opacity(0.3))
.frame(height: 6)
// Progress indicator
RoundedRectangle(cornerRadius: 4)
.fill(Color(.controlAccentColor))
.frame(width: max(0, min(geometry.size.width * totalProgress, geometry.size.width)), height: 6)
}
}
.frame(height: 6)
// Percentage indicator in Apple style
HStack {
Spacer()
Text("\(Int(totalProgress * 100))%")
.font(.system(size: 11, weight: .medium, design: .monospaced))
.foregroundColor(Color(.secondaryLabelColor))
}
}
.padding(.vertical, 4)
.animation(.smooth, value: totalProgress)
}
}