From 0923093bcd66b05fc3840e0b281ce4de5e3ba8df Mon Sep 17 00:00:00 2001 From: Beingpax Date: Tue, 8 Apr 2025 18:08:19 +0545 Subject: [PATCH] Added support for CoreML and redesigned the AI Models view --- VoiceInk/Models/PredefinedModels.swift | 8 +- VoiceInk/Views/ModelCardRowView.swift | 253 +++++++++++++++ VoiceInk/Views/ModelManagementView.swift | 220 ++----------- .../OnboardingModelDownloadView.swift | 12 +- VoiceInk/Whisper/WhisperError.swift | 5 + .../Whisper/WhisperState+ModelManager.swift | 294 ++++++++++++++---- VoiceInk/Whisper/WhisperState.swift | 16 + 7 files changed, 532 insertions(+), 276 deletions(-) create mode 100644 VoiceInk/Views/ModelCardRowView.swift diff --git a/VoiceInk/Models/PredefinedModels.swift b/VoiceInk/Models/PredefinedModels.swift index b8c30b3..603406e 100644 --- a/VoiceInk/Models/PredefinedModels.swift +++ b/VoiceInk/Models/PredefinedModels.swift @@ -82,7 +82,7 @@ import Foundation size: "2.9 GiB", supportedLanguages: getLanguageDictionary(isMultilingual: true), description: "Large model v2, slower than Medium but more accurate, supports multiple languages", - speed: 0.25, + speed: 0.5, accuracy: 0.96, ramUsage: 3.8, hash: "0f4c8e30f21cf1769f637135f521436792c48186" @@ -93,7 +93,7 @@ import Foundation size: "2.9 GiB", supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true), description: "Large model v3, very slow but most accurate, supports multiple languages", - speed: 0.2, + speed: 0.5, accuracy: 0.98, ramUsage: 3.9, hash: "ad82bf6a9043ceed055076d0fd39f5f186ff8062" @@ -105,7 +105,7 @@ import Foundation supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true), description: "Large model v3 Turbo, faster than v3 with similar accuracy, supports multiple languages", - speed: 0.5, + speed: 0.7, accuracy: 0.97, ramUsage: 1.8, hash: "4af2b29d7ec73d781377bfd1758ca957a807e941" @@ -116,7 +116,7 @@ import Foundation size: "547 MiB", supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true), description: "Quantized version of Large v3 Turbo, faster with slightly lower accuracy", - speed: 0.6, + speed: 0.7, accuracy: 0.96, ramUsage: 1.0, hash: "e050f7970618a659205450ad97eb95a18d69c9ee" diff --git a/VoiceInk/Views/ModelCardRowView.swift b/VoiceInk/Views/ModelCardRowView.swift new file mode 100644 index 0000000..b341745 --- /dev/null +++ b/VoiceInk/Views/ModelCardRowView.swift @@ -0,0 +1,253 @@ +import SwiftUI +import AppKit + +struct ModelCardRowView: View { + let model: PredefinedModel + let isDownloaded: Bool + let isCurrent: Bool + let downloadProgress: [String: Double] + let modelURL: URL? + + // Actions + var deleteAction: () -> Void + var setDefaultAction: () -> Void + var downloadAction: () -> Void + + private var isDownloading: Bool { + downloadProgress.keys.contains(model.name + "_main") || + downloadProgress.keys.contains(model.name + "_coreml") + } + + var body: some View { + HStack(alignment: .top, spacing: 16) { + // Main Content + VStack(alignment: .leading, spacing: 6) { + headerSection + metadataSection + descriptionSection + progressSection + } + .frame(maxWidth: .infinity, alignment: .leading) + + // Action Controls + actionSection + } + .padding(16) + .background( + RoundedRectangle(cornerRadius: 14) + .fill( + LinearGradient( + gradient: Gradient(colors: [ + Color(NSColor.controlBackgroundColor).opacity(1.0), + Color(NSColor.controlBackgroundColor).opacity(0.6) + ]), + startPoint: .topLeading, + endPoint: .bottomTrailing + ) + ) + .overlay( + RoundedRectangle(cornerRadius: 14) + .stroke( + LinearGradient( + gradient: Gradient(colors: [ + Color.white.opacity(0.3), + Color.white.opacity(0.05) + ]), + startPoint: .topLeading, + endPoint: .bottomTrailing + ), + lineWidth: 1 + ) + ) + .shadow( + color: Color.primary.opacity(0.08), + radius: 8, + x: 0, + y: 4 + ) + ) + } + + // MARK: - Components + + private var headerSection: some View { + HStack(alignment: .firstTextBaseline) { + Text(model.displayName) + .font(.system(size: 13, weight: .semibold)) + .foregroundColor(Color(.labelColor)) + + statusBadge + + Spacer() + } + } + + private var statusBadge: some View { + Group { + if isCurrent { + Text("Default") + .font(.system(size: 11, weight: .medium)) + .padding(.horizontal, 6) + .padding(.vertical, 2) + .background(Capsule().fill(Color.accentColor)) + .foregroundColor(.white) + } else if isDownloaded { + Text("Downloaded") + .font(.system(size: 11, weight: .medium)) + .padding(.horizontal, 6) + .padding(.vertical, 2) + .background(Capsule().fill(Color(.quaternaryLabelColor))) + .foregroundColor(Color(.labelColor)) + } + } + } + + private var metadataSection: some View { + HStack(spacing: 16) { + // Language + Label(model.language, systemImage: "globe") + .font(.system(size: 11)) + .foregroundColor(Color(.secondaryLabelColor)) + + // Size + Label(model.size, systemImage: "internaldrive") + .font(.system(size: 11)) + .foregroundColor(Color(.secondaryLabelColor)) + + // Speed + HStack(spacing: 4) { + Text("Speed") + .font(.system(size: 11, weight: .medium)) + .foregroundColor(Color(.secondaryLabelColor)) + progressDotsWithNumber(value: model.speed * 10) + } + + // Accuracy + HStack(spacing: 4) { + Text("Accuracy") + .font(.system(size: 11, weight: .medium)) + .foregroundColor(Color(.secondaryLabelColor)) + progressDotsWithNumber(value: model.accuracy * 10) + } + } + } + + private var descriptionSection: some View { + Text(model.description) + .font(.system(size: 11)) + .foregroundColor(Color(.secondaryLabelColor)) + .lineLimit(2) + .fixedSize(horizontal: false, vertical: true) + .padding(.top, 4) + } + + private var progressSection: some View { + Group { + if isDownloading { + DownloadProgressView( + modelName: model.name, + downloadProgress: downloadProgress + ) + .padding(.top, 8) + .frame(maxWidth: .infinity, alignment: .leading) + } + } + } + + private var actionSection: some View { + HStack(spacing: 8) { + if isCurrent { + Text("Default Model") + .font(.system(size: 12)) + .foregroundColor(Color(.secondaryLabelColor)) + } else if isDownloaded { + Button(action: setDefaultAction) { + Text("Set as Default") + .font(.system(size: 12)) + } + .buttonStyle(.bordered) + .controlSize(.small) + } else { + Button(action: downloadAction) { + HStack(spacing: 4) { + Text(isDownloading ? "Downloading..." : "Download") + .font(.system(size: 12, weight: .medium)) + Image(systemName: "arrow.down.circle") + .font(.system(size: 12, weight: .medium)) + } + .foregroundColor(.white) + .padding(.horizontal, 12) + .padding(.vertical, 6) + .background( + Capsule() + .fill(Color(.controlAccentColor)) + .shadow(color: Color(.controlAccentColor).opacity(0.2), radius: 2, x: 0, y: 1) + ) + } + .buttonStyle(.plain) + .disabled(isDownloading) + } + + if isDownloaded { + Menu { + Button(action: deleteAction) { + Label("Delete Model", systemImage: "trash") + } + + if isDownloaded { + Button { + if let modelURL = modelURL { + NSWorkspace.shared.selectFile(modelURL.path, inFileViewerRootedAtPath: "") + } + } label: { + Label("Show in Finder", systemImage: "folder") + } + } + } label: { + Image(systemName: "ellipsis.circle") + .font(.system(size: 14)) + } + .menuStyle(.borderlessButton) + .menuIndicator(.hidden) + .frame(width: 20, height: 20) + } + } + } + + // MARK: - Helpers + + private var downloadComponents: [(String, Double)] { + [ + ("Model", downloadProgress[model.name + "_main"] ?? 0), + ("CoreML", downloadProgress[model.name + "_coreml"] ?? 0) + ].filter { $0.1 > 0 } + } + + private func progressDotsWithNumber(value: Double) -> some View { + HStack(spacing: 4) { + progressDots(value: value) + Text(String(format: "%.1f", value)) + .font(.system(size: 10, weight: .medium, design: .monospaced)) + .foregroundColor(Color(.secondaryLabelColor)) + } + } + + private func progressDots(value: Double) -> some View { + HStack(spacing: 2) { + ForEach(0..<5) { index in + Circle() + .fill(index < Int(value / 2) ? performanceColor(value: value / 10) : Color(.quaternaryLabelColor)) + .frame(width: 6, height: 6) + } + } + } + + private func performanceColor(value: Double) -> Color { + switch value { + case 0.8...1.0: return Color(.systemGreen) + case 0.6..<0.8: return Color(.systemYellow) + case 0.4..<0.6: return Color(.systemOrange) + default: return Color(.systemRed) + } + } +} diff --git a/VoiceInk/Views/ModelManagementView.swift b/VoiceInk/Views/ModelManagementView.swift index c319282..a8fed7b 100644 --- a/VoiceInk/Views/ModelManagementView.swift +++ b/VoiceInk/Views/ModelManagementView.swift @@ -68,217 +68,35 @@ struct ModelManagementView: View { Spacer() } - LazyVGrid(columns: [GridItem(.adaptive(minimum: 300, maximum: 400), spacing: 16)], spacing: 16) { + VStack(spacing: 12) { ForEach(whisperState.predefinedModels) { model in - modelCard(for: model) - } - } - } - .padding() - .background(Color(.windowBackgroundColor).opacity(0.4)) - .cornerRadius(10) - } - - private func modelCard(for model: PredefinedModel) -> some View { - let isDownloaded = whisperState.availableModels.contains { $0.name == model.name } - let isCurrent = whisperState.currentModel?.name == model.name - - return VStack(alignment: .leading, spacing: 12) { - // Model name and details - HStack { - VStack(alignment: .leading, spacing: 4) { - Text(model.displayName) - .font(.headline) - Text("\(model.size) โ€ข \(model.language)") - .font(.caption) - .foregroundColor(.secondary) - } - Spacer() - modelStatusBadge(isDownloaded: isDownloaded, isCurrent: isCurrent) - } - - // Description - Text(model.description) - .font(.subheadline) - .foregroundColor(.secondary) - .lineLimit(2) - - // Performance indicators - HStack(spacing: 16) { - performanceIndicator(label: "Speed", value: model.speed) - performanceIndicator(label: "Accuracy", value: model.accuracy) - ramUsageLabel(gb: model.ramUsage) - } - - // Action buttons - HStack { - modelActionButton(isDownloaded: isDownloaded, isCurrent: isCurrent, model: model) - - if isDownloaded { - Menu { - Button(action: { + ModelCardRowView( + model: model, + isDownloaded: whisperState.availableModels.contains { $0.name == model.name }, + isCurrent: whisperState.currentModel?.name == model.name, + downloadProgress: whisperState.downloadProgress, + modelURL: whisperState.availableModels.first { $0.name == model.name }?.url, + deleteAction: { if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) { modelToDelete = downloadedModel } - }) { - Label("Delete", systemImage: "trash") - } - - Button(action: { + }, + setDefaultAction: { if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) { - NSWorkspace.shared.selectFile(downloadedModel.url.path, inFileViewerRootedAtPath: "") + Task { + await whisperState.setDefaultModel(downloadedModel) + } + } + }, + downloadAction: { + Task { + await whisperState.downloadModel(model) } - }) { - Label("Show in Finder", systemImage: "folder") } - } label: { - Image(systemName: "ellipsis.circle") - .foregroundColor(.secondary) - } - .menuStyle(BorderlessButtonMenuStyle()) - .frame(width: 30, height: 30) + ) } } } .padding() - .background(Color(.windowBackgroundColor).opacity(0.9)) - .cornerRadius(10) - .overlay( - RoundedRectangle(cornerRadius: 10) - .stroke(isCurrent ? Color.accentColor : Color.gray.opacity(0.2), lineWidth: isCurrent ? 2 : 1) - ) - } - - private func modelStatusBadge(isDownloaded: Bool, isCurrent: Bool) -> some View { - Group { - if isCurrent { - Text("Default") - .font(.caption) - .padding(.horizontal, 8) - .padding(.vertical, 4) - .background(Color.green) - .foregroundColor(.white) - .cornerRadius(8) - } else if isDownloaded { - Text("Downloaded") - .font(.caption) - .padding(.horizontal, 8) - .padding(.vertical, 4) - .background(Color.indigo) - .foregroundColor(.white) - .cornerRadius(8) - } - } - } - - private func performanceIndicator(label: String, value: Double) -> some View { - VStack(alignment: .leading, spacing: 4) { - Text(label) - .font(.caption) - .foregroundColor(.secondary) - - HStack(spacing: 2) { - ForEach(0..<5) { index in - RoundedRectangle(cornerRadius: 2) - .fill(index < Int(value * 5) ? performanceColor(value: value) : Color.secondary.opacity(0.2)) - .frame(width: 16, height: 8) - } - } - - Text(String(format: "%.1f", value * 10)) - .font(.caption) - .foregroundColor(.secondary) - } - } - - private func performanceColor(value: Double) -> Color { - switch value { - case 0.8...: return .green - case 0.6..<0.8: return .yellow - case 0.4..<0.6: return .orange - default: return .red - } - } - - private func modelActionButton(isDownloaded: Bool, isCurrent: Bool, model: PredefinedModel) -> some View { - Group { - if isCurrent { - Text("Default Model") - .foregroundColor(.white) - } else if isDownloaded { - Button("Set as Default") { - if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) { - Task { - await whisperState.setDefaultModel(downloadedModel) - } - } - } - .foregroundColor(.white) - } else if whisperState.downloadProgress[model.name] != nil { - VStack { - ProgressView(value: whisperState.downloadProgress[model.name] ?? 0) - .progressViewStyle(LinearProgressViewStyle()) - .animation(.linear, value: whisperState.downloadProgress[model.name]) - Text("\(Int((whisperState.downloadProgress[model.name] ?? 0) * 100))%") - .font(.caption) - .animation(.none) - } - } else { - Button("Download Model") { - Task { - await whisperState.downloadModel(model) - } - } - .foregroundColor(.white) - } - } - .buttonStyle(GradientButtonStyle(isDownloaded: isDownloaded, isCurrent: isCurrent)) - .frame(maxWidth: .infinity) - } - - private func ramUsageLabel(gb: Double) -> some View { - VStack(alignment: .leading, spacing: 4) { - Text("RAM") - .font(.caption) - .foregroundColor(.secondary) - - Text(formatRAMSize(gb)) - .font(.system(size: 12, weight: .bold)) - .foregroundColor(.primary) - } - } - - private func formatRAMSize(_ gb: Double) -> String { - if gb >= 1.0 { - return String(format: "%.1f GB", gb) - } else { - return String(format: "%d MB", Int(gb * 1024)) - } - } -} - -struct GradientButtonStyle: ButtonStyle { - let isDownloaded: Bool - let isCurrent: Bool - - func makeBody(configuration: Configuration) -> some View { - configuration.label - .padding(.vertical, 5) - .padding(.horizontal, 10) - .background( - Group { - if isCurrent { - LinearGradient(gradient: Gradient(colors: [Color.green, Color.green.opacity(0.7)]), startPoint: .top, endPoint: .bottom) - } else if isDownloaded { - LinearGradient(gradient: Gradient(colors: [Color.purple, Color.purple.opacity(0.7)]), startPoint: .top, endPoint: .bottom) - } else { - LinearGradient(gradient: Gradient(colors: [Color.blue, Color.blue.opacity(0.7)]), startPoint: .top, endPoint: .bottom) - } - } - ) - .cornerRadius(10) - .shadow(color: Color.black.opacity(0.2), radius: 5, x: 0, y: 2) - .scaleEffect(configuration.isPressed ? 0.95 : 1) - .animation(.easeInOut(duration: 0.2), value: configuration.isPressed) } } diff --git a/VoiceInk/Views/Onboarding/OnboardingModelDownloadView.swift b/VoiceInk/Views/Onboarding/OnboardingModelDownloadView.swift index 0f08832..5a37f32 100644 --- a/VoiceInk/Views/Onboarding/OnboardingModelDownloadView.swift +++ b/VoiceInk/Views/Onboarding/OnboardingModelDownloadView.swift @@ -83,14 +83,10 @@ struct OnboardingModelDownloadView: View { // Download progress if isDownloading { - VStack(spacing: 8) { - ProgressView(value: whisperState.downloadProgress[turboModel.name] ?? 0) - .progressViewStyle(.linear) - .tint(.white) - Text("\(Int((whisperState.downloadProgress[turboModel.name] ?? 0) * 100))%") - .font(.caption) - .foregroundColor(.white.opacity(0.7)) - } + DownloadProgressView( + modelName: turboModel.name, + downloadProgress: whisperState.downloadProgress + ) .transition(.opacity) } } diff --git a/VoiceInk/Whisper/WhisperError.swift b/VoiceInk/Whisper/WhisperError.swift index fc92c4f..9fdfdf6 100644 --- a/VoiceInk/Whisper/WhisperError.swift +++ b/VoiceInk/Whisper/WhisperError.swift @@ -7,6 +7,7 @@ enum WhisperStateError: Error, Identifiable { case accessibilityPermissionDenied case modelDownloadFailed case modelDeletionFailed + case unzipFailed case unknownError var id: String { UUID().uuidString } @@ -27,6 +28,8 @@ extension WhisperStateError: LocalizedError { return "Failed to download the model." case .modelDeletionFailed: return "Failed to delete the model." + case .unzipFailed: + return "Failed to unzip the downloaded Core ML model." case .unknownError: return "An unknown error occurred." } @@ -46,6 +49,8 @@ extension WhisperStateError: LocalizedError { return "Check your internet connection and try again. If the problem persists, try a different model." case .modelDeletionFailed: return "Restart the application and try again. If the problem persists, you may need to manually delete the model file." + case .unzipFailed: + return "The downloaded Core ML model archive might be corrupted. Try deleting the model and downloading it again. Check available disk space." case .unknownError: return "Please restart the application. If the problem persists, contact support." } diff --git a/VoiceInk/Whisper/WhisperState+ModelManager.swift b/VoiceInk/Whisper/WhisperState+ModelManager.swift index d6fd310..13dd92c 100644 --- a/VoiceInk/Whisper/WhisperState+ModelManager.swift +++ b/VoiceInk/Whisper/WhisperState+ModelManager.swift @@ -1,5 +1,7 @@ import Foundation import os +import Zip +import SwiftUI // MARK: - Model Management Extension extension WhisperState { @@ -10,7 +12,7 @@ extension WhisperState { do { try FileManager.default.createDirectory(at: modelsDirectory, withIntermediateDirectories: true, attributes: nil) } catch { - messageLog += "Error creating models directory: \(error.localizedDescription)\n" + logError("Error creating models directory", error) } } @@ -22,7 +24,7 @@ extension WhisperState { return WhisperModel(name: url.deletingPathExtension().lastPathComponent, url: url) } } catch { - messageLog += "Error loading available models: \(error.localizedDescription)\n" + logError("Error loading available models", error) } } @@ -31,7 +33,6 @@ extension WhisperState { func loadModel(_ model: WhisperModel) async throws { guard whisperContext == nil else { return } - logger.notice("๐Ÿ”„ Loading Whisper model: \(model.name)") isModelLoading = true defer { isModelLoading = false } @@ -39,86 +40,186 @@ extension WhisperState { whisperContext = try await WhisperContext.createContext(path: model.url.path) isModelLoaded = true currentModel = model - logger.notice("โœ… Successfully loaded model: \(model.name)") } catch { - logger.error("โŒ Failed to load model: \(model.name) - \(error.localizedDescription)") throw WhisperStateError.modelLoadFailed } } func setDefaultModel(_ model: WhisperModel) async { - do { - currentModel = model - UserDefaults.standard.set(model.name, forKey: "CurrentModel") - canTranscribe = true - } catch { - currentError = error as? WhisperStateError ?? .unknownError - canTranscribe = false - } + currentModel = model + UserDefaults.standard.set(model.name, forKey: "CurrentModel") + canTranscribe = true } // MARK: - Model Download & Management - func downloadModel(_ model: PredefinedModel) async { - guard let url = URL(string: model.downloadURL) else { return } - - logger.notice("๐Ÿ”ฝ Downloading model: \(model.name)") - do { - let (data, response) = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<(Data, URLResponse), Error>) in - let task = URLSession.shared.dataTask(with: url) { data, response, error in - if let error = error { - continuation.resume(throwing: error) - return - } - guard let httpResponse = response as? HTTPURLResponse, - (200...299).contains(httpResponse.statusCode), - let data = data else { - continuation.resume(throwing: URLError(.badServerResponse)) - return - } - continuation.resume(returning: (data, httpResponse)) + /// Helper function to download a file from a URL with progress tracking + private func downloadFileWithProgress(from url: URL, progressKey: String) async throws -> Data { + let destinationURL = modelsDirectory.appendingPathComponent(UUID().uuidString) + + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let task = URLSession.shared.downloadTask(with: url) { tempURL, response, error in + if let error = error { + continuation.resume(throwing: error) + return } - task.resume() + guard let httpResponse = response as? HTTPURLResponse, + (200...299).contains(httpResponse.statusCode), + let tempURL = tempURL else { + continuation.resume(throwing: URLError(.badServerResponse)) + return + } - let observation = task.progress.observe(\.fractionCompleted) { progress, _ in + do { + // Move the downloaded file to the final destination + try FileManager.default.moveItem(at: tempURL, to: destinationURL) + + // Read the file in chunks to avoid memory pressure + let data = try Data(contentsOf: destinationURL, options: .mappedIfSafe) + continuation.resume(returning: data) + + // Clean up the temporary file + try? FileManager.default.removeItem(at: destinationURL) + } catch { + continuation.resume(throwing: error) + } + } + + task.resume() + + var lastUpdateTime = Date() + var lastProgressValue: Double = 0 + + let observation = task.progress.observe(\.fractionCompleted) { progress, _ in + let currentTime = Date() + let timeSinceLastUpdate = currentTime.timeIntervalSince(lastUpdateTime) + let currentProgress = round(progress.fractionCompleted * 100) / 100 + + if timeSinceLastUpdate >= 0.5 && abs(currentProgress - lastProgressValue) >= 0.01 { + lastUpdateTime = currentTime + lastProgressValue = currentProgress + DispatchQueue.main.async { - self.downloadProgress[model.name] = progress.fractionCompleted - } - } - - Task { - await withTaskCancellationHandler { - observation.invalidate() - } operation: { - await withCheckedContinuation { (_: CheckedContinuation) in } + self.downloadProgress[progressKey] = currentProgress } } } - - let destinationURL = modelsDirectory.appendingPathComponent(model.filename) - try data.write(to: destinationURL) - - availableModels.append(WhisperModel(name: model.name, url: destinationURL)) - self.downloadProgress.removeValue(forKey: model.name) - logger.notice("โœ… Successfully downloaded model: \(model.name)") - } catch { - logger.error("โŒ Failed to download model: \(model.name) - \(error.localizedDescription)") - currentError = .modelDownloadFailed - self.downloadProgress.removeValue(forKey: model.name) + + Task { + await withTaskCancellationHandler { + observation.invalidate() + } operation: { + await withCheckedContinuation { (_: CheckedContinuation) in } + } + } } } + func downloadModel(_ model: PredefinedModel) async { + guard let url = URL(string: model.downloadURL) else { return } + + do { + let whisperModel = try await downloadMainModel(model, from: url) + + if let coreMLZipURL = whisperModel.coreMLZipDownloadURL, + let coreMLURL = URL(string: coreMLZipURL) { + try await downloadAndSetupCoreMLModel(for: whisperModel, from: coreMLURL) + } + + availableModels.append(whisperModel) + self.downloadProgress.removeValue(forKey: model.name + "_main") + } catch { + handleModelDownloadError(model, error) + } + } + + private func downloadMainModel(_ model: PredefinedModel, from url: URL) async throws -> WhisperModel { + let progressKeyMain = model.name + "_main" + let data = try await downloadFileWithProgress(from: url, progressKey: progressKeyMain) + + let destinationURL = modelsDirectory.appendingPathComponent(model.filename) + try data.write(to: destinationURL) + + return WhisperModel(name: model.name, url: destinationURL) + } + + private func downloadAndSetupCoreMLModel(for model: WhisperModel, from url: URL) async throws { + let progressKeyCoreML = model.name + "_coreml" + let coreMLData = try await downloadFileWithProgress(from: url, progressKey: progressKeyCoreML) + + let coreMLZipPath = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc.zip") + try coreMLData.write(to: coreMLZipPath) + + try await unzipAndSetupCoreMLModel(for: model, zipPath: coreMLZipPath, progressKey: progressKeyCoreML) + } + + private func unzipAndSetupCoreMLModel(for model: WhisperModel, zipPath: URL, progressKey: String) async throws { + let coreMLDestination = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc") + + try? FileManager.default.removeItem(at: coreMLDestination) + try await unzipCoreMLFile(zipPath, to: modelsDirectory) + try verifyAndCleanupCoreMLFiles(model, coreMLDestination, zipPath, progressKey) + } + + private func unzipCoreMLFile(_ zipPath: URL, to destination: URL) async throws { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + do { + try FileManager.default.createDirectory(at: destination, withIntermediateDirectories: true) + try Zip.unzipFile(zipPath, destination: destination, overwrite: true, password: nil) + continuation.resume() + } catch { + continuation.resume(throwing: error) + } + } + } + + private func verifyAndCleanupCoreMLFiles(_ model: WhisperModel, _ destination: URL, _ zipPath: URL, _ progressKey: String) throws -> WhisperModel { + var model = model + + var isDirectory: ObjCBool = false + guard FileManager.default.fileExists(atPath: destination.path, isDirectory: &isDirectory), isDirectory.boolValue else { + try? FileManager.default.removeItem(at: zipPath) + throw WhisperStateError.unzipFailed + } + + try? FileManager.default.removeItem(at: zipPath) + model.coreMLEncoderURL = destination + self.downloadProgress.removeValue(forKey: progressKey) + + return model + } + + private func handleModelDownloadError(_ model: PredefinedModel, _ error: Error) { + currentError = .modelDownloadFailed + self.downloadProgress.removeValue(forKey: model.name + "_main") + self.downloadProgress.removeValue(forKey: model.name + "_coreml") + } + func deleteModel(_ model: WhisperModel) async { do { + // Delete main model file try FileManager.default.removeItem(at: model.url) + + // Delete CoreML model if it exists + if let coreMLURL = model.coreMLEncoderURL { + try? FileManager.default.removeItem(at: coreMLURL) + } else { + // Check if there's a CoreML directory matching the model name + let coreMLDir = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc") + if FileManager.default.fileExists(atPath: coreMLDir.path) { + try? FileManager.default.removeItem(at: coreMLDir) + } + } + + // Update model state availableModels.removeAll { $0.id == model.id } if currentModel?.id == model.id { currentModel = nil canTranscribe = false } } catch { - messageLog += "Error deleting model: \(error.localizedDescription)\n" + logError("Error deleting model: \(model.name)", error) currentError = .modelDeletionFailed } } @@ -141,7 +242,7 @@ extension WhisperState { do { try FileManager.default.removeItem(at: model.url) } catch { - messageLog += "Error deleting model: \(error.localizedDescription)\n" + logError("Error deleting model during cleanup", error) } } availableModels.removeAll() @@ -151,21 +252,88 @@ extension WhisperState { func cleanupModelResources() async { recorder.stopRecording() + try? await Task.sleep(nanoseconds: 500_000_000) - // Add a small delay to ensure recording has fully stopped - try? await Task.sleep(nanoseconds: 500_000_000) // 0.5 seconds - - // Only cleanup model resources if we're not actively using them let canCleanup = !isRecording && !isProcessing if canCleanup { - logger.notice("๐Ÿงน Cleaning up Whisper resources") - // Release any resources held by the model await whisperContext?.releaseResources() whisperContext = nil isModelLoaded = false - } else { - logger.info("Skipping cleanup while recording or processing is active") } } + + // MARK: - Helper Methods + + private func logError(_ message: String, _ error: Error) { + messageLog += "\(message): \(error.localizedDescription)\n" + } +} + +// MARK: - Download Progress View +struct DownloadProgressView: View { + let modelName: String + let downloadProgress: [String: Double] + + @Environment(\.colorScheme) private var colorScheme + + private var mainProgress: Double { + downloadProgress[modelName + "_main"] ?? 0 + } + + private var coreMLProgress: Double { + supportsCoreML ? (downloadProgress[modelName + "_coreml"] ?? 0) : 0 + } + + private var supportsCoreML: Bool { + !modelName.contains("q5") && !modelName.contains("q8") + } + + private var totalProgress: Double { + supportsCoreML ? (mainProgress * 0.5) + (coreMLProgress * 0.5) : mainProgress + } + + private var downloadPhase: String { + // Check if we're currently downloading the CoreML model + if supportsCoreML && downloadProgress[modelName + "_coreml"] != nil { + return "Downloading Core ML Model for \(modelName)" + } + // Otherwise, we're downloading the main model + return "Downloading \(modelName) Model" + } + + var body: some View { + VStack(alignment: .leading, spacing: 8) { + // Status text with clean typography + Text(downloadPhase) + .font(.system(size: 12, weight: .medium)) + .foregroundColor(Color(.secondaryLabelColor)) + + // Clean progress bar + GeometryReader { geometry in + ZStack(alignment: .leading) { + // Background track + RoundedRectangle(cornerRadius: 4) + .fill(Color(.separatorColor).opacity(0.3)) + .frame(height: 6) + + // Progress indicator + RoundedRectangle(cornerRadius: 4) + .fill(Color(.controlAccentColor)) + .frame(width: max(0, min(geometry.size.width * totalProgress, geometry.size.width)), height: 6) + } + } + .frame(height: 6) + + // Percentage indicator in Apple style + HStack { + Spacer() + Text("\(Int(totalProgress * 100))%") + .font(.system(size: 11, weight: .medium, design: .monospaced)) + .foregroundColor(Color(.secondaryLabelColor)) + } + } + .padding(.vertical, 4) + .animation(.smooth, value: totalProgress) + } } \ No newline at end of file diff --git a/VoiceInk/Whisper/WhisperState.swift b/VoiceInk/Whisper/WhisperState.swift index 553ea35..d66d28f 100644 --- a/VoiceInk/Whisper/WhisperState.swift +++ b/VoiceInk/Whisper/WhisperState.swift @@ -438,12 +438,28 @@ struct WhisperModel: Identifiable { let id = UUID() let name: String let url: URL + var coreMLEncoderURL: URL? // Path to the unzipped .mlmodelc directory + var isCoreMLDownloaded: Bool { coreMLEncoderURL != nil } + var downloadURL: String { "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/\(filename)" } + var filename: String { "\(name).bin" } + + // Core ML related properties + var coreMLZipDownloadURL: String? { + // Only non-quantized models have Core ML versions + guard !name.contains("q5") && !name.contains("q8") else { return nil } + return "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/\(name)-encoder.mlmodelc.zip" + } + + var coreMLEncoderDirectoryName: String? { + guard coreMLZipDownloadURL != nil else { return nil } + return "\(name)-encoder.mlmodelc" + } } private class TaskDelegate: NSObject, URLSessionTaskDelegate {