diff --git a/VoiceInk/Services/AIEnhancementService.swift b/VoiceInk/Services/AIEnhancementService.swift index ccd91cf..f338b26 100644 --- a/VoiceInk/Services/AIEnhancementService.swift +++ b/VoiceInk/Services/AIEnhancementService.swift @@ -17,16 +17,12 @@ class AIEnhancementService: ObservableObject { @Published var isEnhancementEnabled: Bool { didSet { UserDefaults.standard.set(isEnhancementEnabled, forKey: "isAIEnhancementEnabled") - // When enhancement is enabled, ensure a prompt is selected if isEnhancementEnabled && selectedPromptId == nil { - // Select the first prompt (default) if none is selected selectedPromptId = customPrompts.first?.id } - // Cancel any existing capture task currentCaptureTask?.cancel() - // Trigger screen capture when enhancement is enabled and screen capture is on if isEnhancementEnabled && useScreenCaptureContext { currentCaptureTask = Task { await captureScreenContext() @@ -71,7 +67,6 @@ class AIEnhancementService: ObservableObject { } var allPrompts: [CustomPrompt] { - // Always include the latest default prompt first, followed by custom prompts PredefinedPrompts.createDefaultPrompts() + customPrompts.filter { !$0.isPredefined } } @@ -80,7 +75,7 @@ class AIEnhancementService: ObservableObject { private var currentCaptureTask: Task? private let maxRetries = 3 private let baseTimeout: TimeInterval = 4 - private let rateLimitInterval: TimeInterval = 1.0 // 1 request per second + private let rateLimitInterval: TimeInterval = 1.0 private var lastRequestTime: Date? private let modelContext: ModelContext @@ -89,20 +84,11 @@ class AIEnhancementService: ObservableObject { self.modelContext = modelContext self.screenCaptureService = ScreenCaptureService() - // Print UserDefaults domain - if let domain = Bundle.main.bundleIdentifier { - print("⚙️ UserDefaults domain: \(domain)") - if let prefsPath = NSSearchPathForDirectoriesInDomains(.libraryDirectory, .userDomainMask, true).first { - print("⚙️ Preferences directory: \(prefsPath)/Preferences/\(domain).plist") - } - } - self.isEnhancementEnabled = UserDefaults.standard.bool(forKey: "isAIEnhancementEnabled") self.useClipboardContext = UserDefaults.standard.bool(forKey: "useClipboardContext") self.useScreenCaptureContext = UserDefaults.standard.bool(forKey: "useScreenCaptureContext") self.assistantTriggerWord = UserDefaults.standard.string(forKey: "assistantTriggerWord") ?? "hey" - // Load only custom prompts (non-predefined ones) if let savedPromptsData = UserDefaults.standard.data(forKey: "customPrompts"), let decodedPrompts = try? JSONDecoder().decode([CustomPrompt].self, from: savedPromptsData) { self.customPrompts = decodedPrompts @@ -110,18 +96,14 @@ class AIEnhancementService: ObservableObject { self.customPrompts = [] } - // Load selected prompt ID if let savedPromptId = UserDefaults.standard.string(forKey: "selectedPromptId") { self.selectedPromptId = UUID(uuidString: savedPromptId) } - // Ensure a prompt is selected if enhancement is enabled if isEnhancementEnabled && (selectedPromptId == nil || !allPrompts.contains(where: { $0.id == selectedPromptId })) { - // Set first prompt (default) as selected self.selectedPromptId = allPrompts.first?.id } - // Setup notification observer for API key changes NotificationCenter.default.addObserver( self, selector: #selector(handleAPIKeyChange), @@ -137,7 +119,6 @@ class AIEnhancementService: ObservableObject { @objc private func handleAPIKeyChange() { DispatchQueue.main.async { self.objectWillChange.send() - // Optionally disable enhancement if API key is cleared if !self.aiService.isAPIKeyValid { self.isEnhancementEnabled = false } @@ -159,70 +140,44 @@ class AIEnhancementService: ObservableObject { } private func determineMode(text: String) -> EnhancementMode { - // Only use AI assistant mode if text starts with configured trigger word - if text.lowercased().hasPrefix(assistantTriggerWord.lowercased()) { - return .aiAssistant - } - return .transcriptionEnhancement + text.lowercased().hasPrefix(assistantTriggerWord.lowercased()) ? .aiAssistant : .transcriptionEnhancement } private func getSystemMessage(for mode: EnhancementMode) -> String { - // Get clipboard context if enabled and available let clipboardContext = if useClipboardContext, let clipboardText = NSPasteboard.general.string(forType: .string), !clipboardText.isEmpty { - """ - - Available Clipboard Context: \(clipboardText) - """ + "\n\nAvailable Clipboard Context: \(clipboardText)" } else { "" } - // Get screen capture context if enabled and available let screenCaptureContext = if useScreenCaptureContext, let capturedText = screenCaptureService.lastCapturedText, !capturedText.isEmpty { - """ - - Active Window Context: \(capturedText) - """ + "\n\nActive Window Context: \(capturedText)" } else { "" } - // Format all context information together with instructions let contextSection = if !clipboardContext.isEmpty || !screenCaptureContext.isEmpty { - """ - - \(AIPrompts.contextInstructions) - - - \(clipboardContext) - \(screenCaptureContext) - - """ + "\n\n\(AIPrompts.contextInstructions)\n\n\(clipboardContext)\(screenCaptureContext)\n" } else { "" } switch mode { case .transcriptionEnhancement: - // Check if the active prompt is the Assistant prompt if let activePrompt = activePrompt, activePrompt.id == PredefinedPrompts.assistantPromptId { - // For the Assistant predefined prompt, use the assistant mode prompt directly - // This ensures proper system message formatting return AIPrompts.assistantMode + contextSection } - // For all other prompts, use the custom prompt template var systemMessage = String(format: AIPrompts.customPromptTemplate, activePrompt!.promptText) systemMessage += contextSection return systemMessage case .aiAssistant: - // For AI assistant mode, use the assistant mode prompt directly return AIPrompts.assistantMode + contextSection } } @@ -238,49 +193,37 @@ class AIEnhancementService: ObservableObject { throw EnhancementError.emptyText } - // Format transcript with boundary markers let formattedText = "\n\(text)\n" - - // Determine mode and get system message let mode = determineMode(text: text) let systemMessage = getSystemMessage(for: mode) - // Handle Ollama requests differently if aiService.selectedProvider == .ollama { - logger.notice("📤 Request to Ollama") - logger.notice("🤖 System: \(systemMessage, privacy: .public)") - logger.notice("📝 Sending: \(formattedText, privacy: .public)") do { let result = try await aiService.enhanceWithOllama(text: formattedText, systemPrompt: systemMessage) - logger.notice("✅ Ollama enhancement successful") - logger.notice("📝 Received: \(result, privacy: .public)") return result } catch let error as LocalAIError { switch error { case .serviceUnavailable: - logger.error("🔌 Ollama service unavailable") throw EnhancementError.notConfigured case .modelNotFound: - logger.error("🤖 Ollama model not found") throw EnhancementError.enhancementFailed case .serverError: - logger.error("🔥 Ollama server error") throw EnhancementError.serverError default: - logger.error("❌ Ollama enhancement failed") throw EnhancementError.enhancementFailed } } } - // Handle cloud provider requests - // Wait for rate limit try await waitForRateLimit() - // Special handling for Gemini and Anthropic switch aiService.selectedProvider { case .gemini: - var urlComponents = URLComponents(string: aiService.selectedProvider.baseURL)! + let baseEndpoint = "https://generativelanguage.googleapis.com/v1beta/models" + let model = aiService.currentModel + let fullURL = "\(baseEndpoint)/\(model):generateContent" + + var urlComponents = URLComponents(string: fullURL)! urlComponents.queryItems = [URLQueryItem(name: "key", value: aiService.apiKey)] guard let url = urlComponents.url else { @@ -290,9 +233,7 @@ class AIEnhancementService: ObservableObject { var request = URLRequest(url: url) request.httpMethod = "POST" request.addValue("application/json", forHTTPHeaderField: "Content-Type") - - let timeout = baseTimeout * pow(2.0, Double(retryCount)) - request.timeoutInterval = timeout + request.timeoutInterval = baseTimeout * pow(2.0, Double(retryCount)) let requestBody: [String: Any] = [ "contents": [ @@ -311,13 +252,9 @@ class AIEnhancementService: ObservableObject { request.httpBody = try? JSONSerialization.data(withJSONObject: requestBody) do { - logger.notice("📤 Request to Gemini") - logger.notice("🤖 System: \(systemMessage, privacy: .public)") - logger.notice("📝 Sending: \(formattedText, privacy: .public)") let (data, response) = try await URLSession.shared.data(for: request) guard let httpResponse = response as? HTTPURLResponse else { - logger.error("❌ Invalid Gemini response") throw EnhancementError.invalidResponse } @@ -330,47 +267,33 @@ class AIEnhancementService: ObservableObject { let parts = content["parts"] as? [[String: Any]], let firstPart = parts.first, let enhancedText = firstPart["text"] as? String else { - logger.error("❌ Failed to parse Gemini response") throw EnhancementError.enhancementFailed } - let result = enhancedText.trimmingCharacters(in: .whitespacesAndNewlines) - logger.notice("✅ Gemini enhancement successful") - logger.notice("📝 Received: \(result, privacy: .public)") - return result + return enhancedText.trimmingCharacters(in: .whitespacesAndNewlines) case 401: - logger.error("🔒 Authentication failed") throw EnhancementError.authenticationFailed - case 429: - logger.error("⏳ Rate limit exceeded") throw EnhancementError.rateLimitExceeded - case 500...599: - logger.error("🔥 Server error (\(httpResponse.statusCode))") throw EnhancementError.serverError - default: - logger.error("❌ Unexpected status (\(httpResponse.statusCode))") throw EnhancementError.apiError } } catch let error as EnhancementError { throw error } catch { - logger.error("❌ Network error: \(error.localizedDescription)") - if retryCount < maxRetries { try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000)) return try await makeRequest(text: text, retryCount: retryCount + 1) } - throw EnhancementError.networkError } case .anthropic: let requestBody: [String: Any] = [ - "model": aiService.selectedProvider.defaultModel, + "model": aiService.currentModel, "max_tokens": 1024, "system": systemMessage, "messages": [ @@ -383,20 +306,13 @@ class AIEnhancementService: ObservableObject { request.addValue("application/json", forHTTPHeaderField: "Content-Type") request.addValue(aiService.apiKey, forHTTPHeaderField: "x-api-key") request.addValue("2023-06-01", forHTTPHeaderField: "anthropic-version") - - let timeout = baseTimeout * pow(2.0, Double(retryCount)) - request.timeoutInterval = timeout - + request.timeoutInterval = baseTimeout * pow(2.0, Double(retryCount)) request.httpBody = try? JSONSerialization.data(withJSONObject: requestBody) do { - logger.notice("📤 Request to Anthropic") - logger.notice("🤖 System: \(systemMessage, privacy: .public)") - logger.notice("📝 Sending: \(formattedText, privacy: .public)") let (data, response) = try await URLSession.shared.data(for: request) guard let httpResponse = response as? HTTPURLResponse else { - logger.error("❌ Invalid Anthropic response") throw EnhancementError.invalidResponse } @@ -406,67 +322,45 @@ class AIEnhancementService: ObservableObject { let content = jsonResponse["content"] as? [[String: Any]], let firstContent = content.first, let enhancedText = firstContent["text"] as? String else { - logger.error("❌ Failed to parse Anthropic response") throw EnhancementError.enhancementFailed } - let result = enhancedText.trimmingCharacters(in: .whitespacesAndNewlines) - logger.notice("✅ Anthropic enhancement successful") - logger.notice("📝 Received: \(result, privacy: .public)") - return result + return enhancedText.trimmingCharacters(in: .whitespacesAndNewlines) case 401: - logger.error("🔒 Authentication failed") throw EnhancementError.authenticationFailed - case 429: - logger.error("⏳ Rate limit exceeded") throw EnhancementError.rateLimitExceeded - case 500...599: - logger.error("🔥 Server error (\(httpResponse.statusCode))") throw EnhancementError.serverError - default: - logger.error("❌ Unexpected status (\(httpResponse.statusCode))") throw EnhancementError.apiError } } catch let error as EnhancementError { throw error } catch { - logger.error("❌ Network error: \(error.localizedDescription)") - if retryCount < maxRetries { try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000)) return try await makeRequest(text: text, retryCount: retryCount + 1) } - throw EnhancementError.networkError } default: - // Handle OpenAI compatible providers let url = URL(string: aiService.selectedProvider.baseURL)! var request = URLRequest(url: url) request.httpMethod = "POST" request.addValue("application/json", forHTTPHeaderField: "Content-Type") request.addValue("Bearer \(aiService.apiKey)", forHTTPHeaderField: "Authorization") - - // Set timeout based on retry count with exponential backoff - let timeout = baseTimeout * pow(2.0, Double(retryCount)) - request.timeoutInterval = timeout - - logger.debug("Full system message: \(systemMessage)") + request.timeoutInterval = baseTimeout * pow(2.0, Double(retryCount)) let messages: [[String: Any]] = [ ["role": "system", "content": systemMessage], ["role": "user", "content": formattedText] ] - logger.info("Making request to \(self.aiService.selectedProvider.rawValue) with text length: \(text.count) characters") - let requestBody: [String: Any] = [ - "model": aiService.selectedProvider.defaultModel, + "model": aiService.currentModel, "messages": messages, "temperature": 0.3, "frequency_penalty": 0.0, @@ -477,13 +371,9 @@ class AIEnhancementService: ObservableObject { request.httpBody = try? JSONSerialization.data(withJSONObject: requestBody) do { - logger.notice("📤 Request to \(self.aiService.selectedProvider.rawValue, privacy: .public)") - logger.notice("🤖 System: \(systemMessage, privacy: .public)") - logger.notice("📝 Sending: \(formattedText, privacy: .public)") let (data, response) = try await URLSession.shared.data(for: request) guard let httpResponse = response as? HTTPURLResponse else { - logger.error("❌ Invalid response") throw EnhancementError.invalidResponse } @@ -494,87 +384,65 @@ class AIEnhancementService: ObservableObject { let firstChoice = choices.first, let message = firstChoice["message"] as? [String: Any], let enhancedText = message["content"] as? String else { - logger.error("❌ Failed to parse response") throw EnhancementError.enhancementFailed } - let result = enhancedText.trimmingCharacters(in: .whitespacesAndNewlines) - logger.notice("✅ Enhancement successful") - logger.notice("📝 Received: \(result, privacy: .public)") - return result + return enhancedText.trimmingCharacters(in: .whitespacesAndNewlines) case 401: - logger.error("🔒 Authentication failed") throw EnhancementError.authenticationFailed - case 429: - logger.error("⏳ Rate limit exceeded") throw EnhancementError.rateLimitExceeded - case 500...599: - logger.error("🔥 Server error (\(httpResponse.statusCode))") throw EnhancementError.serverError - default: - logger.error("❌ Unexpected status (\(httpResponse.statusCode))") throw EnhancementError.apiError } } catch let error as EnhancementError { throw error } catch { - logger.error("❌ Network error: \(error.localizedDescription)") - if retryCount < maxRetries { try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000)) return try await makeRequest(text: text, retryCount: retryCount + 1) } - throw EnhancementError.networkError } } } func enhance(_ text: String) async throws -> String { + logger.notice("🚀 Starting AI enhancement for text (\(text.count) characters)") var retryCount = 0 while retryCount < maxRetries { do { - return try await makeRequest(text: text, retryCount: retryCount) + let result = try await makeRequest(text: text, retryCount: retryCount) + logger.notice("✅ AI enhancement completed successfully (\(result.count) characters)") + return result } catch EnhancementError.rateLimitExceeded where retryCount < maxRetries - 1 { + logger.notice("⚠️ Rate limit exceeded, retrying AI enhancement (attempt \(retryCount + 1) of \(self.maxRetries))") retryCount += 1 try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000)) continue } catch { + logger.notice("❌ AI enhancement failed: \(error.localizedDescription)") throw error } } + logger.notice("❌ AI enhancement failed: maximum retries exceeded") throw EnhancementError.maxRetriesExceeded } - // Add a new method to capture screen context func captureScreenContext() async { - // Only check for screen capture context toggle - guard useScreenCaptureContext else { - logger.notice("📷 Screen capture context is disabled") - return - } + guard useScreenCaptureContext else { return } - logger.notice("📷 Initiating screen capture for context") - // Wait for the screen capture to complete and check result if let capturedText = await screenCaptureService.captureAndExtractText() { - logger.notice("📷 Screen capture successful, got \(capturedText.count, privacy: .public) characters") - // Ensure we're on the main thread when updating published properties await MainActor.run { - // Manually trigger objectWillChange to ensure UI updates self.objectWillChange.send() } - } else { - logger.notice("📷 Screen capture failed or returned empty result") } } - // MARK: - Prompt Management - func addPrompt(title: String, promptText: String, icon: PromptIcon = .documentFill, description: String? = nil) { let newPrompt = CustomPrompt(title: title, promptText: promptText, icon: icon, description: description, isPredefined: false) customPrompts.append(newPrompt) @@ -584,10 +452,7 @@ class AIEnhancementService: ObservableObject { } func updatePrompt(_ prompt: CustomPrompt) { - // Don't allow updates to predefined prompts - if prompt.isPredefined { - return - } + if prompt.isPredefined { return } if let index = customPrompts.firstIndex(where: { $0.id == prompt.id }) { customPrompts[index] = prompt @@ -595,10 +460,7 @@ class AIEnhancementService: ObservableObject { } func deletePrompt(_ prompt: CustomPrompt) { - // Don't allow deletion of predefined prompts - if prompt.isPredefined { - return - } + if prompt.isPredefined { return } customPrompts.removeAll { $0.id == prompt.id } if selectedPromptId == prompt.id { diff --git a/VoiceInk/Services/AIPrompts.swift b/VoiceInk/Services/AIPrompts.swift deleted file mode 100644 index 5d22fd9..0000000 --- a/VoiceInk/Services/AIPrompts.swift +++ /dev/null @@ -1,37 +0,0 @@ -enum AIPrompts { - static let customPromptTemplate = """ - - Your task is to reformat and enhance the text provided within tags according to the following guidelines: - - %@ - - IMPORTANT: The input will be wrapped in tags to identify what needs enhancement. - Your response should ONLY be to enhance text WITHOUT any tags. - DO NOT include tags in your response. - - """ - - static let assistantMode = """ - - Provide a direct clear, and concise reply to the user's query. Use the available context if directly related to the user's query. - Remember to: - 1. Be helpful and informative - 2. Be accurate and precise - 3. Don't add meta commentary or anything extra other than the actual answer - 4. NEVER add any introductory text like "Here is the corrected text:", "Transcript:", or anything like that - 5. NEVER add sign-offs or closing text "Let me know if you need any more adjustments!", or anything like that except the actual answer. - 6. Maintain a friendly, casual tone - - """ - - static let contextInstructions = """ - - Your task is to work ONLY with the content within the tags. - - IMPORTANT: The information in section is ONLY for reference. - - NEVER include the context directly in your output - - Context should only help you better understand the user's query - - - """ -} diff --git a/VoiceInk/Services/AIService.swift b/VoiceInk/Services/AIService.swift index f9e9524..fbdd54f 100644 --- a/VoiceInk/Services/AIService.swift +++ b/VoiceInk/Services/AIService.swift @@ -1,4 +1,5 @@ import Foundation +import os enum AIProvider: String, CaseIterable { case groq = "GROQ" @@ -19,7 +20,7 @@ enum AIProvider: String, CaseIterable { case .deepSeek: return "https://api.deepseek.com/v1/chat/completions" case .gemini: - return "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent" + return "https://generativelanguage.googleapis.com/v1beta/models" case .anthropic: return "https://api.anthropic.com/v1/messages" case .mistral: @@ -36,7 +37,7 @@ enum AIProvider: String, CaseIterable { case .groq: return "llama-3.3-70b-versatile" case .openAI: - return "gpt-4o-mini-2024-07-18" + return "gpt-4.1-mini" case .deepSeek: return "deepseek-chat" case .gemini: @@ -44,7 +45,7 @@ enum AIProvider: String, CaseIterable { case .anthropic: return "claude-3-5-sonnet-20241022" case .mistral: - return "mistral-large-2411" + return "mistral-large-latest" case .ollama: return UserDefaults.standard.string(forKey: "ollamaSelectedModel") ?? "mistral" case .custom: @@ -52,6 +53,48 @@ enum AIProvider: String, CaseIterable { } } + var availableModels: [String] { + switch self { + case .groq: + return [ + "llama-3.3-70b-versatile", + "llama-3.1-8b-instant" + ] + case .openAI: + return [ + "gpt-4.1", + "gpt-4.1-mini" + ] + case .deepSeek: + return [ + "deepseek-chat", + "deepseek-reasoner" + ] + case .gemini: + return [ + "gemini-2.5-flash-preview-04-17", + "gemini-2.0-flash", + "gemini-2.0-flash-lite" + ] + case .anthropic: + return [ + "claude-3-7-sonnet-latest", + "claude-3-5-haiku-latest", + "claude-3-5-sonnet-latest" + ] + case .mistral: + return [ + "mistral-large-latest", + "mistral-small-latest", + "mistral-saba-latest" + ] + case .ollama: + return [] + case .custom: + return [] + } + } + var requiresAPIKey: Bool { switch self { case .ollama: @@ -63,6 +106,8 @@ enum AIProvider: String, CaseIterable { } class AIService: ObservableObject { + private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AIService") + @Published var apiKey: String = "" @Published var isAPIKeyValid: Bool = false @Published var customBaseURL: String = UserDefaults.standard.string(forKey: "customProviderBaseURL") ?? "" { @@ -78,7 +123,6 @@ class AIService: ObservableObject { @Published var selectedProvider: AIProvider { didSet { userDefaults.set(selectedProvider.rawValue, forKey: "selectedAIProvider") - // Load API key for the selected provider if it requires one if selectedProvider.requiresAPIKey { if let savedKey = userDefaults.string(forKey: "\(selectedProvider.rawValue)APIKey") { self.apiKey = savedKey @@ -88,10 +132,8 @@ class AIService: ObservableObject { self.isAPIKeyValid = false } } else { - // For providers that don't require API key (like Ollama) self.apiKey = "" self.isAPIKeyValid = true - // Check Ollama connection if selectedProvider == .ollama { Task { await ollamaService.checkConnection() @@ -102,6 +144,7 @@ class AIService: ObservableObject { } } + @Published private var selectedModels: [AIProvider: String] = [:] private let userDefaults = UserDefaults.standard private let ollamaService = OllamaService() @@ -116,25 +159,37 @@ class AIService: ObservableObject { } } + var currentModel: String { + if let selectedModel = selectedModels[selectedProvider], + !selectedModel.isEmpty, + (selectedProvider == .ollama && !selectedModel.isEmpty) || availableModels.contains(selectedModel) { + return selectedModel + } + return selectedProvider.defaultModel + } + + var availableModels: [String] { + if selectedProvider == .ollama { + return ollamaService.availableModels.map { $0.name } + } + return selectedProvider.availableModels + } + init() { - // Load selected provider if let savedProvider = userDefaults.string(forKey: "selectedAIProvider"), let provider = AIProvider(rawValue: savedProvider) { self.selectedProvider = provider } else { - self.selectedProvider = .gemini // Default to Gemini + self.selectedProvider = .gemini } - // Load API key for the current provider if it requires one if selectedProvider.requiresAPIKey { if let savedKey = userDefaults.string(forKey: "\(selectedProvider.rawValue)APIKey") { self.apiKey = savedKey self.isAPIKeyValid = true } } else { - // For providers that don't require API key self.isAPIKeyValid = true - // Check Ollama connection if it's the selected provider if selectedProvider == .ollama { Task { await ollamaService.checkConnection() @@ -142,29 +197,48 @@ class AIService: ObservableObject { } } } + + loadSavedModelSelections() + } + + private func loadSavedModelSelections() { + for provider in AIProvider.allCases { + let key = "\(provider.rawValue)SelectedModel" + if let savedModel = userDefaults.string(forKey: key), !savedModel.isEmpty { + selectedModels[provider] = savedModel + } + } + } + + func selectModel(_ model: String) { + guard !model.isEmpty else { return } + + selectedModels[selectedProvider] = model + let key = "\(selectedProvider.rawValue)SelectedModel" + userDefaults.set(model, forKey: key) + + if selectedProvider == .ollama { + updateSelectedOllamaModel(model) + } + + objectWillChange.send() } func saveAPIKey(_ key: String, completion: @escaping (Bool) -> Void) { - // Skip verification for providers that don't require API key guard selectedProvider.requiresAPIKey else { - print("📝 [\(selectedProvider.rawValue)] API key not required, skipping verification") completion(true) return } - print("🔑 [\(selectedProvider.rawValue)] Starting API key verification...") - // Verify the API key before saving verifyAPIKey(key) { [weak self] isValid in guard let self = self else { return } DispatchQueue.main.async { if isValid { - print("✅ [\(self.selectedProvider.rawValue)] API key verified successfully") self.apiKey = key self.isAPIKeyValid = true self.userDefaults.set(key, forKey: "\(self.selectedProvider.rawValue)APIKey") NotificationCenter.default.post(name: .aiProviderKeyChanged, object: nil) } else { - print("❌ [\(self.selectedProvider.rawValue)] API key verification failed") self.isAPIKeyValid = false } completion(isValid) @@ -173,18 +247,11 @@ class AIService: ObservableObject { } func verifyAPIKey(_ key: String, completion: @escaping (Bool) -> Void) { - // Skip verification for providers that don't require API key guard selectedProvider.requiresAPIKey else { - print("📝 [\(selectedProvider.rawValue)] API key verification skipped - not required") completion(true) return } - print("🔍 [\(selectedProvider.rawValue)] Verifying API key...") - print("🌐 Using base URL: \(selectedProvider.baseURL)") - print("🤖 Using model: \(selectedProvider.defaultModel)") - - // Special handling for different providers switch selectedProvider { case .gemini: verifyGeminiAPIKey(key, completion: completion) @@ -203,7 +270,7 @@ class AIService: ObservableObject { request.addValue("Bearer \(key)", forHTTPHeaderField: "Authorization") let testBody: [String: Any] = [ - "model": selectedProvider.defaultModel, + "model": currentModel, "messages": [ ["role": "user", "content": "test"] ], @@ -212,19 +279,15 @@ class AIService: ObservableObject { request.httpBody = try? JSONSerialization.data(withJSONObject: testBody) - print("📤 Sending verification request...") URLSession.shared.dataTask(with: request) { data, response, error in if let error = error { - print("❌ Network error during verification: \(error.localizedDescription)") completion(false) return } if let httpResponse = response as? HTTPURLResponse { - print("📥 Received response with status code: \(httpResponse.statusCode)") completion(httpResponse.statusCode == 200) } else { - print("❌ Invalid response received") completion(false) } }.resume() @@ -239,7 +302,7 @@ class AIService: ObservableObject { request.addValue("2023-06-01", forHTTPHeaderField: "anthropic-version") let testBody: [String: Any] = [ - "model": selectedProvider.defaultModel, + "model": currentModel, "max_tokens": 1024, "system": "You are a test system.", "messages": [ @@ -249,26 +312,26 @@ class AIService: ObservableObject { request.httpBody = try? JSONSerialization.data(withJSONObject: testBody) - print("📤 Sending Anthropic verification request...") URLSession.shared.dataTask(with: request) { data, response, error in if let error = error { - print("❌ Network error during Anthropic verification: \(error.localizedDescription)") completion(false) return } if let httpResponse = response as? HTTPURLResponse { - print("📥 Received Anthropic response with status code: \(httpResponse.statusCode)") completion(httpResponse.statusCode == 200) } else { - print("❌ Invalid Anthropic response received") completion(false) } }.resume() } private func verifyGeminiAPIKey(_ key: String, completion: @escaping (Bool) -> Void) { - var urlComponents = URLComponents(string: selectedProvider.baseURL)! + let baseEndpoint = "https://generativelanguage.googleapis.com/v1beta/models" + let model = currentModel + let fullURL = "\(baseEndpoint)/\(model):generateContent" + + var urlComponents = URLComponents(string: fullURL)! urlComponents.queryItems = [URLQueryItem(name: "key", value: key)] guard let url = urlComponents.url else { @@ -292,26 +355,21 @@ class AIService: ObservableObject { request.httpBody = try? JSONSerialization.data(withJSONObject: testBody) - print("📤 Sending Gemini verification request...") URLSession.shared.dataTask(with: request) { data, response, error in if let error = error { - print("❌ Network error during Gemini verification: \(error.localizedDescription)") completion(false) return } if let httpResponse = response as? HTTPURLResponse { - print("📥 Received Gemini response with status code: \(httpResponse.statusCode)") completion(httpResponse.statusCode == 200) } else { - print("❌ Invalid Gemini response received") completion(false) } }.resume() } func clearAPIKey() { - // Skip for providers that don't require API key guard selectedProvider.requiresAPIKey else { return } apiKey = "" @@ -320,7 +378,6 @@ class AIService: ObservableObject { NotificationCenter.default.post(name: .aiProviderKeyChanged, object: nil) } - // Add method to check Ollama connection func checkOllamaConnection(completion: @escaping (Bool) -> Void) { Task { [weak self] in guard let self = self else { return } @@ -331,31 +388,34 @@ class AIService: ObservableObject { } } - // Add method to get available Ollama models func fetchOllamaModels() async -> [OllamaService.OllamaModel] { await ollamaService.refreshModels() return ollamaService.availableModels } - // Add method to enhance text using Ollama func enhanceWithOllama(text: String, systemPrompt: String) async throws -> String { - return try await ollamaService.enhance(text, withSystemPrompt: systemPrompt) + logger.notice("🔄 Sending transcription to Ollama for enhancement (model: \(self.ollamaService.selectedModel))") + do { + let result = try await ollamaService.enhance(text, withSystemPrompt: systemPrompt) + logger.notice("✅ Ollama enhancement completed successfully (\(result.count) characters)") + return result + } catch { + logger.notice("❌ Ollama enhancement failed: \(error.localizedDescription)") + throw error + } } - // Add method to update Ollama base URL func updateOllamaBaseURL(_ newURL: String) { ollamaService.baseURL = newURL userDefaults.set(newURL, forKey: "ollamaBaseURL") } - // Add method to update selected Ollama model func updateSelectedOllamaModel(_ modelName: String) { ollamaService.selectedModel = modelName userDefaults.set(modelName, forKey: "ollamaSelectedModel") } } -// Add extension for notification name extension Notification.Name { static let aiProviderKeyChanged = Notification.Name("aiProviderKeyChanged") } diff --git a/VoiceInk/Views/APIKeyManagementView.swift b/VoiceInk/Views/APIKeyManagementView.swift index ec5c1be..9485a96 100644 --- a/VoiceInk/Views/APIKeyManagementView.swift +++ b/VoiceInk/Views/APIKeyManagementView.swift @@ -55,6 +55,20 @@ struct APIKeyManagementView: View { } } + // Model Selection - only show for standard providers with available models + if !aiService.availableModels.isEmpty && + aiService.selectedProvider != .ollama && + aiService.selectedProvider != .custom { + Picker("Model", selection: Binding( + get: { aiService.currentModel }, + set: { aiService.selectModel($0) } + )) { + ForEach(aiService.availableModels, id: \.self) { model in + Text(model).tag(model) + } + } + } + if aiService.selectedProvider == .ollama { // Ollama Configuration VStack(alignment: .leading, spacing: 16) { @@ -441,103 +455,105 @@ struct APIKeyManagementView: View { .padding() .background(Color.secondary.opacity(0.03)) .cornerRadius(12) - } else if aiService.isAPIKeyValid { - // API Key Display for other providers - VStack(alignment: .leading, spacing: 8) { - Text("API Key") - .font(.subheadline) - .foregroundColor(.secondary) - - HStack { - Text(String(repeating: "•", count: 40)) + } else { + // API Key Display for other providers if valid + if aiService.isAPIKeyValid { + VStack(alignment: .leading, spacing: 8) { + Text("API Key") + .font(.subheadline) + .foregroundColor(.secondary) + + HStack { + Text(String(repeating: "•", count: 40)) + .font(.system(.body, design: .monospaced)) + + Spacer() + + Button(action: { + aiService.clearAPIKey() + }) { + Label("Remove Key", systemImage: "trash") + .foregroundColor(.red) + } + .buttonStyle(.borderless) + } + } + } else { + // API Key Input for other providers + VStack(alignment: .leading, spacing: 8) { + Text("Enter your API Key") + .font(.subheadline) + .foregroundColor(.secondary) + + SecureField("API Key", text: $apiKey) + .textFieldStyle(RoundedBorderTextFieldStyle()) .font(.system(.body, design: .monospaced)) - Spacer() - - Button(action: { - aiService.clearAPIKey() - }) { - Label("Remove Key", systemImage: "trash") - .foregroundColor(.red) - } - .buttonStyle(.borderless) - } - } - } else { - // API Key Input for other providers - VStack(alignment: .leading, spacing: 8) { - Text("Enter your API Key") - .font(.subheadline) - .foregroundColor(.secondary) - - SecureField("API Key", text: $apiKey) - .textFieldStyle(RoundedBorderTextFieldStyle()) - .font(.system(.body, design: .monospaced)) - - HStack { - Button(action: { - isVerifying = true - aiService.saveAPIKey(apiKey) { success in - isVerifying = false - if !success { - alertMessage = "Invalid API key. Please check and try again." - showAlert = true + HStack { + Button(action: { + isVerifying = true + aiService.saveAPIKey(apiKey) { success in + isVerifying = false + if !success { + alertMessage = "Invalid API key. Please check and try again." + showAlert = true + } + apiKey = "" } - apiKey = "" - } - }) { - HStack { - if isVerifying { - ProgressView() - .scaleEffect(0.5) - .frame(width: 16, height: 16) - } else { - Image(systemName: "checkmark.circle.fill") + }) { + HStack { + if isVerifying { + ProgressView() + .scaleEffect(0.5) + .frame(width: 16, height: 16) + } else { + Image(systemName: "checkmark.circle.fill") + } + Text("Verify and Save") } - Text("Verify and Save") } - } - - Spacer() - - HStack(spacing: 8) { - Text(aiService.selectedProvider == .groq || aiService.selectedProvider == .gemini ? "Free" : "Paid") - .font(.caption2) - .foregroundColor(.secondary) - .padding(.horizontal, 6) - .padding(.vertical, 2) - .background(Color.secondary.opacity(0.1)) - .cornerRadius(4) - if aiService.selectedProvider != .ollama && aiService.selectedProvider != .custom { - Button { - let url = switch aiService.selectedProvider { - case .groq: - URL(string: "https://console.groq.com/keys")! - case .openAI: - URL(string: "https://platform.openai.com/api-keys")! - case .deepSeek: - URL(string: "https://platform.deepseek.com/api-keys")! - case .gemini: - URL(string: "https://makersuite.google.com/app/apikey")! - case .anthropic: - URL(string: "https://console.anthropic.com/settings/keys")! - case .mistral: - URL(string: "https://console.mistral.ai/api-keys")! - case .ollama, .custom: - URL(string: "")! // This case should never be reached - } - NSWorkspace.shared.open(url) - } label: { - HStack(spacing: 4) { - Text("Get API Key") - .foregroundColor(.accentColor) - Image(systemName: "arrow.up.right") - .font(.caption) - .foregroundColor(.accentColor) + Spacer() + + HStack(spacing: 8) { + Text(aiService.selectedProvider == .groq || aiService.selectedProvider == .gemini ? "Free" : "Paid") + .font(.caption2) + .foregroundColor(.secondary) + .padding(.horizontal, 6) + .padding(.vertical, 2) + .background(Color.secondary.opacity(0.1)) + .cornerRadius(4) + + if aiService.selectedProvider != .ollama && aiService.selectedProvider != .custom { + Button { + let url = switch aiService.selectedProvider { + case .groq: + URL(string: "https://console.groq.com/keys")! + case .openAI: + URL(string: "https://platform.openai.com/api-keys")! + case .deepSeek: + URL(string: "https://platform.deepseek.com/api-keys")! + case .gemini: + URL(string: "https://makersuite.google.com/app/apikey")! + case .anthropic: + URL(string: "https://console.anthropic.com/settings/keys")! + case .mistral: + URL(string: "https://console.mistral.ai/api-keys")! + case .ollama, .custom: + URL(string: "")! // This case should never be reached + } + NSWorkspace.shared.open(url) + } label: { + HStack(spacing: 4) { + Text("Get API Key") + .foregroundColor(.accentColor) + Image(systemName: "arrow.up.right") + .font(.caption) + .foregroundColor(.accentColor) + } } + .buttonStyle(.plain) } - .buttonStyle(.plain) } } }