Added support for CoreML and redesigned the AI Models view

This commit is contained in:
Beingpax 2025-04-08 18:08:19 +05:45
parent 7ae4f62ab6
commit 0923093bcd
7 changed files with 532 additions and 276 deletions

View File

@ -82,7 +82,7 @@ import Foundation
size: "2.9 GiB", size: "2.9 GiB",
supportedLanguages: getLanguageDictionary(isMultilingual: true), supportedLanguages: getLanguageDictionary(isMultilingual: true),
description: "Large model v2, slower than Medium but more accurate, supports multiple languages", description: "Large model v2, slower than Medium but more accurate, supports multiple languages",
speed: 0.25, speed: 0.5,
accuracy: 0.96, accuracy: 0.96,
ramUsage: 3.8, ramUsage: 3.8,
hash: "0f4c8e30f21cf1769f637135f521436792c48186" hash: "0f4c8e30f21cf1769f637135f521436792c48186"
@ -93,7 +93,7 @@ import Foundation
size: "2.9 GiB", size: "2.9 GiB",
supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true), supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true),
description: "Large model v3, very slow but most accurate, supports multiple languages", description: "Large model v3, very slow but most accurate, supports multiple languages",
speed: 0.2, speed: 0.5,
accuracy: 0.98, accuracy: 0.98,
ramUsage: 3.9, ramUsage: 3.9,
hash: "ad82bf6a9043ceed055076d0fd39f5f186ff8062" hash: "ad82bf6a9043ceed055076d0fd39f5f186ff8062"
@ -105,7 +105,7 @@ import Foundation
supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true), supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true),
description: description:
"Large model v3 Turbo, faster than v3 with similar accuracy, supports multiple languages", "Large model v3 Turbo, faster than v3 with similar accuracy, supports multiple languages",
speed: 0.5, speed: 0.7,
accuracy: 0.97, accuracy: 0.97,
ramUsage: 1.8, ramUsage: 1.8,
hash: "4af2b29d7ec73d781377bfd1758ca957a807e941" hash: "4af2b29d7ec73d781377bfd1758ca957a807e941"
@ -116,7 +116,7 @@ import Foundation
size: "547 MiB", size: "547 MiB",
supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true), supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true),
description: "Quantized version of Large v3 Turbo, faster with slightly lower accuracy", description: "Quantized version of Large v3 Turbo, faster with slightly lower accuracy",
speed: 0.6, speed: 0.7,
accuracy: 0.96, accuracy: 0.96,
ramUsage: 1.0, ramUsage: 1.0,
hash: "e050f7970618a659205450ad97eb95a18d69c9ee" hash: "e050f7970618a659205450ad97eb95a18d69c9ee"

View File

