600 lines
24 KiB
Swift
600 lines
24 KiB
Swift
import Foundation
|
|
import os
|
|
import SwiftData
|
|
import AppKit
|
|
|
|
/// The two ways a transcript can be processed by `AIEnhancementService`.
enum EnhancementMode {
    /// Clean up the transcript using the currently active custom prompt.
    case transcriptionEnhancement

    /// Treat the transcript as a request to the assistant; chosen when the
    /// transcript begins with the configured trigger word.
    case aiAssistant
}
|
|
|
|
class AIEnhancementService: ObservableObject {
|
|
/// Unified-logging handle used by every diagnostic message in this service.
private let logger = Logger(subsystem: "com.voiceink.enhancement", category: "AI")
|
|
|
|
/// Whether AI enhancement is switched on. Persisted under "isAIEnhancementEnabled".
///
/// Turning enhancement on also guarantees that a usable prompt is selected:
/// if nothing is selected, or the stored selection no longer matches any
/// entry of `allPrompts`, the first available prompt is chosen. The fallback
/// is taken from `allPrompts` (predefined + custom) rather than from
/// `customPrompts`, because `customPrompts` is empty for users who never
/// created their own prompt and the selection would otherwise stay `nil`.
@Published var isEnhancementEnabled: Bool {
    didSet {
        UserDefaults.standard.set(isEnhancementEnabled, forKey: "isAIEnhancementEnabled")

        guard isEnhancementEnabled else { return }

        // Same validity rule that `init` applies when restoring state.
        let selectionIsValid = allPrompts.contains { $0.id == selectedPromptId }
        if !selectionIsValid {
            selectedPromptId = allPrompts.first?.id
        }
    }
}
|
|
/// Whether the current clipboard text is appended to the system message.
/// Persisted under "useClipboardContext".
@Published var useClipboardContext: Bool {
    didSet {
        let defaults = UserDefaults.standard
        defaults.set(useClipboardContext, forKey: "useClipboardContext")
    }
}

/// Whether text captured from the screen is appended to the system message.
/// Persisted under "useScreenCaptureContext".
@Published var useScreenCaptureContext: Bool {
    didSet {
        let defaults = UserDefaults.standard
        defaults.set(useScreenCaptureContext, forKey: "useScreenCaptureContext")
    }
}

/// Word that, when it starts a transcript, switches processing to assistant mode.
/// Persisted under "assistantTriggerWord".
@Published var assistantTriggerWord: String {
    didSet {
        let defaults = UserDefaults.standard
        defaults.set(assistantTriggerWord, forKey: "assistantTriggerWord")
    }
}
|
|
|
|
/// User-defined prompts. Every change is persisted as JSON under "customPrompts".
///
/// Predefined prompts are filtered out before saving, so only user-created
/// entries are ever written. An encoding failure is logged instead of being
/// silently discarded; the in-memory value is kept either way.
@Published var customPrompts: [CustomPrompt] {
    didSet {
        let userDefined = customPrompts.filter { !$0.isPredefined }
        do {
            let encoded = try JSONEncoder().encode(userDefined)
            UserDefaults.standard.set(encoded, forKey: "customPrompts")
        } catch {
            logger.error("AI Enhancement: Failed to persist custom prompts - \(error.localizedDescription)")
        }
    }
}
|
|
|
|
/// Identifier of the prompt currently chosen by the user, if any.
/// Its string form is persisted under "selectedPromptId" on every change.
@Published var selectedPromptId: UUID? {
    didSet {
        let storedValue = selectedPromptId?.uuidString
        UserDefaults.standard.set(storedValue, forKey: "selectedPromptId")
    }
}
|
|
|
|
/// The prompt whose id equals `selectedPromptId`, or `nil` when no entry of
/// `allPrompts` matches.
var activePrompt: CustomPrompt? {
    let wantedId = selectedPromptId
    return allPrompts.first(where: { $0.id == wantedId })
}

/// Freshly created predefined prompts followed by the user's own prompts.
var allPrompts: [CustomPrompt] {
    var combined = PredefinedPrompts.createDefaultPrompts()
    combined += customPrompts.filter { !$0.isPredefined }
    return combined
}
|
|
|
|
/// Provider configuration (selected provider, API key) and the Ollama entry point.
private let aiService: AIService

/// Source of the "Active Window Context" text added to the system message.
private let screenCaptureService: ScreenCaptureService

/// Upper bound used by both the network retry and the rate-limit retry logic.
private let maxRetries = 3

/// Request timeout in seconds for the first attempt; doubled on each retry.
private let baseTimeout: TimeInterval = 4

/// Minimum spacing between two cloud requests, in seconds (1 request per second).
private let rateLimitInterval: TimeInterval = 1.0

/// When the most recent cloud request passed the rate limiter; `nil` before the first one.
private var lastRequestTime: Date?

/// SwiftData context handed in at construction time.
private let modelContext: ModelContext
|
|
|
|
/// Creates the service and restores its persisted settings from `UserDefaults`.
///
/// - Parameters:
///   - aiService: Provider configuration to use; a fresh `AIService` by default.
///   - modelContext: SwiftData context stored on the service.
///
/// Restored keys: "isAIEnhancementEnabled", "useClipboardContext",
/// "useScreenCaptureContext", "assistantTriggerWord" (defaults to "hey"),
/// "customPrompts" and "selectedPromptId". When enhancement is enabled but the
/// stored selection is missing or stale, the first entry of `allPrompts` is
/// selected. Finally the service starts observing `.aiProviderKeyChanged`.
init(aiService: AIService = AIService(), modelContext: ModelContext) {
    let defaults = UserDefaults.standard

    self.aiService = aiService
    self.modelContext = modelContext
    self.screenCaptureService = ScreenCaptureService()

    self.isEnhancementEnabled = defaults.bool(forKey: "isAIEnhancementEnabled")
    self.useClipboardContext = defaults.bool(forKey: "useClipboardContext")
    self.useScreenCaptureContext = defaults.bool(forKey: "useScreenCaptureContext")
    self.assistantTriggerWord = defaults.string(forKey: "assistantTriggerWord") ?? "hey"

    // Load only custom prompts (predefined ones are never persisted).
    if let savedPromptsData = defaults.data(forKey: "customPrompts"),
       let decodedPrompts = try? JSONDecoder().decode([CustomPrompt].self, from: savedPromptsData) {
        self.customPrompts = decodedPrompts
    } else {
        self.customPrompts = []
    }

    // Load the selected prompt id, if one was stored.
    if let savedPromptId = defaults.string(forKey: "selectedPromptId") {
        self.selectedPromptId = UUID(uuidString: savedPromptId)
    }

    // Diagnostics go through the unified logger at debug level instead of
    // being printed to stdout on every construction.
    if let domain = Bundle.main.bundleIdentifier {
        logger.debug("⚙️ UserDefaults domain: \(domain, privacy: .public)")
        if let prefsPath = NSSearchPathForDirectoriesInDomains(.libraryDirectory, .userDomainMask, true).first {
            logger.debug("⚙️ Preferences directory: \(prefsPath)/Preferences/\(domain, privacy: .public).plist")
        }
    }

    // Ensure a valid prompt is selected if enhancement is enabled.
    let selectionIsValid = allPrompts.contains(where: { $0.id == selectedPromptId })
    if isEnhancementEnabled && !selectionIsValid {
        self.selectedPromptId = allPrompts.first?.id
    }

    // React to API key changes posted by the provider settings.
    NotificationCenter.default.addObserver(
        self,
        selector: #selector(handleAPIKeyChange),
        name: .aiProviderKeyChanged,
        object: nil
    )
}
|
|
|
|
/// Unregisters this object from every notification it observes.
deinit {
    let center = NotificationCenter.default
    center.removeObserver(self)
}
|
|
|
|
/// Notification handler for `.aiProviderKeyChanged`.
///
/// Hops to the main queue, announces a change to observers, and switches
/// enhancement off when the API key is no longer valid.
@objc private func handleAPIKeyChange() {
    DispatchQueue.main.async {
        self.objectWillChange.send()

        guard self.aiService.isAPIKeyValid else {
            // The key was cleared or became invalid: enhancement cannot run.
            self.isEnhancementEnabled = false
            return
        }
    }
}
|
|
|
|
/// `true` when the underlying `AIService` reports a valid API key.
var isConfigured: Bool {
    get {
        return aiService.isAPIKeyValid
    }
}
|
|
|
|
/// Suspends until at least `rateLimitInterval` seconds have passed since the
/// previous request, then records the current time as the latest request.
///
/// - Throws: Whatever `Task.sleep` throws; in that case the timestamp is
///   left untouched.
private func waitForRateLimit() async throws {
    if let previousRequest = lastRequestTime {
        let elapsed = Date().timeIntervalSince(previousRequest)
        let remaining = rateLimitInterval - elapsed
        if remaining > 0 {
            let nanoseconds = UInt64(remaining * 1_000_000_000)
            try await Task.sleep(nanoseconds: nanoseconds)
        }
    }

    lastRequestTime = Date()
}
|
|
|
|
/// Decides how a transcript should be processed.
///
/// Assistant mode is used only when the transcript starts with the configured
/// trigger word, compared case-insensitively and ignoring surrounding
/// whitespace on both sides.
///
/// - Parameter text: The raw transcript.
/// - Returns: `.aiAssistant` when the transcript begins with the trigger word,
///   otherwise `.transcriptionEnhancement`.
private func determineMode(text: String) -> EnhancementMode {
    let trigger = assistantTriggerWord
        .trimmingCharacters(in: .whitespacesAndNewlines)
        .lowercased()

    // An empty trigger is a prefix of every string; without this guard a blank
    // setting would route every transcript to the assistant.
    guard !trigger.isEmpty else {
        return .transcriptionEnhancement
    }

    // A transcript with leading whitespace should still match the trigger.
    let spoken = text
        .trimmingCharacters(in: .whitespacesAndNewlines)
        .lowercased()

    return spoken.hasPrefix(trigger) ? .aiAssistant : .transcriptionEnhancement
}
|
|
|
|
/// Builds the system message sent to the provider for the given mode.
///
/// Optional sections, each included only when enabled and non-empty:
/// - clipboard text (when `useClipboardContext` is on),
/// - the last screen-captured text (when `useScreenCaptureContext` is on),
/// - word replacements stored under the "wordReplacements" defaults key
///   (transcription mode only).
///
/// - Parameter mode: The processing mode chosen for the transcript.
/// - Returns: The assembled system message.
private func getSystemMessage(for mode: EnhancementMode) -> String {
    // Clipboard context, if enabled and available.
    var clipboardContext = ""
    if useClipboardContext,
       let clipboardText = NSPasteboard.general.string(forType: .string),
       !clipboardText.isEmpty {
        clipboardContext = "\nContext Awareness\nAvailable Clipboard Context: \(clipboardText)"
    }

    // Screen capture context, if enabled and available.
    var screenCaptureContext = ""
    if useScreenCaptureContext,
       let capturedText = screenCaptureService.lastCapturedText,
       !capturedText.isEmpty {
        screenCaptureContext = "\nActive Window Context: \(capturedText)"
    }

    // Word replacements, if any are configured. Sorted by key so the same
    // settings always produce the same prompt (dictionary order is unspecified).
    var wordReplacements = ""
    if let replacements = UserDefaults.standard.dictionary(forKey: "wordReplacements") as? [String: String],
       !replacements.isEmpty {
        let rules = replacements
            .sorted { $0.key < $1.key }
            .map { "Replace '\($0.key)' with '\($0.value)'" }
            .joined(separator: "\n")
        wordReplacements = "\nWord Replacements:\n\(rules)"
    }

    switch mode {
    case .transcriptionEnhancement:
        // The selection can be nil or stale; fall back to the first available
        // prompt (and finally to an empty prompt) instead of crashing.
        let selected = activePrompt
        if selected == nil {
            logger.warning("AI Enhancement: No active prompt selected, falling back to the first available prompt")
        }
        let promptText = (selected ?? allPrompts.first)?.promptText ?? ""

        var systemMessage = String(format: AIPrompts.customPromptTemplate, promptText)
        systemMessage += "\n\n" + AIPrompts.contextInstructions
        systemMessage += clipboardContext + screenCaptureContext + wordReplacements
        return systemMessage

    case .aiAssistant:
        return AIPrompts.assistantMode + clipboardContext + screenCaptureContext
    }
}
|
|
|
|
/// Everything needed to send one request to a cloud provider and read its answer.
private struct ProviderCall {
    /// Fully configured HTTP request (URL, headers, timeout, JSON body).
    let request: URLRequest
    /// Provider name used in log messages.
    let providerName: String
    /// Suffix appended to the "AI Enhancement Debug" log header, e.g. " (Gemini)".
    let debugTag: String
    /// Whether the system message and transcript are logged right before sending.
    let logsPrompt: Bool
    /// Extracts the enhanced text from a decoded 200 response, or `nil` if the shape is unexpected.
    let extractText: ([String: Any]) -> String?
}

/// Sends `text` to the selected AI provider and returns the trimmed response.
///
/// Ollama is handled through `AIService.enhanceWithOllama`; every other
/// provider goes through the rate limiter and an HTTP request. Transport
/// failures are retried with exponential back-off (1s, 2s, 4s, ...) by calling
/// this method again with `retryCount + 1` until `maxRetries` is reached.
///
/// - Parameters:
///   - text: The transcript to enhance; must not be empty.
///   - retryCount: Zero-based attempt index; also doubles the request timeout.
/// - Returns: The provider's answer with surrounding whitespace removed.
/// - Throws: `EnhancementError` describing the failure, or an error thrown by
///   `Task.sleep` or by the Ollama call when it is not a `LocalAIError`.
private func makeRequest(text: String, retryCount: Int = 0) async throws -> String {
    guard isConfigured else {
        logger.error("AI Enhancement: API not configured")
        throw EnhancementError.notConfigured
    }

    guard !text.isEmpty else {
        logger.error("AI Enhancement: Empty text received")
        throw EnhancementError.emptyText
    }

    // Determine mode and get system message
    let mode = determineMode(text: text)
    let systemMessage = getSystemMessage(for: mode)

    // Ollama requests bypass the rate limiter and the HTTP path entirely.
    if aiService.selectedProvider == .ollama {
        return try await enhanceViaOllama(text: text, systemMessage: systemMessage)
    }

    // Cloud providers: respect the rate limit before building the request.
    try await waitForRateLimit()

    let call: ProviderCall
    switch aiService.selectedProvider {
    case .gemini:
        call = try makeGeminiCall(text: text, systemMessage: systemMessage, retryCount: retryCount)
    case .anthropic:
        call = try makeAnthropicCall(text: text, systemMessage: systemMessage, retryCount: retryCount)
    default:
        // Every remaining provider speaks the OpenAI-compatible chat format.
        call = try makeOpenAICompatibleCall(text: text, systemMessage: systemMessage, retryCount: retryCount)
    }

    do {
        return try await perform(call, text: text, systemMessage: systemMessage, retryCount: retryCount)
    } catch let error as EnhancementError {
        // Already classified (HTTP status, parse failure, ...): never retried here.
        throw error
    } catch {
        logger.error("AI Enhancement: Network error - \(error.localizedDescription)")

        if retryCount < maxRetries {
            try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000))
            return try await makeRequest(text: text, retryCount: retryCount + 1)
        }

