import Foundation
import SwiftData
import AppKit
import os

/// The kind of enhancement a request is made for.
///
/// NOTE(review): `getSystemMessage(for:)` currently ignores this value; the system
/// message is chosen from the active prompt instead.
enum EnhancementPrompt {
    case transcriptionEnhancement
    case aiAssistant
}

/// Sends transcribed text to the configured AI provider for post-processing and manages
/// the user's enhancement prompts and context-gathering settings.
///
/// Settings are written to `UserDefaults` whenever they change, and several changes are
/// broadcast through `NotificationCenter`.
class AIEnhancementService: ObservableObject {
    private let logger = Logger(subsystem: "com.voiceink.enhancement", category: "AIEnhancementService")

    // MARK: - Persisted settings

    /// Whether enhancement is turned on. Turning it on selects the first prompt when none is selected.
    @Published var isEnhancementEnabled: Bool {
        didSet {
            UserDefaults.standard.set(isEnhancementEnabled, forKey: "isAIEnhancementEnabled")
            if isEnhancementEnabled && selectedPromptId == nil {
                selectedPromptId = customPrompts.first?.id
            }
            NotificationCenter.default.post(name: .AppSettingsDidChange, object: nil)
            NotificationCenter.default.post(name: .enhancementToggleChanged, object: nil)
        }
    }

    /// Whether the current clipboard text is appended to the system message.
    @Published var useClipboardContext: Bool {
        didSet {
            UserDefaults.standard.set(useClipboardContext, forKey: "useClipboardContext")
        }
    }

    /// Whether text from the most recent screen capture is appended to the system message.
    @Published var useScreenCaptureContext: Bool {
        didSet {
            UserDefaults.standard.set(useScreenCaptureContext, forKey: "useScreenCaptureContext")
            NotificationCenter.default.post(name: .AppSettingsDidChange, object: nil)
        }
    }

    /// Every prompt (predefined and user-created); persisted as JSON under `customPrompts`.
    @Published var customPrompts: [CustomPrompt] {
        didSet {
            if let encoded = try? JSONEncoder().encode(customPrompts) {
                UserDefaults.standard.set(encoded, forKey: "customPrompts")
            }
        }
    }

    /// Identifier of the prompt used for enhancement; persisted as a UUID string.
    @Published var selectedPromptId: UUID? {
        didSet {
            UserDefaults.standard.set(selectedPromptId?.uuidString, forKey: "selectedPromptId")
            NotificationCenter.default.post(name: .AppSettingsDidChange, object: nil)
            NotificationCenter.default.post(name: .promptSelectionChanged, object: nil)
        }
    }

    /// The exact system message of the most recent request (exposed for the UI).
    @Published var lastSystemMessageSent: String?
    /// The exact user message of the most recent request (exposed for the UI).
    @Published var lastUserMessageSent: String?

    /// The prompt whose id matches `selectedPromptId`, if any.
    var activePrompt: CustomPrompt? {
        allPrompts.first { $0.id == selectedPromptId }
    }

    var allPrompts: [CustomPrompt] {
        return customPrompts
    }

    // MARK: - Dependencies and request tuning

    private let aiService: AIService
    private let screenCaptureService: ScreenCaptureService
    private let dictionaryContextService: DictionaryContextService
    /// Timeout, in seconds, applied to every HTTP request.
    private let baseTimeout: TimeInterval = 30
    /// Minimum spacing, in seconds, between consecutive remote (non-Ollama) requests.
    private let rateLimitInterval: TimeInterval = 1.0
    private var lastRequestTime: Date?
    private let modelContext: ModelContext