@ -0,0 +1,253 @@
import SwiftUI
import AppKit
struct ModelCardRowView: View {
let model: PredefinedModel
let isDownloaded: Bool
let isCurrent: Bool
let downloadProgress: [String: Double]
let modelURL: URL?
// Actions
var deleteAction: () -> Void
var setDefaultAction: () -> Void
var downloadAction: () -> Void
private var isDownloading: Bool {
downloadProgress.keys.contains(model.name + "_main") ||
downloadProgress.keys.contains(model.name + "_coreml")
}
var body: some View {
HStack(alignment: .top, spacing: 16) {
// Main Content
VStack(alignment: .leading, spacing: 6) {
headerSection
metadataSection
descriptionSection
progressSection
}
.frame(maxWidth: .infinity, alignment: .leading)
// Action Controls
actionSection
}
.padding(16)
.background(
RoundedRectangle(cornerRadius: 14)
.fill(
LinearGradient(
gradient: Gradient(colors: [
Color(NSColor.controlBackgroundColor).opacity(1.0),
Color(NSColor.controlBackgroundColor).opacity(0.6)
]),
startPoint: .topLeading,
endPoint: .bottomTrailing
)
)
.overlay(
RoundedRectangle(cornerRadius: 14)
.stroke(
LinearGradient(
gradient: Gradient(colors: [
Color.white.opacity(0.3),
Color.white.opacity(0.05)
]),
startPoint: .topLeading,
endPoint: .bottomTrailing
),
lineWidth: 1
)
)
.shadow(
color: Color.primary.opacity(0.08),
radius: 8,
x: 0,
y: 4
)
)
}
// MARK: - Components
private var headerSection: some View {
HStack(alignment: .firstTextBaseline) {
Text(model.displayName)
.font(.system(size: 13, weight: .semibold))
.foregroundColor(Color(.labelColor))
statusBadge
Spacer()
}
}
private var statusBadge: some View {
Group {
if isCurrent {
Text("Default")
.font(.system(size: 11, weight: .medium))
.padding(.horizontal, 6)
.padding(.vertical, 2)
.background(Capsule().fill(Color.accentColor))
.foregroundColor(.white)
} else if isDownloaded {
Text("Downloaded")
.font(.system(size: 11, weight: .medium))
.padding(.horizontal, 6)
.padding(.vertical, 2)
.background(Capsule().fill(Color(.quaternaryLabelColor)))
.foregroundColor(Color(.labelColor))
}
}
}
private var metadataSection: some View {
HStack(spacing: 16) {
// Language
Label(model.language, systemImage: "globe")
.font(.system(size: 11))
.foregroundColor(Color(.secondaryLabelColor))
// Size
Label(model.size, systemImage: "internaldrive")
.font(.system(size: 11))
.foregroundColor(Color(.secondaryLabelColor))
// Speed
HStack(spacing: 4) {
Text("Speed")
.font(.system(size: 11, weight: .medium))
.foregroundColor(Color(.secondaryLabelColor))
progressDotsWithNumber(value: model.speed * 10)
}
// Accuracy
HStack(spacing: 4) {
Text("Accuracy")
.font(.system(size: 11, weight: .medium))
.foregroundColor(Color(.secondaryLabelColor))
progressDotsWithNumber(value: model.accuracy * 10)
}
}
}
private var descriptionSection: some View {
Text(model.description)
.font(.system(size: 11))
.foregroundColor(Color(.secondaryLabelColor))
.lineLimit(2)
.fixedSize(horizontal: false, vertical: true)
.padding(.top, 4)
}
private var progressSection: some View {
Group {
if isDownloading {
DownloadProgressView(
modelName: model.name,
downloadProgress: downloadProgress
)
.padding(.top, 8)
.frame(maxWidth: .infinity, alignment: .leading)
}
}
}
private var actionSection: some View {
HStack(spacing: 8) {
if isCurrent {
Text("Default Model")
.font(.system(size: 12))
.foregroundColor(Color(.secondaryLabelColor))
} else if isDownloaded {
Button(action: setDefaultAction) {
Text("Set as Default")
.font(.system(size: 12))
}
.buttonStyle(.bordered)
.controlSize(.small)
} else {
Button(action: downloadAction) {
HStack(spacing: 4) {
Text(isDownloading ? "Downloading..." : "Download")
.font(.system(size: 12, weight: .medium))
Image(systemName: "arrow.down.circle")
.font(.system(size: 12, weight: .medium))
}
.foregroundColor(.white)
.padding(.horizontal, 12)
.padding(.vertical, 6)
.background(
Capsule()
.fill(Color(.controlAccentColor))
.shadow(color: Color(.controlAccentColor).opacity(0.2), radius: 2, x: 0, y: 1)
)
}
.buttonStyle(.plain)
.disabled(isDownloading)
}
if isDownloaded {
Menu {
Button(action: deleteAction) {
Label("Delete Model", systemImage: "trash")
}
if isDownloaded {
Button {
if let modelURL = modelURL {
NSWorkspace.shared.selectFile(modelURL.path, inFileViewerRootedAtPath: "")
}
} label: {
Label("Show in Finder", systemImage: "folder")
}
}
} label: {
Image(systemName: "ellipsis.circle")
.font(.system(size: 14))
}
.menuStyle(.borderlessButton)
.menuIndicator(.hidden)
.frame(width: 20, height: 20)
}
}
}
// MARK: - Helpers
private var downloadComponents: [(String, Double)] {
[
("Model", downloadProgress[model.name + "_main"] ?? 0),
("CoreML", downloadProgress[model.name + "_coreml"] ?? 0)
].filter { $0.1 > 0 }
}
private func progressDotsWithNumber(value: Double) -> some View {
HStack(spacing: 4) {
progressDots(value: value)
Text(String(format: "%.1f", value))
.font(.system(size: 10, weight: .medium, design: .monospaced))
.foregroundColor(Color(.secondaryLabelColor))
}
}
private func progressDots(value: Double) -> some View {
HStack(spacing: 2) {
ForEach(0..<5) { index in
Circle()
.fill(index < Int(value / 2) ? performanceColor(value: value / 10) : Color(.quaternaryLabelColor))
.frame(width: 6, height: 6)
}
}
}
private func performanceColor(value: Double) -> Color {
switch value {
case 0.8...1.0: return Color(.systemGreen)
case 0.6..<0.8: return Color(.systemYellow)
case 0.4..<0.6: return Color(.systemOrange)
default: return Color(.systemRed)
}
}
}