        throw EnhancementError.networkError
    }
}

/// Runs the enhancement through Ollama and maps `LocalAIError` to `EnhancementError`.
private func enhanceViaOllama(text: String, systemMessage: String) async throws -> String {
    logger.info("AI Enhancement: Using Ollama for enhancement")
    do {
        return try await aiService.enhanceWithOllama(text: text, systemPrompt: systemMessage)
    } catch let error as LocalAIError {
        logger.error("AI Enhancement: Ollama error - \(error.localizedDescription)")
        switch error {
        case .serviceUnavailable:
            throw EnhancementError.notConfigured
        case .modelNotFound:
            throw EnhancementError.enhancementFailed
        case .serverError:
            throw EnhancementError.serverError
        default:
            throw EnhancementError.enhancementFailed
        }
    }
}

/// Builds a JSON POST request with the retry-scaled timeout applied.
private func makeJSONPostRequest(url: URL, headers: [String: String], body: [String: Any], retryCount: Int) throws -> URLRequest {
    var request = URLRequest(url: url)
    request.httpMethod = "POST"
    request.addValue("application/json", forHTTPHeaderField: "Content-Type")
    for (field, value) in headers {
        request.addValue(value, forHTTPHeaderField: field)
    }

    // Timeout doubles with every retry: 4s, 8s, 16s, ...
    request.timeoutInterval = baseTimeout * pow(2.0, Double(retryCount))

    do {
        request.httpBody = try JSONSerialization.data(withJSONObject: body)
    } catch {
        // Sending a request without a body could only fail remotely; fail here instead.
        logger.error("AI Enhancement: Failed to encode request body - \(error.localizedDescription)")
        throw EnhancementError.enhancementFailed
    }
    return request
}