    init(aiService: AIService = AIService(), modelContext: ModelContext) {
        self.aiService = aiService
        self.modelContext = modelContext
        self.screenCaptureService = ScreenCaptureService()
        self.dictionaryContextService = DictionaryContextService.shared

        self.isEnhancementEnabled = UserDefaults.standard.bool(forKey: "isAIEnhancementEnabled")
        self.useClipboardContext = UserDefaults.standard.bool(forKey: "useClipboardContext")
        self.useScreenCaptureContext = UserDefaults.standard.bool(forKey: "useScreenCaptureContext")
        self.customPrompts = PromptMigrationService.migratePromptsIfNeeded()

        if let savedPromptId = UserDefaults.standard.string(forKey: "selectedPromptId") {
            self.selectedPromptId = UUID(uuidString: savedPromptId)
        }

        // Merge the predefined prompts BEFORE validating the saved selection, so a saved
        // selection pointing at a predefined prompt is not discarded when the stored list
        // does not (yet) contain it.
        initializePredefinedPrompts()

        if isEnhancementEnabled && (selectedPromptId == nil || !allPrompts.contains(where: { $0.id == selectedPromptId })) {
            self.selectedPromptId = allPrompts.first?.id
        }

        NotificationCenter.default.addObserver(
            self,
            selector: #selector(handleAPIKeyChange),
            name: .aiProviderKeyChanged,
            object: nil
        )
    }

    deinit {
        NotificationCenter.default.removeObserver(self)
    }

    /// Handles `.aiProviderKeyChanged`: on the main queue, notifies observers and turns
    /// enhancement off when the AI service no longer reports a valid API key.
    @objc private func handleAPIKeyChange() {
        DispatchQueue.main.async {
            self.objectWillChange.send()
            if !self.aiService.isAPIKeyValid {
                self.isEnhancementEnabled = false
            }
        }
    }

    func getAIService() -> AIService? {
        return aiService
    }

    /// `true` when the AI service reports a valid API key.
    var isConfigured: Bool {
        aiService.isAPIKeyValid
    }

    /// Sleeps just long enough that consecutive calls are at least `rateLimitInterval` apart,
    /// then records the current time as the latest request time.
    private func waitForRateLimit() async throws {
        if let lastRequest = lastRequestTime {
            let timeSinceLastRequest = Date().timeIntervalSince(lastRequest)
            if timeSinceLastRequest < rateLimitInterval {
                try await Task.sleep(nanoseconds: UInt64((rateLimitInterval - timeSinceLastRequest) * 1_000_000_000))
            }
        }
        lastRequestTime = Date()
    }

    // MARK: - System message

    /// Builds the system message for the active prompt, appending whichever context sources
    /// are enabled (selected text, clipboard, screen capture) and the dictionary context.
    ///
    /// - Parameter mode: Currently unused; kept for call-site compatibility.
    private func getSystemMessage(for mode: EnhancementPrompt) -> String {
        // Resolve the active prompt once so every branch sees the same selection.
        let prompt = activePrompt
        let isAssistantPrompt = prompt?.id == PredefinedPrompts.assistantPromptId

        // Selected text is only consumed by the assistant prompt, so don't query it otherwise.
        let selectedText = isAssistantPrompt ? SelectedTextService.fetchSelectedText() : nil

        let dictionaryContext = dictionaryContextService.getDictionaryContext()
        let dictionaryContextSection = dictionaryContext.isEmpty ? "" : "\n\n\(dictionaryContext)\n"

        // With a non-empty selection, the assistant prompt uses the selection as its context
        // in place of the clipboard / screen-capture context.
        if let prompt, isAssistantPrompt, let selectedText, !selectedText.isEmpty {
            let selectedTextContext = "\n\nSelected Text: \(selectedText)"
            return prompt.promptText + "\n\n\(selectedTextContext)\n" + dictionaryContextSection
        }

        let clipboardContext = if useClipboardContext,
                                  let clipboardText = NSPasteboard.general.string(forType: .string),
                                  !clipboardText.isEmpty {
            "\n\n\n\(clipboardText)\n"
        } else {
            ""
        }

        let screenCaptureContext = if useScreenCaptureContext,
                                      let capturedText = screenCaptureService.lastCapturedText,
                                      !capturedText.isEmpty {
            "\n\nActive Window Context: \(capturedText)"
        } else {
            ""
        }

        let generalContextSection = (clipboardContext.isEmpty && screenCaptureContext.isEmpty)
            ? ""
            : "\n\n\(clipboardContext)\(screenCaptureContext)\n"
        let contextSuffix = generalContextSection + dictionaryContextSection

        guard let prompt else {
            // No active prompt: fall back to the predefined default prompt, then to assistant mode.
            if let defaultPrompt = allPrompts.first(where: { $0.id == PredefinedPrompts.defaultPromptId }) {
                return String(format: AIPrompts.customPromptTemplate, defaultPrompt.promptText) + contextSuffix
            }
            return AIPrompts.assistantMode + contextSuffix
        }

        // The assistant prompt is sent verbatim; all other prompts are wrapped in the template.
        if isAssistantPrompt {
            return prompt.promptText + contextSuffix
        }
        return String(format: AIPrompts.customPromptTemplate, prompt.promptText) + contextSuffix
    }