View File

@ -68,217 +68,35 @@ struct ModelManagementView: View {
Spacer() Spacer()
} }
LazyVGrid(columns: [GridItem(.adaptive(minimum: 300, maximum: 400), spacing: 16)], spacing: 16) { VStack(spacing: 12) {
ForEach(whisperState.predefinedModels) { model in ForEach(whisperState.predefinedModels) { model in
modelCard(for: model) ModelCardRowView(
} model: model,
} isDownloaded: whisperState.availableModels.contains { $0.name == model.name },
} isCurrent: whisperState.currentModel?.name == model.name,
.padding() downloadProgress: whisperState.downloadProgress,
.background(Color(.windowBackgroundColor).opacity(0.4)) modelURL: whisperState.availableModels.first { $0.name == model.name }?.url,
.cornerRadius(10) deleteAction: {
}
private func modelCard(for model: PredefinedModel) -> some View {
let isDownloaded = whisperState.availableModels.contains { $0.name == model.name }
let isCurrent = whisperState.currentModel?.name == model.name
return VStack(alignment: .leading, spacing: 12) {
// Model name and details
HStack {
VStack(alignment: .leading, spacing: 4) {
Text(model.displayName)
.font(.headline)
Text("\(model.size)\(model.language)")
.font(.caption)
.foregroundColor(.secondary)
}
Spacer()
modelStatusBadge(isDownloaded: isDownloaded, isCurrent: isCurrent)
}
// Description
Text(model.description)
.font(.subheadline)
.foregroundColor(.secondary)
.lineLimit(2)
// Performance indicators
HStack(spacing: 16) {
performanceIndicator(label: "Speed", value: model.speed)
performanceIndicator(label: "Accuracy", value: model.accuracy)
ramUsageLabel(gb: model.ramUsage)
}
// Action buttons
HStack {
modelActionButton(isDownloaded: isDownloaded, isCurrent: isCurrent, model: model)
if isDownloaded {
Menu {
Button(action: {
if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) { if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) {
modelToDelete = downloadedModel modelToDelete = downloadedModel
} }
}) { },
Label("Delete", systemImage: "trash") setDefaultAction: {
}
Button(action: {
if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) { if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) {
NSWorkspace.shared.selectFile(downloadedModel.url.path, inFileViewerRootedAtPath: "") Task {
await whisperState.setDefaultModel(downloadedModel)
}
}
},
downloadAction: {
Task {
await whisperState.downloadModel(model)
} }
}) {
Label("Show in Finder", systemImage: "folder")
} }
} label: { )
Image(systemName: "ellipsis.circle")
.foregroundColor(.secondary)
}
.menuStyle(BorderlessButtonMenuStyle())
.frame(width: 30, height: 30)
} }
} }
} }
.padding() .padding()
.background(Color(.windowBackgroundColor).opacity(0.9))
.cornerRadius(10)
.overlay(
RoundedRectangle(cornerRadius: 10)
.stroke(isCurrent ? Color.accentColor : Color.gray.opacity(0.2), lineWidth: isCurrent ? 2 : 1)
)
}
private func modelStatusBadge(isDownloaded: Bool, isCurrent: Bool) -> some View {
Group {
if isCurrent {
Text("Default")
.font(.caption)
.padding(.horizontal, 8)
.padding(.vertical, 4)
.background(Color.green)
.foregroundColor(.white)
.cornerRadius(8)
} else if isDownloaded {
Text("Downloaded")
.font(.caption)
.padding(.horizontal, 8)
.padding(.vertical, 4)
.background(Color.indigo)
.foregroundColor(.white)
.cornerRadius(8)
}
}
}
private func performanceIndicator(label: String, value: Double) -> some View {
VStack(alignment: .leading, spacing: 4) {
Text(label)
.font(.caption)
.foregroundColor(.secondary)
HStack(spacing: 2) {
ForEach(0..<5) { index in
RoundedRectangle(cornerRadius: 2)
.fill(index < Int(value * 5) ? performanceColor(value: value) : Color.secondary.opacity(0.2))
.frame(width: 16, height: 8)
}
}
Text(String(format: "%.1f", value * 10))
.font(.caption)
.foregroundColor(.secondary)
}
}
private func performanceColor(value: Double) -> Color {
switch value {
case 0.8...: return .green
case 0.6..<0.8: return .yellow
case 0.4..<0.6: return .orange
default: return .red
}
}
private func modelActionButton(isDownloaded: Bool, isCurrent: Bool, model: PredefinedModel) -> some View {
Group {
if isCurrent {
Text("Default Model")
.foregroundColor(.white)
} else if isDownloaded {
Button("Set as Default") {
if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) {
Task {
await whisperState.setDefaultModel(downloadedModel)
}
}
}
.foregroundColor(.white)
} else if whisperState.downloadProgress[model.name] != nil {
VStack {
ProgressView(value: whisperState.downloadProgress[model.name] ?? 0)
.progressViewStyle(LinearProgressViewStyle())
.animation(.linear, value: whisperState.downloadProgress[model.name])
Text("\(Int((whisperState.downloadProgress[model.name] ?? 0) * 100))%")
.font(.caption)
.animation(.none)
}
} else {
Button("Download Model") {
Task {
await whisperState.downloadModel(model)
}
}
.foregroundColor(.white)
}
}
.buttonStyle(GradientButtonStyle(isDownloaded: isDownloaded, isCurrent: isCurrent))
.frame(maxWidth: .infinity)
}
private func ramUsageLabel(gb: Double) -> some View {
VStack(alignment: .leading, spacing: 4) {
Text("RAM")
.font(.caption)
.foregroundColor(.secondary)
Text(formatRAMSize(gb))
.font(.system(size: 12, weight: .bold))
.foregroundColor(.primary)
}
}
private func formatRAMSize(_ gb: Double) -> String {
if gb >= 1.0 {
return String(format: "%.1f GB", gb)
} else {
return String(format: "%d MB", Int(gb * 1024))
}
}
}
struct GradientButtonStyle: ButtonStyle {
let isDownloaded: Bool
let isCurrent: Bool
func makeBody(configuration: Configuration) -> some View {
configuration.label
.padding(.vertical, 5)
.padding(.horizontal, 10)
.background(
Group {
if isCurrent {
LinearGradient(gradient: Gradient(colors: [Color.green, Color.green.opacity(0.7)]), startPoint: .top, endPoint: .bottom)
} else if isDownloaded {
LinearGradient(gradient: Gradient(colors: [Color.purple, Color.purple.opacity(0.7)]), startPoint: .top, endPoint: .bottom)
} else {
LinearGradient(gradient: Gradient(colors: [Color.blue, Color.blue.opacity(0.7)]), startPoint: .top, endPoint: .bottom)
}
}
)
.cornerRadius(10)
.shadow(color: Color.black.opacity(0.2), radius: 5, x: 0, y: 2)
.scaleEffect(configuration.isPressed ? 0.95 : 1)
.animation(.easeInOut(duration: 0.2), value: configuration.isPressed)
} }
} }