/// Prepares the Gemini call; the API key travels as the `key` query item.
private func makeGeminiCall(text: String, systemMessage: String, retryCount: Int) throws -> ProviderCall {
    guard var urlComponents = URLComponents(string: aiService.selectedProvider.baseURL) else {
        logger.error("AI Enhancement: Invalid base URL for Gemini")
        throw EnhancementError.invalidResponse
    }
    urlComponents.queryItems = [URLQueryItem(name: "key", value: aiService.apiKey)]

    guard let url = urlComponents.url else {
        throw EnhancementError.invalidResponse
    }

    let requestBody: [String: Any] = [
        "contents": [
            [
                "parts": [
                    ["text": systemMessage],
                    ["text": "Transcript:\n\(text)"]
                ]
            ]
        ],
        "generationConfig": [
            "temperature": 0.3,
        ]
    ]

    let request = try makeJSONPostRequest(url: url, headers: [:], body: requestBody, retryCount: retryCount)
    return ProviderCall(
        request: request,
        providerName: "Gemini",
        debugTag: " (Gemini)",
        logsPrompt: true,
        extractText: { json in
            guard let candidates = json["candidates"] as? [[String: Any]],
                  let content = candidates.first?["content"] as? [String: Any],
                  let parts = content["parts"] as? [[String: Any]] else {
                return nil
            }
            return parts.first?["text"] as? String
        }
    )
}