    // MARK: - Requests

    /// Sends one enhancement request to the selected provider (no retries).
    ///
    /// - Returns: The filtered enhanced text, or `""` when `text` is empty.
    /// - Throws: `EnhancementError` (including `.notConfigured` without a valid key), or a
    ///   `URLError` from the transport so `makeRequestWithRetry` can decide whether to retry.
    private func makeRequest(text: String, mode: EnhancementPrompt) async throws -> String {
        guard isConfigured else {
            throw EnhancementError.notConfigured
        }
        // An empty transcription is not an error; there is simply nothing to enhance.
        guard !text.isEmpty else {
            return ""
        }

        let formattedText = "\n\n\(text)\n"
        let systemMessage = getSystemMessage(for: mode)

        // Persist the exact payload being sent (also used for UI). These are @Published, so
        // they must be mutated on the main thread; this async method is not main-actor isolated.
        await MainActor.run {
            self.lastSystemMessageSent = systemMessage
            self.lastUserMessageSent = formattedText
        }

        // The payload can contain clipboard text, screen contents and dictated speech, so it is
        // logged as private (redacted in the unified log unless private data logging is enabled).
        logger.notice("AI Enhancement - System Message: \(systemMessage, privacy: .private)")
        logger.notice("AI Enhancement - User Message: \(formattedText, privacy: .private)")

        if aiService.selectedProvider == .ollama {
            return try await requestOllamaEnhancement(systemMessage: systemMessage, userMessage: formattedText)
        }

        try await waitForRateLimit()

        // The endpoint may come from configuration, so a malformed value must not crash the app.
        let baseURL = aiService.selectedProvider.baseURL
        guard let url = URL(string: baseURL) else {
            throw EnhancementError.customError("Invalid API endpoint URL: \(baseURL)")
        }

        switch aiService.selectedProvider {
        case .anthropic:
            let body: [String: Any] = [
                "model": aiService.currentModel,
                "max_tokens": 8192,
                "system": systemMessage,
                "messages": [
                    ["role": "user", "content": formattedText]
                ]
            ]
            return try await performJSONRequest(
                body: body,
                to: url,
                headers: ["x-api-key": aiService.apiKey, "anthropic-version": "2023-06-01"]
            ) { json in
                let content = json["content"] as? [[String: Any]]
                return content?.first?["text"] as? String
            }

        default:
            // OpenAI-compatible chat-completions request.
            let model = aiService.currentModel
            let body: [String: Any] = [
                "model": model,
                "messages": [
                    ["role": "system", "content": systemMessage],
                    ["role": "user", "content": formattedText]
                ],
                // NOTE(review): presumably gpt-5 models reject non-default temperatures — confirm.
                "temperature": model.lowercased().hasPrefix("gpt-5") ? 1.0 : 0.3,
                "stream": false
            ]
            return try await performJSONRequest(
                body: body,
                to: url,
                headers: ["Authorization": "Bearer \(aiService.apiKey)"]
            ) { json in
                let choices = json["choices"] as? [[String: Any]]
                let message = choices?.first?["message"] as? [String: Any]
                return message?["content"] as? String
            }
        }
    }

    /// Runs the enhancement through the local Ollama integration, mapping every failure to
    /// `EnhancementError.customError`.
    private func requestOllamaEnhancement(systemMessage: String, userMessage: String) async throws -> String {
        do {
            let result = try await aiService.enhanceWithOllama(text: userMessage, systemPrompt: systemMessage)
            return AIEnhancementOutputFilter.filter(result)
        } catch let localError as LocalAIError {
            throw EnhancementError.customError(localError.errorDescription ?? "An unknown Ollama error occurred.")
        } catch {
            throw EnhancementError.customError(error.localizedDescription)
        }
    }

