diff --git a/VoiceInk/Services/AIEnhancementService.swift b/VoiceInk/Services/AIEnhancementService.swift index 300a887..71742f0 100644 --- a/VoiceInk/Services/AIEnhancementService.swift +++ b/VoiceInk/Services/AIEnhancementService.swift @@ -53,7 +53,6 @@ class AIEnhancementService: ObservableObject { private let aiService: AIService private let screenCaptureService: ScreenCaptureService - private let maxRetries = 3 private let baseTimeout: TimeInterval = 10 private let rateLimitInterval: TimeInterval = 1.0 private var lastRequestTime: Date? @@ -166,7 +165,7 @@ class AIEnhancementService: ObservableObject { return systemMessage } - private func makeRequest(text: String, mode: EnhancementPrompt, retryCount: Int = 0) async throws -> String { + private func makeRequest(text: String, mode: EnhancementPrompt) async throws -> String { guard isConfigured else { throw EnhancementError.notConfigured } @@ -200,80 +199,6 @@ class AIEnhancementService: ObservableObject { try await waitForRateLimit() switch aiService.selectedProvider { - case .gemini: - let baseEndpoint = "https://generativelanguage.googleapis.com/v1beta/models" - let model = aiService.currentModel - let fullURL = "\(baseEndpoint)/\(model):generateContent" - - var urlComponents = URLComponents(string: fullURL)! - urlComponents.queryItems = [URLQueryItem(name: "key", value: aiService.apiKey)] - - guard let url = urlComponents.url else { - throw EnhancementError.invalidResponse - } - - var request = URLRequest(url: url) - request.httpMethod = "POST" - request.addValue("application/json", forHTTPHeaderField: "Content-Type") - request.timeoutInterval = baseTimeout * pow(2.0, Double(retryCount)) - - let requestBody: [String: Any] = [ - "contents": [ - [ - "parts": [ - ["text": systemMessage], - ["text": formattedText] - ] - ] - ], - "generationConfig": [ - "temperature": 0.3, - ] - ] - - request.httpBody = try? JSONSerialization.data(withJSONObject: requestBody) - - do { - let (data, response) = try await URLSession.shared.data(for: request) - - guard let httpResponse = response as? HTTPURLResponse else { - throw EnhancementError.invalidResponse - } - - switch httpResponse.statusCode { - case 200: - guard let jsonResponse = try? JSONSerialization.jsonObject(with: data) as? [String: Any], - let candidates = jsonResponse["candidates"] as? [[String: Any]], - let firstCandidate = candidates.first, - let content = firstCandidate["content"] as? [String: Any], - let parts = content["parts"] as? [[String: Any]], - let firstPart = parts.first, - let enhancedText = firstPart["text"] as? String else { - throw EnhancementError.enhancementFailed - } - - let filteredText = AIEnhancementOutputFilter.filter(enhancedText.trimmingCharacters(in: .whitespacesAndNewlines)) - return filteredText - - case 401: - throw EnhancementError.authenticationFailed - case 429: - throw EnhancementError.rateLimitExceeded - case 500...599: - throw EnhancementError.serverError - default: - throw EnhancementError.apiError - } - } catch let error as EnhancementError { - throw error - } catch { - if retryCount < maxRetries { - try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000)) - return try await makeRequest(text: text, mode: mode, retryCount: retryCount + 1) - } - throw EnhancementError.networkError - } - case .anthropic: let requestBody: [String: Any] = [ "model": aiService.currentModel, @@ -289,7 +214,7 @@ class AIEnhancementService: ObservableObject { request.addValue("application/json", forHTTPHeaderField: "Content-Type") request.addValue(aiService.apiKey, forHTTPHeaderField: "x-api-key") request.addValue("2023-06-01", forHTTPHeaderField: "anthropic-version") - request.timeoutInterval = baseTimeout * pow(2.0, Double(retryCount)) + request.timeoutInterval = baseTimeout request.httpBody = try? JSONSerialization.data(withJSONObject: requestBody) do { @@ -320,14 +245,11 @@ class AIEnhancementService: ObservableObject { default: throw EnhancementError.apiError } + } catch let error as EnhancementError { throw error } catch { - if retryCount < maxRetries { - try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000)) - return try await makeRequest(text: text, mode: mode, retryCount: retryCount + 1) - } - throw EnhancementError.networkError + throw EnhancementError.customError(error.localizedDescription) } default: @@ -336,7 +258,7 @@ class AIEnhancementService: ObservableObject { request.httpMethod = "POST" request.addValue("application/json", forHTTPHeaderField: "Content-Type") request.addValue("Bearer \(aiService.apiKey)", forHTTPHeaderField: "Authorization") - request.timeoutInterval = baseTimeout * pow(2.0, Double(retryCount)) + request.timeoutInterval = baseTimeout let messages: [[String: Any]] = [ ["role": "system", "content": systemMessage], @@ -347,8 +269,6 @@ class AIEnhancementService: ObservableObject { "model": aiService.currentModel, "messages": messages, "temperature": 0.3, - "frequency_penalty": 0.0, - "presence_penalty": 0.0, "stream": false ] @@ -361,8 +281,7 @@ class AIEnhancementService: ObservableObject { throw EnhancementError.invalidResponse } - switch httpResponse.statusCode { - case 200: + if httpResponse.statusCode == 200 { guard let jsonResponse = try? JSONSerialization.jsonObject(with: data) as? [String: Any], let choices = jsonResponse["choices"] as? [[String: Any]], let firstChoice = choices.first, @@ -373,25 +292,15 @@ class AIEnhancementService: ObservableObject { let filteredText = AIEnhancementOutputFilter.filter(enhancedText.trimmingCharacters(in: .whitespacesAndNewlines)) return filteredText - - case 401: - throw EnhancementError.authenticationFailed - case 429: - throw EnhancementError.rateLimitExceeded - case 500...599: - throw EnhancementError.serverError - default: - throw EnhancementError.apiError + } else { + let errorString = String(data: data, encoding: .utf8) ?? "Could not decode error response." + throw EnhancementError.customError("HTTP \(httpResponse.statusCode): \(errorString)") } } catch let error as EnhancementError { throw error } catch { - if retryCount < maxRetries { - try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000)) - return try await makeRequest(text: text, mode: mode, retryCount: retryCount + 1) - } - throw EnhancementError.networkError + throw EnhancementError.customError(error.localizedDescription) } } } @@ -400,26 +309,14 @@ class AIEnhancementService: ObservableObject { let startTime = Date() let enhancementPrompt: EnhancementPrompt = .transcriptionEnhancement - var retryCount = 0 - while retryCount < maxRetries { - do { - let result = try await makeRequest(text: text, mode: enhancementPrompt, retryCount: retryCount) - let duration = Date().timeIntervalSince(startTime) - return (result, duration) - } catch let error as EnhancementError { - if shouldRetry(error: error, retryCount: retryCount) { - retryCount += 1 - let delaySeconds = getRetryDelay(for: retryCount) - try await Task.sleep(nanoseconds: UInt64(delaySeconds * 1_000_000_000)) - continue - } else { - throw error - } - } catch { - throw error - } + do { + let result = try await makeRequest(text: text, mode: enhancementPrompt) + let endTime = Date() + let duration = endTime.timeIntervalSince(startTime) + return (result, duration) + } catch { + throw error } - throw EnhancementError.maxRetriesExceeded } func captureScreenContext() async { @@ -457,21 +354,6 @@ class AIEnhancementService: ObservableObject { selectedPromptId = prompt.id } - private func shouldRetry(error: EnhancementError, retryCount: Int) -> Bool { - guard retryCount < maxRetries - 1 else { return false } - - switch error { - case .rateLimitExceeded, .serverError: - return true - default: - return false - } - } - - private func getRetryDelay(for retryCount: Int) -> TimeInterval { - return retryCount == 1 ? 1.0 : 2.0 - } - private func initializePredefinedPrompts() { let predefinedTemplates = PredefinedPrompts.createDefaultPrompts() @@ -506,7 +388,7 @@ enum EnhancementError: Error { case serverError case apiError case networkError - case maxRetriesExceeded + case customError(String) } extension EnhancementError: LocalizedError { @@ -530,8 +412,8 @@ extension EnhancementError: LocalizedError { return "AI provider API error. Please try again." case .networkError: return "Network connection failed. Check your internet." - case .maxRetriesExceeded: - return "Enhancement failed after multiple attempts." + case .customError(let message): + return message } } } \ No newline at end of file diff --git a/VoiceInk/Services/AIService.swift b/VoiceInk/Services/AIService.swift index a8d5d0f..5aa211d 100644 --- a/VoiceInk/Services/AIService.swift +++ b/VoiceInk/Services/AIService.swift @@ -22,7 +22,7 @@ enum AIProvider: String, CaseIterable { case .groq: return "https://api.groq.com/openai/v1/chat/completions" case .gemini: - return "https://generativelanguage.googleapis.com/v1beta/models" + return "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions" case .anthropic: return "https://api.anthropic.com/v1/messages" case .openAI: @@ -291,8 +291,6 @@ class AIService: ObservableObject { } switch selectedProvider { - case .gemini: - verifyGeminiAPIKey(key, completion: completion) case .anthropic: verifyAnthropicAPIKey(key, completion: completion) case .elevenLabs: @@ -370,49 +368,6 @@ class AIService: ObservableObject { }.resume() } - private func verifyGeminiAPIKey(_ key: String, completion: @escaping (Bool) -> Void) { - let baseEndpoint = "https://generativelanguage.googleapis.com/v1beta/models" - let model = currentModel - let fullURL = "\(baseEndpoint)/\(model):generateContent" - - var urlComponents = URLComponents(string: fullURL)! - urlComponents.queryItems = [URLQueryItem(name: "key", value: key)] - - guard let url = urlComponents.url else { - completion(false) - return - } - - var request = URLRequest(url: url) - request.httpMethod = "POST" - request.addValue("application/json", forHTTPHeaderField: "Content-Type") - - let testBody: [String: Any] = [ - "contents": [ - [ - "parts": [ - ["text": "test"] - ] - ] - ] - ] - - request.httpBody = try? JSONSerialization.data(withJSONObject: testBody) - - URLSession.shared.dataTask(with: request) { data, response, error in - if let error = error { - completion(false) - return - } - - if let httpResponse = response as? HTTPURLResponse { - completion(httpResponse.statusCode == 200) - } else { - completion(false) - } - }.resume() - } - private func verifyElevenLabsAPIKey(_ key: String, completion: @escaping (Bool) -> Void) { let url = URL(string: "https://api.elevenlabs.io/v1/user")! diff --git a/VoiceInk/Whisper/WhisperState.swift b/VoiceInk/Whisper/WhisperState.swift index 74d7ee9..5a2f387 100644 --- a/VoiceInk/Whisper/WhisperState.swift +++ b/VoiceInk/Whisper/WhisperState.swift @@ -294,7 +294,7 @@ class WhisperState: NSObject, ObservableObject { let newTranscription = Transcription( text: originalText, duration: actualDuration, - enhancedText: "Enhancement failed: \(error.localizedDescription)", + enhancedText: "Enhancement failed: \(error)", audioFileURL: url.absoluteString, transcriptionModelName: model.displayName, transcriptionDuration: transcriptionDuration