Add OpenRouter support and enhancement error handling

This commit is contained in:
Beingpax 2025-06-24 16:29:12 +05:45
parent 9b0c2366b6
commit 6ddb89cbf8
4 changed files with 134 additions and 61 deletions

View File

@ -1,5 +1,4 @@
import Foundation
import os
import SwiftData
import AppKit
@ -9,11 +8,6 @@ enum EnhancementPrompt {
}
class AIEnhancementService: ObservableObject {
private let logger = Logger(
subsystem: "com.prakashjoshipax.VoiceInk",
category: "aienhancement"
)
@Published var isEnhancementEnabled: Bool {
didSet {
UserDefaults.standard.set(isEnhancementEnabled, forKey: "isAIEnhancementEnabled")
@ -21,7 +15,8 @@ class AIEnhancementService: ObservableObject {
selectedPromptId = customPrompts.first?.id
}
}
}
}
@Published var useClipboardContext: Bool {
didSet {
UserDefaults.standard.set(useClipboardContext, forKey: "useClipboardContext")
@ -73,7 +68,6 @@ class AIEnhancementService: ObservableObject {
self.useClipboardContext = UserDefaults.standard.bool(forKey: "useClipboardContext")
self.useScreenCaptureContext = UserDefaults.standard.bool(forKey: "useScreenCaptureContext")
// Use migration service to load prompts, preserving existing data
self.customPrompts = PromptMigrationService.migratePromptsIfNeeded()
if let savedPromptId = UserDefaults.standard.string(forKey: "selectedPromptId") {
@ -174,31 +168,16 @@ class AIEnhancementService: ObservableObject {
private func makeRequest(text: String, mode: EnhancementPrompt, retryCount: Int = 0) async throws -> String {
guard isConfigured else {
logger.error("AI Enhancement: API not configured")
throw EnhancementError.notConfigured
}
guard !text.isEmpty else {
logger.error("AI Enhancement: Empty text received")
throw EnhancementError.emptyText
}
let formattedText = "\n<TRANSCRIPT>\n\(text)\n</TRANSCRIPT>"
// Log individual contexts if enabled and available
if useClipboardContext, let clipboardText = NSPasteboard.general.string(forType: .string), !clipboardText.isEmpty {
logger.notice("Clipboard Context: \(clipboardText, privacy: .public)")
}
if useScreenCaptureContext, let capturedText = screenCaptureService.lastCapturedText, !capturedText.isEmpty {
logger.notice("Screen Capture Context: \(capturedText, privacy: .public)")
}
let systemMessage = getSystemMessage(for: mode)
logger.notice("🛰️ Sending to AI provider: \(self.aiService.selectedProvider.rawValue, privacy: .public)")
logger.notice("System Message: \(systemMessage, privacy: .public)")
logger.notice("User Message: \(formattedText, privacy: .public)")
if aiService.selectedProvider == .ollama {
do {
let result = try await aiService.enhanceWithOllama(text: formattedText, systemPrompt: systemMessage)
@ -279,7 +258,6 @@ class AIEnhancementService: ObservableObject {
case 429:
throw EnhancementError.rateLimitExceeded
case 500...599:
logger.error("Server error (HTTP \(httpResponse.statusCode)): \(String(data: data, encoding: .utf8) ?? "No response data")")
throw EnhancementError.serverError
default:
throw EnhancementError.apiError
@ -335,7 +313,6 @@ class AIEnhancementService: ObservableObject {
case 429:
throw EnhancementError.rateLimitExceeded
case 500...599:
logger.error("Server error (HTTP \(httpResponse.statusCode)): \(String(data: data, encoding: .utf8) ?? "No response data")")
throw EnhancementError.serverError
default:
throw EnhancementError.apiError
@ -398,7 +375,6 @@ class AIEnhancementService: ObservableObject {
case 429:
throw EnhancementError.rateLimitExceeded
case 500...599:
logger.error("Server error (HTTP \(httpResponse.statusCode)): \(String(data: data, encoding: .utf8) ?? "No response data")")
throw EnhancementError.serverError
default:
throw EnhancementError.apiError
@ -417,40 +393,26 @@ class AIEnhancementService: ObservableObject {
}
/// Enhances a raw transcription via the configured AI provider, retrying
/// transient failures (rate limit, server, network) up to `maxRetries` times
/// with a per-attempt backoff from `getRetryDelay(for:)`.
///
/// - Parameter text: The raw transcript text to enhance.
/// - Returns: The enhanced text returned by the provider.
/// - Throws: `EnhancementError.maxRetriesExceeded` when all attempts are
///   consumed; otherwise rethrows the first non-retryable error from
///   `makeRequest(text:mode:retryCount:)`.
func enhance(_ text: String) async throws -> String {
logger.notice("🚀 Starting AI enhancement for text (\(text.count) characters)")
let enhancementPrompt: EnhancementPrompt = .transcriptionEnhancement
var retryCount = 0
while retryCount < maxRetries {
do {
let result = try await makeRequest(text: text, mode: enhancementPrompt, retryCount: retryCount)
logger.notice("✅ AI enhancement completed successfully (\(result.count) characters)")
return result
} catch let error as EnhancementError {
// shouldRetry decides based on both the error kind and how many
// attempts we have already burned.
if shouldRetry(error: error, retryCount: retryCount) {
// Human-readable label for the retry log line only; does not
// affect control flow.
let errorType = switch error {
case .rateLimitExceeded: "Rate limit exceeded"
case .serverError: "Server error occurred"
case .networkError: "Network error occurred"
default: "Unknown error"
}
logger.notice("⚠️ \(errorType), retrying AI enhancement (attempt \(retryCount + 1) of \(self.maxRetries))")
retryCount += 1
// Delay is computed from the *incremented* count, so later
// retries can back off longer. Task.sleep throws on
// cancellation, which propagates out of this method.
let delaySeconds = getRetryDelay(for: retryCount)
try await Task.sleep(nanoseconds: UInt64(delaySeconds * 1_000_000_000))
continue
} else {
// Non-retryable EnhancementError: surface it immediately.
logger.notice("❌ AI enhancement failed: \(error.localizedDescription)")
throw error
}
} catch {
// Any non-EnhancementError (e.g. cancellation) is never retried.
logger.notice("❌ AI enhancement failed: \(error.localizedDescription)")
throw error
}
}
logger.notice("❌ AI enhancement failed: maximum retries exceeded")
throw EnhancementError.maxRetriesExceeded
}
@ -509,21 +471,19 @@ class AIEnhancementService: ObservableObject {
for template in predefinedTemplates {
if let existingIndex = customPrompts.firstIndex(where: { $0.id == template.id }) {
// Update existing predefined prompt: only update prompt text, preserve trigger word
var updatedPrompt = customPrompts[existingIndex]
updatedPrompt = CustomPrompt(
id: updatedPrompt.id,
title: template.title,
promptText: template.promptText, // Update from template
promptText: template.promptText,
isActive: updatedPrompt.isActive,
icon: template.icon,
description: template.description,
isPredefined: true,
triggerWords: updatedPrompt.triggerWords // Preserve user's trigger words
triggerWords: updatedPrompt.triggerWords
)
customPrompts[existingIndex] = updatedPrompt
} else {
// Add new predefined prompt (no default trigger word)
customPrompts.append(template)
}
}
@ -541,6 +501,31 @@ enum EnhancementError: Error {
case apiError
case networkError
case maxRetriesExceeded
}
}
// MARK: - EnhancementError + LocalizedError
/// User-facing descriptions for every enhancement failure mode. These strings
/// surface directly in notifications and saved transcriptions (via
/// `error.localizedDescription`), so keep them short and actionable.
extension EnhancementError: LocalizedError {
var errorDescription: String? {
switch self {
case .notConfigured:
return "AI provider not configured. Please check your API key."
case .emptyText:
return "No text to enhance."
case .invalidResponse:
return "Invalid response from AI provider."
case .enhancementFailed:
return "AI enhancement failed to process the text."
case .authenticationFailed:
return "API key is invalid. Please check your credentials."
case .rateLimitExceeded:
return "Rate limit exceeded. Please try again later."
case .serverError:
return "AI provider server error. Please try again."
case .apiError:
return "AI provider API error. Please try again."
case .networkError:
return "Network connection failed. Check your internet."
case .maxRetriesExceeded:
return "Enhancement failed after multiple attempts."
}
}
}

View File

@ -7,12 +7,14 @@ enum AIProvider: String, CaseIterable {
case deepSeek = "DeepSeek"
case gemini = "Gemini"
case anthropic = "Anthropic"
case openRouter = "OpenRouter"
case mistral = "Mistral"
case ollama = "Ollama"
case elevenLabs = "ElevenLabs"
case deepgram = "Deepgram"
case custom = "Custom"
var baseURL: String {
switch self {
case .groq:
@ -25,6 +27,8 @@ enum AIProvider: String, CaseIterable {
return "https://generativelanguage.googleapis.com/v1beta/models"
case .anthropic:
return "https://api.anthropic.com/v1/messages"
case .openRouter:
return "https://openrouter.ai/api/v1/chat/completions"
case .mistral:
return "https://api.mistral.ai/v1/chat/completions"
case .elevenLabs:
@ -35,6 +39,7 @@ enum AIProvider: String, CaseIterable {
return "https://api.deepgram.com/v1/listen"
case .custom:
return UserDefaults.standard.string(forKey: "customProviderBaseURL") ?? ""
}
}
@ -60,6 +65,8 @@ enum AIProvider: String, CaseIterable {
return "whisper-1"
case .custom:
return UserDefaults.standard.string(forKey: "customProviderModel") ?? ""
case .openRouter:
return "openai/gpt-4o"
}
}
@ -107,6 +114,8 @@ enum AIProvider: String, CaseIterable {
return ["whisper-1"]
case .custom:
return []
case .openRouter:
return []
}
}
@ -163,6 +172,8 @@ class AIService: ObservableObject {
private let userDefaults = UserDefaults.standard
private lazy var ollamaService = OllamaService()
private var openRouterModels: [String] = []
var connectedProviders: [AIProvider] {
AIProvider.allCases.filter { provider in
if provider == .ollama {
@ -186,6 +197,8 @@ class AIService: ObservableObject {
var availableModels: [String] {
if selectedProvider == .ollama {
return ollamaService.availableModels.map { $0.name }
} else if selectedProvider == .openRouter {
return openRouterModels
}
return selectedProvider.availableModels
}
@ -205,7 +218,6 @@ class AIService: ObservableObject {
}
} else {
self.isAPIKeyValid = true
}
loadSavedModelSelections()
@ -451,14 +463,6 @@ class AIService: ObservableObject {
}
func enhanceWithOllama(text: String, systemPrompt: String) async throws -> String {
// Ensure connection is established before attempting enhancement
if !ollamaService.isConnected {
await ollamaService.checkConnection()
if ollamaService.isConnected && ollamaService.availableModels.isEmpty {
await ollamaService.refreshModels()
}
}
logger.notice("🔄 Sending transcription to Ollama for enhancement (model: \(self.ollamaService.selectedModel))")
do {
let result = try await ollamaService.enhance(text, withSystemPrompt: systemPrompt)
@ -479,8 +483,57 @@ class AIService: ObservableObject {
ollamaService.selectedModel = modelName
userDefaults.set(modelName, forKey: "ollamaSelectedModel")
}
/// Fetches the list of available model IDs from the OpenRouter catalog
/// endpoint and publishes them through `openRouterModels`.
///
/// On any failure (bad HTTP status, malformed JSON, transport error) the
/// published list is cleared so the UI reflects that no models are available.
/// When OpenRouter is the selected provider and the user is still on the
/// provider's default model, the alphabetically first fetched model is
/// auto-selected.
func fetchOpenRouterModels() async {
    // Single publication point for success and all three failure paths; the
    // original duplicated this MainActor block in each branch.
    @MainActor func publish(_ models: [String]) {
        self.openRouterModels = models
        if self.selectedProvider == .openRouter,
           self.currentModel == self.selectedProvider.defaultModel,
           let firstModel = models.first {
            // No explicit user choice yet: pick the first model the API
            // actually offers instead of the hard-coded default.
            self.selectModel(firstModel)
        }
        self.objectWillChange.send()
    }

    let url = URL(string: "https://openrouter.ai/api/v1/models")!
    var request = URLRequest(url: url)
    request.httpMethod = "GET"
    // A GET carries no body, so declare the response format we expect.
    // (Previously sent as `Content-Type`, which is meaningless on a
    // body-less request.)
    request.addValue("application/json", forHTTPHeaderField: "Accept")

    do {
        let (data, response) = try await URLSession.shared.data(for: request)
        guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 else {
            logger.error("Failed to fetch OpenRouter models: Invalid HTTP response")
            await publish([])
            return
        }
        guard let jsonResponse = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
              let dataArray = jsonResponse["data"] as? [[String: Any]] else {
            logger.error("Failed to parse OpenRouter models JSON")
            await publish([])
            return
        }
        // Sort once up front; the original sorted the same array twice and
        // force-unwrapped `.first!` behind a separate emptiness check.
        let models = dataArray.compactMap { $0["id"] as? String }.sorted()
        await publish(models)
        logger.info("Successfully fetched \(models.count) OpenRouter models.")
    } catch {
        logger.error("Error fetching OpenRouter models: \(error.localizedDescription)")
        await publish([])
    }
}
}
extension Notification.Name {
/// Posted when an AI provider API key changes so observers can re-validate
/// connectivity. NOTE(review): the posting and observing sites are outside
/// this diff — confirm the exact trigger semantics against the callers.
static let aiProviderKeyChanged = Notification.Name("aiProviderKeyChanged")
}

View File

@ -52,6 +52,10 @@ struct APIKeyManagementView: View {
.onChange(of: aiService.selectedProvider) { oldValue, newValue in
if aiService.selectedProvider == .ollama {
checkOllamaConnection()
} else if aiService.selectedProvider == .openRouter {
Task {
await aiService.fetchOpenRouterModels()
}
}
}
@ -59,12 +63,26 @@ struct APIKeyManagementView: View {
if !aiService.availableModels.isEmpty &&
aiService.selectedProvider != .ollama &&
aiService.selectedProvider != .custom {
Picker("Model", selection: Binding(
get: { aiService.currentModel },
set: { aiService.selectModel($0) }
)) {
ForEach(aiService.availableModels, id: \.self) { model in
Text(model).tag(model)
HStack {
Picker("Model", selection: Binding(
get: { aiService.currentModel },
set: { aiService.selectModel($0) }
)) {
ForEach(aiService.availableModels, id: \.self) { model in
Text(model).tag(model)
}
}
if aiService.selectedProvider == .openRouter {
Button(action: {
Task {
await aiService.fetchOpenRouterModels()
}
}) {
Image(systemName: "arrow.clockwise")
}
.buttonStyle(.borderless)
.help("Refresh models")
}
}
}
@ -200,6 +218,7 @@ struct APIKeyManagementView: View {
.padding()
.background(Color.secondary.opacity(0.03))
.cornerRadius(12)
} else if aiService.selectedProvider == .custom {
VStack(alignment: .leading, spacing: 16) {
// Header
@ -391,6 +410,8 @@ struct APIKeyManagementView: View {
URL(string: "https://console.deepgram.com/api-keys")!
case .ollama, .custom:
URL(string: "")! // This case should never be reached
case .openRouter:
URL(string: "https://openrouter.ai/keys")!
}
NSWorkspace.shared.open(url)
} label: {
@ -419,6 +440,10 @@ struct APIKeyManagementView: View {
.onAppear {
if aiService.selectedProvider == .ollama {
checkOllamaConnection()
} else if aiService.selectedProvider == .openRouter {
Task {
await aiService.fetchOpenRouterModels()
}
}
}
}

View File

@ -336,13 +336,23 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
try? modelContext.save()
text = enhancedText
} catch {
// Enhancement failed - save error in enhancedText field and show notification
let newTranscription = Transcription(
text: originalText,
duration: actualDuration,
enhancedText: "Enhancement failed: \(error.localizedDescription)",
audioFileURL: permanentURL?.absoluteString
)
modelContext.insert(newTranscription)
try? modelContext.save()
// Show notification about enhancement failure
await MainActor.run {
NotificationManager.shared.showNotification(
title: "AI enhancement failed",
type: .error
)
}
}
} else {
let newTranscription = Transcription(