    /// POSTs `body` as JSON to `url` and returns the trimmed, filtered completion text.
    ///
    /// - Parameters:
    ///   - body: The JSON request body.
    ///   - url: The provider endpoint.
    ///   - headers: Provider-specific headers, added after `Content-Type`.
    ///   - extractText: Pulls the completion text out of a decoded 200 response; `nil` if absent.
    /// - Throws: `EnhancementError.invalidResponse` for a non-HTTP response, `.enhancementFailed`
    ///   when a 200 response has no completion text, `.rateLimitExceeded` for 429, `.serverError`
    ///   for 5xx, `.customError` for other statuses or unexpected errors. `URLError`s are rethrown
    ///   unchanged so the retry logic can inspect them.
    private func performJSONRequest(
        body: [String: Any],
        to url: URL,
        headers: KeyValuePairs<String, String>,
        extractText: ([String: Any]) -> String?
    ) async throws -> String {
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        request.addValue("application/json", forHTTPHeaderField: "Content-Type")
        for (field, value) in headers {
            request.addValue(value, forHTTPHeaderField: field)
        }
        request.timeoutInterval = baseTimeout

        do {
            request.httpBody = try JSONSerialization.data(withJSONObject: body)
            let (data, response) = try await URLSession.shared.data(for: request)

            guard let httpResponse = response as? HTTPURLResponse else {
                throw EnhancementError.invalidResponse
            }

            switch httpResponse.statusCode {
            case 200:
                guard let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
                      let enhancedText = extractText(json) else {
                    throw EnhancementError.enhancementFailed
                }
                return AIEnhancementOutputFilter.filter(enhancedText.trimmingCharacters(in: .whitespacesAndNewlines))
            case 429:
                throw EnhancementError.rateLimitExceeded
            case 500...599:
                throw EnhancementError.serverError
            default:
                let errorString = String(data: data, encoding: .utf8) ?? "Could not decode error response."
                throw EnhancementError.customError("HTTP \(httpResponse.statusCode): \(errorString)")
            }
        } catch let error as EnhancementError {
            throw error
        } catch let error as URLError {
            throw error
        } catch {
            throw EnhancementError.customError(error.localizedDescription)
        }
    }

    /// Calls `makeRequest`, retrying transient failures with exponential backoff.
    ///
    /// Retries `.networkError`, `.serverError`, `.rateLimitExceeded`, and URL-domain errors for
    /// "not connected", "timed out" and "connection lost". Other errors are thrown immediately.
    /// After `maxRetries` failed attempts the last error is thrown (retryable URL errors become
    /// `EnhancementError.networkError`).
    private func makeRequestWithRetry(text: String, mode: EnhancementPrompt, maxRetries: Int = 3, initialDelay: TimeInterval = 1.0) async throws -> String {
        let retryableURLErrorCodes = [NSURLErrorNotConnectedToInternet, NSURLErrorTimedOut, NSURLErrorNetworkConnectionLost]
        var retries = 0
        var currentDelay = initialDelay

        while retries < maxRetries {
            do {
                return try await makeRequest(text: text, mode: mode)
            } catch let error as EnhancementError {
                switch error {
                case .networkError, .serverError, .rateLimitExceeded:
                    retries += 1
                    guard retries < maxRetries else {
                        logger.error("Request failed after \(maxRetries) retries.")
                        throw error
                    }
                    logger.warning("Request failed, retrying in \(currentDelay)s... (Attempt \(retries)/\(maxRetries))")
                default:
                    throw error
                }
            } catch {
                // Only transient URL-loading failures are worth retrying.
                let nsError = error as NSError
                guard nsError.domain == NSURLErrorDomain, retryableURLErrorCodes.contains(nsError.code) else {
                    throw error
                }
                retries += 1
                guard retries < maxRetries else {
                    logger.error("Request failed after \(maxRetries) retries with network error.")
                    throw EnhancementError.networkError
                }
                logger.warning("Request failed with network error, retrying in \(currentDelay)s... (Attempt \(retries)/\(maxRetries))")
            }

            try await Task.sleep(nanoseconds: UInt64(currentDelay * 1_000_000_000))
            currentDelay *= 2 // Exponential backoff
        }

        // Only reached when maxRetries <= 0.
        throw EnhancementError.enhancementFailed
    }