View File

@ -83,14 +83,10 @@ struct OnboardingModelDownloadView: View {
// Download progress // Download progress
if isDownloading { if isDownloading {
VStack(spacing: 8) { DownloadProgressView(
ProgressView(value: whisperState.downloadProgress[turboModel.name] ?? 0) modelName: turboModel.name,
.progressViewStyle(.linear) downloadProgress: whisperState.downloadProgress
.tint(.white) )
Text("\(Int((whisperState.downloadProgress[turboModel.name] ?? 0) * 100))%")
.font(.caption)
.foregroundColor(.white.opacity(0.7))
}
.transition(.opacity) .transition(.opacity)
} }
} }

View File

@ -7,6 +7,7 @@ enum WhisperStateError: Error, Identifiable {
case accessibilityPermissionDenied case accessibilityPermissionDenied
case modelDownloadFailed case modelDownloadFailed
case modelDeletionFailed case modelDeletionFailed
case unzipFailed
case unknownError case unknownError
var id: String { UUID().uuidString } var id: String { UUID().uuidString }
@ -27,6 +28,8 @@ extension WhisperStateError: LocalizedError {
return "Failed to download the model." return "Failed to download the model."
case .modelDeletionFailed: case .modelDeletionFailed:
return "Failed to delete the model." return "Failed to delete the model."
case .unzipFailed:
return "Failed to unzip the downloaded Core ML model."
case .unknownError: case .unknownError:
return "An unknown error occurred." return "An unknown error occurred."
} }
@ -46,6 +49,8 @@ extension WhisperStateError: LocalizedError {
return "Check your internet connection and try again. If the problem persists, try a different model." return "Check your internet connection and try again. If the problem persists, try a different model."
case .modelDeletionFailed: case .modelDeletionFailed:
return "Restart the application and try again. If the problem persists, you may need to manually delete the model file." return "Restart the application and try again. If the problem persists, you may need to manually delete the model file."
case .unzipFailed:
return "The downloaded Core ML model archive might be corrupted. Try deleting the model and downloading it again. Check available disk space."
case .unknownError: case .unknownError:
return "Please restart the application. If the problem persists, contact support." return "Please restart the application. If the problem persists, contact support."
} }

