Added support for CoreML and redesigned the AI Models view
This commit is contained in:
parent
7ae4f62ab6
commit
0923093bcd
@ -82,7 +82,7 @@ import Foundation
|
||||
size: "2.9 GiB",
|
||||
supportedLanguages: getLanguageDictionary(isMultilingual: true),
|
||||
description: "Large model v2, slower than Medium but more accurate, supports multiple languages",
|
||||
speed: 0.25,
|
||||
speed: 0.5,
|
||||
accuracy: 0.96,
|
||||
ramUsage: 3.8,
|
||||
hash: "0f4c8e30f21cf1769f637135f521436792c48186"
|
||||
@ -93,7 +93,7 @@ import Foundation
|
||||
size: "2.9 GiB",
|
||||
supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true),
|
||||
description: "Large model v3, very slow but most accurate, supports multiple languages",
|
||||
speed: 0.2,
|
||||
speed: 0.5,
|
||||
accuracy: 0.98,
|
||||
ramUsage: 3.9,
|
||||
hash: "ad82bf6a9043ceed055076d0fd39f5f186ff8062"
|
||||
@ -105,7 +105,7 @@ import Foundation
|
||||
supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true),
|
||||
description:
|
||||
"Large model v3 Turbo, faster than v3 with similar accuracy, supports multiple languages",
|
||||
speed: 0.5,
|
||||
speed: 0.7,
|
||||
accuracy: 0.97,
|
||||
ramUsage: 1.8,
|
||||
hash: "4af2b29d7ec73d781377bfd1758ca957a807e941"
|
||||
@ -116,7 +116,7 @@ import Foundation
|
||||
size: "547 MiB",
|
||||
supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true),
|
||||
description: "Quantized version of Large v3 Turbo, faster with slightly lower accuracy",
|
||||
speed: 0.6,
|
||||
speed: 0.7,
|
||||
accuracy: 0.96,
|
||||
ramUsage: 1.0,
|
||||
hash: "e050f7970618a659205450ad97eb95a18d69c9ee"
|
||||
|
||||
253
VoiceInk/Views/ModelCardRowView.swift
Normal file
253
VoiceInk/Views/ModelCardRowView.swift
Normal file
@ -0,0 +1,253 @@
|
||||
import SwiftUI
|
||||
import AppKit
|
||||
|
||||
struct ModelCardRowView: View {
|
||||
let model: PredefinedModel
|
||||
let isDownloaded: Bool
|
||||
let isCurrent: Bool
|
||||
let downloadProgress: [String: Double]
|
||||
let modelURL: URL?
|
||||
|
||||
// Actions
|
||||
var deleteAction: () -> Void
|
||||
var setDefaultAction: () -> Void
|
||||
var downloadAction: () -> Void
|
||||
|
||||
private var isDownloading: Bool {
|
||||
downloadProgress.keys.contains(model.name + "_main") ||
|
||||
downloadProgress.keys.contains(model.name + "_coreml")
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
HStack(alignment: .top, spacing: 16) {
|
||||
// Main Content
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
headerSection
|
||||
metadataSection
|
||||
descriptionSection
|
||||
progressSection
|
||||
}
|
||||
.frame(maxWidth: .infinity, alignment: .leading)
|
||||
|
||||
// Action Controls
|
||||
actionSection
|
||||
}
|
||||
.padding(16)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 14)
|
||||
.fill(
|
||||
LinearGradient(
|
||||
gradient: Gradient(colors: [
|
||||
Color(NSColor.controlBackgroundColor).opacity(1.0),
|
||||
Color(NSColor.controlBackgroundColor).opacity(0.6)
|
||||
]),
|
||||
startPoint: .topLeading,
|
||||
endPoint: .bottomTrailing
|
||||
)
|
||||
)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 14)
|
||||
.stroke(
|
||||
LinearGradient(
|
||||
gradient: Gradient(colors: [
|
||||
Color.white.opacity(0.3),
|
||||
Color.white.opacity(0.05)
|
||||
]),
|
||||
startPoint: .topLeading,
|
||||
endPoint: .bottomTrailing
|
||||
),
|
||||
lineWidth: 1
|
||||
)
|
||||
)
|
||||
.shadow(
|
||||
color: Color.primary.opacity(0.08),
|
||||
radius: 8,
|
||||
x: 0,
|
||||
y: 4
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Components
|
||||
|
||||
private var headerSection: some View {
|
||||
HStack(alignment: .firstTextBaseline) {
|
||||
Text(model.displayName)
|
||||
.font(.system(size: 13, weight: .semibold))
|
||||
.foregroundColor(Color(.labelColor))
|
||||
|
||||
statusBadge
|
||||
|
||||
Spacer()
|
||||
}
|
||||
}
|
||||
|
||||
private var statusBadge: some View {
|
||||
Group {
|
||||
if isCurrent {
|
||||
Text("Default")
|
||||
.font(.system(size: 11, weight: .medium))
|
||||
.padding(.horizontal, 6)
|
||||
.padding(.vertical, 2)
|
||||
.background(Capsule().fill(Color.accentColor))
|
||||
.foregroundColor(.white)
|
||||
} else if isDownloaded {
|
||||
Text("Downloaded")
|
||||
.font(.system(size: 11, weight: .medium))
|
||||
.padding(.horizontal, 6)
|
||||
.padding(.vertical, 2)
|
||||
.background(Capsule().fill(Color(.quaternaryLabelColor)))
|
||||
.foregroundColor(Color(.labelColor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var metadataSection: some View {
|
||||
HStack(spacing: 16) {
|
||||
// Language
|
||||
Label(model.language, systemImage: "globe")
|
||||
.font(.system(size: 11))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
|
||||
// Size
|
||||
Label(model.size, systemImage: "internaldrive")
|
||||
.font(.system(size: 11))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
|
||||
// Speed
|
||||
HStack(spacing: 4) {
|
||||
Text("Speed")
|
||||
.font(.system(size: 11, weight: .medium))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
progressDotsWithNumber(value: model.speed * 10)
|
||||
}
|
||||
|
||||
// Accuracy
|
||||
HStack(spacing: 4) {
|
||||
Text("Accuracy")
|
||||
.font(.system(size: 11, weight: .medium))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
progressDotsWithNumber(value: model.accuracy * 10)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var descriptionSection: some View {
|
||||
Text(model.description)
|
||||
.font(.system(size: 11))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
.lineLimit(2)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
.padding(.top, 4)
|
||||
}
|
||||
|
||||
private var progressSection: some View {
|
||||
Group {
|
||||
if isDownloading {
|
||||
DownloadProgressView(
|
||||
modelName: model.name,
|
||||
downloadProgress: downloadProgress
|
||||
)
|
||||
.padding(.top, 8)
|
||||
.frame(maxWidth: .infinity, alignment: .leading)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var actionSection: some View {
|
||||
HStack(spacing: 8) {
|
||||
if isCurrent {
|
||||
Text("Default Model")
|
||||
.font(.system(size: 12))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
} else if isDownloaded {
|
||||
Button(action: setDefaultAction) {
|
||||
Text("Set as Default")
|
||||
.font(.system(size: 12))
|
||||
}
|
||||
.buttonStyle(.bordered)
|
||||
.controlSize(.small)
|
||||
} else {
|
||||
Button(action: downloadAction) {
|
||||
HStack(spacing: 4) {
|
||||
Text(isDownloading ? "Downloading..." : "Download")
|
||||
.font(.system(size: 12, weight: .medium))
|
||||
Image(systemName: "arrow.down.circle")
|
||||
.font(.system(size: 12, weight: .medium))
|
||||
}
|
||||
.foregroundColor(.white)
|
||||
.padding(.horizontal, 12)
|
||||
.padding(.vertical, 6)
|
||||
.background(
|
||||
Capsule()
|
||||
.fill(Color(.controlAccentColor))
|
||||
.shadow(color: Color(.controlAccentColor).opacity(0.2), radius: 2, x: 0, y: 1)
|
||||
)
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.disabled(isDownloading)
|
||||
}
|
||||
|
||||
if isDownloaded {
|
||||
Menu {
|
||||
Button(action: deleteAction) {
|
||||
Label("Delete Model", systemImage: "trash")
|
||||
}
|
||||
|
||||
if isDownloaded {
|
||||
Button {
|
||||
if let modelURL = modelURL {
|
||||
NSWorkspace.shared.selectFile(modelURL.path, inFileViewerRootedAtPath: "")
|
||||
}
|
||||
} label: {
|
||||
Label("Show in Finder", systemImage: "folder")
|
||||
}
|
||||
}
|
||||
} label: {
|
||||
Image(systemName: "ellipsis.circle")
|
||||
.font(.system(size: 14))
|
||||
}
|
||||
.menuStyle(.borderlessButton)
|
||||
.menuIndicator(.hidden)
|
||||
.frame(width: 20, height: 20)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
private var downloadComponents: [(String, Double)] {
|
||||
[
|
||||
("Model", downloadProgress[model.name + "_main"] ?? 0),
|
||||
("CoreML", downloadProgress[model.name + "_coreml"] ?? 0)
|
||||
].filter { $0.1 > 0 }
|
||||
}
|
||||
|
||||
private func progressDotsWithNumber(value: Double) -> some View {
|
||||
HStack(spacing: 4) {
|
||||
progressDots(value: value)
|
||||
Text(String(format: "%.1f", value))
|
||||
.font(.system(size: 10, weight: .medium, design: .monospaced))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
}
|
||||
}
|
||||
|
||||
private func progressDots(value: Double) -> some View {
|
||||
HStack(spacing: 2) {
|
||||
ForEach(0..<5) { index in
|
||||
Circle()
|
||||
.fill(index < Int(value / 2) ? performanceColor(value: value / 10) : Color(.quaternaryLabelColor))
|
||||
.frame(width: 6, height: 6)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func performanceColor(value: Double) -> Color {
|
||||
switch value {
|
||||
case 0.8...1.0: return Color(.systemGreen)
|
||||
case 0.6..<0.8: return Color(.systemYellow)
|
||||
case 0.4..<0.6: return Color(.systemOrange)
|
||||
default: return Color(.systemRed)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -68,217 +68,35 @@ struct ModelManagementView: View {
|
||||
Spacer()
|
||||
}
|
||||
|
||||
LazyVGrid(columns: [GridItem(.adaptive(minimum: 300, maximum: 400), spacing: 16)], spacing: 16) {
|
||||
VStack(spacing: 12) {
|
||||
ForEach(whisperState.predefinedModels) { model in
|
||||
modelCard(for: model)
|
||||
}
|
||||
}
|
||||
}
|
||||
.padding()
|
||||
.background(Color(.windowBackgroundColor).opacity(0.4))
|
||||
.cornerRadius(10)
|
||||
}
|
||||
|
||||
private func modelCard(for model: PredefinedModel) -> some View {
|
||||
let isDownloaded = whisperState.availableModels.contains { $0.name == model.name }
|
||||
let isCurrent = whisperState.currentModel?.name == model.name
|
||||
|
||||
return VStack(alignment: .leading, spacing: 12) {
|
||||
// Model name and details
|
||||
HStack {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text(model.displayName)
|
||||
.font(.headline)
|
||||
Text("\(model.size) • \(model.language)")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
Spacer()
|
||||
modelStatusBadge(isDownloaded: isDownloaded, isCurrent: isCurrent)
|
||||
}
|
||||
|
||||
// Description
|
||||
Text(model.description)
|
||||
.font(.subheadline)
|
||||
.foregroundColor(.secondary)
|
||||
.lineLimit(2)
|
||||
|
||||
// Performance indicators
|
||||
HStack(spacing: 16) {
|
||||
performanceIndicator(label: "Speed", value: model.speed)
|
||||
performanceIndicator(label: "Accuracy", value: model.accuracy)
|
||||
ramUsageLabel(gb: model.ramUsage)
|
||||
}
|
||||
|
||||
// Action buttons
|
||||
HStack {
|
||||
modelActionButton(isDownloaded: isDownloaded, isCurrent: isCurrent, model: model)
|
||||
|
||||
if isDownloaded {
|
||||
Menu {
|
||||
Button(action: {
|
||||
ModelCardRowView(
|
||||
model: model,
|
||||
isDownloaded: whisperState.availableModels.contains { $0.name == model.name },
|
||||
isCurrent: whisperState.currentModel?.name == model.name,
|
||||
downloadProgress: whisperState.downloadProgress,
|
||||
modelURL: whisperState.availableModels.first { $0.name == model.name }?.url,
|
||||
deleteAction: {
|
||||
if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) {
|
||||
modelToDelete = downloadedModel
|
||||
}
|
||||
}) {
|
||||
Label("Delete", systemImage: "trash")
|
||||
}
|
||||
|
||||
Button(action: {
|
||||
},
|
||||
setDefaultAction: {
|
||||
if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) {
|
||||
NSWorkspace.shared.selectFile(downloadedModel.url.path, inFileViewerRootedAtPath: "")
|
||||
Task {
|
||||
await whisperState.setDefaultModel(downloadedModel)
|
||||
}
|
||||
}
|
||||
},
|
||||
downloadAction: {
|
||||
Task {
|
||||
await whisperState.downloadModel(model)
|
||||
}
|
||||
}) {
|
||||
Label("Show in Finder", systemImage: "folder")
|
||||
}
|
||||
} label: {
|
||||
Image(systemName: "ellipsis.circle")
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
.menuStyle(BorderlessButtonMenuStyle())
|
||||
.frame(width: 30, height: 30)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
.padding()
|
||||
.background(Color(.windowBackgroundColor).opacity(0.9))
|
||||
.cornerRadius(10)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 10)
|
||||
.stroke(isCurrent ? Color.accentColor : Color.gray.opacity(0.2), lineWidth: isCurrent ? 2 : 1)
|
||||
)
|
||||
}
|
||||
|
||||
private func modelStatusBadge(isDownloaded: Bool, isCurrent: Bool) -> some View {
|
||||
Group {
|
||||
if isCurrent {
|
||||
Text("Default")
|
||||
.font(.caption)
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, 4)
|
||||
.background(Color.green)
|
||||
.foregroundColor(.white)
|
||||
.cornerRadius(8)
|
||||
} else if isDownloaded {
|
||||
Text("Downloaded")
|
||||
.font(.caption)
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, 4)
|
||||
.background(Color.indigo)
|
||||
.foregroundColor(.white)
|
||||
.cornerRadius(8)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func performanceIndicator(label: String, value: Double) -> some View {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text(label)
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
|
||||
HStack(spacing: 2) {
|
||||
ForEach(0..<5) { index in
|
||||
RoundedRectangle(cornerRadius: 2)
|
||||
.fill(index < Int(value * 5) ? performanceColor(value: value) : Color.secondary.opacity(0.2))
|
||||
.frame(width: 16, height: 8)
|
||||
}
|
||||
}
|
||||
|
||||
Text(String(format: "%.1f", value * 10))
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
}
|
||||
|
||||
private func performanceColor(value: Double) -> Color {
|
||||
switch value {
|
||||
case 0.8...: return .green
|
||||
case 0.6..<0.8: return .yellow
|
||||
case 0.4..<0.6: return .orange
|
||||
default: return .red
|
||||
}
|
||||
}
|
||||
|
||||
private func modelActionButton(isDownloaded: Bool, isCurrent: Bool, model: PredefinedModel) -> some View {
|
||||
Group {
|
||||
if isCurrent {
|
||||
Text("Default Model")
|
||||
.foregroundColor(.white)
|
||||
} else if isDownloaded {
|
||||
Button("Set as Default") {
|
||||
if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) {
|
||||
Task {
|
||||
await whisperState.setDefaultModel(downloadedModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
.foregroundColor(.white)
|
||||
} else if whisperState.downloadProgress[model.name] != nil {
|
||||
VStack {
|
||||
ProgressView(value: whisperState.downloadProgress[model.name] ?? 0)
|
||||
.progressViewStyle(LinearProgressViewStyle())
|
||||
.animation(.linear, value: whisperState.downloadProgress[model.name])
|
||||
Text("\(Int((whisperState.downloadProgress[model.name] ?? 0) * 100))%")
|
||||
.font(.caption)
|
||||
.animation(.none)
|
||||
}
|
||||
} else {
|
||||
Button("Download Model") {
|
||||
Task {
|
||||
await whisperState.downloadModel(model)
|
||||
}
|
||||
}
|
||||
.foregroundColor(.white)
|
||||
}
|
||||
}
|
||||
.buttonStyle(GradientButtonStyle(isDownloaded: isDownloaded, isCurrent: isCurrent))
|
||||
.frame(maxWidth: .infinity)
|
||||
}
|
||||
|
||||
private func ramUsageLabel(gb: Double) -> some View {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("RAM")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
|
||||
Text(formatRAMSize(gb))
|
||||
.font(.system(size: 12, weight: .bold))
|
||||
.foregroundColor(.primary)
|
||||
}
|
||||
}
|
||||
|
||||
private func formatRAMSize(_ gb: Double) -> String {
|
||||
if gb >= 1.0 {
|
||||
return String(format: "%.1f GB", gb)
|
||||
} else {
|
||||
return String(format: "%d MB", Int(gb * 1024))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct GradientButtonStyle: ButtonStyle {
|
||||
let isDownloaded: Bool
|
||||
let isCurrent: Bool
|
||||
|
||||
func makeBody(configuration: Configuration) -> some View {
|
||||
configuration.label
|
||||
.padding(.vertical, 5)
|
||||
.padding(.horizontal, 10)
|
||||
.background(
|
||||
Group {
|
||||
if isCurrent {
|
||||
LinearGradient(gradient: Gradient(colors: [Color.green, Color.green.opacity(0.7)]), startPoint: .top, endPoint: .bottom)
|
||||
} else if isDownloaded {
|
||||
LinearGradient(gradient: Gradient(colors: [Color.purple, Color.purple.opacity(0.7)]), startPoint: .top, endPoint: .bottom)
|
||||
} else {
|
||||
LinearGradient(gradient: Gradient(colors: [Color.blue, Color.blue.opacity(0.7)]), startPoint: .top, endPoint: .bottom)
|
||||
}
|
||||
}
|
||||
)
|
||||
.cornerRadius(10)
|
||||
.shadow(color: Color.black.opacity(0.2), radius: 5, x: 0, y: 2)
|
||||
.scaleEffect(configuration.isPressed ? 0.95 : 1)
|
||||
.animation(.easeInOut(duration: 0.2), value: configuration.isPressed)
|
||||
}
|
||||
}
|
||||
|
||||
@ -83,14 +83,10 @@ struct OnboardingModelDownloadView: View {
|
||||
|
||||
// Download progress
|
||||
if isDownloading {
|
||||
VStack(spacing: 8) {
|
||||
ProgressView(value: whisperState.downloadProgress[turboModel.name] ?? 0)
|
||||
.progressViewStyle(.linear)
|
||||
.tint(.white)
|
||||
Text("\(Int((whisperState.downloadProgress[turboModel.name] ?? 0) * 100))%")
|
||||
.font(.caption)
|
||||
.foregroundColor(.white.opacity(0.7))
|
||||
}
|
||||
DownloadProgressView(
|
||||
modelName: turboModel.name,
|
||||
downloadProgress: whisperState.downloadProgress
|
||||
)
|
||||
.transition(.opacity)
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ enum WhisperStateError: Error, Identifiable {
|
||||
case accessibilityPermissionDenied
|
||||
case modelDownloadFailed
|
||||
case modelDeletionFailed
|
||||
case unzipFailed
|
||||
case unknownError
|
||||
|
||||
var id: String { UUID().uuidString }
|
||||
@ -27,6 +28,8 @@ extension WhisperStateError: LocalizedError {
|
||||
return "Failed to download the model."
|
||||
case .modelDeletionFailed:
|
||||
return "Failed to delete the model."
|
||||
case .unzipFailed:
|
||||
return "Failed to unzip the downloaded Core ML model."
|
||||
case .unknownError:
|
||||
return "An unknown error occurred."
|
||||
}
|
||||
@ -46,6 +49,8 @@ extension WhisperStateError: LocalizedError {
|
||||
return "Check your internet connection and try again. If the problem persists, try a different model."
|
||||
case .modelDeletionFailed:
|
||||
return "Restart the application and try again. If the problem persists, you may need to manually delete the model file."
|
||||
case .unzipFailed:
|
||||
return "The downloaded Core ML model archive might be corrupted. Try deleting the model and downloading it again. Check available disk space."
|
||||
case .unknownError:
|
||||
return "Please restart the application. If the problem persists, contact support."
|
||||
}
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import Foundation
|
||||
import os
|
||||
import Zip
|
||||
import SwiftUI
|
||||
|
||||
// MARK: - Model Management Extension
|
||||
extension WhisperState {
|
||||
@ -10,7 +12,7 @@ extension WhisperState {
|
||||
do {
|
||||
try FileManager.default.createDirectory(at: modelsDirectory, withIntermediateDirectories: true, attributes: nil)
|
||||
} catch {
|
||||
messageLog += "Error creating models directory: \(error.localizedDescription)\n"
|
||||
logError("Error creating models directory", error)
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,7 +24,7 @@ extension WhisperState {
|
||||
return WhisperModel(name: url.deletingPathExtension().lastPathComponent, url: url)
|
||||
}
|
||||
} catch {
|
||||
messageLog += "Error loading available models: \(error.localizedDescription)\n"
|
||||
logError("Error loading available models", error)
|
||||
}
|
||||
}
|
||||
|
||||
@ -31,7 +33,6 @@ extension WhisperState {
|
||||
func loadModel(_ model: WhisperModel) async throws {
|
||||
guard whisperContext == nil else { return }
|
||||
|
||||
logger.notice("🔄 Loading Whisper model: \(model.name)")
|
||||
isModelLoading = true
|
||||
defer { isModelLoading = false }
|
||||
|
||||
@ -39,86 +40,186 @@ extension WhisperState {
|
||||
whisperContext = try await WhisperContext.createContext(path: model.url.path)
|
||||
isModelLoaded = true
|
||||
currentModel = model
|
||||
logger.notice("✅ Successfully loaded model: \(model.name)")
|
||||
} catch {
|
||||
logger.error("❌ Failed to load model: \(model.name) - \(error.localizedDescription)")
|
||||
throw WhisperStateError.modelLoadFailed
|
||||
}
|
||||
}
|
||||
|
||||
func setDefaultModel(_ model: WhisperModel) async {
|
||||
do {
|
||||
currentModel = model
|
||||
UserDefaults.standard.set(model.name, forKey: "CurrentModel")
|
||||
canTranscribe = true
|
||||
} catch {
|
||||
currentError = error as? WhisperStateError ?? .unknownError
|
||||
canTranscribe = false
|
||||
}
|
||||
currentModel = model
|
||||
UserDefaults.standard.set(model.name, forKey: "CurrentModel")
|
||||
canTranscribe = true
|
||||
}
|
||||
|
||||
// MARK: - Model Download & Management
|
||||
|
||||
func downloadModel(_ model: PredefinedModel) async {
|
||||
guard let url = URL(string: model.downloadURL) else { return }
|
||||
|
||||
logger.notice("🔽 Downloading model: \(model.name)")
|
||||
do {
|
||||
let (data, response) = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<(Data, URLResponse), Error>) in
|
||||
let task = URLSession.shared.dataTask(with: url) { data, response, error in
|
||||
if let error = error {
|
||||
continuation.resume(throwing: error)
|
||||
return
|
||||
}
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200...299).contains(httpResponse.statusCode),
|
||||
let data = data else {
|
||||
continuation.resume(throwing: URLError(.badServerResponse))
|
||||
return
|
||||
}
|
||||
continuation.resume(returning: (data, httpResponse))
|
||||
/// Helper function to download a file from a URL with progress tracking
|
||||
private func downloadFileWithProgress(from url: URL, progressKey: String) async throws -> Data {
|
||||
let destinationURL = modelsDirectory.appendingPathComponent(UUID().uuidString)
|
||||
|
||||
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Data, Error>) in
|
||||
let task = URLSession.shared.downloadTask(with: url) { tempURL, response, error in
|
||||
if let error = error {
|
||||
continuation.resume(throwing: error)
|
||||
return
|
||||
}
|
||||
|
||||
task.resume()
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200...299).contains(httpResponse.statusCode),
|
||||
let tempURL = tempURL else {
|
||||
continuation.resume(throwing: URLError(.badServerResponse))
|
||||
return
|
||||
}
|
||||
|
||||
let observation = task.progress.observe(\.fractionCompleted) { progress, _ in
|
||||
do {
|
||||
// Move the downloaded file to the final destination
|
||||
try FileManager.default.moveItem(at: tempURL, to: destinationURL)
|
||||
|
||||
// Read the file in chunks to avoid memory pressure
|
||||
let data = try Data(contentsOf: destinationURL, options: .mappedIfSafe)
|
||||
continuation.resume(returning: data)
|
||||
|
||||
// Clean up the temporary file
|
||||
try? FileManager.default.removeItem(at: destinationURL)
|
||||
} catch {
|
||||
continuation.resume(throwing: error)
|
||||
}
|
||||
}
|
||||
|
||||
task.resume()
|
||||
|
||||
var lastUpdateTime = Date()
|
||||
var lastProgressValue: Double = 0
|
||||
|
||||
let observation = task.progress.observe(\.fractionCompleted) { progress, _ in
|
||||
let currentTime = Date()
|
||||
let timeSinceLastUpdate = currentTime.timeIntervalSince(lastUpdateTime)
|
||||
let currentProgress = round(progress.fractionCompleted * 100) / 100
|
||||
|
||||
if timeSinceLastUpdate >= 0.5 && abs(currentProgress - lastProgressValue) >= 0.01 {
|
||||
lastUpdateTime = currentTime
|
||||
lastProgressValue = currentProgress
|
||||
|
||||
DispatchQueue.main.async {
|
||||
self.downloadProgress[model.name] = progress.fractionCompleted
|
||||
}
|
||||
}
|
||||
|
||||
Task {
|
||||
await withTaskCancellationHandler {
|
||||
observation.invalidate()
|
||||
} operation: {
|
||||
await withCheckedContinuation { (_: CheckedContinuation<Void, Never>) in }
|
||||
self.downloadProgress[progressKey] = currentProgress
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let destinationURL = modelsDirectory.appendingPathComponent(model.filename)
|
||||
try data.write(to: destinationURL)
|
||||
|
||||
availableModels.append(WhisperModel(name: model.name, url: destinationURL))
|
||||
self.downloadProgress.removeValue(forKey: model.name)
|
||||
logger.notice("✅ Successfully downloaded model: \(model.name)")
|
||||
} catch {
|
||||
logger.error("❌ Failed to download model: \(model.name) - \(error.localizedDescription)")
|
||||
currentError = .modelDownloadFailed
|
||||
self.downloadProgress.removeValue(forKey: model.name)
|
||||
|
||||
Task {
|
||||
await withTaskCancellationHandler {
|
||||
observation.invalidate()
|
||||
} operation: {
|
||||
await withCheckedContinuation { (_: CheckedContinuation<Void, Never>) in }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func downloadModel(_ model: PredefinedModel) async {
|
||||
guard let url = URL(string: model.downloadURL) else { return }
|
||||
|
||||
do {
|
||||
let whisperModel = try await downloadMainModel(model, from: url)
|
||||
|
||||
if let coreMLZipURL = whisperModel.coreMLZipDownloadURL,
|
||||
let coreMLURL = URL(string: coreMLZipURL) {
|
||||
try await downloadAndSetupCoreMLModel(for: whisperModel, from: coreMLURL)
|
||||
}
|
||||
|
||||
availableModels.append(whisperModel)
|
||||
self.downloadProgress.removeValue(forKey: model.name + "_main")
|
||||
} catch {
|
||||
handleModelDownloadError(model, error)
|
||||
}
|
||||
}
|
||||
|
||||
private func downloadMainModel(_ model: PredefinedModel, from url: URL) async throws -> WhisperModel {
|
||||
let progressKeyMain = model.name + "_main"
|
||||
let data = try await downloadFileWithProgress(from: url, progressKey: progressKeyMain)
|
||||
|
||||
let destinationURL = modelsDirectory.appendingPathComponent(model.filename)
|
||||
try data.write(to: destinationURL)
|
||||
|
||||
return WhisperModel(name: model.name, url: destinationURL)
|
||||
}
|
||||
|
||||
private func downloadAndSetupCoreMLModel(for model: WhisperModel, from url: URL) async throws {
|
||||
let progressKeyCoreML = model.name + "_coreml"
|
||||
let coreMLData = try await downloadFileWithProgress(from: url, progressKey: progressKeyCoreML)
|
||||
|
||||
let coreMLZipPath = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc.zip")
|
||||
try coreMLData.write(to: coreMLZipPath)
|
||||
|
||||
try await unzipAndSetupCoreMLModel(for: model, zipPath: coreMLZipPath, progressKey: progressKeyCoreML)
|
||||
}
|
||||
|
||||
private func unzipAndSetupCoreMLModel(for model: WhisperModel, zipPath: URL, progressKey: String) async throws {
|
||||
let coreMLDestination = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc")
|
||||
|
||||
try? FileManager.default.removeItem(at: coreMLDestination)
|
||||
try await unzipCoreMLFile(zipPath, to: modelsDirectory)
|
||||
try verifyAndCleanupCoreMLFiles(model, coreMLDestination, zipPath, progressKey)
|
||||
}
|
||||
|
||||
private func unzipCoreMLFile(_ zipPath: URL, to destination: URL) async throws {
|
||||
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
|
||||
do {
|
||||
try FileManager.default.createDirectory(at: destination, withIntermediateDirectories: true)
|
||||
try Zip.unzipFile(zipPath, destination: destination, overwrite: true, password: nil)
|
||||
continuation.resume()
|
||||
} catch {
|
||||
continuation.resume(throwing: error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func verifyAndCleanupCoreMLFiles(_ model: WhisperModel, _ destination: URL, _ zipPath: URL, _ progressKey: String) throws -> WhisperModel {
|
||||
var model = model
|
||||
|
||||
var isDirectory: ObjCBool = false
|
||||
guard FileManager.default.fileExists(atPath: destination.path, isDirectory: &isDirectory), isDirectory.boolValue else {
|
||||
try? FileManager.default.removeItem(at: zipPath)
|
||||
throw WhisperStateError.unzipFailed
|
||||
}
|
||||
|
||||
try? FileManager.default.removeItem(at: zipPath)
|
||||
model.coreMLEncoderURL = destination
|
||||
self.downloadProgress.removeValue(forKey: progressKey)
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
private func handleModelDownloadError(_ model: PredefinedModel, _ error: Error) {
|
||||
currentError = .modelDownloadFailed
|
||||
self.downloadProgress.removeValue(forKey: model.name + "_main")
|
||||
self.downloadProgress.removeValue(forKey: model.name + "_coreml")
|
||||
}
|
||||
|
||||
func deleteModel(_ model: WhisperModel) async {
|
||||
do {
|
||||
// Delete main model file
|
||||
try FileManager.default.removeItem(at: model.url)
|
||||
|
||||
// Delete CoreML model if it exists
|
||||
if let coreMLURL = model.coreMLEncoderURL {
|
||||
try? FileManager.default.removeItem(at: coreMLURL)
|
||||
} else {
|
||||
// Check if there's a CoreML directory matching the model name
|
||||
let coreMLDir = modelsDirectory.appendingPathComponent("\(model.name)-encoder.mlmodelc")
|
||||
if FileManager.default.fileExists(atPath: coreMLDir.path) {
|
||||
try? FileManager.default.removeItem(at: coreMLDir)
|
||||
}
|
||||
}
|
||||
|
||||
// Update model state
|
||||
availableModels.removeAll { $0.id == model.id }
|
||||
if currentModel?.id == model.id {
|
||||
currentModel = nil
|
||||
canTranscribe = false
|
||||
}
|
||||
} catch {
|
||||
messageLog += "Error deleting model: \(error.localizedDescription)\n"
|
||||
logError("Error deleting model: \(model.name)", error)
|
||||
currentError = .modelDeletionFailed
|
||||
}
|
||||
}
|
||||
@ -141,7 +242,7 @@ extension WhisperState {
|
||||
do {
|
||||
try FileManager.default.removeItem(at: model.url)
|
||||
} catch {
|
||||
messageLog += "Error deleting model: \(error.localizedDescription)\n"
|
||||
logError("Error deleting model during cleanup", error)
|
||||
}
|
||||
}
|
||||
availableModels.removeAll()
|
||||
@ -151,21 +252,88 @@ extension WhisperState {
|
||||
|
||||
func cleanupModelResources() async {
|
||||
recorder.stopRecording()
|
||||
try? await Task.sleep(nanoseconds: 500_000_000)
|
||||
|
||||
// Add a small delay to ensure recording has fully stopped
|
||||
try? await Task.sleep(nanoseconds: 500_000_000) // 0.5 seconds
|
||||
|
||||
// Only cleanup model resources if we're not actively using them
|
||||
let canCleanup = !isRecording && !isProcessing
|
||||
|
||||
if canCleanup {
|
||||
logger.notice("🧹 Cleaning up Whisper resources")
|
||||
// Release any resources held by the model
|
||||
await whisperContext?.releaseResources()
|
||||
whisperContext = nil
|
||||
isModelLoaded = false
|
||||
} else {
|
||||
logger.info("Skipping cleanup while recording or processing is active")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Helper Methods
|
||||
|
||||
private func logError(_ message: String, _ error: Error) {
|
||||
messageLog += "\(message): \(error.localizedDescription)\n"
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Download Progress View
|
||||
struct DownloadProgressView: View {
|
||||
let modelName: String
|
||||
let downloadProgress: [String: Double]
|
||||
|
||||
@Environment(\.colorScheme) private var colorScheme
|
||||
|
||||
private var mainProgress: Double {
|
||||
downloadProgress[modelName + "_main"] ?? 0
|
||||
}
|
||||
|
||||
private var coreMLProgress: Double {
|
||||
supportsCoreML ? (downloadProgress[modelName + "_coreml"] ?? 0) : 0
|
||||
}
|
||||
|
||||
private var supportsCoreML: Bool {
|
||||
!modelName.contains("q5") && !modelName.contains("q8")
|
||||
}
|
||||
|
||||
private var totalProgress: Double {
|
||||
supportsCoreML ? (mainProgress * 0.5) + (coreMLProgress * 0.5) : mainProgress
|
||||
}
|
||||
|
||||
private var downloadPhase: String {
|
||||
// Check if we're currently downloading the CoreML model
|
||||
if supportsCoreML && downloadProgress[modelName + "_coreml"] != nil {
|
||||
return "Downloading Core ML Model for \(modelName)"
|
||||
}
|
||||
// Otherwise, we're downloading the main model
|
||||
return "Downloading \(modelName) Model"
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
// Status text with clean typography
|
||||
Text(downloadPhase)
|
||||
.font(.system(size: 12, weight: .medium))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
|
||||
// Clean progress bar
|
||||
GeometryReader { geometry in
|
||||
ZStack(alignment: .leading) {
|
||||
// Background track
|
||||
RoundedRectangle(cornerRadius: 4)
|
||||
.fill(Color(.separatorColor).opacity(0.3))
|
||||
.frame(height: 6)
|
||||
|
||||
// Progress indicator
|
||||
RoundedRectangle(cornerRadius: 4)
|
||||
.fill(Color(.controlAccentColor))
|
||||
.frame(width: max(0, min(geometry.size.width * totalProgress, geometry.size.width)), height: 6)
|
||||
}
|
||||
}
|
||||
.frame(height: 6)
|
||||
|
||||
// Percentage indicator in Apple style
|
||||
HStack {
|
||||
Spacer()
|
||||
Text("\(Int(totalProgress * 100))%")
|
||||
.font(.system(size: 11, weight: .medium, design: .monospaced))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
}
|
||||
}
|
||||
.padding(.vertical, 4)
|
||||
.animation(.smooth, value: totalProgress)
|
||||
}
|
||||
}
|
||||
@ -438,12 +438,28 @@ struct WhisperModel: Identifiable {
|
||||
let id = UUID()
|
||||
let name: String
|
||||
let url: URL
|
||||
var coreMLEncoderURL: URL? // Path to the unzipped .mlmodelc directory
|
||||
var isCoreMLDownloaded: Bool { coreMLEncoderURL != nil }
|
||||
|
||||
var downloadURL: String {
|
||||
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/\(filename)"
|
||||
}
|
||||
|
||||
var filename: String {
|
||||
"\(name).bin"
|
||||
}
|
||||
|
||||
// Core ML related properties
|
||||
var coreMLZipDownloadURL: String? {
|
||||
// Only non-quantized models have Core ML versions
|
||||
guard !name.contains("q5") && !name.contains("q8") else { return nil }
|
||||
return "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/\(name)-encoder.mlmodelc.zip"
|
||||
}
|
||||
|
||||
var coreMLEncoderDirectoryName: String? {
|
||||
guard coreMLZipDownloadURL != nil else { return nil }
|
||||
return "\(name)-encoder.mlmodelc"
|
||||
}
|
||||
}
|
||||
|
||||
private class TaskDelegate: NSObject, URLSessionTaskDelegate {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user