    // MARK: - Public API

    /// Enhances `text` with the active prompt.
    ///
    /// - Returns: The enhanced text, the elapsed time in seconds, and the title of the prompt
    ///   that was active when the call started (if any).
    func enhance(_ text: String) async throws -> (String, TimeInterval, String?) {
        let startTime = Date()
        let promptName = activePrompt?.title
        let result = try await makeRequestWithRetry(text: text, mode: .transcriptionEnhancement)
        return (result, Date().timeIntervalSince(startTime), promptName)
    }

    /// When screen-capture context is enabled, captures the screen's text and notifies
    /// observers (on the main actor) if any text was extracted.
    func captureScreenContext() async {
        guard useScreenCaptureContext else { return }

        if await screenCaptureService.captureAndExtractText() != nil {
            await MainActor.run {
                self.objectWillChange.send()
            }
        }
    }

    // MARK: - Prompt management

    /// Appends a user-created prompt; it becomes the selection if it is the only prompt.
    func addPrompt(title: String, promptText: String, icon: PromptIcon = .documentFill, description: String? = nil, triggerWords: [String] = []) {
        let newPrompt = CustomPrompt(title: title, promptText: promptText, icon: icon, description: description, isPredefined: false, triggerWords: triggerWords)
        customPrompts.append(newPrompt)
        if customPrompts.count == 1 {
            selectedPromptId = newPrompt.id
        }
    }

    /// Replaces the stored prompt that has the same id; does nothing if none matches.
    func updatePrompt(_ prompt: CustomPrompt) {
        if let index = customPrompts.firstIndex(where: { $0.id == prompt.id }) {
            customPrompts[index] = prompt
        }
    }

    /// Removes the prompt; if it was selected, the first remaining prompt (or none) is selected.
    func deletePrompt(_ prompt: CustomPrompt) {
        customPrompts.removeAll { $0.id == prompt.id }
        if selectedPromptId == prompt.id {
            selectedPromptId = allPrompts.first?.id
        }
    }

    func setActivePrompt(_ prompt: CustomPrompt) {
        selectedPromptId = prompt.id
    }

    /// Merges the predefined prompt templates into `customPrompts`.
    ///
    /// Existing predefined prompts take the template's title, text, icon and description but
    /// keep their `isActive` flag and trigger words; missing templates are appended.
    private func initializePredefinedPrompts() {
        var prompts = customPrompts

        for template in PredefinedPrompts.createDefaultPrompts() {
            if let index = prompts.firstIndex(where: { $0.id == template.id }) {
                let existing = prompts[index]
                prompts[index] = CustomPrompt(
                    id: existing.id,
                    title: template.title,
                    promptText: template.promptText,
                    isActive: existing.isActive,
                    icon: template.icon,
                    description: template.description,
                    isPredefined: true,
                    triggerWords: existing.triggerWords
                )
            } else {
                prompts.append(template)
            }
        }

        // Assign once so the persisting didSet encodes and writes only a single time.
        customPrompts = prompts
    }
}
/// Failures that can be surfaced while enhancing text with an AI provider.
enum EnhancementError: Error {
    case notConfigured, invalidResponse, enhancementFailed
    case networkError, serverError, rateLimitExceeded
    /// A failure whose user-facing message is supplied by the thrower.
    case customError(String)
}

// MARK: - LocalizedError

extension EnhancementError: LocalizedError {
    /// Human-readable explanation of the failure, suitable for showing to the user.
    var errorDescription: String? {
        let message: String
        switch self {
        case .notConfigured:
            message = "AI provider not configured. Please check your API key."
        case .invalidResponse:
            message = "Invalid response from AI provider."
        case .enhancementFailed:
            message = "AI enhancement failed to process the text."
        case .networkError:
            message = "Network connection failed. Check your internet."
        case .serverError:
            message = "The AI provider's server encountered an error. Please try again later."
        case .rateLimitExceeded:
            message = "Rate limit exceeded. Please try again later."
        case let .customError(text):
            message = text
        }
        return message
    }
}