Add OpenRouter support and enhancement error handling
This commit is contained in:
parent
9b0c2366b6
commit
6ddb89cbf8
@ -1,5 +1,4 @@
|
||||
import Foundation
|
||||
import os
|
||||
import SwiftData
|
||||
import AppKit
|
||||
|
||||
@ -9,11 +8,6 @@ enum EnhancementPrompt {
|
||||
}
|
||||
|
||||
class AIEnhancementService: ObservableObject {
|
||||
private let logger = Logger(
|
||||
subsystem: "com.prakashjoshipax.VoiceInk",
|
||||
category: "aienhancement"
|
||||
)
|
||||
|
||||
@Published var isEnhancementEnabled: Bool {
|
||||
didSet {
|
||||
UserDefaults.standard.set(isEnhancementEnabled, forKey: "isAIEnhancementEnabled")
|
||||
@ -21,7 +15,8 @@ class AIEnhancementService: ObservableObject {
|
||||
selectedPromptId = customPrompts.first?.id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Published var useClipboardContext: Bool {
|
||||
didSet {
|
||||
UserDefaults.standard.set(useClipboardContext, forKey: "useClipboardContext")
|
||||
@ -73,7 +68,6 @@ class AIEnhancementService: ObservableObject {
|
||||
self.useClipboardContext = UserDefaults.standard.bool(forKey: "useClipboardContext")
|
||||
self.useScreenCaptureContext = UserDefaults.standard.bool(forKey: "useScreenCaptureContext")
|
||||
|
||||
// Use migration service to load prompts, preserving existing data
|
||||
self.customPrompts = PromptMigrationService.migratePromptsIfNeeded()
|
||||
|
||||
if let savedPromptId = UserDefaults.standard.string(forKey: "selectedPromptId") {
|
||||
@ -174,31 +168,16 @@ class AIEnhancementService: ObservableObject {
|
||||
|
||||
private func makeRequest(text: String, mode: EnhancementPrompt, retryCount: Int = 0) async throws -> String {
|
||||
guard isConfigured else {
|
||||
logger.error("AI Enhancement: API not configured")
|
||||
throw EnhancementError.notConfigured
|
||||
}
|
||||
|
||||
guard !text.isEmpty else {
|
||||
logger.error("AI Enhancement: Empty text received")
|
||||
throw EnhancementError.emptyText
|
||||
}
|
||||
|
||||
let formattedText = "\n<TRANSCRIPT>\n\(text)\n</TRANSCRIPT>"
|
||||
|
||||
// Log individual contexts if enabled and available
|
||||
if useClipboardContext, let clipboardText = NSPasteboard.general.string(forType: .string), !clipboardText.isEmpty {
|
||||
logger.notice("Clipboard Context: \(clipboardText, privacy: .public)")
|
||||
}
|
||||
if useScreenCaptureContext, let capturedText = screenCaptureService.lastCapturedText, !capturedText.isEmpty {
|
||||
logger.notice("Screen Capture Context: \(capturedText, privacy: .public)")
|
||||
}
|
||||
|
||||
let systemMessage = getSystemMessage(for: mode)
|
||||
|
||||
logger.notice("🛰️ Sending to AI provider: \(self.aiService.selectedProvider.rawValue, privacy: .public)")
|
||||
logger.notice("System Message: \(systemMessage, privacy: .public)")
|
||||
logger.notice("User Message: \(formattedText, privacy: .public)")
|
||||
|
||||
if aiService.selectedProvider == .ollama {
|
||||
do {
|
||||
let result = try await aiService.enhanceWithOllama(text: formattedText, systemPrompt: systemMessage)
|
||||
@ -279,7 +258,6 @@ class AIEnhancementService: ObservableObject {
|
||||
case 429:
|
||||
throw EnhancementError.rateLimitExceeded
|
||||
case 500...599:
|
||||
logger.error("Server error (HTTP \(httpResponse.statusCode)): \(String(data: data, encoding: .utf8) ?? "No response data")")
|
||||
throw EnhancementError.serverError
|
||||
default:
|
||||
throw EnhancementError.apiError
|
||||
@ -335,7 +313,6 @@ class AIEnhancementService: ObservableObject {
|
||||
case 429:
|
||||
throw EnhancementError.rateLimitExceeded
|
||||
case 500...599:
|
||||
logger.error("Server error (HTTP \(httpResponse.statusCode)): \(String(data: data, encoding: .utf8) ?? "No response data")")
|
||||
throw EnhancementError.serverError
|
||||
default:
|
||||
throw EnhancementError.apiError
|
||||
@ -398,7 +375,6 @@ class AIEnhancementService: ObservableObject {
|
||||
case 429:
|
||||
throw EnhancementError.rateLimitExceeded
|
||||
case 500...599:
|
||||
logger.error("Server error (HTTP \(httpResponse.statusCode)): \(String(data: data, encoding: .utf8) ?? "No response data")")
|
||||
throw EnhancementError.serverError
|
||||
default:
|
||||
throw EnhancementError.apiError
|
||||
@ -417,40 +393,26 @@ class AIEnhancementService: ObservableObject {
|
||||
}
|
||||
|
||||
func enhance(_ text: String) async throws -> String {
|
||||
logger.notice("🚀 Starting AI enhancement for text (\(text.count) characters)")
|
||||
|
||||
let enhancementPrompt: EnhancementPrompt = .transcriptionEnhancement
|
||||
|
||||
var retryCount = 0
|
||||
while retryCount < maxRetries {
|
||||
do {
|
||||
let result = try await makeRequest(text: text, mode: enhancementPrompt, retryCount: retryCount)
|
||||
logger.notice("✅ AI enhancement completed successfully (\(result.count) characters)")
|
||||
return result
|
||||
} catch let error as EnhancementError {
|
||||
if shouldRetry(error: error, retryCount: retryCount) {
|
||||
let errorType = switch error {
|
||||
case .rateLimitExceeded: "Rate limit exceeded"
|
||||
case .serverError: "Server error occurred"
|
||||
case .networkError: "Network error occurred"
|
||||
default: "Unknown error"
|
||||
}
|
||||
|
||||
logger.notice("⚠️ \(errorType), retrying AI enhancement (attempt \(retryCount + 1) of \(self.maxRetries))")
|
||||
retryCount += 1
|
||||
let delaySeconds = getRetryDelay(for: retryCount)
|
||||
try await Task.sleep(nanoseconds: UInt64(delaySeconds * 1_000_000_000))
|
||||
continue
|
||||
} else {
|
||||
logger.notice("❌ AI enhancement failed: \(error.localizedDescription)")
|
||||
throw error
|
||||
}
|
||||
} catch {
|
||||
logger.notice("❌ AI enhancement failed: \(error.localizedDescription)")
|
||||
throw error
|
||||
}
|
||||
}
|
||||
logger.notice("❌ AI enhancement failed: maximum retries exceeded")
|
||||
throw EnhancementError.maxRetriesExceeded
|
||||
}
|
||||
|
||||
@ -509,21 +471,19 @@ class AIEnhancementService: ObservableObject {
|
||||
|
||||
for template in predefinedTemplates {
|
||||
if let existingIndex = customPrompts.firstIndex(where: { $0.id == template.id }) {
|
||||
// Update existing predefined prompt: only update prompt text, preserve trigger word
|
||||
var updatedPrompt = customPrompts[existingIndex]
|
||||
updatedPrompt = CustomPrompt(
|
||||
id: updatedPrompt.id,
|
||||
title: template.title,
|
||||
promptText: template.promptText, // Update from template
|
||||
promptText: template.promptText,
|
||||
isActive: updatedPrompt.isActive,
|
||||
icon: template.icon,
|
||||
description: template.description,
|
||||
isPredefined: true,
|
||||
triggerWords: updatedPrompt.triggerWords // Preserve user's trigger words
|
||||
triggerWords: updatedPrompt.triggerWords
|
||||
)
|
||||
customPrompts[existingIndex] = updatedPrompt
|
||||
} else {
|
||||
// Add new predefined prompt (no default trigger word)
|
||||
customPrompts.append(template)
|
||||
}
|
||||
}
|
||||
@ -541,6 +501,31 @@ enum EnhancementError: Error {
|
||||
case apiError
|
||||
case networkError
|
||||
case maxRetriesExceeded
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
extension EnhancementError: LocalizedError {
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .notConfigured:
|
||||
return "AI provider not configured. Please check your API key."
|
||||
case .emptyText:
|
||||
return "No text to enhance."
|
||||
case .invalidResponse:
|
||||
return "Invalid response from AI provider."
|
||||
case .enhancementFailed:
|
||||
return "AI enhancement failed to process the text."
|
||||
case .authenticationFailed:
|
||||
return "API key is invalid. Please check your credentials."
|
||||
case .rateLimitExceeded:
|
||||
return "Rate limit exceeded. Please try again later."
|
||||
case .serverError:
|
||||
return "AI provider server error. Please try again."
|
||||
case .apiError:
|
||||
return "AI provider API error. Please try again."
|
||||
case .networkError:
|
||||
return "Network connection failed. Check your internet."
|
||||
case .maxRetriesExceeded:
|
||||
return "Enhancement failed after multiple attempts."
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -7,12 +7,14 @@ enum AIProvider: String, CaseIterable {
|
||||
case deepSeek = "DeepSeek"
|
||||
case gemini = "Gemini"
|
||||
case anthropic = "Anthropic"
|
||||
case openRouter = "OpenRouter"
|
||||
case mistral = "Mistral"
|
||||
case ollama = "Ollama"
|
||||
case elevenLabs = "ElevenLabs"
|
||||
case deepgram = "Deepgram"
|
||||
case custom = "Custom"
|
||||
|
||||
|
||||
var baseURL: String {
|
||||
switch self {
|
||||
case .groq:
|
||||
@ -25,6 +27,8 @@ enum AIProvider: String, CaseIterable {
|
||||
return "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
case .anthropic:
|
||||
return "https://api.anthropic.com/v1/messages"
|
||||
case .openRouter:
|
||||
return "https://openrouter.ai/api/v1/chat/completions"
|
||||
case .mistral:
|
||||
return "https://api.mistral.ai/v1/chat/completions"
|
||||
case .elevenLabs:
|
||||
@ -35,6 +39,7 @@ enum AIProvider: String, CaseIterable {
|
||||
return "https://api.deepgram.com/v1/listen"
|
||||
case .custom:
|
||||
return UserDefaults.standard.string(forKey: "customProviderBaseURL") ?? ""
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -60,6 +65,8 @@ enum AIProvider: String, CaseIterable {
|
||||
return "whisper-1"
|
||||
case .custom:
|
||||
return UserDefaults.standard.string(forKey: "customProviderModel") ?? ""
|
||||
case .openRouter:
|
||||
return "openai/gpt-4o"
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,6 +114,8 @@ enum AIProvider: String, CaseIterable {
|
||||
return ["whisper-1"]
|
||||
case .custom:
|
||||
return []
|
||||
case .openRouter:
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
@ -163,6 +172,8 @@ class AIService: ObservableObject {
|
||||
private let userDefaults = UserDefaults.standard
|
||||
private lazy var ollamaService = OllamaService()
|
||||
|
||||
private var openRouterModels: [String] = []
|
||||
|
||||
var connectedProviders: [AIProvider] {
|
||||
AIProvider.allCases.filter { provider in
|
||||
if provider == .ollama {
|
||||
@ -186,6 +197,8 @@ class AIService: ObservableObject {
|
||||
var availableModels: [String] {
|
||||
if selectedProvider == .ollama {
|
||||
return ollamaService.availableModels.map { $0.name }
|
||||
} else if selectedProvider == .openRouter {
|
||||
return openRouterModels
|
||||
}
|
||||
return selectedProvider.availableModels
|
||||
}
|
||||
@ -205,7 +218,6 @@ class AIService: ObservableObject {
|
||||
}
|
||||
} else {
|
||||
self.isAPIKeyValid = true
|
||||
|
||||
}
|
||||
|
||||
loadSavedModelSelections()
|
||||
@ -451,14 +463,6 @@ class AIService: ObservableObject {
|
||||
}
|
||||
|
||||
func enhanceWithOllama(text: String, systemPrompt: String) async throws -> String {
|
||||
// Ensure connection is established before attempting enhancement
|
||||
if !ollamaService.isConnected {
|
||||
await ollamaService.checkConnection()
|
||||
if ollamaService.isConnected && ollamaService.availableModels.isEmpty {
|
||||
await ollamaService.refreshModels()
|
||||
}
|
||||
}
|
||||
|
||||
logger.notice("🔄 Sending transcription to Ollama for enhancement (model: \(self.ollamaService.selectedModel))")
|
||||
do {
|
||||
let result = try await ollamaService.enhance(text, withSystemPrompt: systemPrompt)
|
||||
@ -479,8 +483,57 @@ class AIService: ObservableObject {
|
||||
ollamaService.selectedModel = modelName
|
||||
userDefaults.set(modelName, forKey: "ollamaSelectedModel")
|
||||
}
|
||||
|
||||
func fetchOpenRouterModels() async {
|
||||
let url = URL(string: "https://openrouter.ai/api/v1/models")!
|
||||
var request = URLRequest(url: url)
|
||||
request.httpMethod = "GET"
|
||||
request.addValue("application/json", forHTTPHeaderField: "Content-Type")
|
||||
|
||||
do {
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
|
||||
guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 else {
|
||||
logger.error("Failed to fetch OpenRouter models: Invalid HTTP response")
|
||||
await MainActor.run {
|
||||
self.openRouterModels = []
|
||||
self.objectWillChange.send()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
guard let jsonResponse = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
|
||||
let dataArray = jsonResponse["data"] as? [[String: Any]] else {
|
||||
logger.error("Failed to parse OpenRouter models JSON")
|
||||
await MainActor.run {
|
||||
self.openRouterModels = []
|
||||
self.objectWillChange.send()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
let models = dataArray.compactMap { $0["id"] as? String }
|
||||
await MainActor.run {
|
||||
self.openRouterModels = models.sorted()
|
||||
if self.selectedProvider == .openRouter && self.currentModel == self.selectedProvider.defaultModel && !models.isEmpty {
|
||||
self.selectModel(models.sorted().first!)
|
||||
}
|
||||
self.objectWillChange.send()
|
||||
}
|
||||
logger.info("Successfully fetched \(models.count) OpenRouter models.")
|
||||
|
||||
} catch {
|
||||
logger.error("Error fetching OpenRouter models: \(error.localizedDescription)")
|
||||
await MainActor.run {
|
||||
self.openRouterModels = []
|
||||
self.objectWillChange.send()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extension Notification.Name {
|
||||
static let aiProviderKeyChanged = Notification.Name("aiProviderKeyChanged")
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -52,6 +52,10 @@ struct APIKeyManagementView: View {
|
||||
.onChange(of: aiService.selectedProvider) { oldValue, newValue in
|
||||
if aiService.selectedProvider == .ollama {
|
||||
checkOllamaConnection()
|
||||
} else if aiService.selectedProvider == .openRouter {
|
||||
Task {
|
||||
await aiService.fetchOpenRouterModels()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,12 +63,26 @@ struct APIKeyManagementView: View {
|
||||
if !aiService.availableModels.isEmpty &&
|
||||
aiService.selectedProvider != .ollama &&
|
||||
aiService.selectedProvider != .custom {
|
||||
Picker("Model", selection: Binding(
|
||||
get: { aiService.currentModel },
|
||||
set: { aiService.selectModel($0) }
|
||||
)) {
|
||||
ForEach(aiService.availableModels, id: \.self) { model in
|
||||
Text(model).tag(model)
|
||||
HStack {
|
||||
Picker("Model", selection: Binding(
|
||||
get: { aiService.currentModel },
|
||||
set: { aiService.selectModel($0) }
|
||||
)) {
|
||||
ForEach(aiService.availableModels, id: \.self) { model in
|
||||
Text(model).tag(model)
|
||||
}
|
||||
}
|
||||
|
||||
if aiService.selectedProvider == .openRouter {
|
||||
Button(action: {
|
||||
Task {
|
||||
await aiService.fetchOpenRouterModels()
|
||||
}
|
||||
}) {
|
||||
Image(systemName: "arrow.clockwise")
|
||||
}
|
||||
.buttonStyle(.borderless)
|
||||
.help("Refresh models")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -200,6 +218,7 @@ struct APIKeyManagementView: View {
|
||||
.padding()
|
||||
.background(Color.secondary.opacity(0.03))
|
||||
.cornerRadius(12)
|
||||
|
||||
} else if aiService.selectedProvider == .custom {
|
||||
VStack(alignment: .leading, spacing: 16) {
|
||||
// Header
|
||||
@ -391,6 +410,8 @@ struct APIKeyManagementView: View {
|
||||
URL(string: "https://console.deepgram.com/api-keys")!
|
||||
case .ollama, .custom:
|
||||
URL(string: "")! // This case should never be reached
|
||||
case .openRouter:
|
||||
URL(string: "https://openrouter.ai/keys")!
|
||||
}
|
||||
NSWorkspace.shared.open(url)
|
||||
} label: {
|
||||
@ -419,6 +440,10 @@ struct APIKeyManagementView: View {
|
||||
.onAppear {
|
||||
if aiService.selectedProvider == .ollama {
|
||||
checkOllamaConnection()
|
||||
} else if aiService.selectedProvider == .openRouter {
|
||||
Task {
|
||||
await aiService.fetchOpenRouterModels()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -336,13 +336,23 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
try? modelContext.save()
|
||||
text = enhancedText
|
||||
} catch {
|
||||
// Enhancement failed - save error in enhancedText field and show notification
|
||||
let newTranscription = Transcription(
|
||||
text: originalText,
|
||||
duration: actualDuration,
|
||||
enhancedText: "Enhancement failed: \(error.localizedDescription)",
|
||||
audioFileURL: permanentURL?.absoluteString
|
||||
)
|
||||
modelContext.insert(newTranscription)
|
||||
try? modelContext.save()
|
||||
|
||||
// Show notification about enhancement failure
|
||||
await MainActor.run {
|
||||
NotificationManager.shared.showNotification(
|
||||
title: "AI enhancement failed",
|
||||
type: .error
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let newTranscription = Transcription(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user