diff --git a/VoiceInk/Models/PredefinedModels.swift b/VoiceInk/Models/PredefinedModels.swift index 8c4b54a..b064c95 100644 --- a/VoiceInk/Models/PredefinedModels.swift +++ b/VoiceInk/Models/PredefinedModels.swift @@ -1,6 +1,6 @@ import Foundation - struct PredefinedModel: Identifiable, Hashable { + struct LocalModel: TranscriptionModel { let id = UUID() let name: String let displayName: String @@ -11,6 +11,7 @@ import Foundation let accuracy: Double let ramUsage: Double let hash: String + let provider: ModelProvider = .local var downloadURL: String { "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/\(filename)" @@ -20,21 +21,9 @@ import Foundation "\(name).bin" } - func hash(into hasher: inout Hasher) { - hasher.combine(id) - } - - static func == (lhs: PredefinedModel, rhs: PredefinedModel) -> Bool { - lhs.id == rhs.id - } - var isMultilingualModel: Bool { supportedLanguages.count > 1 } - - var language: String { - isMultilingualModel ? "Multilingual" : "English-only" - } } enum PredefinedModels { @@ -53,8 +42,9 @@ import Foundation } } - static let models: [PredefinedModel] = [ - PredefinedModel( + static let models: [any TranscriptionModel] = [ + // Local Models + LocalModel( name: "ggml-tiny", displayName: "Tiny", size: "75 MiB", @@ -65,7 +55,7 @@ import Foundation ramUsage: 0.3, hash: "bd577a113a864445d4c299885e0cb97d4ba92b5f" ), - PredefinedModel( + LocalModel( name: "ggml-tiny.en", displayName: "Tiny (English)", size: "75 MiB", @@ -76,7 +66,7 @@ import Foundation ramUsage: 0.3, hash: "c78c86eb1a8faa21b369bcd33207cc90d64ae9df" ), - PredefinedModel( + LocalModel( name: "ggml-base.en", displayName: "Base (English)", size: "142 MiB", @@ -87,7 +77,7 @@ import Foundation ramUsage: 0.5, hash: "137c40403d78fd54d454da0f9bd998f78703390c" ), - PredefinedModel( + LocalModel( name: "ggml-large-v2", displayName: "Large v2", size: "2.9 GiB", @@ -98,7 +88,7 @@ import Foundation ramUsage: 3.8, hash: "0f4c8e30f21cf1769f637135f521436792c48186" ), - PredefinedModel( + LocalModel( name: "ggml-large-v3", displayName: "Large v3", size: "2.9 GiB", @@ -109,7 +99,7 @@ import Foundation ramUsage: 3.9, hash: "ad82bf6a9043ceed055076d0fd39f5f186ff8062" ), - PredefinedModel( + LocalModel( name: "ggml-large-v3-turbo", displayName: "Large v3 Turbo", size: "1.5 GiB", @@ -121,7 +111,7 @@ import Foundation ramUsage: 1.8, hash: "4af2b29d7ec73d781377bfd1758ca957a807e941" ), - PredefinedModel( + LocalModel( name: "ggml-large-v3-turbo-q5_0", displayName: "Large v3 Turbo (Quantized)", size: "547 MiB", @@ -132,6 +122,29 @@ import Foundation ramUsage: 1.0, hash: "e050f7970618a659205450ad97eb95a18d69c9ee" ), + + // Cloud Models + CloudModel( + name: "whisper-large-v3-turbo", + displayName: "Whisper Large v3 Turbo (Groq)", + description: "Groq's ultra-fast Whisper Large v3 Turbo model with lightning-speed inference", + provider: .groq, + speed: 0.7, + accuracy: 0.96, + isMultilingual: true, + supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true) + ), + CloudModel( + name: "scribe_v1", + displayName: "Scribe v1 (ElevenLabs)", + description: "ElevenLabs' Scribe model for fast and accurate transcription.", + provider: .elevenLabs, + speed: 0.8, + accuracy: 0.95, + isMultilingual: true, + supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true) + ), + ] static let allLanguages = [ diff --git a/VoiceInk/Models/TranscriptionModel.swift b/VoiceInk/Models/TranscriptionModel.swift new file mode 100644 index 0000000..47d53e0 --- /dev/null +++ b/VoiceInk/Models/TranscriptionModel.swift @@ -0,0 +1,57 @@ +import Foundation + +// Enum to differentiate between model providers +enum ModelProvider: String, Codable, Hashable, CaseIterable { + case local = "Local" + case groq = "Groq" + case elevenLabs = "ElevenLabs" + // Future providers can be added here +} + +// A unified protocol for any transcription model +protocol TranscriptionModel: Identifiable, Hashable { + var id: UUID { get } + var name: String { get } + var displayName: String { get } + var description: String { get } + var provider: ModelProvider { get } + + // Language capabilities + var isMultilingualModel: Bool { get } + var supportedLanguages: [String: String] { get } +} + +extension TranscriptionModel { + func hash(into hasher: inout Hasher) { + hasher.combine(id) + } + + var language: String { + isMultilingualModel ? "Multilingual" : "English-only" + } +} + +// A new struct for cloud models +struct CloudModel: TranscriptionModel { + let id: UUID + let name: String + let displayName: String + let description: String + let provider: ModelProvider + let speed: Double + let accuracy: Double + let isMultilingualModel: Bool + let supportedLanguages: [String: String] + + init(id: UUID = UUID(), name: String, displayName: String, description: String, provider: ModelProvider, speed: Double, accuracy: Double, isMultilingual: Bool, supportedLanguages: [String: String]) { + self.id = id + self.name = name + self.displayName = displayName + self.description = description + self.provider = provider + self.speed = speed + self.accuracy = accuracy + self.isMultilingualModel = isMultilingual + self.supportedLanguages = supportedLanguages + } +} \ No newline at end of file diff --git a/VoiceInk/PowerMode/ActiveWindowService.swift b/VoiceInk/PowerMode/ActiveWindowService.swift index 0a2d655..21a21e7 100644 --- a/VoiceInk/PowerMode/ActiveWindowService.swift +++ b/VoiceInk/PowerMode/ActiveWindowService.swift @@ -116,20 +116,35 @@ class ActiveWindowService: ObservableObject { } if let whisperState = self.whisperState, - let modelName = config.selectedWhisperModel { - let models = await MainActor.run { whisperState.availableModels } - if let selectedModel = models.first(where: { $0.name == modelName }) { + let modelName = config.selectedWhisperModel, + let selectedModel = await whisperState.allAvailableModels.first(where: { $0.name == modelName }) { + + let currentModelName = await MainActor.run { whisperState.currentTranscriptionModel?.name } + + // Only change the model if it's different from the current one. + if currentModelName != modelName { + // Set the new model as default. This works for both local and cloud models. + await whisperState.setDefaultTranscriptionModel(selectedModel) - let currentModelName = await MainActor.run { whisperState.currentModel?.name } - if currentModelName != modelName { - await whisperState.setDefaultModel(selectedModel) + // The cleanup and load cycle is only necessary for local models. + if selectedModel.provider == .local { + // Unload any previously loaded model to free up memory. await whisperState.cleanupModelResources() - do { - try await whisperState.loadModel(selectedModel) - } catch { - + // Load the new local model into memory. + if let localModel = await whisperState.availableModels.first(where: { $0.name == selectedModel.name }) { + do { + try await whisperState.loadModel(localModel) + logger.info("✅ Power Mode: Successfully loaded local model '\(localModel.name)'.") + } catch { + logger.error("❌ Power Mode: Failed to load local model '\(localModel.name)': \(error.localizedDescription)") + } } + } else { + // For cloud models, no in-memory loading is needed, but we should still + // clean up if the *previous* model was a local one. + await whisperState.cleanupModelResources() + logger.info("✅ Power Mode: Switched to cloud model '\(selectedModel.name)'. No local load needed.") } } } diff --git a/VoiceInk/PowerMode/PowerModeConfigView.swift b/VoiceInk/PowerMode/PowerModeConfigView.swift index 6ac9f59..dcf2aae 100644 --- a/VoiceInk/PowerMode/PowerModeConfigView.swift +++ b/VoiceInk/PowerMode/PowerModeConfigView.swift @@ -375,9 +375,12 @@ struct ConfigurationView: View { .foregroundColor(.secondary) Picker("", selection: modelBinding) { - ForEach(whisperState.availableModels) { model in - let displayName = whisperState.predefinedModels.first { $0.name == model.name }?.displayName ?? model.name - Text(displayName).tag(model.name as String?) + Text("Default") + .tag(nil as String?) + + ForEach(whisperState.usableModels, id: \.name) { model in + Text(model.displayName) + .tag(model.name as String?) } } .labelsHidden() @@ -387,7 +390,7 @@ struct ConfigurationView: View { // Language Selection Subsection if let selectedModel = effectiveModelName, - let modelInfo = whisperState.predefinedModels.first(where: { $0.name == selectedModel }), + let modelInfo = whisperState.allAvailableModels.first(where: { $0.name == selectedModel }), modelInfo.isMultilingualModel { // Create a simple binding that uses UserDefaults language if nil @@ -416,7 +419,7 @@ struct ConfigurationView: View { .frame(maxWidth: .infinity) } } else if let selectedModel = effectiveModelName, - let modelInfo = whisperState.predefinedModels.first(where: { $0.name == selectedModel }), + let modelInfo = whisperState.allAvailableModels.first(where: { $0.name == selectedModel }), !modelInfo.isMultilingualModel { // Silently set to English without showing UI EmptyView() @@ -488,7 +491,7 @@ struct ConfigurationView: View { .frame(maxWidth: .infinity, alignment: .leading) } else { Picker("", selection: providerBinding) { - ForEach(aiService.connectedProviders, id: \.self) { provider in + ForEach(aiService.connectedProviders.filter { $0 != .elevenLabs }, id: \.self) { provider in Text(provider.rawValue).tag(provider) } } diff --git a/VoiceInk/PowerMode/PowerModeViewComponents.swift b/VoiceInk/PowerMode/PowerModeViewComponents.swift index 2b77bc2..1f317b7 100644 --- a/VoiceInk/PowerMode/PowerModeViewComponents.swift +++ b/VoiceInk/PowerMode/PowerModeViewComponents.swift @@ -105,7 +105,7 @@ struct ConfigurationRow: View { private var selectedModel: String? { if let modelName = config.selectedWhisperModel, - let model = whisperState.predefinedModels.first(where: { $0.name == modelName }) { + let model = whisperState.allAvailableModels.first(where: { $0.name == modelName }) { return model.displayName } return "Default" @@ -117,7 +117,7 @@ struct ConfigurationRow: View { if langCode == "en" { return "English" } if let modelName = config.selectedWhisperModel, - let model = whisperState.predefinedModels.first(where: { $0.name == modelName }), + let model = whisperState.allAvailableModels.first(where: { $0.name == modelName }), let langName = model.supportedLanguages[langCode] { return langName } diff --git a/VoiceInk/Services/AIService.swift b/VoiceInk/Services/AIService.swift index fbdd54f..c85125c 100644 --- a/VoiceInk/Services/AIService.swift +++ b/VoiceInk/Services/AIService.swift @@ -9,6 +9,7 @@ enum AIProvider: String, CaseIterable { case anthropic = "Anthropic" case mistral = "Mistral" case ollama = "Ollama" + case elevenLabs = "ElevenLabs" case custom = "Custom" var baseURL: String { @@ -25,6 +26,8 @@ enum AIProvider: String, CaseIterable { return "https://api.anthropic.com/v1/messages" case .mistral: return "https://api.mistral.ai/v1/chat/completions" + case .elevenLabs: + return "https://api.elevenlabs.io/v1/speech-to-text" case .ollama: return UserDefaults.standard.string(forKey: "ollamaBaseURL") ?? "http://localhost:11434" case .custom: @@ -46,6 +49,8 @@ enum AIProvider: String, CaseIterable { return "claude-3-5-sonnet-20241022" case .mistral: return "mistral-large-latest" + case .elevenLabs: + return "scribe_v1" case .ollama: return UserDefaults.standard.string(forKey: "ollamaSelectedModel") ?? "mistral" case .custom: @@ -88,6 +93,8 @@ enum AIProvider: String, CaseIterable { "mistral-small-latest", "mistral-saba-latest" ] + case .elevenLabs: + return ["scribe_v1", "scribe_v1_experimental"] case .ollama: return [] case .custom: @@ -257,6 +264,8 @@ class AIService: ObservableObject { verifyGeminiAPIKey(key, completion: completion) case .anthropic: verifyAnthropicAPIKey(key, completion: completion) + case .elevenLabs: + verifyElevenLabsAPIKey(key, completion: completion) default: verifyOpenAICompatibleAPIKey(key, completion: completion) } @@ -369,6 +378,30 @@ class AIService: ObservableObject { }.resume() } + private func verifyElevenLabsAPIKey(_ key: String, completion: @escaping (Bool) -> Void) { + let url = URL(string: "https://api.elevenlabs.io/v1/user")! + var request = URLRequest(url: url) + request.httpMethod = "GET" + request.addValue("application/json", forHTTPHeaderField: "Content-Type") + request.addValue(key, forHTTPHeaderField: "xi-api-key") + + URLSession.shared.dataTask(with: request) { _, response, error in + if let error = error { + self.logger.error("ElevenLabs API key verification failed: \(error.localizedDescription)") + completion(false) + return + } + + if let httpResponse = response as? HTTPURLResponse { + completion(httpResponse.statusCode == 200) + } else { + completion(false) + } + }.resume() + } + + + func clearAPIKey() { guard selectedProvider.requiresAPIKey else { return } diff --git a/VoiceInk/Services/AudioProcessor.swift b/VoiceInk/Services/AudioFileProcessor.swift similarity index 81% rename from VoiceInk/Services/AudioProcessor.swift rename to VoiceInk/Services/AudioFileProcessor.swift index 74acff7..4fe96f0 100644 --- a/VoiceInk/Services/AudioProcessor.swift +++ b/VoiceInk/Services/AudioFileProcessor.swift @@ -173,4 +173,47 @@ class AudioProcessor { logger.notice("✅ Successfully converted \(samples.count) samples") return samples } -} \ No newline at end of file + func saveSamplesAsWav(samples: [Float], to url: URL) throws { + let outputFormat = AVAudioFormat( + commonFormat: .pcmFormatInt16, + sampleRate: AudioFormat.targetSampleRate, + channels: AudioFormat.targetChannels, + interleaved: true + ) + + guard let outputFormat = outputFormat else { + throw AudioProcessingError.unsupportedFormat + } + + let buffer = AVAudioPCMBuffer( + pcmFormat: outputFormat, + frameCapacity: AVAudioFrameCount(samples.count) + ) + + guard let buffer = buffer else { + throw AudioProcessingError.conversionFailed + } + + // Convert float samples to int16 + let int16Samples = samples.map { max(-1.0, min(1.0, $0)) * Float(Int16.max) }.map { Int16($0) } + + // Copy samples to buffer + int16Samples.withUnsafeBufferPointer { int16Buffer in + let int16Pointer = int16Buffer.baseAddress! + buffer.int16ChannelData![0].update(from: int16Pointer, count: int16Samples.count) + } + buffer.frameLength = AVAudioFrameCount(samples.count) + + // Create audio file + let audioFile = try AVAudioFile( + forWriting: url, + settings: outputFormat.settings, + commonFormat: .pcmFormatInt16, + interleaved: true + ) + + try audioFile.write(from: buffer) + logger.notice("✅ Successfully saved processed audio to \(url.lastPathComponent)") + } +} + diff --git a/VoiceInk/Services/AudioTranscriptionManager.swift b/VoiceInk/Services/AudioFileTranscriptionManager.swift similarity index 84% rename from VoiceInk/Services/AudioTranscriptionManager.swift rename to VoiceInk/Services/AudioFileTranscriptionManager.swift index 0770c45..283388e 100644 --- a/VoiceInk/Services/AudioTranscriptionManager.swift +++ b/VoiceInk/Services/AudioFileTranscriptionManager.swift @@ -15,10 +15,13 @@ class AudioTranscriptionManager: ObservableObject { @Published var errorMessage: String? private var currentTask: Task? - private var whisperContext: WhisperContext? private let audioProcessor = AudioProcessor() private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionManager") + // Transcription services - will be initialized when needed + private var localTranscriptionService: LocalTranscriptionService? + private let cloudTranscriptionService = CloudTranscriptionService() + enum ProcessingPhase { case idle case loading @@ -58,12 +61,14 @@ class AudioTranscriptionManager: ObservableObject { currentTask = Task { do { - guard let currentModel = whisperState.currentModel else { + guard let currentModel = whisperState.currentTranscriptionModel else { throw TranscriptionError.noModelSelected } - // Load Whisper model - whisperContext = try await WhisperContext.createContext(path: currentModel.url.path) + // Initialize local transcription service if needed + if localTranscriptionService == nil { + localTranscriptionService = LocalTranscriptionService(modelsDirectory: whisperState.modelsDirectory, whisperState: whisperState) + } // Process audio file processingPhase = .processingAudio @@ -82,15 +87,19 @@ class AudioTranscriptionManager: ObservableObject { let permanentURL = recordingsDirectory.appendingPathComponent(fileName) try FileManager.default.createDirectory(at: recordingsDirectory, withIntermediateDirectories: true) - try FileManager.default.copyItem(at: url, to: permanentURL) + try audioProcessor.saveSamplesAsWav(samples: samples, to: permanentURL) - // Transcribe + // Transcribe using appropriate service processingPhase = .transcribing - await whisperContext?.setPrompt(whisperState.whisperPrompt.transcriptionPrompt) - try await whisperContext?.fullTranscribe(samples: samples) - var text = await whisperContext?.getTranscription() ?? "" + var text: String + + if currentModel.provider == .local { + text = try await localTranscriptionService!.transcribe(audioURL: permanentURL, model: currentModel) + } else { + text = try await cloudTranscriptionService.transcribe(audioURL: permanentURL, model: currentModel) + } + text = text.trimmingCharacters(in: .whitespacesAndNewlines) - text = WhisperTextFormatter.format(text) // Apply word replacements if enabled if UserDefaults.standard.bool(forKey: "IsWordReplacementEnabled") { @@ -115,7 +124,6 @@ class AudioTranscriptionManager: ObservableObject { currentTranscription = transcription } catch { logger.error("Enhancement failed: \(error.localizedDescription)") - messageLog += "Enhancement failed: \(error.localizedDescription). Using original transcription.\n" let transcription = Transcription( text: text, duration: duration, @@ -148,28 +156,20 @@ class AudioTranscriptionManager: ObservableObject { func cancelProcessing() { currentTask?.cancel() - cleanupResources() } private func finishProcessing() { isProcessing = false processingPhase = .idle currentTask = nil - cleanupResources() } private func handleError(_ error: Error) { logger.error("Transcription error: \(error.localizedDescription)") errorMessage = error.localizedDescription - messageLog += "Error: \(error.localizedDescription)\n" isProcessing = false processingPhase = .idle currentTask = nil - cleanupResources() - } - - private func cleanupResources() { - whisperContext = nil } } diff --git a/VoiceInk/Services/AudioTranscriptionService.swift b/VoiceInk/Services/AudioFileTranscriptionService.swift similarity index 61% rename from VoiceInk/Services/AudioTranscriptionService.swift rename to VoiceInk/Services/AudioFileTranscriptionService.swift index e820dcf..759cbf6 100644 --- a/VoiceInk/Services/AudioTranscriptionService.swift +++ b/VoiceInk/Services/AudioFileTranscriptionService.swift @@ -10,12 +10,15 @@ class AudioTranscriptionService: ObservableObject { @Published var messageLog = "" @Published var currentError: TranscriptionError? - private var whisperContext: WhisperContext? private let modelContext: ModelContext private let enhancementService: AIEnhancementService? private let whisperState: WhisperState private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionService") + // Transcription services + private let localTranscriptionService: LocalTranscriptionService + private let cloudTranscriptionService = CloudTranscriptionService() + enum TranscriptionError: Error { case noAudioFile case transcriptionFailed @@ -27,82 +30,67 @@ class AudioTranscriptionService: ObservableObject { self.modelContext = modelContext self.whisperState = whisperState self.enhancementService = whisperState.enhancementService + self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: whisperState.modelsDirectory, whisperState: whisperState) } - func retranscribeAudio(from url: URL, using whisperModel: WhisperModel) async throws -> Transcription { + func retranscribeAudio(from url: URL, using model: any TranscriptionModel) async throws -> Transcription { guard FileManager.default.fileExists(atPath: url.path) else { throw TranscriptionError.noAudioFile } await MainActor.run { isTranscribing = true - messageLog = "Loading model...\n" + messageLog = "Starting retranscription...\n" } - // Load the whisper model if needed - if whisperContext == nil { - do { - whisperContext = try await WhisperContext.createContext(path: whisperModel.url.path) - messageLog += "Model loaded successfully.\n" - } catch { - logger.error("❌ Failed to load model: \(error.localizedDescription)") - messageLog += "Failed to load model: \(error.localizedDescription)\n" - isTranscribing = false - throw TranscriptionError.modelNotLoaded + do { + // Delegate transcription to appropriate service + var text: String + + if model.provider == .local { + messageLog += "Using local transcription service...\n" + text = try await localTranscriptionService.transcribe(audioURL: url, model: model) + messageLog += "Local transcription completed.\n" + } else { + messageLog += "Using cloud transcription service...\n" + text = try await cloudTranscriptionService.transcribe(audioURL: url, model: model) + messageLog += "Cloud transcription completed.\n" } - } - - guard let whisperContext = whisperContext else { - isTranscribing = false - throw TranscriptionError.modelNotLoaded - } - - // Get audio duration - let audioAsset = AVURLAsset(url: url) - let duration = CMTimeGetSeconds(try await audioAsset.load(.duration)) - - // Create a permanent copy of the audio file - let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0] - .appendingPathComponent("com.prakashjoshipax.VoiceInk") - .appendingPathComponent("Recordings") - - let fileName = "retranscribed_\(UUID().uuidString).wav" - let permanentURL = recordingsDirectory.appendingPathComponent(fileName) - - do { - try FileManager.default.copyItem(at: url, to: permanentURL) - } catch { - logger.error("❌ Failed to create permanent copy of audio: \(error.localizedDescription)") - messageLog += "Failed to create permanent copy of audio: \(error.localizedDescription)\n" - isTranscribing = false - throw error - } - - let permanentURLString = permanentURL.absoluteString - - // Transcribe the audio - messageLog += "Transcribing audio...\n" - - do { - // Read audio samples - let samples = try readAudioSamples(permanentURL) - // Process with Whisper - using the same prompt as WhisperState - messageLog += "Setting prompt: \(whisperState.whisperPrompt.transcriptionPrompt)\n" - await whisperContext.setPrompt(whisperState.whisperPrompt.transcriptionPrompt) - - try await whisperContext.fullTranscribe(samples: samples) - var text = await whisperContext.getTranscription() + // Common post-processing for both local and cloud transcriptions text = text.trimmingCharacters(in: .whitespacesAndNewlines) - logger.notice("✅ Retranscription completed successfully, length: \(text.count) characters") // Apply word replacements if enabled if UserDefaults.standard.bool(forKey: "IsWordReplacementEnabled") { text = WordReplacementService.shared.applyReplacements(to: text) + messageLog += "Word replacements applied.\n" logger.notice("✅ Word replacements applied") } - // Apply AI enhancement if enabled - using the same enhancement service as WhisperState + // Get audio duration + let audioAsset = AVURLAsset(url: url) + let duration = CMTimeGetSeconds(try await audioAsset.load(.duration)) + + // Create a permanent copy of the audio file + let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0] + .appendingPathComponent("com.prakashjoshipax.VoiceInk") + .appendingPathComponent("Recordings") + + let fileName = "retranscribed_\(UUID().uuidString).wav" + let permanentURL = recordingsDirectory.appendingPathComponent(fileName) + + do { + try FileManager.default.copyItem(at: url, to: permanentURL) + } catch { + logger.error("❌ Failed to create permanent copy of audio: \(error.localizedDescription)") + messageLog += "Failed to create permanent copy of audio: \(error.localizedDescription)\n" + isTranscribing = false + throw error + } + + let permanentURLString = permanentURL.absoluteString + + // Apply AI enhancement if enabled if let enhancementService = enhancementService, enhancementService.isEnhancementEnabled, enhancementService.isConfigured { @@ -182,19 +170,4 @@ class AudioTranscriptionService: ObservableObject { throw error } } - - private func readAudioSamples(_ url: URL) throws -> [Float] { - return try decodeWaveFile(url) - } - - private func decodeWaveFile(_ url: URL) throws -> [Float] { - let data = try Data(contentsOf: url) - let floats = stride(from: 44, to: data.count, by: 2).map { - return data[$0..<$0 + 2].withUnsafeBytes { - let short = Int16(littleEndian: $0.load(as: Int16.self)) - return max(-1.0, min(Float(short) / 32767.0, 1.0)) - } - } - return floats - } } diff --git a/VoiceInk/Services/CloudTranscriptionService.swift b/VoiceInk/Services/CloudTranscriptionService.swift new file mode 100644 index 0000000..ee66625 --- /dev/null +++ b/VoiceInk/Services/CloudTranscriptionService.swift @@ -0,0 +1,211 @@ +import Foundation +import os + +enum CloudTranscriptionError: Error, LocalizedError { + case unsupportedProvider + case missingAPIKey + case invalidAPIKey + case audioFileNotFound + case apiRequestFailed(statusCode: Int, message: String) + case networkError(Error) + case noTranscriptionReturned + case dataEncodingError + + var errorDescription: String? { + switch self { + case .unsupportedProvider: + return "The model provider is not supported by this service." + case .missingAPIKey: + return "API key for this service is missing. Please configure it in the settings." + case .invalidAPIKey: + return "The provided API key is invalid." + case .audioFileNotFound: + return "The audio file to transcribe could not be found." + case .apiRequestFailed(let statusCode, let message): + return "The API request failed with status code \(statusCode): \(message)" + case .networkError(let error): + return "A network error occurred: \(error.localizedDescription)" + case .noTranscriptionReturned: + return "The API returned an empty or invalid response." + case .dataEncodingError: + return "Failed to encode the request body." + } + } +} + +class CloudTranscriptionService: TranscriptionService { + + private struct APIConfig { + let url: URL + let apiKey: String + let modelName: String + } + + func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String { + let config = try getAPIConfig(for: model) + + var request: URLRequest + var body: Data + + switch model.provider { + case .elevenLabs: + let boundary = "Boundary-\(UUID().uuidString)" + var elevenLabsRequest = URLRequest(url: config.url) + elevenLabsRequest.httpMethod = "POST" + elevenLabsRequest.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type") + elevenLabsRequest.setValue(config.apiKey, forHTTPHeaderField: "xi-api-key") + body = try createElevenLabsRequestBody(audioURL: audioURL, modelName: config.modelName, boundary: boundary) + request = elevenLabsRequest + + case .groq: + let boundary = "Boundary-\(UUID().uuidString)" + var openAICompatibleRequest = URLRequest(url: config.url) + openAICompatibleRequest.httpMethod = "POST" + openAICompatibleRequest.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type") + openAICompatibleRequest.setValue("Bearer \(config.apiKey)", forHTTPHeaderField: "Authorization") + body = try createOpenAICompatibleRequestBody(audioURL: audioURL, modelName: config.modelName, boundary: boundary) + request = openAICompatibleRequest + + + + default: + throw CloudTranscriptionError.unsupportedProvider + } + + let (data, response) = try await URLSession.shared.upload(for: request, from: body) + guard let httpResponse = response as? HTTPURLResponse else { + throw CloudTranscriptionError.networkError(URLError(.badServerResponse)) + } + + if !(200...299).contains(httpResponse.statusCode) { + let errorMessage = String(data: data, encoding: .utf8) ?? "No error message" + throw CloudTranscriptionError.apiRequestFailed(statusCode: httpResponse.statusCode, message: errorMessage) + } + + do { + let transcriptionResponse = try JSONDecoder().decode(TranscriptionResponse.self, from: data) + return transcriptionResponse.text + } catch { + throw CloudTranscriptionError.noTranscriptionReturned + } + } + + private func getAPIConfig(for model: any TranscriptionModel) throws -> APIConfig { + let providerKey: String + let apiURL: URL + + switch model.provider { + case .groq: + providerKey = "GROQ" + apiURL = URL(string: "https://api.groq.com/openai/v1/audio/transcriptions")! + case .elevenLabs: + providerKey = "ElevenLabs" + apiURL = URL(string: "https://api.elevenlabs.io/v1/speech-to-text")! + default: + throw CloudTranscriptionError.unsupportedProvider + } + + guard let apiKey = UserDefaults.standard.string(forKey: "\(providerKey)APIKey"), !apiKey.isEmpty else { + throw CloudTranscriptionError.missingAPIKey + } + + return APIConfig(url: apiURL, apiKey: apiKey, modelName: model.name) + } + + private func createElevenLabsRequestBody(audioURL: URL, modelName: String, boundary: String) throws -> Data { + var body = Data() + let crlf = "\r\n" + + guard let audioData = try? Data(contentsOf: audioURL) else { + throw CloudTranscriptionError.audioFileNotFound + } + + // File + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"file\"; filename=\"\(audioURL.lastPathComponent)\"\(crlf)".data(using: .utf8)!) + body.append("Content-Type: audio/wav\(crlf)\(crlf)".data(using: .utf8)!) + body.append(audioData) + body.append(crlf.data(using: .utf8)!) + + // Model ID + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"model_id\"\(crlf)\(crlf)".data(using: .utf8)!) + body.append(modelName.data(using: .utf8)!) + body.append(crlf.data(using: .utf8)!) + + let selectedLanguage = UserDefaults.standard.string(forKey: "SelectedLanguage") ?? "auto" + if selectedLanguage != "auto", !selectedLanguage.isEmpty { + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"language_code\"\(crlf)\(crlf)".data(using: .utf8)!) + body.append(selectedLanguage.data(using: .utf8)!) + body.append(crlf.data(using: .utf8)!) + } + + body.append("--\(boundary)--\(crlf)".data(using: .utf8)!) + + return body + } + + private func createOpenAICompatibleRequestBody(audioURL: URL, modelName: String, boundary: String) throws -> Data { + var body = Data() + let crlf = "\r\n" + + guard let audioData = try? Data(contentsOf: audioURL) else { + throw CloudTranscriptionError.audioFileNotFound + } + + let selectedLanguage = UserDefaults.standard.string(forKey: "SelectedLanguage") ?? "auto" + let prompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? "" + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"file\"; filename=\"\(audioURL.lastPathComponent)\"\(crlf)".data(using: .utf8)!) + body.append("Content-Type: audio/wav\(crlf)\(crlf)".data(using: .utf8)!) + body.append(audioData) + body.append(crlf.data(using: .utf8)!) + + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"model\"\(crlf)\(crlf)".data(using: .utf8)!) + body.append(modelName.data(using: .utf8)!) + body.append(crlf.data(using: .utf8)!) + + if selectedLanguage != "auto", !selectedLanguage.isEmpty { + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"language\"\(crlf)\(crlf)".data(using: .utf8)!) + body.append(selectedLanguage.data(using: .utf8)!) + body.append(crlf.data(using: .utf8)!) + } + + // Include prompt for OpenAI-compatible APIs + if !prompt.isEmpty { + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"prompt\"\(crlf)\(crlf)".data(using: .utf8)!) + body.append(prompt.data(using: .utf8)!) + body.append(crlf.data(using: .utf8)!) + } + + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"response_format\"\(crlf)\(crlf)".data(using: .utf8)!) + body.append("json".data(using: .utf8)!) + body.append(crlf.data(using: .utf8)!) + + body.append("--\(boundary)\(crlf)".data(using: .utf8)!) + body.append("Content-Disposition: form-data; name=\"temperature\"\(crlf)\(crlf)".data(using: .utf8)!) + body.append("0".data(using: .utf8)!) + body.append(crlf.data(using: .utf8)!) + body.append("--\(boundary)--\(crlf)".data(using: .utf8)!) + + return body + } + + private struct TranscriptionResponse: Decodable { + let text: String + let language: String? + let duration: Double? + let x_groq: GroqMetadata? + + struct GroqMetadata: Decodable { + let id: String? + } + } + + +} \ No newline at end of file diff --git a/VoiceInk/Services/LocalTranscriptionService.swift b/VoiceInk/Services/LocalTranscriptionService.swift new file mode 100644 index 0000000..bba3600 --- /dev/null +++ b/VoiceInk/Services/LocalTranscriptionService.swift @@ -0,0 +1,87 @@ +import Foundation +import AVFoundation +import os + +class LocalTranscriptionService: TranscriptionService { + + private var whisperContext: WhisperContext? + private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "LocalTranscriptionService") + private let modelsDirectory: URL + private weak var whisperState: WhisperState? + + init(modelsDirectory: URL, whisperState: WhisperState? = nil) { + self.modelsDirectory = modelsDirectory + self.whisperState = whisperState + } + + func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String { + guard let localModel = model as? LocalModel else { + throw WhisperError.couldNotInitializeContext + } + + logger.notice("Initiating local transcription for model: \(localModel.displayName)") + + // Check if the required model is already loaded in WhisperState + if let whisperState = whisperState, + await whisperState.isModelLoaded, + let loadedContext = await whisperState.whisperContext, + await whisperState.currentModel?.name == localModel.name { + + logger.notice("✅ Using already loaded model: \(localModel.name)") + whisperContext = loadedContext + } else { + // Model not loaded or wrong model loaded, proceed with loading + let modelURL = modelsDirectory.appendingPathComponent(localModel.filename) + + guard FileManager.default.fileExists(atPath: modelURL.path) else { + logger.error("Model file not found at path: \(modelURL.path)") + throw WhisperError.couldNotInitializeContext + } + + logger.notice("Loading model: \(localModel.name)") + do { + whisperContext = try await WhisperContext.createContext(path: modelURL.path) + } catch { + logger.error("Failed to load model: \(localModel.name) - \(error.localizedDescription)") + throw WhisperError.couldNotInitializeContext + } + } + + guard let whisperContext = whisperContext else { + logger.error("Cannot transcribe: Model could not be loaded") + throw WhisperError.couldNotInitializeContext + } + + // Read audio data + let data = try readAudioSamples(audioURL) + + // Set prompt + let currentPrompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? "" + await whisperContext.setPrompt(currentPrompt) + + // Transcribe + await whisperContext.fullTranscribe(samples: data) + let text = await whisperContext.getTranscription() + + logger.notice("✅ Local transcription completed successfully.") + + // Only release resources if we created a new context (not using the shared one) + if await whisperState?.whisperContext !== whisperContext { + await whisperContext.releaseResources() + self.whisperContext = nil + } + + return text + } + + private func readAudioSamples(_ url: URL) throws -> [Float] { + let data = try Data(contentsOf: url) + let floats = stride(from: 44, to: data.count, by: 2).map { + return data[$0..<$0 + 2].withUnsafeBytes { + let short = Int16(littleEndian: $0.load(as: Int16.self)) + return max(-1.0, min(Float(short) / 32767.0, 1.0)) + } + } + return floats + } +} \ No newline at end of file diff --git a/VoiceInk/Services/TranscriptionService.swift b/VoiceInk/Services/TranscriptionService.swift new file mode 100644 index 0000000..5823a87 --- /dev/null +++ b/VoiceInk/Services/TranscriptionService.swift @@ -0,0 +1,14 @@ +import Foundation + +/// A protocol defining the interface for a transcription service. +/// This allows for a unified way to handle both local and cloud-based transcription models. +protocol TranscriptionService { + /// Transcribes the audio from a given file URL. + /// + /// - Parameters: + /// - audioURL: The URL of the audio file to transcribe. + /// - model: The `TranscriptionModel` to use for transcription. This provides context about the provider (local, OpenAI, etc.). + /// - Returns: The transcribed text as a `String`. + /// - Throws: An error if the transcription fails. + func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String +} \ No newline at end of file diff --git a/VoiceInk/Views/APIKeyManagementView.swift b/VoiceInk/Views/APIKeyManagementView.swift index 70c15be..902a1c7 100644 --- a/VoiceInk/Views/APIKeyManagementView.swift +++ b/VoiceInk/Views/APIKeyManagementView.swift @@ -44,7 +44,7 @@ struct APIKeyManagementView: View { // Provider Selection Picker("AI Provider", selection: $aiService.selectedProvider) { - ForEach(AIProvider.allCases, id: \.self) { provider in + ForEach(AIProvider.allCases.filter { $0 != .elevenLabs }, id: \.self) { provider in Text(provider.rawValue).tag(provider) } } @@ -385,6 +385,8 @@ struct APIKeyManagementView: View { URL(string: "https://console.anthropic.com/settings/keys")! case .mistral: URL(string: "https://console.mistral.ai/api-keys")! + case .elevenLabs: + URL(string: "https://elevenlabs.io/speech-synthesis")! case .ollama, .custom: URL(string: "")! // This case should never be reached } diff --git a/VoiceInk/Views/AudioPlayerView.swift b/VoiceInk/Views/AudioPlayerView.swift index a87c062..ff3399c 100644 --- a/VoiceInk/Views/AudioPlayerView.swift +++ b/VoiceInk/Views/AudioPlayerView.swift @@ -388,7 +388,7 @@ struct AudioPlayerView: View { } private func retranscribeAudio() { - guard let currentModel = whisperState.currentModel else { + guard let currentTranscriptionModel = whisperState.currentTranscriptionModel else { errorMessage = "No transcription model selected" showRetranscribeError = true DispatchQueue.main.asyncAfter(deadline: .now() + 3) { @@ -401,7 +401,7 @@ struct AudioPlayerView: View { Task { do { - let _ = try await transcriptionService.retranscribeAudio(from: url, using: currentModel) + let _ = try await transcriptionService.retranscribeAudio(from: url, using: currentTranscriptionModel) await MainActor.run { isRetranscribing = false showRetranscribeSuccess = true diff --git a/VoiceInk/Views/LanguageSelectionView.swift b/VoiceInk/Views/LanguageSelectionView.swift index 64bdeed..4a157b4 100644 --- a/VoiceInk/Views/LanguageSelectionView.swift +++ b/VoiceInk/Views/LanguageSelectionView.swift @@ -26,23 +26,18 @@ struct LanguageSelectionView: View { // Function to check if current model is multilingual private func isMultilingualModel() -> Bool { - guard let currentModel = whisperState.currentModel, - let predefinedModel = PredefinedModels.models.first(where: { $0.name == currentModel.name }) else { + guard let currentModel = whisperState.currentTranscriptionModel else { return false } - return predefinedModel.isMultilingualModel + return currentModel.isMultilingualModel } // Function to get current model's supported languages private func getCurrentModelLanguages() -> [String: String] { - guard let currentModel = whisperState.currentModel, - let predefinedModel = PredefinedModels.models.first(where: { - $0.name == currentModel.name - }) - else { + guard let currentModel = whisperState.currentTranscriptionModel else { return ["en": "English"] // Default to English if no model found } - return predefinedModel.supportedLanguages + return currentModel.supportedLanguages } // Get the display name of the current language @@ -74,16 +69,13 @@ struct LanguageSelectionView: View { Text("Transcription Language") .font(.headline) - if let currentModel = whisperState.currentModel, - let predefinedModel = PredefinedModels.models.first(where: { - $0.name == currentModel.name - }) + if let currentModel = whisperState.currentTranscriptionModel { if isMultilingualModel() { VStack(alignment: .leading, spacing: 8) { Picker("Select Language", selection: $selectedLanguage) { ForEach( - predefinedModel.supportedLanguages.sorted(by: { + currentModel.supportedLanguages.sorted(by: { if $0.key == "auto" { return true } if $1.key == "auto" { return false } return $0.value < $1.value @@ -97,7 +89,7 @@ struct LanguageSelectionView: View { updateLanguage(newValue) } - Text("Current model: \(predefinedModel.displayName)") + Text("Current model: \(currentModel.displayName)") .font(.caption) .foregroundColor(.secondary) @@ -114,7 +106,7 @@ struct LanguageSelectionView: View { .font(.subheadline) .foregroundColor(.primary) - Text("Current model: \(predefinedModel.displayName)") + Text("Current model: \(currentModel.displayName)") .font(.caption) .foregroundColor(.secondary) diff --git a/VoiceInk/Views/MenuBarView.swift b/VoiceInk/Views/MenuBarView.swift index 5d58f7f..f06ed7d 100644 --- a/VoiceInk/Views/MenuBarView.swift +++ b/VoiceInk/Views/MenuBarView.swift @@ -80,26 +80,21 @@ struct MenuBarView: View { } Menu { - ForEach(whisperState.availableModels) { model in + ForEach(whisperState.usableModels, id: \.id) { model in Button { Task { - await whisperState.setDefaultModel(model) + await whisperState.setDefaultTranscriptionModel(model) } } label: { HStack { - Text(PredefinedModels.models.first { $0.name == model.name }?.displayName ?? model.name) - if whisperState.currentModel?.name == model.name { + Text(model.displayName) + if whisperState.currentTranscriptionModel?.id == model.id { Image(systemName: "checkmark") } } } } - if whisperState.availableModels.isEmpty { - Text("No models downloaded") - .foregroundColor(.secondary) - } - Divider() Button("Manage Models") { @@ -107,7 +102,7 @@ struct MenuBarView: View { } } label: { HStack { - Text("Model: \(PredefinedModels.models.first { $0.name == whisperState.currentModel?.name }?.displayName ?? "None")") + Text("Model: \(whisperState.currentTranscriptionModel?.displayName ?? "None")") Image(systemName: "chevron.up.chevron.down") .font(.system(size: 10)) } diff --git a/VoiceInk/Views/ModelCardRowView.swift b/VoiceInk/Views/ModelCardRowView.swift index 2b0bc1e..9771e6c 100644 --- a/VoiceInk/Views/ModelCardRowView.swift +++ b/VoiceInk/Views/ModelCardRowView.swift @@ -2,7 +2,49 @@ import SwiftUI import AppKit struct ModelCardRowView: View { - let model: PredefinedModel + let model: any TranscriptionModel + let isDownloaded: Bool + let isCurrent: Bool + let downloadProgress: [String: Double] + let modelURL: URL? + + // Actions + var deleteAction: () -> Void + var setDefaultAction: () -> Void + var downloadAction: () -> Void + + var body: some View { + Group { + switch model.provider { + case .local: + if let localModel = model as? LocalModel { + LocalModelCardView( + model: localModel, + isDownloaded: isDownloaded, + isCurrent: isCurrent, + downloadProgress: downloadProgress, + modelURL: modelURL, + deleteAction: deleteAction, + setDefaultAction: setDefaultAction, + downloadAction: downloadAction + ) + } + case .groq, .elevenLabs: + if let cloudModel = model as? CloudModel { + CloudModelCardView( + model: cloudModel, + isCurrent: isCurrent, + setDefaultAction: setDefaultAction + ) + } + } + } + } +} + +// MARK: - Local Model Card View +struct LocalModelCardView: View { + let model: LocalModel let isDownloaded: Bool let isCurrent: Bool let downloadProgress: [String: Double] @@ -36,8 +78,6 @@ struct ModelCardRowView: View { .background(CardBackground(isSelected: isCurrent, useAccentGradientWhenSelected: isCurrent)) } - // MARK: - Components - private var headerSection: some View { HStack(alignment: .firstTextBaseline) { Text(model.displayName) @@ -162,14 +202,210 @@ struct ModelCardRowView: View { Label("Delete Model", systemImage: "trash") } - if isDownloaded { - Button { - if let modelURL = modelURL { - NSWorkspace.shared.selectFile(modelURL.path, inFileViewerRootedAtPath: "") - } - } label: { - Label("Show in Finder", systemImage: "folder") + Button { + if let modelURL = modelURL { + NSWorkspace.shared.selectFile(modelURL.path, inFileViewerRootedAtPath: "") } + } label: { + Label("Show in Finder", systemImage: "folder") + } + } label: { + Image(systemName: "ellipsis.circle") + .font(.system(size: 14)) + } + .menuStyle(.borderlessButton) + .menuIndicator(.hidden) + .frame(width: 20, height: 20) + } + } + } +} + +// MARK: - Cloud Model Card View +struct CloudModelCardView: View { + let model: CloudModel + let isCurrent: Bool + var setDefaultAction: () -> Void + + @StateObject private var aiService = AIService() + @State private var isExpanded = false + @State private var apiKey = "" + @State private var isVerifying = false + @State private var verificationStatus: VerificationStatus = .none + + enum VerificationStatus { + case none, verifying, success, failure + } + + private var isConfigured: Bool { + guard let savedKey = UserDefaults.standard.string(forKey: "\(providerKey)APIKey") else { + return false + } + return !savedKey.isEmpty + } + + private var providerKey: String { + switch model.provider { + case .groq: + return "GROQ" + case .elevenLabs: + return "ElevenLabs" + default: + return model.provider.rawValue + } + } + + var body: some View { + VStack(alignment: .leading, spacing: 0) { + // Main card content + HStack(alignment: .top, spacing: 16) { + VStack(alignment: .leading, spacing: 6) { + headerSection + metadataSection + descriptionSection + } + .frame(maxWidth: .infinity, alignment: .leading) + + actionSection + } + .padding(16) + + // Expandable configuration section + if isExpanded { + Divider() + .padding(.horizontal, 16) + + configurationSection + .padding(16) + } + } + .background(CardBackground(isSelected: isCurrent, useAccentGradientWhenSelected: isCurrent)) + .onAppear { + loadSavedAPIKey() + } + } + + private var headerSection: some View { + HStack(alignment: .firstTextBaseline) { + Text(model.displayName) + .font(.system(size: 13, weight: .semibold)) + .foregroundColor(Color(.labelColor)) + + statusBadge + + Spacer() + } + } + + private var statusBadge: some View { + Group { + if isCurrent { + Text("Default") + .font(.system(size: 11, weight: .medium)) + .padding(.horizontal, 6) + .padding(.vertical, 2) + .background(Capsule().fill(Color.accentColor)) + .foregroundColor(.white) + } else if isConfigured { + Text("Configured") + .font(.system(size: 11, weight: .medium)) + .padding(.horizontal, 6) + .padding(.vertical, 2) + .background(Capsule().fill(Color(.systemGreen).opacity(0.2))) + .foregroundColor(Color(.systemGreen)) + } else { + Text("Setup Required") + .font(.system(size: 11, weight: .medium)) + .padding(.horizontal, 6) + .padding(.vertical, 2) + .background(Capsule().fill(Color(.systemOrange).opacity(0.2))) + .foregroundColor(Color(.systemOrange)) + } + } + } + + private var metadataSection: some View { + HStack(spacing: 16) { + // Provider + Label(model.provider.rawValue, systemImage: "cloud") + .font(.system(size: 11)) + .foregroundColor(Color(.secondaryLabelColor)) + + // Language + Label(model.language, systemImage: "globe") + .font(.system(size: 11)) + .foregroundColor(Color(.secondaryLabelColor)) + + // Speed + HStack(spacing: 4) { + Text("Speed") + .font(.system(size: 11, weight: .medium)) + .foregroundColor(Color(.secondaryLabelColor)) + progressDotsWithNumber(value: model.speed * 10) + } + + // Accuracy + HStack(spacing: 4) { + Text("Accuracy") + .font(.system(size: 11, weight: .medium)) + .foregroundColor(Color(.secondaryLabelColor)) + progressDotsWithNumber(value: model.accuracy * 10) + } + } + } + + private var descriptionSection: some View { + Text(model.description) + .font(.system(size: 11)) + .foregroundColor(Color(.secondaryLabelColor)) + .lineLimit(2) + .fixedSize(horizontal: false, vertical: true) + .padding(.top, 4) + } + + private var actionSection: some View { + HStack(spacing: 8) { + if isCurrent { + Text("Default Model") + .font(.system(size: 12)) + .foregroundColor(Color(.secondaryLabelColor)) + } else if isConfigured { + Button(action: setDefaultAction) { + Text("Set as Default") + .font(.system(size: 12)) + } + .buttonStyle(.bordered) + .controlSize(.small) + } else { + Button(action: { + withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) { + isExpanded.toggle() + } + }) { + HStack(spacing: 4) { + Text("Configure") + .font(.system(size: 12, weight: .medium)) + Image(systemName: "gear") + .font(.system(size: 12, weight: .medium)) + } + .foregroundColor(.white) + .padding(.horizontal, 12) + .padding(.vertical, 6) + .background( + Capsule() + .fill(Color(.controlAccentColor)) + .shadow(color: Color(.controlAccentColor).opacity(0.2), radius: 2, x: 0, y: 1) + ) + } + .buttonStyle(.plain) + } + + if isConfigured { + Menu { + Button { + clearAPIKey() + } label: { + Label("Remove API Key", systemImage: "trash") } } label: { Image(systemName: "ellipsis.circle") @@ -182,40 +418,134 @@ struct ModelCardRowView: View { } } - // MARK: - Helpers - - private var downloadComponents: [(String, Double)] { - [ - ("Model", downloadProgress[model.name + "_main"] ?? 0), - ("CoreML", downloadProgress[model.name + "_coreml"] ?? 0) - ].filter { $0.1 > 0 } - } - - private func progressDotsWithNumber(value: Double) -> some View { - HStack(spacing: 4) { - progressDots(value: value) - Text(String(format: "%.1f", value)) - .font(.system(size: 10, weight: .medium, design: .monospaced)) - .foregroundColor(Color(.secondaryLabelColor)) - } - } - - private func progressDots(value: Double) -> some View { - HStack(spacing: 2) { - ForEach(0..<5) { index in - Circle() - .fill(index < Int(value / 2) ? performanceColor(value: value / 10) : Color(.quaternaryLabelColor)) - .frame(width: 6, height: 6) + private var configurationSection: some View { + VStack(alignment: .leading, spacing: 12) { + Text("API Key Configuration") + .font(.system(size: 13, weight: .semibold)) + .foregroundColor(Color(.labelColor)) + + HStack(spacing: 8) { + SecureField("Enter your \(model.provider.rawValue) API key", text: $apiKey) + .textFieldStyle(.roundedBorder) + .disabled(isVerifying) + + Button(action: verifyAPIKey) { + HStack(spacing: 4) { + if isVerifying { + ProgressView() + .scaleEffect(0.7) + .frame(width: 12, height: 12) + } else { + Image(systemName: verificationStatus == .success ? "checkmark" : "checkmark.shield") + .font(.system(size: 12, weight: .medium)) + } + Text(isVerifying ? "Verifying..." : "Verify") + .font(.system(size: 12, weight: .medium)) + } + .foregroundColor(.white) + .padding(.horizontal, 12) + .padding(.vertical, 6) + .background( + Capsule() + .fill(verificationStatus == .success ? Color(.systemGreen) : Color(.controlAccentColor)) + ) + } + .buttonStyle(.plain) + .disabled(apiKey.isEmpty || isVerifying) + } + + if verificationStatus == .failure { + Text("Invalid API key. Please check your key and try again.") + .font(.caption) + .foregroundColor(Color(.systemRed)) + } else if verificationStatus == .success { + Text("API key verified successfully!") + .font(.caption) + .foregroundColor(Color(.systemGreen)) } } } - private func performanceColor(value: Double) -> Color { - switch value { - case 0.8...1.0: return Color(.systemGreen) - case 0.6..<0.8: return Color(.systemYellow) - case 0.4..<0.6: return Color(.systemOrange) - default: return Color(.systemRed) + private func loadSavedAPIKey() { + if let savedKey = UserDefaults.standard.string(forKey: "\(providerKey)APIKey") { + apiKey = savedKey + verificationStatus = .success } } + + private func verifyAPIKey() { + guard !apiKey.isEmpty else { return } + + isVerifying = true + verificationStatus = .verifying + + // Set the provider in AIService temporarily for verification + let originalProvider = aiService.selectedProvider + if model.provider == .groq { + aiService.selectedProvider = .groq + } else if model.provider == .elevenLabs { + aiService.selectedProvider = .elevenLabs + } + + aiService.verifyAPIKey(apiKey) { [self] isValid in + DispatchQueue.main.async { + self.isVerifying = false + if isValid { + self.verificationStatus = .success + // Save the API key + UserDefaults.standard.set(self.apiKey, forKey: "\(self.providerKey)APIKey") + + // Collapse the configuration section after successful verification + withAnimation(.easeInOut(duration: 0.3)) { + self.isExpanded = false + } + } else { + self.verificationStatus = .failure + } + + // Restore original provider + aiService.selectedProvider = originalProvider + } + } + } + + private func clearAPIKey() { + UserDefaults.standard.removeObject(forKey: "\(providerKey)APIKey") + apiKey = "" + verificationStatus = .none + + withAnimation(.easeInOut(duration: 0.3)) { + isExpanded = false + } + } +} + +// MARK: - Helper Views and Functions + +private func progressDotsWithNumber(value: Double) -> some View { + HStack(spacing: 4) { + progressDots(value: value) + Text(String(format: "%.1f", value)) + .font(.system(size: 10, weight: .medium, design: .monospaced)) + .foregroundColor(Color(.secondaryLabelColor)) + } +} + +private func progressDots(value: Double) -> some View { + HStack(spacing: 2) { + ForEach(0..<5) { index in + Circle() + .fill(index < Int(value / 2) ? performanceColor(value: value / 10) : Color(.quaternaryLabelColor)) + .frame(width: 6, height: 6) + } + } +} + +private func performanceColor(value: Double) -> Color { + switch value { + case 0.8...1.0: return Color(.systemGreen) + case 0.6..<0.8: return Color(.systemYellow) + case 0.4..<0.6: return Color(.systemOrange) + default: return Color(.systemRed) + } } diff --git a/VoiceInk/Views/ModelManagementView.swift b/VoiceInk/Views/ModelManagementView.swift index 339a6a1..5fa5e59 100644 --- a/VoiceInk/Views/ModelManagementView.swift +++ b/VoiceInk/Views/ModelManagementView.swift @@ -39,9 +39,7 @@ struct ModelManagementView: View { Text("Default Model") .font(.headline) .foregroundColor(.secondary) - Text(whisperState.currentModel.flatMap { model in - PredefinedModels.models.first { $0.name == model.name }?.displayName - } ?? "No model selected") + Text(whisperState.currentTranscriptionModel?.displayName ?? "No model selected") .font(.title2) .fontWeight(.bold) } @@ -62,7 +60,7 @@ struct ModelManagementView: View { .font(.title3) .fontWeight(.semibold) - Text("(\(whisperState.predefinedModels.count))") + Text("(\(whisperState.allAvailableModels.count))") .foregroundColor(.secondary) .font(.subheadline) @@ -70,11 +68,11 @@ struct ModelManagementView: View { } VStack(spacing: 12) { - ForEach(whisperState.predefinedModels) { model in + ForEach(whisperState.allAvailableModels, id: \.id) { model in ModelCardRowView( model: model, isDownloaded: whisperState.availableModels.contains { $0.name == model.name }, - isCurrent: whisperState.currentModel?.name == model.name, + isCurrent: whisperState.currentTranscriptionModel?.name == model.name, downloadProgress: whisperState.downloadProgress, modelURL: whisperState.availableModels.first { $0.name == model.name }?.url, deleteAction: { @@ -83,15 +81,15 @@ struct ModelManagementView: View { } }, setDefaultAction: { - if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) { - Task { - await whisperState.setDefaultModel(downloadedModel) - } + Task { + await whisperState.setDefaultTranscriptionModel(model) } }, downloadAction: { - Task { - await whisperState.downloadModel(model) + if let localModel = model as? LocalModel { + Task { + await whisperState.downloadModel(localModel) + } } } ) diff --git a/VoiceInk/Views/Onboarding/OnboardingModelDownloadView.swift b/VoiceInk/Views/Onboarding/OnboardingModelDownloadView.swift index 5a37f32..22d992a 100644 --- a/VoiceInk/Views/Onboarding/OnboardingModelDownloadView.swift +++ b/VoiceInk/Views/Onboarding/OnboardingModelDownloadView.swift @@ -9,7 +9,7 @@ struct OnboardingModelDownloadView: View { @State private var isModelSet = false @State private var showTutorial = false - private let turboModel = PredefinedModels.models.first { $0.name == "ggml-large-v3-turbo-q5_0" }! + private let turboModel = PredefinedModels.models.first { $0.name == "ggml-large-v3-turbo-q5_0" } as! LocalModel var body: some View { ZStack { diff --git a/VoiceInk/Whisper/WhisperState+ModelManager.swift b/VoiceInk/Whisper/WhisperState+ModelManager.swift index 5408972..215b39d 100644 --- a/VoiceInk/Whisper/WhisperState+ModelManager.swift +++ b/VoiceInk/Whisper/WhisperState+ModelManager.swift @@ -122,7 +122,7 @@ extension WhisperState { } // Shows an alert about Core ML support and first-run optimization - private func showCoreMLAlert(for model: PredefinedModel, completion: @escaping () -> Void) { + private func showCoreMLAlert(for model: LocalModel, completion: @escaping () -> Void) { Task { @MainActor in let alert = NSAlert() alert.messageText = "Core ML Support for \(model.displayName) Model" @@ -138,7 +138,7 @@ extension WhisperState { } } - func downloadModel(_ model: PredefinedModel) async { + func downloadModel(_ model: LocalModel) async { guard let url = URL(string: model.downloadURL) else { return } // Check if model supports Core ML (non-quantized models) @@ -160,7 +160,7 @@ extension WhisperState { } } - private func performModelDownload(_ model: PredefinedModel, _ url: URL) async { + private func performModelDownload(_ model: LocalModel, _ url: URL) async { do { let whisperModel = try await downloadMainModel(model, from: url) @@ -176,7 +176,7 @@ extension WhisperState { } } - private func downloadMainModel(_ model: PredefinedModel, from url: URL) async throws -> WhisperModel { + private func downloadMainModel(_ model: LocalModel, from url: URL) async throws -> WhisperModel { let progressKeyMain = model.name + "_main" let data = try await downloadFileWithProgress(from: url, progressKey: progressKeyMain) @@ -232,8 +232,7 @@ extension WhisperState { return model } - private func handleModelDownloadError(_ model: PredefinedModel, _ error: Error) { - currentError = .modelDownloadFailed + private func handleModelDownloadError(_ model: LocalModel, _ error: Error) { self.downloadProgress.removeValue(forKey: model.name + "_main") self.downloadProgress.removeValue(forKey: model.name + "_coreml") } @@ -262,7 +261,6 @@ extension WhisperState { } } catch { logError("Error deleting model: \(model.name)", error) - currentError = .modelDeletionFailed } } diff --git a/VoiceInk/Whisper/WhisperState+ModelQueries.swift b/VoiceInk/Whisper/WhisperState+ModelQueries.swift new file mode 100644 index 0000000..4058b9b --- /dev/null +++ b/VoiceInk/Whisper/WhisperState+ModelQueries.swift @@ -0,0 +1,18 @@ +import Foundation + +extension WhisperState { + var usableModels: [any TranscriptionModel] { + allAvailableModels.filter { model in + switch model.provider { + case .local: + return availableModels.contains { $0.name == model.name } + case .groq: + let key = UserDefaults.standard.string(forKey: "GROQAPIKey") + return key != nil && !key!.isEmpty + case .elevenLabs: + let key = UserDefaults.standard.string(forKey: "ElevenLabsAPIKey") + return key != nil && !key!.isEmpty + } + } + } +} diff --git a/VoiceInk/Whisper/WhisperState.swift b/VoiceInk/Whisper/WhisperState.swift index bfbfaf6..e2b03ef 100644 --- a/VoiceInk/Whisper/WhisperState.swift +++ b/VoiceInk/Whisper/WhisperState.swift @@ -12,9 +12,10 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { @Published var canTranscribe = false @Published var isRecording = false @Published var currentModel: WhisperModel? + @Published var currentTranscriptionModel: (any TranscriptionModel)? @Published var isModelLoading = false @Published var availableModels: [WhisperModel] = [] - @Published var predefinedModels: [PredefinedModel] = PredefinedModels.models + @Published var allAvailableModels: [any TranscriptionModel] = PredefinedModels.models @Published var clipboardMessage = "" @Published var miniRecorderError: String? @Published var isProcessing = false @@ -53,6 +54,10 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { let modelContext: ModelContext + // Transcription Services + private var localTranscriptionService: LocalTranscriptionService + private let cloudTranscriptionService = CloudTranscriptionService() + private var modelUrl: URL? { let possibleURLs = [ Bundle.main.url(forResource: "ggml-base.en", withExtension: "bin", subdirectory: "Models"), @@ -85,21 +90,28 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { init(modelContext: ModelContext, enhancementService: AIEnhancementService? = nil) { self.modelContext = modelContext - self.modelsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0] + let appSupportDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0] .appendingPathComponent("com.prakashjoshipax.VoiceInk") - .appendingPathComponent("WhisperModels") - self.recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0] - .appendingPathComponent("com.prakashjoshipax.VoiceInk") - .appendingPathComponent("Recordings") + + self.modelsDirectory = appSupportDirectory.appendingPathComponent("WhisperModels") + self.recordingsDirectory = appSupportDirectory.appendingPathComponent("Recordings") + + // Initialize services without whisperState reference first + self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: self.modelsDirectory) + self.enhancementService = enhancementService self.licenseViewModel = LicenseViewModel() super.init() + // Set the whisperState reference after super.init() + self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: self.modelsDirectory, whisperState: self) + setupNotifications() createModelsDirectoryIfNeeded() createRecordingsDirectoryIfNeeded() loadAvailableModels() + loadCurrentTranscriptionModel() if let savedModelName = UserDefaults.standard.string(forKey: "CurrentModel"), let savedModel = availableModels.first(where: { $0.name == savedModelName }) { @@ -139,11 +151,11 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { logger.error("❌ No recorded file found after stopping recording") } } else { - guard currentModel != nil else { + guard currentTranscriptionModel != nil else { await MainActor.run { let alert = NSAlert() - alert.messageText = "No Whisper Model Selected" - alert.informativeText = "Please select a default whisper model in AI Models tab before recording." + alert.messageText = "No AI Model Selected" + alert.informativeText = "Please select a default AI model in AI Models tab before recording." alert.alertStyle = .warning alert.addButton(withTitle: "OK") alert.runModal() @@ -174,11 +186,15 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { await ActiveWindowService.shared.applyConfigurationForCurrentApp() - if let currentModel = await self.currentModel, await self.whisperContext == nil { - do { - try await self.loadModel(currentModel) - } catch { - self.logger.error("❌ Model loading failed: \(error.localizedDescription)") + // Only load model if it's a local model and not already loaded + if let model = self.currentTranscriptionModel, model.provider == .local { + if let localWhisperModel = self.availableModels.first(where: { $0.name == model.name }), + self.whisperContext == nil { + do { + try await self.loadModel(localWhisperModel) + } catch { + self.logger.error("❌ Model loading failed: \(error.localizedDescription)") + } } } @@ -256,11 +272,13 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { await cleanupModelResources() return } + await MainActor.run { isProcessing = true isTranscribing = true canTranscribe = false } + defer { if shouldCancelRecording { Task { @@ -268,56 +286,33 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { } } } - guard let currentModel = currentModel else { + + guard let model = currentTranscriptionModel else { logger.error("❌ Cannot transcribe: No model selected") - currentError = .modelLoadFailed return } - if whisperContext == nil { - logger.notice("🔄 Model not loaded yet, attempting to load now: \(currentModel.name)") - do { - try await loadModel(currentModel) - } catch { - logger.error("❌ Failed to load model: \(currentModel.name) - \(error.localizedDescription)") - currentError = .modelLoadFailed - return - } - } - guard let whisperContext = whisperContext else { - logger.error("❌ Cannot transcribe: Model could not be loaded") - currentError = .modelLoadFailed - return - } - logger.notice("🔄 Starting transcription with model: \(currentModel.name)") + + logger.notice("🔄 Starting transcription with model: \(model.displayName)") + do { - let permanentURL = try saveRecordingPermanently(url) - let permanentURLString = permanentURL.absoluteString - if shouldCancelRecording { return } - let data = try readAudioSamples(url) - if shouldCancelRecording { return } - - // Get the actual audio duration from the file - let audioAsset = AVURLAsset(url: url) - let actualDuration = CMTimeGetSeconds(try await audioAsset.load(.duration)) - logger.notice("📊 Audio file duration: \(actualDuration) seconds") - - // Ensure we're using the most recent prompt from UserDefaults - let currentPrompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? whisperPrompt.transcriptionPrompt - await whisperContext.setPrompt(currentPrompt) - - if shouldCancelRecording { return } - await whisperContext.fullTranscribe(samples: data) - if shouldCancelRecording { return } - var text = await whisperContext.getTranscription() + // --- Core Transcription Logic --- + let transcriptionService: TranscriptionService = (model.provider == .local) ? localTranscriptionService : cloudTranscriptionService + var text = try await transcriptionService.transcribe(audioURL: url, model: model) text = text.trimmingCharacters(in: .whitespacesAndNewlines) + logger.notice("✅ Transcription completed successfully, length: \(text.count) characters") + + // --- Post-processing and Saving --- + let permanentURL = try saveRecordingPermanently(url) if UserDefaults.standard.bool(forKey: "IsWordReplacementEnabled") { text = WordReplacementService.shared.applyReplacements(to: text) logger.notice("✅ Word replacements applied") } + let audioAsset = AVURLAsset(url: url) + let actualDuration = CMTimeGetSeconds(try await audioAsset.load(.duration)) var promptDetectionResult: PromptDetectionService.PromptDetectionResult? = nil - let originalText = text + let originalText = text if let enhancementService = enhancementService, enhancementService.isConfigured { let detectionResult = promptDetectionService.analyzeText(text, with: enhancementService) @@ -330,23 +325,22 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { enhancementService.isConfigured { do { if shouldCancelRecording { return } - // Use processed text (without trigger words) for AI enhancement let textForAI = promptDetectionResult?.processedText ?? text let enhancedText = try await enhancementService.enhance(textForAI) let newTranscription = Transcription( - text: originalText, + text: originalText, duration: actualDuration, enhancedText: enhancedText, - audioFileURL: permanentURLString + audioFileURL: permanentURL.absoluteString ) modelContext.insert(newTranscription) try? modelContext.save() - text = enhancedText + text = enhancedText } catch { let newTranscription = Transcription( text: originalText, duration: actualDuration, - audioFileURL: permanentURLString + audioFileURL: permanentURL.absoluteString ) modelContext.insert(newTranscription) try? modelContext.save() @@ -355,11 +349,12 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { let newTranscription = Transcription( text: originalText, duration: actualDuration, - audioFileURL: permanentURLString + audioFileURL: permanentURL.absoluteString ) modelContext.insert(newTranscription) try? modelContext.save() } + if case .trialExpired = licenseViewModel.licenseState { text = """ Your trial has expired. Upgrade to VoiceInk Pro at tryvoiceink.com/buy @@ -367,7 +362,6 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { """ } - // Add a space to the end of the text text += " " SoundManager.shared.playStopSound() @@ -386,8 +380,8 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { } try? FileManager.default.removeItem(at: url) - if let result = promptDetectionResult, - let enhancementService = enhancementService, + if let result = promptDetectionResult, + let enhancementService = enhancementService, result.shouldEnableAI { await promptDetectionService.restoreOriginalSettings(result, to: enhancementService) } @@ -396,39 +390,64 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { await cleanupModelResources() } catch { - currentError = .transcriptionFailed + logger.error("❌ Transcription failed: \(error.localizedDescription)") await cleanupModelResources() await dismissMiniRecorder() } } - private func readAudioSamples(_ url: URL) throws -> [Float] { - return try decodeWaveFile(url) - } - - private func decodeWaveFile(_ url: URL) throws -> [Float] { - let data = try Data(contentsOf: url) - let floats = stride(from: 44, to: data.count, by: 2).map { - return data[$0..<$0 + 2].withUnsafeBytes { - let short = Int16(littleEndian: $0.load(as: Int16.self)) - return max(-1.0, min(Float(short) / 32767.0, 1.0)) - } - } - return floats - } - - @Published var currentError: WhisperStateError? - - func getEnhancementService() -> AIEnhancementService? { - return enhancementService - } - private func saveRecordingPermanently(_ tempURL: URL) throws -> URL { let fileName = "\(UUID().uuidString).wav" let permanentURL = recordingsDirectory.appendingPathComponent(fileName) try FileManager.default.copyItem(at: tempURL, to: permanentURL) return permanentURL } + + private func loadCurrentTranscriptionModel() { + if let savedModelName = UserDefaults.standard.string(forKey: "CurrentTranscriptionModel"), + let savedModel = allAvailableModels.first(where: { $0.name == savedModelName }) { + currentTranscriptionModel = savedModel + + // If it's a local model, also set it as currentModel for backward compatibility + if let localModel = savedModel as? LocalModel, + let whisperModel = availableModels.first(where: { $0.name == localModel.name }) { + currentModel = whisperModel + } + } + } + + // Function to set any transcription model as default + func setDefaultTranscriptionModel(_ model: any TranscriptionModel) async { + await MainActor.run { + self.currentTranscriptionModel = model + UserDefaults.standard.set(model.name, forKey: "CurrentTranscriptionModel") + + // If it's a local model, also update currentModel for backward compatibility + if let localModel = model as? LocalModel, + let whisperModel = self.availableModels.first(where: { $0.name == localModel.name }) { + self.currentModel = whisperModel + UserDefaults.standard.set(whisperModel.name, forKey: "CurrentModel") + } else { + // For cloud models, clear the old currentModel + self.currentModel = nil + } + + // Enable transcription for cloud models immediately since they don't need loading + if model.provider != .local { + self.canTranscribe = true + self.isModelLoaded = true + } + } + + logger.info("Default transcription model set to: \(model.name) (\(model.provider.rawValue))") + + // Post notification about the model change + NotificationCenter.default.post(name: .didChangeModel, object: nil, userInfo: ["modelName": model.name]) + } + + func getEnhancementService() -> AIEnhancementService? { + return enhancementService + } } struct WhisperModel: Identifiable { @@ -473,4 +492,5 @@ private class TaskDelegate: NSObject, URLSessionTaskDelegate { extension Notification.Name { static let toggleMiniRecorder = Notification.Name("toggleMiniRecorder") + static let didChangeModel = Notification.Name("didChangeModel") }