/// Prepares the Anthropic call (`x-api-key` and `anthropic-version` headers).
private func makeAnthropicCall(text: String, systemMessage: String, retryCount: Int) throws -> ProviderCall {
    guard let url = URL(string: aiService.selectedProvider.baseURL) else {
        logger.error("AI Enhancement: Invalid base URL for Anthropic")
        throw EnhancementError.invalidResponse
    }

    let requestBody: [String: Any] = [
        "model": aiService.selectedProvider.defaultModel,
        "max_tokens": 1024,
        "system": systemMessage,
        "messages": [
            ["role": "user", "content": text]
        ]
    ]

    let headers = [
        "x-api-key": aiService.apiKey,
        "anthropic-version": "2023-06-01"
    ]
    let request = try makeJSONPostRequest(url: url, headers: headers, body: requestBody, retryCount: retryCount)
    return ProviderCall(
        request: request,
        providerName: "Anthropic",
        debugTag: " (Anthropic)",
        logsPrompt: true,
        extractText: { json in
            guard let content = json["content"] as? [[String: Any]] else { return nil }
            return content.first?["text"] as? String
        }
    )
}

/// Prepares a call for OpenAI-compatible providers (bearer-token authorization).
private func makeOpenAICompatibleCall(text: String, systemMessage: String, retryCount: Int) throws -> ProviderCall {
    let providerName = aiService.selectedProvider.rawValue

    guard let url = URL(string: aiService.selectedProvider.baseURL) else {
        logger.error("AI Enhancement: Invalid base URL for \(providerName)")
        throw EnhancementError.invalidResponse
    }

    logger.debug("Full system message: \(systemMessage)")

    let messages: [[String: Any]] = [
        ["role": "system", "content": systemMessage],
        ["role": "user", "content": "Transcript:\n\(text)"]
    ]

    logger.info("Making request to \(providerName) with text length: \(text.count) characters")

    let requestBody: [String: Any] = [
        "model": aiService.selectedProvider.defaultModel,
        "messages": messages,
        "temperature": 0.3,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
        "stream": false
    ]

    let headers = ["Authorization": "Bearer \(aiService.apiKey)"]
    let request = try makeJSONPostRequest(url: url, headers: headers, body: requestBody, retryCount: retryCount)
    return ProviderCall(
        request: request,
        providerName: providerName,
        debugTag: "",
        logsPrompt: false,
        extractText: { json in
            guard let choices = json["choices"] as? [[String: Any]],
                  let message = choices.first?["message"] as? [String: Any] else {
                return nil
            }
            return message["content"] as? String
        }
    )
}

