376 lines
14 KiB
Swift
376 lines
14 KiB
Swift
import Foundation
|
|
import os
|
|
import Zip
|
|
import SwiftUI
|
|
|
|
// MARK: - Model Management Extension
|
|
extension WhisperState {
|
|
|
|
// MARK: - Model Directory Management
|
|
|
|
func createModelsDirectoryIfNeeded() {
|
|
do {
|
|
try FileManager.default.createDirectory(at: modelsDirectory, withIntermediateDirectories: true, attributes: nil)
|
|
} catch {
|
|
logError("Error creating models directory", error)
|
|
}
|
|
}
|
|
|
|
func loadAvailableModels() {
|
|
do {
|
|
let fileURLs = try FileManager.default.contentsOfDirectory(at: modelsDirectory, includingPropertiesForKeys: nil)
|
|
availableModels = fileURLs.compactMap { url in
|
|
guard url.pathExtension == "bin" else { return nil }
|
|
return WhisperModel(name: url.deletingPathExtension().lastPathComponent, url: url)
|
|
}
|
|
} catch {
|
|
logError("Error loading available models", error)
|
|
}
|
|
}
|
|
|
|
// MARK: - Model Loading
|
|
|
|
func loadModel(_ model: WhisperModel) async throws {
|
|
guard whisperContext == nil else { return }
|
|
|
|
isModelLoading = true
|
|
defer { isModelLoading = false }
|
|
|
|
do {
|
|
whisperContext = try await WhisperContext.createContext(path: model.url.path)
|
|
isModelLoaded = true
|
|
currentModel = model
|
|
} catch {
|
|
throw WhisperStateError.modelLoadFailed
|
|
}
|
|
}
|
|
|
|
func setDefaultModel(_ model: WhisperModel) async {
|
|
currentModel = model
|
|
UserDefaults.standard.set(model.name, forKey: "CurrentModel")
|
|
canTranscribe = true
|
|
}
|
|
|
|
// MARK: - Model Download & Management
|
|
|
|
/// Helper function to download a file from a URL with progress tracking
|
|
private func downloadFileWithProgress(from url: URL, progressKey: String) async throws -> Data {
|
|
let destinationURL = modelsDirectory.appendingPathComponent(UUID().uuidString)
|
|
|
|
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Data, Error>) in
|
|
let task = URLSession.shared.downloadTask(with: url) { tempURL, response, error in
|
|
if let error = error {
|
|
continuation.resume(throwing: error)
|
|
return
|
|
}
|
|
|
|
guard let httpResponse = response as? HTTPURLResponse,
|
|
(200...299).contains(httpResponse.statusCode),
|
|
let tempURL = tempURL else {
|
|
continuation.resume(throwing: URLError(.badServerResponse))
|
|
return
|
|
}
|
|
|
|
do {
|
|
// Move the downloaded file to the final destination
|
|
try FileManager.default.moveItem(at: tempURL, to: destinationURL)
|
|
|
|
// Read the file in chunks to avoid memory pressure
|
|
let data = try Data(contentsOf: destinationURL, options: .mappedIfSafe)
|
|
continuation.resume(returning: data)
|
|
|
|
// Clean up the temporary file
|
|
try? FileManager.default.removeItem(at: destinationURL)
|
|
} catch {
|
|
continuation.resume(throwing: error)
|
|
}
|
|
}
|
|
|
|
task.resume()
|
|
|
|
var lastUpdateTime = Date()
|
|
var lastProgressValue: Double = 0
|
|
|
|
let observation = task.progress.observe(\.fractionCompleted) { progress, _ in
|
|
let currentTime = Date()
|
|
let timeSinceLastUpdate = currentTime.timeIntervalSince(lastUpdateTime)
|
|
let currentProgress = round(progress.fractionCompleted * 100) / 100
|
|
|
|
if timeSinceLastUpdate >= 0.5 && abs(currentProgress - lastProgressValue) >= 0.01 {
|
|
lastUpdateTime = currentTime
|
|
lastProgressValue = currentProgress
|
|
|
|
DispatchQueue.main.async {
|
|
self.downloadProgress[progressKey] = currentProgress
|
|
}
|
|
}
|
|
}
|
|
|
|
Task {
|
|
await withTaskCancellationHandler {
|
|
observation.invalidate()
|
|
} operation: {
|
|
await withCheckedContinuation { (_: CheckedContinuation<Void, Never>) in }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Shows an alert about Core ML support and first-run optimization
|
|
private func showCoreMLAlert(for model: PredefinedModel, completion: @escaping () -> Void) {
|
|
Task { @MainActor in
|
|
let alert = NSAlert()
|
|
alert.messageText = "Core ML Support for \(model.displayName) Model"
|
|
alert.informativeText = "This Whisper model supports Core ML, which can improve performance by 2-4x on Apple Silicon devices.\n\nDuring the first run, it can take several minutes to optimize the model for your system. Subsequent runs will be much faster."
|
|
alert.alertStyle = .informational
|
|
alert.addButton(withTitle: "Download")
|
|
alert.addButton(withTitle: "Cancel")
|
|
|
|
let response = alert.runModal()
|
|
if response == .alertFirstButtonReturn {
|
|
completion()
|
|
}
|
|
}
|
|
}
|
|
|
|
func downloadModel(_ model: PredefinedModel) async {
|
|
guard let url = URL(string: model.downloadURL) else { return }
|
|
|
|
// Check if model supports Core ML (non-quantized models)
|
|
let supportsCoreML = !model.name.contains("q5") && !model.name.contains("q8")
|
|
|
|
if supportsCoreML {
|
|
// Show the CoreML alert for models that support it
|
|
await MainActor.run {
|
|
showCoreMLAlert(for: model) {
|
|
// This completion handler is called when user clicks "Download"
|
|
Task {
|
|
await self.performModelDownload(model, url)
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
// Directly download the model if it doesn't support Core ML
|
|
await performModelDownload(model, url)
|
|
}
|
|
}
|
|
|
|
private func performModelDownload(_ model: PredefinedModel, _ url: URL) async {
|
|
do {
|
|
let whisperModel = try await downloadMainModel(model, from: url)
|
|
|
|
if let coreMLZipURL = whisperModel.coreMLZipDownloadURL,
|
|
let coreMLURL = URL(string: coreMLZipURL) {
|
|
try await downloadAndSetupCoreMLModel(for: whisperModel, from: coreMLURL)
|
|
}
|
|
|
|
availableModels.append(whisperModel)
|
|
self.downloadProgress.removeValue(forKey: model.name + "_main")
|
|
} catch {
|
|
handleModelDownloadError(model, error)
|
|
}
|
|
}
|
|
|
|
private func downloadMainModel(_ model: PredefinedModel, from url: URL) async throws -> WhisperModel {
|
|
let progressKeyMain = model.name + "_main"
|
|
let data = try await downloadFileWithProgress(from: url, progressKey: progressKeyMain)
|
|
|
|
let destinationURL = modelsDirectory.appendingPathComponent(model.filename)
|
|
try data.write(to: destinationURL)
|
|
|
|
return WhisperModel(name: model.name, url: destinationURL)
|
|
}
|
|
|
|
private func downloadAndSetupCoreMLModel(for model: WhisperModel, from url: URL) async throws {
|
|
let progressKeyCoreML = model.name + "_coreml"
|
|
let coreMLData = try await downloadFileWithProgress(from: url, progressKey: progressKeyCoreML)
|
|
|
|
let coreMLZipPath = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc.zip")
|
|
try coreMLData.write(to: coreMLZipPath)
|
|
|
|
try await unzipAndSetupCoreMLModel(for: model, zipPath: coreMLZipPath, progressKey: progressKeyCoreML)
|
|
}
|
|
|
|
private func unzipAndSetupCoreMLModel(for model: WhisperModel, zipPath: URL, progressKey: String) async throws {
|
|
let coreMLDestination = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc")
|
|
|
|
try? FileManager.default.removeItem(at: coreMLDestination)
|
|
try await unzipCoreMLFile(zipPath, to: modelsDirectory)
|
|
try verifyAndCleanupCoreMLFiles(model, coreMLDestination, zipPath, progressKey)
|
|
}
|
|
|
|
private func unzipCoreMLFile(_ zipPath: URL, to destination: URL) async throws {
|
|
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
|
|
do {
|
|
try FileManager.default.createDirectory(at: destination, withIntermediateDirectories: true)
|
|
try Zip.unzipFile(zipPath, destination: destination, overwrite: true, password: nil)
|
|
continuation.resume()
|
|
} catch {
|
|
continuation.resume(throwing: error)
|
|
}
|
|
}
|
|
}
|
|
|
|
private func verifyAndCleanupCoreMLFiles(_ model: WhisperModel, _ destination: URL, _ zipPath: URL, _ progressKey: String) throws -> WhisperModel {
|
|
var model = model
|
|
|
|
var isDirectory: ObjCBool = false
|
|
guard FileManager.default.fileExists(atPath: destination.path, isDirectory: &isDirectory), isDirectory.boolValue else {
|
|
try? FileManager.default.removeItem(at: zipPath)
|
|
throw WhisperStateError.unzipFailed
|
|
}
|
|
|
|
try? FileManager.default.removeItem(at: zipPath)
|
|
model.coreMLEncoderURL = destination
|
|
self.downloadProgress.removeValue(forKey: progressKey)
|
|
|
|
return model
|
|
}
|
|
|
|
private func handleModelDownloadError(_ model: PredefinedModel, _ error: Error) {
|
|
currentError = .modelDownloadFailed
|
|
self.downloadProgress.removeValue(forKey: model.name + "_main")
|
|
self.downloadProgress.removeValue(forKey: model.name + "_coreml")
|
|
}
|
|
|
|
func deleteModel(_ model: WhisperModel) async {
|
|
do {
|
|
// Delete main model file
|
|
try FileManager.default.removeItem(at: model.url)
|
|
|
|
// Delete CoreML model if it exists
|
|
if let coreMLURL = model.coreMLEncoderURL {
|
|
try? FileManager.default.removeItem(at: coreMLURL)
|
|
} else {
|
|
// Check if there's a CoreML directory matching the model name
|
|
let coreMLDir = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc")
|
|
if FileManager.default.fileExists(atPath: coreMLDir.path) {
|
|
try? FileManager.default.removeItem(at: coreMLDir)
|
|
}
|
|
}
|
|
|
|
// Update model state
|
|
availableModels.removeAll { $0.id == model.id }
|
|
if currentModel?.id == model.id {
|
|
currentModel = nil
|
|
canTranscribe = false
|
|
}
|
|
} catch {
|
|
logError("Error deleting model: \(model.name)", error)
|
|
currentError = .modelDeletionFailed
|
|
}
|
|
}
|
|
|
|
func unloadModel() {
|
|
Task {
|
|
await whisperContext?.releaseResources()
|
|
whisperContext = nil
|
|
isModelLoaded = false
|
|
|
|
if let recordedFile = recordedFile {
|
|
try? FileManager.default.removeItem(at: recordedFile)
|
|
self.recordedFile = nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func clearDownloadedModels() async {
|
|
for model in availableModels {
|
|
do {
|
|
try FileManager.default.removeItem(at: model.url)
|
|
} catch {
|
|
logError("Error deleting model during cleanup", error)
|
|
}
|
|
}
|
|
availableModels.removeAll()
|
|
}
|
|
|
|
// MARK: - Resource Management
|
|
|
|
func cleanupModelResources() async {
|
|
recorder.stopRecording()
|
|
try? await Task.sleep(nanoseconds: 500_000_000)
|
|
|
|
let canCleanup = !isRecording && !isProcessing
|
|
|
|
if canCleanup {
|
|
await whisperContext?.releaseResources()
|
|
whisperContext = nil
|
|
isModelLoaded = false
|
|
}
|
|
}
|
|
|
|
// MARK: - Helper Methods
|
|
|
|
private func logError(_ message: String, _ error: Error) {
|
|
self.logger.error("\(message): \(error.localizedDescription)")
|
|
}
|
|
}
|
|
|
|
// MARK: - Download Progress View
|
|
struct DownloadProgressView: View {
|
|
let modelName: String
|
|
let downloadProgress: [String: Double]
|
|
|
|
@Environment(\.colorScheme) private var colorScheme
|
|
|
|
private var mainProgress: Double {
|
|
downloadProgress[modelName + "_main"] ?? 0
|
|
}
|
|
|
|
private var coreMLProgress: Double {
|
|
supportsCoreML ? (downloadProgress[modelName + "_coreml"] ?? 0) : 0
|
|
}
|
|
|
|
private var supportsCoreML: Bool {
|
|
!modelName.contains("q5") && !modelName.contains("q8")
|
|
}
|
|
|
|
private var totalProgress: Double {
|
|
supportsCoreML ? (mainProgress * 0.5) + (coreMLProgress * 0.5) : mainProgress
|
|
}
|
|
|
|
private var downloadPhase: String {
|
|
// Check if we're currently downloading the CoreML model
|
|
if supportsCoreML && downloadProgress[modelName + "_coreml"] != nil {
|
|
return "Downloading Core ML Model for \(modelName)"
|
|
}
|
|
// Otherwise, we're downloading the main model
|
|
return "Downloading \(modelName) Model"
|
|
}
|
|
|
|
var body: some View {
|
|
VStack(alignment: .leading, spacing: 8) {
|
|
// Status text with clean typography
|
|
Text(downloadPhase)
|
|
.font(.system(size: 12, weight: .medium))
|
|
.foregroundColor(Color(.secondaryLabelColor))
|
|
|
|
// Clean progress bar
|
|
GeometryReader { geometry in
|
|
ZStack(alignment: .leading) {
|
|
// Background track
|
|
RoundedRectangle(cornerRadius: 4)
|
|
.fill(Color(.separatorColor).opacity(0.3))
|
|
.frame(height: 6)
|
|
|
|
// Progress indicator
|
|
RoundedRectangle(cornerRadius: 4)
|
|
.fill(Color(.controlAccentColor))
|
|
.frame(width: max(0, min(geometry.size.width * totalProgress, geometry.size.width)), height: 6)
|
|
}
|
|
}
|
|
.frame(height: 6)
|
|
|
|
// Percentage indicator in Apple style
|
|
HStack {
|
|
Spacer()
|
|
Text("\(Int(totalProgress * 100))%")
|
|
.font(.system(size: 11, weight: .medium, design: .monospaced))
|
|
.foregroundColor(Color(.secondaryLabelColor))
|
|
}
|
|
}
|
|
.padding(.vertical, 4)
|
|
.animation(.smooth, value: totalProgress)
|
|
}
|
|
} |