feat: add AI model selection support

This commit is contained in:
Beingpax 2025-05-04 16:58:03 +05:45
parent a74de9dce5
commit d389d8ecc1
4 changed files with 240 additions and 339 deletions

View File

@ -17,16 +17,12 @@ class AIEnhancementService: ObservableObject {
@Published var isEnhancementEnabled: Bool {
didSet {
UserDefaults.standard.set(isEnhancementEnabled, forKey: "isAIEnhancementEnabled")
// When enhancement is enabled, ensure a prompt is selected
if isEnhancementEnabled && selectedPromptId == nil {
// Select the first prompt (default) if none is selected
selectedPromptId = customPrompts.first?.id
}
// Cancel any existing capture task
currentCaptureTask?.cancel()
// Trigger screen capture when enhancement is enabled and screen capture is on
if isEnhancementEnabled && useScreenCaptureContext {
currentCaptureTask = Task {
await captureScreenContext()
@ -71,7 +67,6 @@ class AIEnhancementService: ObservableObject {
}
/// Every prompt offered to the user: the freshly created predefined prompts
/// first, followed by the stored prompts that are not flagged as predefined.
var allPrompts: [CustomPrompt] {
    let predefined = PredefinedPrompts.createDefaultPrompts()
    let userDefined = customPrompts.filter { $0.isPredefined == false }
    return predefined + userDefined
}
@ -80,7 +75,7 @@ class AIEnhancementService: ObservableObject {
private var currentCaptureTask: Task<Void, Never>?
private let maxRetries = 3
private let baseTimeout: TimeInterval = 4
private let rateLimitInterval: TimeInterval = 1.0 // 1 request per second
private let rateLimitInterval: TimeInterval = 1.0
private var lastRequestTime: Date?
private let modelContext: ModelContext
@ -89,20 +84,11 @@ class AIEnhancementService: ObservableObject {
self.modelContext = modelContext
self.screenCaptureService = ScreenCaptureService()
// Print UserDefaults domain
if let domain = Bundle.main.bundleIdentifier {
print("⚙️ UserDefaults domain: \(domain)")
if let prefsPath = NSSearchPathForDirectoriesInDomains(.libraryDirectory, .userDomainMask, true).first {
print("⚙️ Preferences directory: \(prefsPath)/Preferences/\(domain).plist")
}
}
self.isEnhancementEnabled = UserDefaults.standard.bool(forKey: "isAIEnhancementEnabled")
self.useClipboardContext = UserDefaults.standard.bool(forKey: "useClipboardContext")
self.useScreenCaptureContext = UserDefaults.standard.bool(forKey: "useScreenCaptureContext")
self.assistantTriggerWord = UserDefaults.standard.string(forKey: "assistantTriggerWord") ?? "hey"
// Load only custom prompts (non-predefined ones)
if let savedPromptsData = UserDefaults.standard.data(forKey: "customPrompts"),
let decodedPrompts = try? JSONDecoder().decode([CustomPrompt].self, from: savedPromptsData) {
self.customPrompts = decodedPrompts
@ -110,18 +96,14 @@ class AIEnhancementService: ObservableObject {
self.customPrompts = []
}
// Load selected prompt ID
if let savedPromptId = UserDefaults.standard.string(forKey: "selectedPromptId") {
self.selectedPromptId = UUID(uuidString: savedPromptId)
}
// Ensure a prompt is selected if enhancement is enabled
if isEnhancementEnabled && (selectedPromptId == nil || !allPrompts.contains(where: { $0.id == selectedPromptId })) {
// Set first prompt (default) as selected
self.selectedPromptId = allPrompts.first?.id
}
// Setup notification observer for API key changes
NotificationCenter.default.addObserver(
self,
selector: #selector(handleAPIKeyChange),
@ -137,7 +119,6 @@ class AIEnhancementService: ObservableObject {
@objc private func handleAPIKeyChange() {
DispatchQueue.main.async {
self.objectWillChange.send()
// Optionally disable enhancement if API key is cleared
if !self.aiService.isAPIKeyValid {
self.isEnhancementEnabled = false
}
@ -159,70 +140,44 @@ class AIEnhancementService: ObservableObject {
}
/// Chooses how a transcription should be processed.
///
/// - Parameter text: The raw transcription.
/// - Returns: `.aiAssistant` when `text` starts with `assistantTriggerWord`
///   (both compared lowercased), otherwise `.transcriptionEnhancement`.
private func determineMode(text: String) -> EnhancementMode {
    let trigger = assistantTriggerWord.lowercased()
    // An empty trigger word is a prefix of every string and would route all
    // transcriptions to assistant mode; treat it as "no trigger configured".
    guard !trigger.isEmpty else { return .transcriptionEnhancement }
    return text.lowercased().hasPrefix(trigger) ? .aiAssistant : .transcriptionEnhancement
}
/// Builds the system prompt sent to the model for `mode`.
///
/// The result is the mode's base instructions followed by an optional context
/// section. The context section is only appended when at least one of the
/// clipboard text (`useClipboardContext`) or the last captured window text
/// (`useScreenCaptureContext`) is enabled and non-empty.
///
/// - Parameter mode: The processing mode chosen for the transcription.
/// - Returns: The complete system message.
private func getSystemMessage(for mode: EnhancementMode) -> String {
    // Clipboard text, read from the general pasteboard when the toggle is on.
    let clipboardContext = if useClipboardContext,
        let clipboardText = NSPasteboard.general.string(forType: .string),
        !clipboardText.isEmpty {
        "\n\nAvailable Clipboard Context: \(clipboardText)"
    } else {
        ""
    }

    // Text from the most recent screen capture, when the toggle is on.
    let screenCaptureContext = if useScreenCaptureContext,
        let capturedText = screenCaptureService.lastCapturedText,
        !capturedText.isEmpty {
        "\n\nActive Window Context: \(capturedText)"
    } else {
        ""
    }

    // Wrap whatever context exists together with the usage instructions.
    let contextSection = if !clipboardContext.isEmpty || !screenCaptureContext.isEmpty {
        "\n\n\(AIPrompts.contextInstructions)\n\n<CONTEXT_INFORMATION>\(clipboardContext)\(screenCaptureContext)\n</CONTEXT_INFORMATION>"
    } else {
        ""
    }

    switch mode {
    case .transcriptionEnhancement:
        // The predefined Assistant prompt uses the assistant instructions directly.
        if let activePrompt = activePrompt,
           activePrompt.id == PredefinedPrompts.assistantPromptId {
            return AIPrompts.assistantMode + contextSection
        }
        // Avoid crashing when no prompt is selected: fall back to empty guidelines
        // instead of force-unwrapping `activePrompt`.
        let guidelines = activePrompt?.promptText ?? ""
        return String(format: AIPrompts.customPromptTemplate, guidelines) + contextSection
    case .aiAssistant:
        return AIPrompts.assistantMode + contextSection
    }
}
@ -238,49 +193,37 @@ class AIEnhancementService: ObservableObject {
throw EnhancementError.emptyText
}
// Format transcript with boundary markers
let formattedText = "<TRANSCRIPT>\n\(text)\n</TRANSCRIPT>"
// Determine mode and get system message
let mode = determineMode(text: text)
let systemMessage = getSystemMessage(for: mode)
// Handle Ollama requests differently
if aiService.selectedProvider == .ollama {
logger.notice("📤 Request to Ollama")
logger.notice("🤖 System: \(systemMessage, privacy: .public)")
logger.notice("📝 Sending: \(formattedText, privacy: .public)")
do {
let result = try await aiService.enhanceWithOllama(text: formattedText, systemPrompt: systemMessage)
logger.notice("✅ Ollama enhancement successful")
logger.notice("📝 Received: \(result, privacy: .public)")
return result
} catch let error as LocalAIError {
switch error {
case .serviceUnavailable:
logger.error("🔌 Ollama service unavailable")
throw EnhancementError.notConfigured
case .modelNotFound:
logger.error("🤖 Ollama model not found")
throw EnhancementError.enhancementFailed
case .serverError:
logger.error("🔥 Ollama server error")
throw EnhancementError.serverError
default:
logger.error("❌ Ollama enhancement failed")
throw EnhancementError.enhancementFailed
}
}
}
// Handle cloud provider requests
// Wait for rate limit
try await waitForRateLimit()
// Special handling for Gemini and Anthropic
switch aiService.selectedProvider {
case .gemini:
var urlComponents = URLComponents(string: aiService.selectedProvider.baseURL)!
let baseEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
let model = aiService.currentModel
let fullURL = "\(baseEndpoint)/\(model):generateContent"
var urlComponents = URLComponents(string: fullURL)!
urlComponents.queryItems = [URLQueryItem(name: "key", value: aiService.apiKey)]
guard let url = urlComponents.url else {
@ -290,9 +233,7 @@ class AIEnhancementService: ObservableObject {
var request = URLRequest(url: url)
request.httpMethod = "POST"
request.addValue("application/json", forHTTPHeaderField: "Content-Type")
let timeout = baseTimeout * pow(2.0, Double(retryCount))
request.timeoutInterval = timeout
request.timeoutInterval = baseTimeout * pow(2.0, Double(retryCount))
let requestBody: [String: Any] = [
"contents": [
@ -311,13 +252,9 @@ class AIEnhancementService: ObservableObject {
request.httpBody = try? JSONSerialization.data(withJSONObject: requestBody)
do {
logger.notice("📤 Request to Gemini")
logger.notice("🤖 System: \(systemMessage, privacy: .public)")
logger.notice("📝 Sending: \(formattedText, privacy: .public)")
let (data, response) = try await URLSession.shared.data(for: request)
guard let httpResponse = response as? HTTPURLResponse else {
logger.error("❌ Invalid Gemini response")
throw EnhancementError.invalidResponse
}
@ -330,47 +267,33 @@ class AIEnhancementService: ObservableObject {
let parts = content["parts"] as? [[String: Any]],
let firstPart = parts.first,
let enhancedText = firstPart["text"] as? String else {
logger.error("❌ Failed to parse Gemini response")
throw EnhancementError.enhancementFailed
}
let result = enhancedText.trimmingCharacters(in: .whitespacesAndNewlines)
logger.notice("✅ Gemini enhancement successful")
logger.notice("📝 Received: \(result, privacy: .public)")
return result
return enhancedText.trimmingCharacters(in: .whitespacesAndNewlines)
case 401:
logger.error("🔒 Authentication failed")
throw EnhancementError.authenticationFailed
case 429:
logger.error("⏳ Rate limit exceeded")
throw EnhancementError.rateLimitExceeded
case 500...599:
logger.error("🔥 Server error (\(httpResponse.statusCode))")
throw EnhancementError.serverError
default:
logger.error("❌ Unexpected status (\(httpResponse.statusCode))")
throw EnhancementError.apiError
}
} catch let error as EnhancementError {
throw error
} catch {
logger.error("❌ Network error: \(error.localizedDescription)")
if retryCount < maxRetries {
try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000))
return try await makeRequest(text: text, retryCount: retryCount + 1)
}
throw EnhancementError.networkError
}
case .anthropic:
let requestBody: [String: Any] = [
"model": aiService.selectedProvider.defaultModel,
"model": aiService.currentModel,
"max_tokens": 1024,
"system": systemMessage,
"messages": [
@ -383,20 +306,13 @@ class AIEnhancementService: ObservableObject {
request.addValue("application/json", forHTTPHeaderField: "Content-Type")
request.addValue(aiService.apiKey, forHTTPHeaderField: "x-api-key")
request.addValue("2023-06-01", forHTTPHeaderField: "anthropic-version")
let timeout = baseTimeout * pow(2.0, Double(retryCount))
request.timeoutInterval = timeout
request.timeoutInterval = baseTimeout * pow(2.0, Double(retryCount))
request.httpBody = try? JSONSerialization.data(withJSONObject: requestBody)
do {
logger.notice("📤 Request to Anthropic")
logger.notice("🤖 System: \(systemMessage, privacy: .public)")
logger.notice("📝 Sending: \(formattedText, privacy: .public)")
let (data, response) = try await URLSession.shared.data(for: request)
guard let httpResponse = response as? HTTPURLResponse else {
logger.error("❌ Invalid Anthropic response")
throw EnhancementError.invalidResponse
}
@ -406,67 +322,45 @@ class AIEnhancementService: ObservableObject {
let content = jsonResponse["content"] as? [[String: Any]],
let firstContent = content.first,
let enhancedText = firstContent["text"] as? String else {
logger.error("❌ Failed to parse Anthropic response")
throw EnhancementError.enhancementFailed
}
let result = enhancedText.trimmingCharacters(in: .whitespacesAndNewlines)
logger.notice("✅ Anthropic enhancement successful")
logger.notice("📝 Received: \(result, privacy: .public)")
return result
return enhancedText.trimmingCharacters(in: .whitespacesAndNewlines)
case 401:
logger.error("🔒 Authentication failed")
throw EnhancementError.authenticationFailed
case 429:
logger.error("⏳ Rate limit exceeded")
throw EnhancementError.rateLimitExceeded
case 500...599:
logger.error("🔥 Server error (\(httpResponse.statusCode))")
throw EnhancementError.serverError
default:
logger.error("❌ Unexpected status (\(httpResponse.statusCode))")
throw EnhancementError.apiError
}
} catch let error as EnhancementError {
throw error
} catch {
logger.error("❌ Network error: \(error.localizedDescription)")
if retryCount < maxRetries {
try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000))
return try await makeRequest(text: text, retryCount: retryCount + 1)
}
throw EnhancementError.networkError
}
default:
// Handle OpenAI compatible providers
let url = URL(string: aiService.selectedProvider.baseURL)!
var request = URLRequest(url: url)
request.httpMethod = "POST"
request.addValue("application/json", forHTTPHeaderField: "Content-Type")
request.addValue("Bearer \(aiService.apiKey)", forHTTPHeaderField: "Authorization")
// Set timeout based on retry count with exponential backoff
let timeout = baseTimeout * pow(2.0, Double(retryCount))
request.timeoutInterval = timeout
logger.debug("Full system message: \(systemMessage)")
request.timeoutInterval = baseTimeout * pow(2.0, Double(retryCount))
let messages: [[String: Any]] = [
["role": "system", "content": systemMessage],
["role": "user", "content": formattedText]
]
logger.info("Making request to \(self.aiService.selectedProvider.rawValue) with text length: \(text.count) characters")
let requestBody: [String: Any] = [
"model": aiService.selectedProvider.defaultModel,
"model": aiService.currentModel,
"messages": messages,
"temperature": 0.3,
"frequency_penalty": 0.0,
@ -477,13 +371,9 @@ class AIEnhancementService: ObservableObject {
request.httpBody = try? JSONSerialization.data(withJSONObject: requestBody)
do {
logger.notice("📤 Request to \(self.aiService.selectedProvider.rawValue, privacy: .public)")
logger.notice("🤖 System: \(systemMessage, privacy: .public)")
logger.notice("📝 Sending: \(formattedText, privacy: .public)")
let (data, response) = try await URLSession.shared.data(for: request)
guard let httpResponse = response as? HTTPURLResponse else {
logger.error("❌ Invalid response")
throw EnhancementError.invalidResponse
}
@ -494,87 +384,65 @@ class AIEnhancementService: ObservableObject {
let firstChoice = choices.first,
let message = firstChoice["message"] as? [String: Any],
let enhancedText = message["content"] as? String else {
logger.error("❌ Failed to parse response")
throw EnhancementError.enhancementFailed
}
let result = enhancedText.trimmingCharacters(in: .whitespacesAndNewlines)
logger.notice("✅ Enhancement successful")
logger.notice("📝 Received: \(result, privacy: .public)")
return result
return enhancedText.trimmingCharacters(in: .whitespacesAndNewlines)
case 401:
logger.error("🔒 Authentication failed")
throw EnhancementError.authenticationFailed
case 429:
logger.error("⏳ Rate limit exceeded")
throw EnhancementError.rateLimitExceeded
case 500...599:
logger.error("🔥 Server error (\(httpResponse.statusCode))")
throw EnhancementError.serverError
default:
logger.error("❌ Unexpected status (\(httpResponse.statusCode))")
throw EnhancementError.apiError
}
} catch let error as EnhancementError {
throw error
} catch {
logger.error("❌ Network error: \(error.localizedDescription)")
if retryCount < maxRetries {
try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000))
return try await makeRequest(text: text, retryCount: retryCount + 1)
}
throw EnhancementError.networkError
}
}
}
/// Enhances `text` through `makeRequest`, retrying while the provider reports
/// that the rate limit was exceeded.
///
/// A rate-limit failure is retried after an exponentially growing delay as long
/// as another attempt remains; any other error is logged and rethrown at once.
///
/// - Parameter text: The transcription to enhance.
/// - Returns: The text returned by `makeRequest`.
/// - Throws: The error thrown by `makeRequest`, an error from `Task.sleep`, or
///   `EnhancementError.maxRetriesExceeded` if the loop ends without a result.
func enhance(_ text: String) async throws -> String {
    logger.notice("🚀 Starting AI enhancement for text (\(text.count) characters)")
    var retryCount = 0
    while retryCount < maxRetries {
        do {
            let result = try await makeRequest(text: text, retryCount: retryCount)
            logger.notice("✅ AI enhancement completed successfully (\(result.count) characters)")
            return result
        } catch EnhancementError.rateLimitExceeded where retryCount < maxRetries - 1 {
            logger.notice("⚠️ Rate limit exceeded, retrying AI enhancement (attempt \(retryCount + 1) of \(self.maxRetries))")
            retryCount += 1
            // Back off 2^retryCount seconds before the next attempt.
            try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000))
            continue
        } catch {
            // Failures are reported at error level so they stand out from progress logs.
            logger.error("❌ AI enhancement failed: \(error.localizedDescription)")
            throw error
        }
    }
    logger.error("❌ AI enhancement failed: maximum retries exceeded")
    throw EnhancementError.maxRetriesExceeded
}
/// Captures the screen and extracts its text for use as enhancement context.
///
/// Does nothing when `useScreenCaptureContext` is off. After a successful
/// capture, `objectWillChange` is sent on the main actor so observers refresh.
func captureScreenContext() async {
    // One guard is enough; the flag was previously checked twice in a row.
    guard useScreenCaptureContext else { return }
    logger.notice("📷 Initiating screen capture for context")
    if let capturedText = await screenCaptureService.captureAndExtractText() {
        logger.notice("📷 Screen capture successful, got \(capturedText.count, privacy: .public) characters")
        await MainActor.run {
            self.objectWillChange.send()
        }
    } else {
        logger.notice("📷 Screen capture failed or returned empty result")
    }
}
// MARK: - Prompt Management
func addPrompt(title: String, promptText: String, icon: PromptIcon = .documentFill, description: String? = nil) {
let newPrompt = CustomPrompt(title: title, promptText: promptText, icon: icon, description: description, isPredefined: false)
customPrompts.append(newPrompt)
@ -584,10 +452,7 @@ class AIEnhancementService: ObservableObject {
}
func updatePrompt(_ prompt: CustomPrompt) {
// Don't allow updates to predefined prompts
if prompt.isPredefined {
return
}
if prompt.isPredefined { return }
if let index = customPrompts.firstIndex(where: { $0.id == prompt.id }) {
customPrompts[index] = prompt
@ -595,10 +460,7 @@ class AIEnhancementService: ObservableObject {
}
func deletePrompt(_ prompt: CustomPrompt) {
// Don't allow deletion of predefined prompts
if prompt.isPredefined {
return
}
if prompt.isPredefined { return }
customPrompts.removeAll { $0.id == prompt.id }
if selectedPromptId == prompt.id {

View File

@ -1,37 +0,0 @@
/// Prompt text fragments used to assemble the system message sent to the model.
enum AIPrompts {
/// Format template for prompt-driven enhancement. The `%@` placeholder is
/// filled with the selected prompt's guideline text via `String(format:)`.
static let customPromptTemplate = """
<SYSTEM_INSTRUCTIONS>
Your task is to reformat and enhance the text provided within <TRANSCRIPT> tags according to the following guidelines:
%@
IMPORTANT: The input will be wrapped in <TRANSCRIPT> tags to identify what needs enhancement.
Your response should ONLY be to enhance text WITHOUT any tags.
DO NOT include <TRANSCRIPT> tags in your response.
</SYSTEM_INSTRUCTIONS>
"""
/// System instructions used when the model should answer the user's query
/// directly rather than reformat the transcript.
static let assistantMode = """
<SYSTEM_INSTRUCTIONS>
Provide a direct clear, and concise reply to the user's query. Use the available context if directly related to the user's query.
Remember to:
1. Be helpful and informative
2. Be accurate and precise
3. Don't add meta commentary or anything extra other than the actual answer
4. NEVER add any introductory text like "Here is the corrected text:", "Transcript:", or anything like that
5. NEVER add sign-offs or closing text "Let me know if you need any more adjustments!", or anything like that except the actual answer.
6. Maintain a friendly, casual tone
</SYSTEM_INSTRUCTIONS>
"""
/// Instructions that tell the model how to treat the <CONTEXT_INFORMATION>
/// section (reference only, never echoed in the output).
static let contextInstructions = """
<CONTEXT_USAGE_INSTRUCTIONS>
Your task is to work ONLY with the content within the <TRANSCRIPT> tags.
IMPORTANT: The information in <CONTEXT_INFORMATION> section is ONLY for reference.
- NEVER include the context directly in your output
- Context should only help you better understand the user's query
</CONTEXT_USAGE_INSTRUCTIONS>
"""
}

View File

@ -1,4 +1,5 @@
import Foundation
import os
enum AIProvider: String, CaseIterable {
case groq = "GROQ"
@ -19,7 +20,7 @@ enum AIProvider: String, CaseIterable {
case .deepSeek:
return "https://api.deepseek.com/v1/chat/completions"
case .gemini:
return "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
return "https://generativelanguage.googleapis.com/v1beta/models"
case .anthropic:
return "https://api.anthropic.com/v1/messages"
case .mistral:
@ -36,7 +37,7 @@ enum AIProvider: String, CaseIterable {
case .groq:
return "llama-3.3-70b-versatile"
case .openAI:
return "gpt-4o-mini-2024-07-18"
return "gpt-4.1-mini"
case .deepSeek:
return "deepseek-chat"
case .gemini:
@ -44,7 +45,7 @@ enum AIProvider: String, CaseIterable {
case .anthropic:
return "claude-3-5-sonnet-20241022"
case .mistral:
return "mistral-large-2411"
return "mistral-large-latest"
case .ollama:
return UserDefaults.standard.string(forKey: "ollamaSelectedModel") ?? "mistral"
case .custom:
@ -52,6 +53,48 @@ enum AIProvider: String, CaseIterable {
}
}
/// Model identifiers offered for this provider, in display order.
///
/// `.ollama` and `.custom` have no built-in list and return an empty array.
var availableModels: [String] {
    switch self {
    case .groq:
        return ["llama-3.3-70b-versatile", "llama-3.1-8b-instant"]
    case .openAI:
        return ["gpt-4.1", "gpt-4.1-mini"]
    case .deepSeek:
        return ["deepseek-chat", "deepseek-reasoner"]
    case .gemini:
        return ["gemini-2.5-flash-preview-04-17", "gemini-2.0-flash", "gemini-2.0-flash-lite"]
    case .anthropic:
        return ["claude-3-7-sonnet-latest", "claude-3-5-haiku-latest", "claude-3-5-sonnet-latest"]
    case .mistral:
        return ["mistral-large-latest", "mistral-small-latest", "mistral-saba-latest"]
    case .ollama, .custom:
        return []
    }
}
var requiresAPIKey: Bool {
switch self {
case .ollama:
@ -63,6 +106,8 @@ enum AIProvider: String, CaseIterable {
}
class AIService: ObservableObject {
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AIService")
@Published var apiKey: String = ""
@Published var isAPIKeyValid: Bool = false
@Published var customBaseURL: String = UserDefaults.standard.string(forKey: "customProviderBaseURL") ?? "" {
@ -78,7 +123,6 @@ class AIService: ObservableObject {
@Published var selectedProvider: AIProvider {
didSet {
userDefaults.set(selectedProvider.rawValue, forKey: "selectedAIProvider")
// Load API key for the selected provider if it requires one
if selectedProvider.requiresAPIKey {
if let savedKey = userDefaults.string(forKey: "\(selectedProvider.rawValue)APIKey") {
self.apiKey = savedKey
@ -88,10 +132,8 @@ class AIService: ObservableObject {
self.isAPIKeyValid = false
}
} else {
// For providers that don't require API key (like Ollama)
self.apiKey = ""
self.isAPIKeyValid = true
// Check Ollama connection
if selectedProvider == .ollama {
Task {
await ollamaService.checkConnection()
@ -102,6 +144,7 @@ class AIService: ObservableObject {
}
}
@Published private var selectedModels: [AIProvider: String] = [:]
private let userDefaults = UserDefaults.standard
private let ollamaService = OllamaService()
@ -116,25 +159,37 @@ class AIService: ObservableObject {
}
}
/// The model identifier to use for the selected provider.
///
/// Returns the stored selection for `selectedProvider` when it is non-empty and
/// either the provider is `.ollama` or the selection appears in `availableModels`;
/// otherwise returns the provider's `defaultModel`.
var currentModel: String {
    guard let chosen = selectedModels[selectedProvider], !chosen.isEmpty else {
        return selectedProvider.defaultModel
    }
    // Ollama selections are accepted as-is; other providers must list the model.
    let isUsable = selectedProvider == .ollama || availableModels.contains(chosen)
    return isUsable ? chosen : selectedProvider.defaultModel
}
/// Model names selectable for the current provider: the names reported by
/// `ollamaService` for `.ollama`, otherwise the provider's own list.
var availableModels: [String] {
    guard selectedProvider == .ollama else {
        return selectedProvider.availableModels
    }
    return ollamaService.availableModels.map(\.name)
}
init() {
// Load selected provider
if let savedProvider = userDefaults.string(forKey: "selectedAIProvider"),
let provider = AIProvider(rawValue: savedProvider) {
self.selectedProvider = provider
} else {
self.selectedProvider = .gemini // Default to Gemini
self.selectedProvider = .gemini
}
// Load API key for the current provider if it requires one
if selectedProvider.requiresAPIKey {
if let savedKey = userDefaults.string(forKey: "\(selectedProvider.rawValue)APIKey") {
self.apiKey = savedKey
self.isAPIKeyValid = true
}
} else {
// For providers that don't require API key
self.isAPIKeyValid = true
// Check Ollama connection if it's the selected provider
if selectedProvider == .ollama {
Task {
await ollamaService.checkConnection()
@ -142,29 +197,48 @@ class AIService: ObservableObject {
}
}
}
loadSavedModelSelections()
}
/// Restores the per-provider model selection from user defaults.
///
/// Each provider's value is read from the key "<rawValue>SelectedModel";
/// missing or empty values are skipped.
private func loadSavedModelSelections() {
    for provider in AIProvider.allCases {
        guard let stored = userDefaults.string(forKey: "\(provider.rawValue)SelectedModel"),
              stored.isEmpty == false else {
            continue
        }
        selectedModels[provider] = stored
    }
}
/// Records `model` as the selection for the current provider and persists it.
///
/// Empty names are ignored. For `.ollama` the selection is also forwarded to
/// `updateSelectedOllamaModel`. Observers are notified afterwards.
///
/// - Parameter model: The model identifier to select.
func selectModel(_ model: String) {
    if model.isEmpty { return }

    selectedModels[selectedProvider] = model
    userDefaults.set(model, forKey: "\(selectedProvider.rawValue)SelectedModel")

    if case .ollama = selectedProvider {
        updateSelectedOllamaModel(model)
    }

    objectWillChange.send()
}
func saveAPIKey(_ key: String, completion: @escaping (Bool) -> Void) {
// Skip verification for providers that don't require API key
guard selectedProvider.requiresAPIKey else {
print("📝 [\(selectedProvider.rawValue)] API key not required, skipping verification")
completion(true)
return
}
print("🔑 [\(selectedProvider.rawValue)] Starting API key verification...")
// Verify the API key before saving
verifyAPIKey(key) { [weak self] isValid in
guard let self = self else { return }
DispatchQueue.main.async {
if isValid {
print("✅ [\(self.selectedProvider.rawValue)] API key verified successfully")
self.apiKey = key
self.isAPIKeyValid = true
self.userDefaults.set(key, forKey: "\(self.selectedProvider.rawValue)APIKey")
NotificationCenter.default.post(name: .aiProviderKeyChanged, object: nil)
} else {
print("❌ [\(self.selectedProvider.rawValue)] API key verification failed")
self.isAPIKeyValid = false
}
completion(isValid)
@ -173,18 +247,11 @@ class AIService: ObservableObject {
}
func verifyAPIKey(_ key: String, completion: @escaping (Bool) -> Void) {
// Skip verification for providers that don't require API key
guard selectedProvider.requiresAPIKey else {
print("📝 [\(selectedProvider.rawValue)] API key verification skipped - not required")
completion(true)
return
}
print("🔍 [\(selectedProvider.rawValue)] Verifying API key...")
print("🌐 Using base URL: \(selectedProvider.baseURL)")
print("🤖 Using model: \(selectedProvider.defaultModel)")
// Special handling for different providers
switch selectedProvider {
case .gemini:
verifyGeminiAPIKey(key, completion: completion)
@ -203,7 +270,7 @@ class AIService: ObservableObject {
request.addValue("Bearer \(key)", forHTTPHeaderField: "Authorization")
let testBody: [String: Any] = [
"model": selectedProvider.defaultModel,
"model": currentModel,
"messages": [
["role": "user", "content": "test"]
],
@ -212,19 +279,15 @@ class AIService: ObservableObject {
request.httpBody = try? JSONSerialization.data(withJSONObject: testBody)
print("📤 Sending verification request...")
URLSession.shared.dataTask(with: request) { data, response, error in
if let error = error {
print("❌ Network error during verification: \(error.localizedDescription)")
completion(false)
return
}
if let httpResponse = response as? HTTPURLResponse {
print("📥 Received response with status code: \(httpResponse.statusCode)")
completion(httpResponse.statusCode == 200)
} else {
print("❌ Invalid response received")
completion(false)
}
}.resume()
@ -239,7 +302,7 @@ class AIService: ObservableObject {
request.addValue("2023-06-01", forHTTPHeaderField: "anthropic-version")
let testBody: [String: Any] = [
"model": selectedProvider.defaultModel,
"model": currentModel,
"max_tokens": 1024,
"system": "You are a test system.",
"messages": [
@ -249,26 +312,26 @@ class AIService: ObservableObject {
request.httpBody = try? JSONSerialization.data(withJSONObject: testBody)
print("📤 Sending Anthropic verification request...")
URLSession.shared.dataTask(with: request) { data, response, error in
if let error = error {
print("❌ Network error during Anthropic verification: \(error.localizedDescription)")
completion(false)
return
}
if let httpResponse = response as? HTTPURLResponse {
print("📥 Received Anthropic response with status code: \(httpResponse.statusCode)")
completion(httpResponse.statusCode == 200)
} else {
print("❌ Invalid Anthropic response received")
completion(false)
}
}.resume()
}
private func verifyGeminiAPIKey(_ key: String, completion: @escaping (Bool) -> Void) {
var urlComponents = URLComponents(string: selectedProvider.baseURL)!
let baseEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
let model = currentModel
let fullURL = "\(baseEndpoint)/\(model):generateContent"
var urlComponents = URLComponents(string: fullURL)!
urlComponents.queryItems = [URLQueryItem(name: "key", value: key)]
guard let url = urlComponents.url else {
@ -292,26 +355,21 @@ class AIService: ObservableObject {
request.httpBody = try? JSONSerialization.data(withJSONObject: testBody)
print("📤 Sending Gemini verification request...")
URLSession.shared.dataTask(with: request) { data, response, error in
if let error = error {
print("❌ Network error during Gemini verification: \(error.localizedDescription)")
completion(false)
return
}
if let httpResponse = response as? HTTPURLResponse {
print("📥 Received Gemini response with status code: \(httpResponse.statusCode)")
completion(httpResponse.statusCode == 200)
} else {
print("❌ Invalid Gemini response received")
completion(false)
}
}.resume()
}
func clearAPIKey() {
// Skip for providers that don't require API key
guard selectedProvider.requiresAPIKey else { return }
apiKey = ""
@ -320,7 +378,6 @@ class AIService: ObservableObject {
NotificationCenter.default.post(name: .aiProviderKeyChanged, object: nil)
}
// Add method to check Ollama connection
func checkOllamaConnection(completion: @escaping (Bool) -> Void) {
Task { [weak self] in
guard let self = self else { return }
@ -331,31 +388,34 @@ class AIService: ObservableObject {
}
}
/// Asks `ollamaService` to refresh its model list, then returns the service's
/// `availableModels` as they stand after the refresh.
func fetchOllamaModels() async -> [OllamaService.OllamaModel] {
await ollamaService.refreshModels()
return ollamaService.availableModels
}
/// Enhances `text` with the Ollama service, logging the outcome.
///
/// - Parameters:
///   - text: The text to enhance.
///   - systemPrompt: The system prompt passed to the service.
/// - Returns: The enhanced text returned by `ollamaService`.
/// - Throws: Rethrows any error from `ollamaService.enhance` after logging it.
func enhanceWithOllama(text: String, systemPrompt: String) async throws -> String {
    logger.notice("🔄 Sending transcription to Ollama for enhancement (model: \(self.ollamaService.selectedModel))")
    do {
        let result = try await ollamaService.enhance(text, withSystemPrompt: systemPrompt)
        logger.notice("✅ Ollama enhancement completed successfully (\(result.count) characters)")
        return result
    } catch {
        // Failures are reported at error level so they stand out from progress logs.
        logger.error("❌ Ollama enhancement failed: \(error.localizedDescription)")
        throw error
    }
}
/// Sets the Ollama service's base URL to `newURL` and stores the value in
/// user defaults under the key "ollamaBaseURL".
func updateOllamaBaseURL(_ newURL: String) {
ollamaService.baseURL = newURL
userDefaults.set(newURL, forKey: "ollamaBaseURL")
}
/// Sets the Ollama service's selected model to `modelName` and stores the
/// value in user defaults under the key "ollamaSelectedModel".
func updateSelectedOllamaModel(_ modelName: String) {
ollamaService.selectedModel = modelName
userDefaults.set(modelName, forKey: "ollamaSelectedModel")
}
}
/// Notification names used by the AI services.
extension Notification.Name {
    /// Posted after an API key has been saved or cleared.
    static let aiProviderKeyChanged = Notification.Name(rawValue: "aiProviderKeyChanged")
}

View File

@ -55,6 +55,20 @@ struct APIKeyManagementView: View {
}
}
// Model Selection - only show for standard providers with available models
if !aiService.availableModels.isEmpty &&
aiService.selectedProvider != .ollama &&
aiService.selectedProvider != .custom {
Picker("Model", selection: Binding(
get: { aiService.currentModel },
set: { aiService.selectModel($0) }
)) {
ForEach(aiService.availableModels, id: \.self) { model in
Text(model).tag(model)
}
}
}
if aiService.selectedProvider == .ollama {
// Ollama Configuration
VStack(alignment: .leading, spacing: 16) {
@ -441,103 +455,105 @@ struct APIKeyManagementView: View {
.padding()
.background(Color.secondary.opacity(0.03))
.cornerRadius(12)
} else if aiService.isAPIKeyValid {
// API Key Display for other providers
VStack(alignment: .leading, spacing: 8) {
Text("API Key")
.font(.subheadline)
.foregroundColor(.secondary)
HStack {
Text(String(repeating: "", count: 40))
} else {
// API Key Display for other providers if valid
if aiService.isAPIKeyValid {
VStack(alignment: .leading, spacing: 8) {
Text("API Key")
.font(.subheadline)
.foregroundColor(.secondary)
HStack {
Text(String(repeating: "", count: 40))
.font(.system(.body, design: .monospaced))
Spacer()
Button(action: {
aiService.clearAPIKey()
}) {
Label("Remove Key", systemImage: "trash")
.foregroundColor(.red)
}
.buttonStyle(.borderless)
}
}
} else {
// API Key Input for other providers
VStack(alignment: .leading, spacing: 8) {
Text("Enter your API Key")
.font(.subheadline)
.foregroundColor(.secondary)
SecureField("API Key", text: $apiKey)
.textFieldStyle(RoundedBorderTextFieldStyle())
.font(.system(.body, design: .monospaced))
Spacer()
Button(action: {
aiService.clearAPIKey()
}) {
Label("Remove Key", systemImage: "trash")
.foregroundColor(.red)
}
.buttonStyle(.borderless)
}
}
} else {
// API Key Input for other providers
VStack(alignment: .leading, spacing: 8) {
Text("Enter your API Key")
.font(.subheadline)
.foregroundColor(.secondary)
SecureField("API Key", text: $apiKey)
.textFieldStyle(RoundedBorderTextFieldStyle())
.font(.system(.body, design: .monospaced))
HStack {
Button(action: {
isVerifying = true
aiService.saveAPIKey(apiKey) { success in
isVerifying = false
if !success {
alertMessage = "Invalid API key. Please check and try again."
showAlert = true
HStack {
Button(action: {
isVerifying = true
aiService.saveAPIKey(apiKey) { success in
isVerifying = false
if !success {
alertMessage = "Invalid API key. Please check and try again."
showAlert = true
}
apiKey = ""
}
apiKey = ""
}
}) {
HStack {
if isVerifying {
ProgressView()
.scaleEffect(0.5)
.frame(width: 16, height: 16)
} else {
Image(systemName: "checkmark.circle.fill")
}) {
HStack {
if isVerifying {
ProgressView()
.scaleEffect(0.5)
.frame(width: 16, height: 16)
} else {
Image(systemName: "checkmark.circle.fill")
}
Text("Verify and Save")
}
Text("Verify and Save")
}
}
Spacer()
HStack(spacing: 8) {
Text(aiService.selectedProvider == .groq || aiService.selectedProvider == .gemini ? "Free" : "Paid")
.font(.caption2)
.foregroundColor(.secondary)
.padding(.horizontal, 6)
.padding(.vertical, 2)
.background(Color.secondary.opacity(0.1))
.cornerRadius(4)
if aiService.selectedProvider != .ollama && aiService.selectedProvider != .custom {
Button {
let url = switch aiService.selectedProvider {
case .groq:
URL(string: "https://console.groq.com/keys")!
case .openAI:
URL(string: "https://platform.openai.com/api-keys")!
case .deepSeek:
URL(string: "https://platform.deepseek.com/api-keys")!
case .gemini:
URL(string: "https://makersuite.google.com/app/apikey")!
case .anthropic:
URL(string: "https://console.anthropic.com/settings/keys")!
case .mistral:
URL(string: "https://console.mistral.ai/api-keys")!
case .ollama, .custom:
URL(string: "")! // This case should never be reached
}
NSWorkspace.shared.open(url)
} label: {
HStack(spacing: 4) {
Text("Get API Key")
.foregroundColor(.accentColor)
Image(systemName: "arrow.up.right")
.font(.caption)
.foregroundColor(.accentColor)
Spacer()
HStack(spacing: 8) {
Text(aiService.selectedProvider == .groq || aiService.selectedProvider == .gemini ? "Free" : "Paid")
.font(.caption2)
.foregroundColor(.secondary)
.padding(.horizontal, 6)
.padding(.vertical, 2)
.background(Color.secondary.opacity(0.1))
.cornerRadius(4)
if aiService.selectedProvider != .ollama && aiService.selectedProvider != .custom {
Button {
let url = switch aiService.selectedProvider {
case .groq:
URL(string: "https://console.groq.com/keys")!
case .openAI:
URL(string: "https://platform.openai.com/api-keys")!
case .deepSeek:
URL(string: "https://platform.deepseek.com/api-keys")!
case .gemini:
URL(string: "https://makersuite.google.com/app/apikey")!
case .anthropic:
URL(string: "https://console.anthropic.com/settings/keys")!
case .mistral:
URL(string: "https://console.mistral.ai/api-keys")!
case .ollama, .custom:
URL(string: "")! // This case should never be reached
}
NSWorkspace.shared.open(url)
} label: {
HStack(spacing: 4) {
Text("Get API Key")
.foregroundColor(.accentColor)
Image(systemName: "arrow.up.right")
.font(.caption)
.foregroundColor(.accentColor)
}
}
.buttonStyle(.plain)
}
.buttonStyle(.plain)
}
}
}