/// Sends one prepared call, classifies the HTTP status, and extracts the enhanced text.
///
/// Transport errors from `URLSession` are propagated unchanged so the caller
/// can decide whether to retry.
private func perform(_ call: ProviderCall, text: String, systemMessage: String, retryCount: Int) async throws -> String {
    let providerName = call.providerName
    let debugTag = call.debugTag

    logger.info("AI Enhancement: Sending request to \(providerName, privacy: .public) API (attempt \(retryCount + 1))")
    if call.logsPrompt {
        logger.debug("""
        AI Enhancement Debug\(debugTag, privacy: .public):
        System Message: \(systemMessage)
        Original Text: \(text)
        """)
    }

    let (data, response) = try await URLSession.shared.data(for: call.request)

    guard let httpResponse = response as? HTTPURLResponse else {
        logger.error("AI Enhancement: Invalid response received from \(providerName, privacy: .public)")
        throw EnhancementError.invalidResponse
    }

    switch httpResponse.statusCode {
    case 200:
        guard let jsonResponse = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
              let enhancedText = call.extractText(jsonResponse) else {
            logger.error("AI Enhancement: Failed to parse \(providerName, privacy: .public) API response")
            throw EnhancementError.enhancementFailed
        }

        let result = enhancedText.trimmingCharacters(in: .whitespacesAndNewlines)
        logger.info("AI Enhancement: Successfully enhanced text using \(providerName, privacy: .public)")
        logger.debug("""
        AI Enhancement Debug\(debugTag, privacy: .public):
        Original Text: \(text)
        Enhanced Text: \(result)
        """)
        return result

    case 401:
        logger.error("AI Enhancement: Authentication failed")
        throw EnhancementError.authenticationFailed

    case 429:
        logger.error("AI Enhancement: Rate limit exceeded")
        throw EnhancementError.rateLimitExceeded

    case 500...599:
        logger.error("AI Enhancement: Server error - Status code: \(httpResponse.statusCode)")
        throw EnhancementError.serverError

    default:
        logger.error("AI Enhancement: Unexpected status code: \(httpResponse.statusCode)")
        throw EnhancementError.apiError
    }
}
|
|
|
|
/// Enhances `text` with the selected provider, retrying when rate limited.
///
/// A `rateLimitExceeded` failure is retried after an exponential pause
/// (2s, then 4s) as long as the attempt index is below `maxRetries - 1`.
/// Every other error, including a rate-limit failure on the last allowed
/// attempt, is rethrown unchanged. The attempt index is also forwarded to
/// `makeRequest` as its `retryCount`.
///
/// - Parameter text: The transcript to enhance.
/// - Returns: The enhanced text.
/// - Throws: The error from `makeRequest` or `Task.sleep`, or
///   `EnhancementError.maxRetriesExceeded` if the loop ends without a result.
func enhance(_ text: String) async throws -> String {
    let lastRetryableAttempt = maxRetries - 1
    var attempt = 0

    while attempt < maxRetries {
        do {
            return try await makeRequest(text: text, retryCount: attempt)
        } catch let failure as EnhancementError {
            guard case .rateLimitExceeded = failure, attempt < lastRetryableAttempt else {
                throw failure
            }

            attempt += 1
            let pauseSeconds = pow(2.0, Double(attempt))
            try await Task.sleep(nanoseconds: UInt64(pauseSeconds * 1_000_000_000))
        }
    }

    throw EnhancementError.maxRetriesExceeded
}
|
|
|
|
/// Asks the screen capture service to capture and extract text, but only
/// while `useScreenCaptureContext` is on. The returned value is discarded.
func captureScreenContext() async {
    if useScreenCaptureContext {
        _ = await screenCaptureService.captureAndExtractText()
    }
}
|
|
|
|
// MARK: - Prompt Management
|
|
|
|
/// Creates a user-defined prompt and appends it to `customPrompts`.
///
/// When the new prompt is the only entry in `customPrompts`, it also becomes
/// the selected prompt.
///
/// - Parameters:
///   - title: Display title of the prompt.
///   - promptText: Instruction text of the prompt.
///   - icon: Icon to show; `.documentFill` by default.
///   - description: Optional description.
func addPrompt(title: String, promptText: String, icon: PromptIcon = .documentFill, description: String? = nil) {
    let created = CustomPrompt(
        title: title,
        promptText: promptText,
        icon: icon,
        description: description,
        isPredefined: false
    )
    customPrompts.append(created)

    guard customPrompts.count == 1 else { return }
    selectedPromptId = created.id
}
|
|
|
|
/// Replaces the stored custom prompt that has the same id as `prompt`.
///
/// Predefined prompts and prompts not present in `customPrompts` are ignored.
func updatePrompt(_ prompt: CustomPrompt) {
    guard !prompt.isPredefined,
          let slot = customPrompts.firstIndex(where: { $0.id == prompt.id }) else {
        return
    }

    customPrompts[slot] = prompt
}
|
|
|
|
/// Removes a user-defined prompt; predefined prompts cannot be deleted.
///
/// If the removed prompt was selected, the selection moves to the first
/// entry of `allPrompts`.
func deletePrompt(_ prompt: CustomPrompt) {
    guard !prompt.isPredefined else { return }

    let removedId = prompt.id
    customPrompts.removeAll(where: { $0.id == removedId })

    guard selectedPromptId == removedId else { return }
    selectedPromptId = allPrompts.first?.id
}
|
|
|
|
/// Makes `prompt` the selected prompt by storing its id.
func setActivePrompt(_ prompt: CustomPrompt) {
    let chosenId = prompt.id
    selectedPromptId = chosenId
}
|
|
}
|
|
|
|
/// Failures reported by `AIEnhancementService`.
enum EnhancementError: Error {
    /// No valid API key is configured, or the local Ollama service is unavailable.
    case notConfigured

    /// The transcript passed in was empty.
    case emptyText

    /// The request URL could not be formed or the reply was not an HTTP response.
    case invalidResponse

    /// The provider answered, but no enhanced text could be obtained from it.
    case enhancementFailed

    /// The provider rejected the credentials (HTTP 401).
    case authenticationFailed

    /// The provider throttled the request (HTTP 429).
    case rateLimitExceeded

    /// The provider reported a server-side failure (HTTP 5xx or an Ollama server error).
    case serverError

    /// The provider returned an HTTP status that has no dedicated case.
    case apiError

    /// The request kept failing at the transport level after all retries.
    case networkError

    /// The retry loop in `enhance(_:)` ended without producing a result.
    case maxRetriesExceeded
}
|
|
|
|
|