From 05cfbd1a0bf314677ce7c20b3074c610ad2c64bf Mon Sep 17 00:00:00 2001 From: Beingpax Date: Tue, 17 Jun 2025 13:50:22 +0545 Subject: [PATCH] Custom Trancription model with openAI compatibility --- VoiceInk/Models/PredefinedModels.swift | 22 +- VoiceInk/Models/TranscriptionModel.swift | 27 ++ VoiceInk/PowerMode/PowerModeConfigView.swift | 2 +- .../CloudTranscriptionService.swift | 6 + ...OpenAICompatibleTranscriptionService.swift | 104 +++++++ VoiceInk/Services/CustomModelManager.swift | 138 +++++++++ VoiceInk/Views/APIKeyManagementView.swift | 2 +- VoiceInk/Views/AddCustomModelView.swift | 290 ++++++++++++++++++ VoiceInk/Views/ModelCardRowView.swift | 139 +++++++++ VoiceInk/Views/ModelManagementView.swift | 36 ++- .../Whisper/WhisperState+ModelQueries.swift | 3 + VoiceInk/Whisper/WhisperState.swift | 5 + 12 files changed, 761 insertions(+), 13 deletions(-) create mode 100644 VoiceInk/Services/CloudTranscription/OpenAICompatibleTranscriptionService.swift create mode 100644 VoiceInk/Services/CustomModelManager.swift create mode 100644 VoiceInk/Views/AddCustomModelView.swift diff --git a/VoiceInk/Models/PredefinedModels.swift b/VoiceInk/Models/PredefinedModels.swift index 480614c..2b49a32 100644 --- a/VoiceInk/Models/PredefinedModels.swift +++ b/VoiceInk/Models/PredefinedModels.swift @@ -27,15 +27,19 @@ import Foundation } enum PredefinedModels { - static func getLanguageDictionary(isMultilingual: Bool) -> [String: String] { - if !isMultilingual { - return ["en": "English"] - } else { - return allLanguages - } - } - - static let models: [any TranscriptionModel] = [ + static func getLanguageDictionary(isMultilingual: Bool) -> [String: String] { + if !isMultilingual { + return ["en": "English"] + } else { + return allLanguages + } + } + + static var models: [any TranscriptionModel] { + return predefinedModels + CustomModelManager.shared.customModels + } + + private static let predefinedModels: [any TranscriptionModel] = [ // Local Models LocalModel( name: "ggml-tiny", diff --git a/VoiceInk/Models/TranscriptionModel.swift b/VoiceInk/Models/TranscriptionModel.swift index 39941f1..3a36c92 100644 --- a/VoiceInk/Models/TranscriptionModel.swift +++ b/VoiceInk/Models/TranscriptionModel.swift @@ -6,6 +6,7 @@ enum ModelProvider: String, Codable, Hashable, CaseIterable { case groq = "Groq" case elevenLabs = "ElevenLabs" case deepgram = "Deepgram" + case custom = "Custom" // Future providers can be added here } @@ -55,4 +56,30 @@ struct CloudModel: TranscriptionModel { self.isMultilingualModel = isMultilingual self.supportedLanguages = supportedLanguages } +} + +// A new struct for custom cloud models +struct CustomCloudModel: TranscriptionModel, Codable { + let id: UUID + let name: String + let displayName: String + let description: String + let provider: ModelProvider = .custom + let apiEndpoint: String + let apiKey: String + let modelName: String + let isMultilingualModel: Bool + let supportedLanguages: [String: String] + + init(id: UUID = UUID(), name: String, displayName: String, description: String, apiEndpoint: String, apiKey: String, modelName: String, isMultilingual: Bool = true, supportedLanguages: [String: String]? = nil) { + self.id = id + self.name = name + self.displayName = displayName + self.description = description + self.apiEndpoint = apiEndpoint + self.apiKey = apiKey + self.modelName = modelName + self.isMultilingualModel = isMultilingual + self.supportedLanguages = supportedLanguages ?? PredefinedModels.getLanguageDictionary(isMultilingual: isMultilingual) + } } \ No newline at end of file diff --git a/VoiceInk/PowerMode/PowerModeConfigView.swift b/VoiceInk/PowerMode/PowerModeConfigView.swift index eee0cc8..8dfdb6a 100644 --- a/VoiceInk/PowerMode/PowerModeConfigView.swift +++ b/VoiceInk/PowerMode/PowerModeConfigView.swift @@ -487,7 +487,7 @@ struct ConfigurationView: View { .frame(maxWidth: .infinity, alignment: .leading) } else { Picker("", selection: providerBinding) { - ForEach(aiService.connectedProviders.filter { $0 != .elevenLabs }, id: \.self) { provider in + ForEach(aiService.connectedProviders.filter { $0 != .elevenLabs && $0 != .deepgram }, id: \.self) { provider in Text(provider.rawValue).tag(provider) } } diff --git a/VoiceInk/Services/CloudTranscription/CloudTranscriptionService.swift b/VoiceInk/Services/CloudTranscription/CloudTranscriptionService.swift index 28299c8..49785dd 100644 --- a/VoiceInk/Services/CloudTranscription/CloudTranscriptionService.swift +++ b/VoiceInk/Services/CloudTranscription/CloudTranscriptionService.swift @@ -38,6 +38,7 @@ class CloudTranscriptionService: TranscriptionService { private let groqService = GroqTranscriptionService() private let elevenLabsService = ElevenLabsTranscriptionService() private let deepgramService = DeepgramTranscriptionService() + private let openAICompatibleService = OpenAICompatibleTranscriptionService() func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String { switch model.provider { @@ -47,6 +48,11 @@ class CloudTranscriptionService: TranscriptionService { return try await elevenLabsService.transcribe(audioURL: audioURL, model: model) case .deepgram: return try await deepgramService.transcribe(audioURL: audioURL, model: model) + case .custom: + guard let customModel = model as? CustomCloudModel else { + throw CloudTranscriptionError.unsupportedProvider + } + return try await openAICompatibleService.transcribe(audioURL: audioURL, model: customModel) default: throw CloudTranscriptionError.unsupportedProvider } diff --git a/VoiceInk/Services/CloudTranscription/OpenAICompatibleTranscriptionService.swift b/VoiceInk/Services/CloudTranscription/OpenAICompatibleTranscriptionService.swift new file mode 100644 index 0000000..57d0325 --- /dev/null +++ b/VoiceInk/Services/CloudTranscription/OpenAICompatibleTranscriptionService.swift @@ -0,0 +1,104 @@ +import Foundation +import os + +class OpenAICompatibleTranscriptionService { + private let logger = Logger(subsystem: "com.voiceink.transcription", category: "OpenAICompatibleService") + + func transcribe(audioURL: URL, model: CustomCloudModel) async throws -> String { + let config = APIConfig( + url: URL(string: model.apiEndpoint)!, + apiKey: model.apiKey, + modelName: model.modelName + ) + + let boundary = "Boundary-\(UUID().uuidString)" + var request = URLRequest(url: config.url) + request.httpMethod = "POST" + request.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type") + request.setValue("Bearer \(config.apiKey)", forHTTPHeaderField: "Authorization") + + let body = try createOpenAICompatibleRequestBody(audioURL: audioURL, modelName: config.modelName, boundary: boundary) + + let (data, response) = try await URLSession.shared.upload(for: request, from: body) + guard let httpResponse = response as? HTTPURLResponse else { + throw CloudTranscriptionError.networkError(URLError(.badServerResponse)) + } + + if !(200...299).contains(httpResponse.statusCode) { + let errorMessage = String(data: data, encoding: .utf8) ?? "No error message" + logger.error("OpenAI-compatible API request failed with status \(httpResponse.statusCode): \(errorMessage, privacy: .public)") + throw CloudTranscriptionError.apiRequestFailed(statusCode: httpResponse.statusCode, message: errorMessage) + } + + do { + let transcriptionResponse = try JSONDecoder().decode(TranscriptionResponse.self, from: data) + return transcriptionResponse.text + } catch { + logger.error("Failed to decode OpenAI-compatible API response: \(error.localizedDescription)") + throw CloudTranscriptionError.noTranscriptionReturned + } + } + + private func createOpenAICompatibleRequestBody(audioURL: URL, modelName: String, boundary: String) throws -> Data { + var body = Data() + let crlf = "\r\n" + + guard let audioData = try? Data(contentsOf: audioURL) else { + throw CloudTranscriptionError.audioFileNotFound + } + + let selectedLanguage = UserDefaults.standard.string(forKey: "SelectedLanguage") ?? "auto" + let prompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? "" + + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"file\"; filename=\"\(audioURL.lastPathComponent)\"\(crlf)".data(using: .utf8)!) + body.append("Content-Type: audio/wav\(crlf)\(crlf)".data(using: .utf8)!) + body.append(audioData) + body.append(crlf.data(using: .utf8)!) + + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"model\"\(crlf)\(crlf)".data(using: .utf8)!) + body.append(modelName.data(using: .utf8)!) + body.append(crlf.data(using: .utf8)!) + + if selectedLanguage != "auto", !selectedLanguage.isEmpty { + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"language\"\(crlf)\(crlf)".data(using: .utf8)!) + body.append(selectedLanguage.data(using: .utf8)!) + body.append(crlf.data(using: .utf8)!) + } + + // Include prompt for OpenAI-compatible APIs + if !prompt.isEmpty { + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"prompt\"\(crlf)\(crlf)".data(using: .utf8)!) + body.append(prompt.data(using: .utf8)!) + body.append(crlf.data(using: .utf8)!) + } + + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"response_format\"\(crlf)\(crlf)".data(using: .utf8)!) + body.append("json".data(using: .utf8)!) + body.append(crlf.data(using: .utf8)!) + + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"temperature\"\(crlf)\(crlf)".data(using: .utf8)!) + body.append("0".data(using: .utf8)!) + body.append(crlf.data(using: .utf8)!) + body.append("--\(boundary)--\(crlf)".data(using: .utf8)!) + + return body + } + + private struct APIConfig { + let url: URL + let apiKey: String + let modelName: String + } + + private struct TranscriptionResponse: Decodable { + let text: String + let language: String? + let duration: Double? + } +} \ No newline at end of file diff --git a/VoiceInk/Services/CustomModelManager.swift b/VoiceInk/Services/CustomModelManager.swift new file mode 100644 index 0000000..1a9b910 --- /dev/null +++ b/VoiceInk/Services/CustomModelManager.swift @@ -0,0 +1,138 @@ +import Foundation +import os + +class CustomModelManager: ObservableObject { + static let shared = CustomModelManager() + + private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "CustomModelManager") + private let userDefaults = UserDefaults.standard + private let customModelsKey = "customCloudModels" + + @Published var customModels: [CustomCloudModel] = [] + + private init() { + loadCustomModels() + } + + // MARK: - CRUD Operations + + func addCustomModel(_ model: CustomCloudModel) { + customModels.append(model) + saveCustomModels() + logger.info("Added custom model: \(model.displayName)") + } + + func removeCustomModel(withId id: UUID) { + customModels.removeAll { $0.id == id } + saveCustomModels() + logger.info("Removed custom model with ID: \(id)") + } + + func updateCustomModel(_ updatedModel: CustomCloudModel) { + if let index = customModels.firstIndex(where: { $0.id == updatedModel.id }) { + customModels[index] = updatedModel + saveCustomModels() + logger.info("Updated custom model: \(updatedModel.displayName)") + } + } + + // MARK: - Persistence + + private func loadCustomModels() { + guard let data = userDefaults.data(forKey: customModelsKey) else { + logger.info("No custom models found in UserDefaults") + return + } + + do { + customModels = try JSONDecoder().decode([CustomCloudModel].self, from: data) + } catch { + logger.error("Failed to decode custom models: \(error.localizedDescription)") + customModels = [] + } + } + + private func saveCustomModels() { + do { + let data = try JSONEncoder().encode(customModels) + userDefaults.set(data, forKey: customModelsKey) + } catch { + logger.error("Failed to encode custom models: \(error.localizedDescription)") + } + } + + // MARK: - Validation + + func validateModel(name: String, displayName: String, apiEndpoint: String, apiKey: String, modelName: String) -> [String] { + var errors: [String] = [] + + if name.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + errors.append("Name cannot be empty") + } + + if displayName.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + errors.append("Display name cannot be empty") + } + + if apiEndpoint.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + errors.append("API endpoint cannot be empty") + } else if !isValidURL(apiEndpoint) { + errors.append("API endpoint must be a valid URL") + } + + if apiKey.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + errors.append("API key cannot be empty") + } + + if modelName.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + errors.append("Model name cannot be empty") + } + + // Check for duplicate names + if customModels.contains(where: { $0.name == name }) { + errors.append("A model with this name already exists") + } + + return errors + } + + func validateModel(name: String, displayName: String, apiEndpoint: String, apiKey: String, modelName: String, excludingId: UUID? = nil) -> [String] { + var errors: [String] = [] + + if name.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + errors.append("Name cannot be empty") + } + + if displayName.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + errors.append("Display name cannot be empty") + } + + if apiEndpoint.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + errors.append("API endpoint cannot be empty") + } else if !isValidURL(apiEndpoint) { + errors.append("API endpoint must be a valid URL") + } + + if apiKey.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + errors.append("API key cannot be empty") + } + + if modelName.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + errors.append("Model name cannot be empty") + } + + // Check for duplicate names, excluding the specified ID + if customModels.contains(where: { $0.name == name && $0.id != excludingId }) { + errors.append("A model with this name already exists") + } + + return errors + } + + private func isValidURL(_ string: String) -> Bool { + if let url = URL(string: string) { + return url.scheme != nil && url.host != nil + } + return false + } +} diff --git a/VoiceInk/Views/APIKeyManagementView.swift b/VoiceInk/Views/APIKeyManagementView.swift index 9ebf516..5488155 100644 --- a/VoiceInk/Views/APIKeyManagementView.swift +++ b/VoiceInk/Views/APIKeyManagementView.swift @@ -44,7 +44,7 @@ struct APIKeyManagementView: View { // Provider Selection Picker("AI Provider", selection: $aiService.selectedProvider) { - ForEach(AIProvider.allCases.filter { $0 != .elevenLabs }, id: \.self) { provider in + ForEach(AIProvider.allCases.filter { $0 != .elevenLabs && $0 != .deepgram }, id: \.self) { provider in Text(provider.rawValue).tag(provider) } } diff --git a/VoiceInk/Views/AddCustomModelView.swift b/VoiceInk/Views/AddCustomModelView.swift new file mode 100644 index 0000000..d2b87ee --- /dev/null +++ b/VoiceInk/Views/AddCustomModelView.swift @@ -0,0 +1,290 @@ +import SwiftUI + +struct AddCustomModelCardView: View { + @ObservedObject var customModelManager: CustomModelManager + var onModelAdded: () -> Void + var editingModel: CustomCloudModel? = nil + + @State private var isExpanded = false + @State private var displayName = "" + @State private var apiEndpoint = "" + @State private var apiKey = "" + @State private var modelName = "" + @State private var isMultilingual = true + + @State private var validationErrors: [String] = [] + @State private var showingAlert = false + @State private var isSaving = false + + var body: some View { + VStack(spacing: 0) { + // Simple Add Model Button + if !isExpanded { + Button(action: { + withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) { + isExpanded = true + // Pre-fill values - either from editing model or defaults + if let editing = editingModel { + displayName = editing.displayName + apiEndpoint = editing.apiEndpoint + apiKey = editing.apiKey + modelName = editing.modelName + isMultilingual = editing.isMultilingualModel + } else { + // Pre-fill some default values when adding new + if apiEndpoint.isEmpty { + apiEndpoint = "https://api.example.com/v1/audio/transcriptions" + } + if modelName.isEmpty { + modelName = "large-v3-turbo" + } + } + } + }) { + HStack(spacing: 8) { + Image(systemName: "plus") + .font(.system(size: 14, weight: .medium)) + Text(editingModel != nil ? "Edit Model" : "Add Model") + .font(.system(size: 14, weight: .semibold)) + } + .foregroundColor(.white) + .frame(maxWidth: .infinity) + .padding(.vertical, 12) + .background(Color.accentColor) + .cornerRadius(12) + } + .buttonStyle(.plain) + .shadow(color: Color.accentColor.opacity(0.3), radius: 8, y: 4) + } + + // Expandable Form Section + if isExpanded { + VStack(alignment: .leading, spacing: 20) { + // Header + HStack { + Text(editingModel != nil ? "Edit Custom Model" : "Add Custom Model") + .font(.headline) + .foregroundColor(.primary) + + Spacer() + + Button(action: { + withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) { + isExpanded = false + clearForm() + } + }) { + Image(systemName: "xmark") + .font(.system(size: 12, weight: .medium)) + .foregroundColor(.secondary) + } + .buttonStyle(.plain) + } + + // Disclaimer + HStack(spacing: 8) { + Image(systemName: "exclamationmark.triangle.fill") + .foregroundColor(.orange) + .font(.caption) + Text("Only OpenAI-compatible transcription APIs are supported") + .font(.caption) + .foregroundColor(.secondary) + } + .padding(.horizontal, 12) + .padding(.vertical, 8) + .background(Color.orange.opacity(0.1)) + .cornerRadius(8) + + // Form fields + VStack(alignment: .leading, spacing: 16) { + FormField(title: "Display Name", text: $displayName, placeholder: "My Custom Model") + FormField(title: "API Endpoint", text: $apiEndpoint, placeholder: "https://api.example.com/v1/audio/transcriptions") + FormField(title: "API Key", text: $apiKey, placeholder: "your-api-key", isSecure: true) + FormField(title: "Model Name", text: $modelName, placeholder: "whisper-1") + + Toggle("Multilingual Model", isOn: $isMultilingual) + } + + // Action buttons + HStack(spacing: 12) { + Button(action: { + withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) { + isExpanded = false + clearForm() + } + }) { + Text("Cancel") + .font(.system(size: 13, weight: .medium)) + .foregroundColor(.secondary) + .frame(maxWidth: .infinity) + .padding(.vertical, 10) + .background(Color.secondary.opacity(0.1)) + .cornerRadius(8) + } + .buttonStyle(.plain) + + Button(action: { + addModel() + }) { + HStack(spacing: 6) { + if isSaving { + ProgressView() + .scaleEffect(0.8) + .frame(width: 14, height: 14) + } else { + Image(systemName: editingModel != nil ? "checkmark.circle.fill" : "plus.circle.fill") + .font(.system(size: 14)) + } + Text(editingModel != nil ? "Update Model" : "Add Model") + .font(.system(size: 13, weight: .medium)) + } + .foregroundColor(.white) + .frame(maxWidth: .infinity) + .padding(.vertical, 10) + .background( + RoundedRectangle(cornerRadius: 8) + .fill(isFormValid ? Color(.controlAccentColor) : Color.secondary) + .shadow(color: (isFormValid ? Color(.controlAccentColor) : Color.secondary).opacity(0.2), radius: 2, x: 0, y: 1) + ) + } + .buttonStyle(.plain) + .disabled(!isFormValid || isSaving) + } + } + .padding(16) + .background( + RoundedRectangle(cornerRadius: 12) + .fill(Color(.windowBackgroundColor)) + .overlay( + RoundedRectangle(cornerRadius: 12) + .stroke(Color(.separatorColor), lineWidth: 1) + ) + ) + } + } + .alert("Validation Errors", isPresented: $showingAlert) { + Button("OK") { } + } message: { + Text(validationErrors.joined(separator: "\n")) + } + .onChange(of: editingModel) { oldValue, newValue in + if newValue != nil { + withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) { + isExpanded = true + // Pre-fill values from editing model + if let editing = newValue { + displayName = editing.displayName + apiEndpoint = editing.apiEndpoint + apiKey = editing.apiKey + modelName = editing.modelName + isMultilingual = editing.isMultilingualModel + } + } + } + } + } + + private var isFormValid: Bool { + !displayName.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty && + !apiEndpoint.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty && + !apiKey.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty && + !modelName.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty + } + + private func clearForm() { + displayName = "" + apiEndpoint = "" + apiKey = "" + modelName = "" + isMultilingual = true + } + + private func addModel() { + let trimmedDisplayName = displayName.trimmingCharacters(in: .whitespacesAndNewlines) + let trimmedApiEndpoint = apiEndpoint.trimmingCharacters(in: .whitespacesAndNewlines) + let trimmedApiKey = apiKey.trimmingCharacters(in: .whitespacesAndNewlines) + let trimmedModelName = modelName.trimmingCharacters(in: .whitespacesAndNewlines) + + // Generate a name from display name (lowercase, no spaces) + let generatedName = trimmedDisplayName.lowercased().replacingOccurrences(of: " ", with: "-") + + validationErrors = customModelManager.validateModel( + name: generatedName, + displayName: trimmedDisplayName, + apiEndpoint: trimmedApiEndpoint, + apiKey: trimmedApiKey, + modelName: trimmedModelName, + excludingId: editingModel?.id + ) + + if !validationErrors.isEmpty { + showingAlert = true + return + } + + isSaving = true + + // Simulate a brief save operation for better UX + DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) { + if let editing = editingModel { + // Update existing model + let updatedModel = CustomCloudModel( + id: editing.id, + name: generatedName, + displayName: trimmedDisplayName, + description: "Custom transcription model", + apiEndpoint: trimmedApiEndpoint, + apiKey: trimmedApiKey, + modelName: trimmedModelName, + isMultilingual: isMultilingual + ) + customModelManager.updateCustomModel(updatedModel) + } else { + // Add new model + let customModel = CustomCloudModel( + name: generatedName, + displayName: trimmedDisplayName, + description: "Custom transcription model", + apiEndpoint: trimmedApiEndpoint, + apiKey: trimmedApiKey, + modelName: trimmedModelName, + isMultilingual: isMultilingual + ) + customModelManager.addCustomModel(customModel) + } + + onModelAdded() + + // Reset form and collapse + withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) { + isExpanded = false + clearForm() + isSaving = false + } + } + } +} + +struct FormField: View { + let title: String + @Binding var text: String + let placeholder: String + var isSecure: Bool = false + + var body: some View { + VStack(alignment: .leading, spacing: 6) { + Text(title) + .font(.subheadline) + .fontWeight(.medium) + .foregroundColor(.primary) + + if isSecure { + SecureField(placeholder, text: $text) + .textFieldStyle(.roundedBorder) + } else { + TextField(placeholder, text: $text) + .textFieldStyle(.roundedBorder) + } + } + } +} \ No newline at end of file diff --git a/VoiceInk/Views/ModelCardRowView.swift b/VoiceInk/Views/ModelCardRowView.swift index 457a6da..f50f2d7 100644 --- a/VoiceInk/Views/ModelCardRowView.swift +++ b/VoiceInk/Views/ModelCardRowView.swift @@ -12,6 +12,7 @@ struct ModelCardRowView: View { var deleteAction: () -> Void var setDefaultAction: () -> Void var downloadAction: () -> Void + var editAction: ((CustomCloudModel) -> Void)? var body: some View { Group { @@ -37,6 +38,16 @@ struct ModelCardRowView: View { setDefaultAction: setDefaultAction ) } + case .custom: + if let customModel = model as? CustomCloudModel { + CustomModelCardView( + model: customModel, + isCurrent: isCurrent, + setDefaultAction: setDefaultAction, + deleteAction: deleteAction, + editAction: editAction ?? { _ in } + ) + } } } } @@ -576,4 +587,132 @@ private func performanceColor(value: Double) -> Color { case 0.4..<0.6: return Color(.systemOrange) default: return Color(.systemRed) } +} + +// MARK: - Custom Model Card View +struct CustomModelCardView: View { + let model: CustomCloudModel + let isCurrent: Bool + var setDefaultAction: () -> Void + var deleteAction: () -> Void + var editAction: (CustomCloudModel) -> Void + + var body: some View { + VStack(alignment: .leading, spacing: 0) { + // Main card content + HStack(alignment: .top, spacing: 16) { + VStack(alignment: .leading, spacing: 6) { + headerSection + metadataSection + descriptionSection + } + .frame(maxWidth: .infinity, alignment: .leading) + + actionSection + } + .padding(16) + } + .background(CardBackground(isSelected: isCurrent, useAccentGradientWhenSelected: isCurrent)) + } + + private var headerSection: some View { + HStack(alignment: .firstTextBaseline) { + Text(model.displayName) + .font(.system(size: 13, weight: .semibold)) + .foregroundColor(Color(.labelColor)) + + statusBadge + + Spacer() + } + } + + private var statusBadge: some View { + Group { + if isCurrent { + Text("Default") + .font(.system(size: 11, weight: .medium)) + .padding(.horizontal, 6) + .padding(.vertical, 2) + .background(Capsule().fill(Color.accentColor)) + .foregroundColor(.white) + } else { + Text("Custom") + .font(.system(size: 11, weight: .medium)) + .padding(.horizontal, 6) + .padding(.vertical, 2) + .background(Capsule().fill(Color.orange.opacity(0.2))) + .foregroundColor(Color.orange) + } + } + } + + private var metadataSection: some View { + HStack(spacing: 12) { + // Provider + Label("Custom Provider", systemImage: "cloud") + .font(.system(size: 11)) + .foregroundColor(Color(.secondaryLabelColor)) + .lineLimit(1) + + // Language + Label(model.language, systemImage: "globe") + .font(.system(size: 11)) + .foregroundColor(Color(.secondaryLabelColor)) + .lineLimit(1) + + // OpenAI Compatible + Label("OpenAI Compatible", systemImage: "checkmark.seal") + .font(.system(size: 11)) + .foregroundColor(Color(.secondaryLabelColor)) + .lineLimit(1) + } + .lineLimit(1) + } + + private var descriptionSection: some View { + Text(model.description) + .font(.system(size: 11)) + .foregroundColor(Color(.secondaryLabelColor)) + .lineLimit(2) + .fixedSize(horizontal: false, vertical: true) + .padding(.top, 4) + } + + private var actionSection: some View { + HStack(spacing: 8) { + if isCurrent { + Text("Default Model") + .font(.system(size: 12)) + .foregroundColor(Color(.secondaryLabelColor)) + } else { + Button(action: setDefaultAction) { + Text("Set as Default") + .font(.system(size: 12)) + } + .buttonStyle(.bordered) + .controlSize(.small) + } + + Menu { + Button { + editAction(model) + } label: { + Label("Edit Model", systemImage: "pencil") + } + + Button(role: .destructive) { + deleteAction() + } label: { + Label("Delete Model", systemImage: "trash") + } + } label: { + Image(systemName: "ellipsis.circle") + .font(.system(size: 14)) + } + .menuStyle(.borderlessButton) + .menuIndicator(.hidden) + .frame(width: 20, height: 20) + } + } } diff --git a/VoiceInk/Views/ModelManagementView.swift b/VoiceInk/Views/ModelManagementView.swift index 5fa5e59..2ad4ab1 100644 --- a/VoiceInk/Views/ModelManagementView.swift +++ b/VoiceInk/Views/ModelManagementView.swift @@ -4,10 +4,14 @@ import SwiftData struct ModelManagementView: View { @ObservedObject var whisperState: WhisperState @State private var modelToDelete: WhisperModel? + @State private var customModelToDelete: CustomCloudModel? + @State private var customModelToEdit: CustomCloudModel? @StateObject private var aiService = AIService() + @StateObject private var customModelManager = CustomModelManager.shared @EnvironmentObject private var enhancementService: AIEnhancementService @Environment(\.modelContext) private var modelContext @StateObject private var whisperPrompt = WhisperPrompt() + var body: some View { ScrollView { @@ -32,6 +36,19 @@ struct ModelManagementView: View { secondaryButton: .cancel() ) } + .alert(item: $customModelToDelete) { model in + Alert( + title: Text("Delete Custom Model"), + message: Text("Are you sure you want to delete the custom model '\(model.displayName)'?"), + primaryButton: .destructive(Text("Delete")) { + customModelManager.removeCustomModel(withId: model.id) + // Update whisperState to refresh the UI + whisperState.refreshAllAvailableModels() + }, + secondaryButton: .cancel() + ) + } + } private var defaultModelSection: some View { @@ -76,7 +93,9 @@ struct ModelManagementView: View { downloadProgress: whisperState.downloadProgress, modelURL: whisperState.availableModels.first { $0.name == model.name }?.url, deleteAction: { - if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) { + if model.provider == .custom, let customModel = model as? CustomCloudModel { + customModelToDelete = customModel + } else if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) { modelToDelete = downloadedModel } }, @@ -91,9 +110,22 @@ struct ModelManagementView: View { await whisperState.downloadModel(localModel) } } - } + }, + editAction: model.provider == .custom ? { customModel in + customModelToEdit = customModel + } : nil ) } + + // Add Custom Model Card at the bottom + AddCustomModelCardView( + customModelManager: customModelManager, + editingModel: customModelToEdit + ) { + // Refresh the models when a new custom model is added + whisperState.refreshAllAvailableModels() + customModelToEdit = nil // Clear editing state + } } } .padding() diff --git a/VoiceInk/Whisper/WhisperState+ModelQueries.swift b/VoiceInk/Whisper/WhisperState+ModelQueries.swift index b74666a..689b28c 100644 --- a/VoiceInk/Whisper/WhisperState+ModelQueries.swift +++ b/VoiceInk/Whisper/WhisperState+ModelQueries.swift @@ -15,6 +15,9 @@ extension WhisperState { case .deepgram: let key = UserDefaults.standard.string(forKey: "DeepgramAPIKey") return key != nil && !key!.isEmpty + case .custom: + // Custom models are always usable since they contain their own API keys + return true } } } diff --git a/VoiceInk/Whisper/WhisperState.swift b/VoiceInk/Whisper/WhisperState.swift index 9e4e502..5cb3b5e 100644 --- a/VoiceInk/Whisper/WhisperState.swift +++ b/VoiceInk/Whisper/WhisperState.swift @@ -114,6 +114,7 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { createRecordingsDirectoryIfNeeded() loadAvailableModels() loadCurrentTranscriptionModel() + refreshAllAvailableModels() } private func createRecordingsDirectoryIfNeeded() { @@ -512,6 +513,10 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { func getEnhancementService() -> AIEnhancementService? { return enhancementService } + + func refreshAllAvailableModels() { + allAvailableModels = PredefinedModels.models + } } struct WhisperModel: Identifiable {