diff --git a/VoiceInk/Services/AIEnhancementService.swift b/VoiceInk/Services/AIEnhancementService.swift index 663ef34..7ae13d8 100644 --- a/VoiceInk/Services/AIEnhancementService.swift +++ b/VoiceInk/Services/AIEnhancementService.swift @@ -261,6 +261,8 @@ class AIEnhancementService: ObservableObject { let filteredText = AIEnhancementOutputFilter.filter(enhancedText.trimmingCharacters(in: .whitespacesAndNewlines)) return filteredText + } else if (500...599).contains(httpResponse.statusCode) { + throw EnhancementError.serverError } else { let errorString = String(data: data, encoding: .utf8) ?? "Could not decode error response." throw EnhancementError.customError("HTTP \(httpResponse.statusCode): \(errorString)") @@ -268,6 +270,8 @@ class AIEnhancementService: ObservableObject { } catch let error as EnhancementError { throw error + } catch let error as URLError { + throw error } catch { throw EnhancementError.customError(error.localizedDescription) } @@ -312,6 +316,8 @@ class AIEnhancementService: ObservableObject { let filteredText = AIEnhancementOutputFilter.filter(enhancedText.trimmingCharacters(in: .whitespacesAndNewlines)) return filteredText + } else if (500...599).contains(httpResponse.statusCode) { + throw EnhancementError.serverError } else { let errorString = String(data: data, encoding: .utf8) ?? "Could not decode error response." throw EnhancementError.customError("HTTP \(httpResponse.statusCode): \(errorString)") @@ -319,19 +325,66 @@ class AIEnhancementService: ObservableObject { } catch let error as EnhancementError { throw error + } catch let error as URLError { + throw error } catch { throw EnhancementError.customError(error.localizedDescription) } } } + private func makeRequestWithRetry(text: String, mode: EnhancementPrompt, maxRetries: Int = 3, initialDelay: TimeInterval = 1.0) async throws -> String { + var retries = 0 + var currentDelay = initialDelay + + while retries < maxRetries { + do { + return try await makeRequest(text: text, mode: mode) + } catch let error as EnhancementError { + switch error { + case .networkError, .serverError: + retries += 1 + if retries < maxRetries { + logger.warning("Request failed, retrying in \(currentDelay)s... (Attempt \(retries)/\(maxRetries))") + try await Task.sleep(nanoseconds: UInt64(currentDelay * 1_000_000_000)) + currentDelay *= 2 // Exponential backoff + } else { + logger.error("Request failed after \(maxRetries) retries.") + throw error + } + default: + throw error + } + } catch { + // For other errors, check if it's a network-related URLError + let nsError = error as NSError + if nsError.domain == NSURLErrorDomain && [NSURLErrorNotConnectedToInternet, NSURLErrorTimedOut, NSURLErrorNetworkConnectionLost].contains(nsError.code) { + retries += 1 + if retries < maxRetries { + logger.warning("Request failed with network error, retrying in \(currentDelay)s... (Attempt \(retries)/\(maxRetries))") + try await Task.sleep(nanoseconds: UInt64(currentDelay * 1_000_000_000)) + currentDelay *= 2 // Exponential backoff + } else { + logger.error("Request failed after \(maxRetries) retries with network error.") + throw EnhancementError.networkError + } + } else { + throw error + } + } + } + + // This part should ideally not be reached, but as a fallback: + throw EnhancementError.enhancementFailed + } + func enhance(_ text: String) async throws -> (String, TimeInterval, String?) { let startTime = Date() let enhancementPrompt: EnhancementPrompt = .transcriptionEnhancement let promptName = activePrompt?.title do { - let result = try await makeRequest(text: text, mode: enhancementPrompt) + let result = try await makeRequestWithRetry(text: text, mode: enhancementPrompt) let endTime = Date() let duration = endTime.timeIntervalSince(startTime) return (result, duration, promptName) @@ -404,6 +457,7 @@ enum EnhancementError: Error { case invalidResponse case enhancementFailed case networkError + case serverError case customError(String) } @@ -418,6 +472,8 @@ extension EnhancementError: LocalizedError { return "AI enhancement failed to process the text." case .networkError: return "Network connection failed. Check your internet." + case .serverError: + return "The AI provider's server encountered an error. Please try again later." case .customError(let message): return message }