View File

@ -1,5 +1,7 @@
import Foundation import Foundation
import os import os
import Zip
import SwiftUI
// MARK: - Model Management Extension // MARK: - Model Management Extension
extension WhisperState { extension WhisperState {
@ -10,7 +12,7 @@ extension WhisperState {
do { do {
try FileManager.default.createDirectory(at: modelsDirectory, withIntermediateDirectories: true, attributes: nil) try FileManager.default.createDirectory(at: modelsDirectory, withIntermediateDirectories: true, attributes: nil)
} catch { } catch {
messageLog += "Error creating models directory: \(error.localizedDescription)\n" logError("Error creating models directory", error)
} }
} }
@ -22,7 +24,7 @@ extension WhisperState {
return WhisperModel(name: url.deletingPathExtension().lastPathComponent, url: url) return WhisperModel(name: url.deletingPathExtension().lastPathComponent, url: url)
} }
} catch { } catch {
messageLog += "Error loading available models: \(error.localizedDescription)\n" logError("Error loading available models", error)
} }
} }
@ -31,7 +33,6 @@ extension WhisperState {
func loadModel(_ model: WhisperModel) async throws { func loadModel(_ model: WhisperModel) async throws {
guard whisperContext == nil else { return } guard whisperContext == nil else { return }
logger.notice("🔄 Loading Whisper model: \(model.name)")
isModelLoading = true isModelLoading = true
defer { isModelLoading = false } defer { isModelLoading = false }
@ -39,86 +40,186 @@ extension WhisperState {
whisperContext = try await WhisperContext.createContext(path: model.url.path) whisperContext = try await WhisperContext.createContext(path: model.url.path)
isModelLoaded = true isModelLoaded = true
currentModel = model currentModel = model
logger.notice("✅ Successfully loaded model: \(model.name)")
} catch { } catch {
logger.error("❌ Failed to load model: \(model.name) - \(error.localizedDescription)")
throw WhisperStateError.modelLoadFailed throw WhisperStateError.modelLoadFailed
} }
} }
func setDefaultModel(_ model: WhisperModel) async { func setDefaultModel(_ model: WhisperModel) async {
do { currentModel = model
currentModel = model UserDefaults.standard.set(model.name, forKey: "CurrentModel")
UserDefaults.standard.set(model.name, forKey: "CurrentModel") canTranscribe = true
canTranscribe = true
} catch {
currentError = error as? WhisperStateError ?? .unknownError
canTranscribe = false
}
} }
// MARK: - Model Download & Management // MARK: - Model Download & Management
func downloadModel(_ model: PredefinedModel) async { /// Helper function to download a file from a URL with progress tracking
guard let url = URL(string: model.downloadURL) else { return } private func downloadFileWithProgress(from url: URL, progressKey: String) async throws -> Data {
let destinationURL = modelsDirectory.appendingPathComponent(UUID().uuidString)
logger.notice("🔽 Downloading model: \(model.name)")
do { return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Data, Error>) in
let (data, response) = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<(Data, URLResponse), Error>) in let task = URLSession.shared.downloadTask(with: url) { tempURL, response, error in
let task = URLSession.shared.dataTask(with: url) { data, response, error in if let error = error {
if let error = error { continuation.resume(throwing: error)
continuation.resume(throwing: error) return
return
}
guard let httpResponse = response as? HTTPURLResponse,
(200...299).contains(httpResponse.statusCode),
let data = data else {
continuation.resume(throwing: URLError(.badServerResponse))
return
}
continuation.resume(returning: (data, httpResponse))
} }
task.resume() guard let httpResponse = response as? HTTPURLResponse,
(200...299).contains(httpResponse.statusCode),
let tempURL = tempURL else {
continuation.resume(throwing: URLError(.badServerResponse))
return
}
let observation = task.progress.observe(\.fractionCompleted) { progress, _ in do {
// Move the downloaded file to the final destination
try FileManager.default.moveItem(at: tempURL, to: destinationURL)
// Read the file in chunks to avoid memory pressure
let data = try Data(contentsOf: destinationURL, options: .mappedIfSafe)
continuation.resume(returning: data)
// Clean up the temporary file
try? FileManager.default.removeItem(at: destinationURL)
} catch {
continuation.resume(throwing: error)
}
}
task.resume()
var lastUpdateTime = Date()
var lastProgressValue: Double = 0
let observation = task.progress.observe(\.fractionCompleted) { progress, _ in
let currentTime = Date()
let timeSinceLastUpdate = currentTime.timeIntervalSince(lastUpdateTime)
let currentProgress = round(progress.fractionCompleted * 100) / 100
if timeSinceLastUpdate >= 0.5 && abs(currentProgress - lastProgressValue) >= 0.01 {
lastUpdateTime = currentTime
lastProgressValue = currentProgress
DispatchQueue.main.async { DispatchQueue.main.async {
self.downloadProgress[model.name] = progress.fractionCompleted self.downloadProgress[progressKey] = currentProgress
}
}
Task {
await withTaskCancellationHandler {
observation.invalidate()
} operation: {
await withCheckedContinuation { (_: CheckedContinuation<Void, Never>) in }
} }
} }
} }
let destinationURL = modelsDirectory.appendingPathComponent(model.filename) Task {
try data.write(to: destinationURL) await withTaskCancellationHandler {
observation.invalidate()
availableModels.append(WhisperModel(name: model.name, url: destinationURL)) } operation: {
self.downloadProgress.removeValue(forKey: model.name) await withCheckedContinuation { (_: CheckedContinuation<Void, Never>) in }
logger.notice("✅ Successfully downloaded model: \(model.name)") }
} catch { }
logger.error("❌ Failed to download model: \(model.name) - \(error.localizedDescription)")
currentError = .modelDownloadFailed
self.downloadProgress.removeValue(forKey: model.name)
} }
} }
func downloadModel(_ model: PredefinedModel) async {
guard let url = URL(string: model.downloadURL) else { return }
do {
let whisperModel = try await downloadMainModel(model, from: url)
if let coreMLZipURL = whisperModel.coreMLZipDownloadURL,
let coreMLURL = URL(string: coreMLZipURL) {
try await downloadAndSetupCoreMLModel(for: whisperModel, from: coreMLURL)
}
availableModels.append(whisperModel)
self.downloadProgress.removeValue(forKey: model.name + "_main")
} catch {
handleModelDownloadError(model, error)
}
}
private func downloadMainModel(_ model: PredefinedModel, from url: URL) async throws -> WhisperModel {
let progressKeyMain = model.name + "_main"
let data = try await downloadFileWithProgress(from: url, progressKey: progressKeyMain)
let destinationURL = modelsDirectory.appendingPathComponent(model.filename)
try data.write(to: destinationURL)
return WhisperModel(name: model.name, url: destinationURL)
}
private func downloadAndSetupCoreMLModel(for model: WhisperModel, from url: URL) async throws {
let progressKeyCoreML = model.name + "_coreml"
let coreMLData = try await downloadFileWithProgress(from: url, progressKey: progressKeyCoreML)
let coreMLZipPath = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc.zip")
try coreMLData.write(to: coreMLZipPath)
try await unzipAndSetupCoreMLModel(for: model, zipPath: coreMLZipPath, progressKey: progressKeyCoreML)
}
private func unzipAndSetupCoreMLModel(for model: WhisperModel, zipPath: URL, progressKey: String) async throws {
let coreMLDestination = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc")
try? FileManager.default.removeItem(at: coreMLDestination)
try await unzipCoreMLFile(zipPath, to: modelsDirectory)
try verifyAndCleanupCoreMLFiles(model, coreMLDestination, zipPath, progressKey)
}
private func unzipCoreMLFile(_ zipPath: URL, to destination: URL) async throws {
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
do {
try FileManager.default.createDirectory(at: destination, withIntermediateDirectories: true)
try Zip.unzipFile(zipPath, destination: destination, overwrite: true, password: nil)
continuation.resume()
} catch {
continuation.resume(throwing: error)
}
}
}
private func verifyAndCleanupCoreMLFiles(_ model: WhisperModel, _ destination: URL, _ zipPath: URL, _ progressKey: String) throws -> WhisperModel {
var model = model
var isDirectory: ObjCBool = false
guard FileManager.default.fileExists(atPath: destination.path, isDirectory: &isDirectory), isDirectory.boolValue else {
try? FileManager.default.removeItem(at: zipPath)
throw WhisperStateError.unzipFailed
}
try? FileManager.default.removeItem(at: zipPath)
model.coreMLEncoderURL = destination
self.downloadProgress.removeValue(forKey: progressKey)
return model
}
private func handleModelDownloadError(_ model: PredefinedModel, _ error: Error) {
currentError = .modelDownloadFailed
self.downloadProgress.removeValue(forKey: model.name + "_main")
self.downloadProgress.removeValue(forKey: model.name + "_coreml")
}
func deleteModel(_ model: WhisperModel) async { func deleteModel(_ model: WhisperModel) async {
do { do {
// Delete main model file
try FileManager.default.removeItem(at: model.url) try FileManager.default.removeItem(at: model.url)
// Delete CoreML model if it exists
if let coreMLURL = model.coreMLEncoderURL {
try? FileManager.default.removeItem(at: coreMLURL)
} else {
// Check if there's a CoreML directory matching the model name
let coreMLDir = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc")
if FileManager.default.fileExists(atPath: coreMLDir.path) {
try? FileManager.default.removeItem(at: coreMLDir)
}
}
// Update model state
availableModels.removeAll { $0.id == model.id } availableModels.removeAll { $0.id == model.id }
if currentModel?.id == model.id { if currentModel?.id == model.id {
currentModel = nil currentModel = nil
canTranscribe = false canTranscribe = false
} }
} catch { } catch {
messageLog += "Error deleting model: \(error.localizedDescription)\n" logError("Error deleting model: \(model.name)", error)
currentError = .modelDeletionFailed currentError = .modelDeletionFailed
} }
} }
@ -141,7 +242,7 @@ extension WhisperState {
do { do {
try FileManager.default.removeItem(at: model.url) try FileManager.default.removeItem(at: model.url)
} catch { } catch {
messageLog += "Error deleting model: \(error.localizedDescription)\n" logError("Error deleting model during cleanup", error)
} }
} }
availableModels.removeAll() availableModels.removeAll()
@ -151,21 +252,88 @@ extension WhisperState {
func cleanupModelResources() async { func cleanupModelResources() async {
recorder.stopRecording() recorder.stopRecording()
try? await Task.sleep(nanoseconds: 500_000_000)
// Add a small delay to ensure recording has fully stopped
try? await Task.sleep(nanoseconds: 500_000_000) // 0.5 seconds
// Only cleanup model resources if we're not actively using them
let canCleanup = !isRecording && !isProcessing let canCleanup = !isRecording && !isProcessing
if canCleanup { if canCleanup {
logger.notice("🧹 Cleaning up Whisper resources")
// Release any resources held by the model
await whisperContext?.releaseResources() await whisperContext?.releaseResources()
whisperContext = nil whisperContext = nil
isModelLoaded = false isModelLoaded = false
} else {
logger.info("Skipping cleanup while recording or processing is active")
} }
} }
// MARK: - Helper Methods
private func logError(_ message: String, _ error: Error) {
messageLog += "\(message): \(error.localizedDescription)\n"
}
}
// MARK: - Download Progress View
struct DownloadProgressView: View {
let modelName: String
let downloadProgress: [String: Double]
@Environment(\.colorScheme) private var colorScheme
private var mainProgress: Double {
downloadProgress[modelName + "_main"] ?? 0
}
private var coreMLProgress: Double {
supportsCoreML ? (downloadProgress[modelName + "_coreml"] ?? 0) : 0
}
private var supportsCoreML: Bool {
!modelName.contains("q5") && !modelName.contains("q8")
}
private var totalProgress: Double {
supportsCoreML ? (mainProgress * 0.5) + (coreMLProgress * 0.5) : mainProgress
}
private var downloadPhase: String {
// Check if we're currently downloading the CoreML model
if supportsCoreML && downloadProgress[modelName + "_coreml"] != nil {
return "Downloading Core ML Model for \(modelName)"
}
// Otherwise, we're downloading the main model
return "Downloading \(modelName) Model"
}
var body: some View {
VStack(alignment: .leading, spacing: 8) {
// Status text with clean typography
Text(downloadPhase)
.font(.system(size: 12, weight: .medium))
.foregroundColor(Color(.secondaryLabelColor))
// Clean progress bar
GeometryReader { geometry in
ZStack(alignment: .leading) {
// Background track
RoundedRectangle(cornerRadius: 4)
.fill(Color(.separatorColor).opacity(0.3))
.frame(height: 6)
// Progress indicator
RoundedRectangle(cornerRadius: 4)
.fill(Color(.controlAccentColor))
.frame(width: max(0, min(geometry.size.width * totalProgress, geometry.size.width)), height: 6)
}
}
.frame(height: 6)
// Percentage indicator in Apple style
HStack {
Spacer()
Text("\(Int(totalProgress * 100))%")
.font(.system(size: 11, weight: .medium, design: .monospaced))
.foregroundColor(Color(.secondaryLabelColor))
}
}
.padding(.vertical, 4)
.animation(.smooth, value: totalProgress)
}
} }

View File

@ -438,12 +438,28 @@ struct WhisperModel: Identifiable {
let id = UUID() let id = UUID()
let name: String let name: String
let url: URL let url: URL
var coreMLEncoderURL: URL? // Path to the unzipped .mlmodelc directory
var isCoreMLDownloaded: Bool { coreMLEncoderURL != nil }
var downloadURL: String { var downloadURL: String {
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/\(filename)" "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/\(filename)"
} }
var filename: String { var filename: String {
"\(name).bin" "\(name).bin"
} }
// Core ML related properties
var coreMLZipDownloadURL: String? {
// Only non-quantized models have Core ML versions
guard !name.contains("q5") && !name.contains("q8") else { return nil }
return "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/\(name)-encoder.mlmodelc.zip"
}
var coreMLEncoderDirectoryName: String? {
guard coreMLZipDownloadURL != nil else { return nil }
return "\(name)-encoder.mlmodelc"
}
} }
private class TaskDelegate: NSObject, URLSessionTaskDelegate { private class TaskDelegate: NSObject, URLSessionTaskDelegate {