import Foundation
import SwiftData
import AppKit
import os
/// Mode tag that accompanies an enhancement request.
enum EnhancementPrompt {
    /// Clean-up of a transcription.
    /// Assistant-style request.
    case transcriptionEnhancement, aiAssistant
}
class AIEnhancementService: ObservableObject {
/// Unified-logging handle used by this service.
private let logger = Logger(subsystem: "com.voiceink.enhancement", category: "AIEnhancementService")

/// Whether AI enhancement is switched on. Persisted under "isAIEnhancementEnabled".
@Published var isEnhancementEnabled: Bool {
    didSet {
        UserDefaults.standard.set(isEnhancementEnabled, forKey: "isAIEnhancementEnabled")
        // Enabling with no selection falls back to the first stored prompt (if any).
        if isEnhancementEnabled && selectedPromptId == nil {
            selectedPromptId = customPrompts.first?.id
        }
        NotificationCenter.default.post(name: .AppSettingsDidChange, object: nil)
        NotificationCenter.default.post(name: .enhancementToggleChanged, object: nil)
    }
}

/// Whether the general pasteboard text is appended to the system message.
/// Persisted under "useClipboardContext".
@Published var useClipboardContext: Bool {
    didSet {
        UserDefaults.standard.set(useClipboardContext, forKey: "useClipboardContext")
    }
}

/// Whether the last captured screen text is appended to the system message.
/// Persisted under "useScreenCaptureContext".
@Published var useScreenCaptureContext: Bool {
    didSet {
        UserDefaults.standard.set(useScreenCaptureContext, forKey: "useScreenCaptureContext")
        NotificationCenter.default.post(name: .AppSettingsDidChange, object: nil)
    }
}

/// All stored prompts. Every change is JSON-encoded and written to
/// `UserDefaults` under "customPrompts".
@Published var customPrompts: [CustomPrompt] {
    didSet {
        do {
            let encoded = try JSONEncoder().encode(customPrompts)
            UserDefaults.standard.set(encoded, forKey: "customPrompts")
        } catch {
            // Previously swallowed by `try?`; surface the failure so a prompt list
            // that silently stops persisting can be diagnosed.
            logger.error("Failed to persist custom prompts: \(error.localizedDescription, privacy: .public)")
        }
    }
}

/// Identifier of the currently selected prompt, stored as a UUID string under
/// "selectedPromptId" (a `nil` value is written as `nil`).
@Published var selectedPromptId: UUID? {
    didSet {
        UserDefaults.standard.set(selectedPromptId?.uuidString, forKey: "selectedPromptId")
        NotificationCenter.default.post(name: .AppSettingsDidChange, object: nil)
        NotificationCenter.default.post(name: .promptSelectionChanged, object: nil)
    }
}
/// The prompt whose `id` equals `selectedPromptId`, or `nil` when none matches.
var activePrompt: CustomPrompt? {
    return allPrompts.first(where: { candidate in candidate.id == selectedPromptId })
}

/// Every prompt the service knows about; currently just `customPrompts`.
var allPrompts: [CustomPrompt] {
    customPrompts
}

// MARK: - Dependencies and request tuning

/// Provider client used for configuration checks and Ollama requests.
private let aiService: AIService
/// Source of the most recently captured screen text.
private let screenCaptureService: ScreenCaptureService
/// Source of the dictionary context appended to system messages.
private let dictionaryContextService: DictionaryContextService
/// Timeout, in seconds, applied to each HTTP request.
private let baseTimeout: TimeInterval = 30
/// Minimum spacing, in seconds, between rate-limited requests.
private let rateLimitInterval: TimeInterval = 1.0
/// When the previous rate-limited request was started; `nil` before the first one.
private var lastRequestTime: Date?
/// SwiftData context handed in at construction.
private let modelContext: ModelContext
/// Builds the service, restoring persisted settings from `UserDefaults`.
///
/// - Parameters:
///   - aiService: Provider client; a fresh `AIService` is created when omitted.
///   - modelContext: SwiftData context retained by the service.
init(aiService: AIService = AIService(), modelContext: ModelContext) {
    let defaults = UserDefaults.standard

    // Collaborators.
    self.aiService = aiService
    self.modelContext = modelContext
    screenCaptureService = ScreenCaptureService()
    dictionaryContextService = DictionaryContextService.shared

    // Persisted toggles and prompt list.
    isEnhancementEnabled = defaults.bool(forKey: "isAIEnhancementEnabled")
    useClipboardContext = defaults.bool(forKey: "useClipboardContext")
    useScreenCaptureContext = defaults.bool(forKey: "useScreenCaptureContext")
    customPrompts = PromptMigrationService.migratePromptsIfNeeded()

    // Restore the stored selection only when a value was saved.
    if let storedIdentifier = defaults.string(forKey: "selectedPromptId") {
        selectedPromptId = UUID(uuidString: storedIdentifier)
    }

    // With enhancement on, a missing or unknown selection falls back to the first prompt.
    if isEnhancementEnabled,
       selectedPromptId == nil || !allPrompts.contains(where: { $0.id == selectedPromptId }) {
        selectedPromptId = allPrompts.first?.id
    }

    // React to provider API-key changes.
    NotificationCenter.default.addObserver(self, selector: #selector(handleAPIKeyChange), name: .aiProviderKeyChanged, object: nil)

    initializePredefinedPrompts()
}
/// Unregisters this instance from `NotificationCenter.default`, dropping the
/// observer added in `init`.
deinit {
NotificationCenter.default.removeObserver(self)
}
/// Notification handler for `.aiProviderKeyChanged`.
///
/// Hops to the main queue, announces a change to observers, and switches
/// enhancement off when the provider's API key is no longer valid.
@objc private func handleAPIKeyChange() {
    // Capture weakly: the closure escapes onto the main queue, and a strong capture
    // would keep the service alive (and mutate it) after its owner released it.
    DispatchQueue.main.async { [weak self] in
        guard let self else { return }
        self.objectWillChange.send()
        if !self.aiService.isAPIKeyValid {
            self.isEnhancementEnabled = false
        }
    }
}
/// Exposes the provider client this service was constructed with.
///
/// - Returns: The injected `AIService`; the optional return type is kept for callers.
func getAIService() -> AIService? {
    aiService
}
/// `true` when the provider client reports a valid API key.
var isConfigured: Bool {
    return aiService.isAPIKeyValid
}
/// Suspends until at least `rateLimitInterval` seconds have passed since the
/// previous call, then records the current time as the new reference point.
///
/// - Throws: Whatever `Task.sleep` throws (e.g. on task cancellation).
private func waitForRateLimit() async throws {
    if let lastRequest = lastRequestTime {
        let elapsed = Date().timeIntervalSince(lastRequest)
        // `Date()` is wall-clock time: if the clock is set backwards, `elapsed` goes
        // negative and the unclamped wait would exceed the interval by the size of
        // the jump. Never sleep longer than one full interval.
        let remaining = min(rateLimitInterval, rateLimitInterval - elapsed)
        if remaining > 0 {
            try await Task.sleep(nanoseconds: UInt64(remaining * 1_000_000_000))
        }
    }
    lastRequestTime = Date()
}
/// Assembles the system message sent with an enhancement request.
///
/// The message is the prompt text followed by optional context sections:
/// - When the active prompt is the predefined assistant prompt and selected text is
///   available, only the selected text and the dictionary context are appended.
/// - Otherwise clipboard text (if `useClipboardContext`), the last captured screen
///   text (if `useScreenCaptureContext`) and the dictionary context are appended.
///
/// Prompt text is the active prompt's text (wrapped in
/// `AIPrompts.customPromptTemplate` unless it is the assistant prompt). With no
/// active prompt, the predefined default prompt is used, and failing that
/// `AIPrompts.assistantMode`.
///
/// - Parameter mode: Currently not consulted by this method.
/// - Returns: The complete system message.
private func getSystemMessage(for mode: EnhancementPrompt) -> String {
    let selectedText = SelectedTextService.fetchSelectedText()
    let activePrompt = self.activePrompt

    // Read the dictionary context once so the emptiness check and the inserted text
    // always refer to the same value, and build its section in a single place.
    let dictionaryContext = dictionaryContextService.getDictionaryContext()
    let dictionaryContextSection = dictionaryContext.isEmpty ? "" : "\n\n\(dictionaryContext)\n"

    // Assistant prompt with a non-empty selection: the selection replaces the
    // clipboard / screen context.
    if let activePrompt,
       activePrompt.id == PredefinedPrompts.assistantPromptId,
       let selectedText, !selectedText.isEmpty {
        let selectedTextContext = "\n\nSelected Text: \(selectedText)"
        let generalContextSection = "\n\n\(selectedTextContext)\n"
        return activePrompt.promptText + generalContextSection + dictionaryContextSection
    }

    let clipboardContext = if useClipboardContext,
                              let clipboardText = NSPasteboard.general.string(forType: .string),
                              !clipboardText.isEmpty {
        "\n\n\n\(clipboardText)\n"
    } else {
        ""
    }

    let screenCaptureContext = if useScreenCaptureContext,
                                  let capturedText = screenCaptureService.lastCapturedText,
                                  !capturedText.isEmpty {
        "\n\nActive Window Context: \(capturedText)"
    } else {
        ""
    }

    let generalContextSection = if !clipboardContext.isEmpty || !screenCaptureContext.isEmpty {
        "\n\n\(clipboardContext)\(screenCaptureContext)\n"
    } else {
        ""
    }

    // Choose the prompt text; every path appends the same context sections.
    let promptBody: String
    if let activePrompt {
        promptBody = activePrompt.id == PredefinedPrompts.assistantPromptId
            ? activePrompt.promptText
            : String(format: AIPrompts.customPromptTemplate, activePrompt.promptText)
    } else if let defaultPrompt = allPrompts.first(where: { $0.id == PredefinedPrompts.defaultPromptId }) {
        promptBody = String(format: AIPrompts.customPromptTemplate, defaultPrompt.promptText)
    } else {
        promptBody = AIPrompts.assistantMode
    }
    return promptBody + generalContextSection + dictionaryContextSection
}
/// Sends one enhancement request to the selected provider and returns the filtered text.
///
/// Ollama is delegated to `aiService.enhanceWithOllama`. Every other provider is
/// called over HTTP after `waitForRateLimit()`: Anthropic with its `system` /
/// `messages` body and `x-api-key` header, all others with a chat-completions style
/// body and a bearer token.
///
/// - Parameters:
///   - text: Text to enhance. An empty string short-circuits to `""`.
///   - mode: Forwarded to `getSystemMessage(for:)`.
/// - Returns: The provider's reply, trimmed (HTTP providers) and passed through
///   `AIEnhancementOutputFilter.filter`.
/// - Throws: `EnhancementError.notConfigured` when no valid API key is present;
///   `.invalidResponse`, `.enhancementFailed`, `.rateLimitExceeded` (HTTP 429),
///   `.serverError` (HTTP 5xx) or `.customError` for other failures; `URLError`
///   is rethrown unchanged so the caller can classify network problems.
private func makeRequest(text: String, mode: EnhancementPrompt) async throws -> String {
    guard isConfigured else {
        throw EnhancementError.notConfigured
    }
    guard !text.isEmpty else {
        return "" // Silently return empty string instead of throwing error
    }

    let formattedText = "\n\n\(text)\n"
    let systemMessage = getSystemMessage(for: mode)

    // Both strings can contain clipboard contents, on-screen text and the user's
    // dictation, so they are logged as private rather than written in clear text
    // to the unified log.
    logger.notice("AI Enhancement - System Message: \(systemMessage, privacy: .private)")
    logger.notice("AI Enhancement - User Message: \(formattedText, privacy: .private)")

    if aiService.selectedProvider == .ollama {
        do {
            let result = try await aiService.enhanceWithOllama(text: formattedText, systemPrompt: systemMessage)
            return AIEnhancementOutputFilter.filter(result)
        } catch let localError as LocalAIError {
            throw EnhancementError.customError(localError.errorDescription ?? "An unknown Ollama error occurred.")
        } catch {
            throw EnhancementError.customError(error.localizedDescription)
        }
    }

    try await waitForRateLimit()

    let provider = aiService.selectedProvider
    // A malformed base URL is a configuration problem, not a reason to crash.
    guard let endpoint = URL(string: provider.baseURL) else {
        throw EnhancementError.customError("Invalid API endpoint URL: \(provider.baseURL)")
    }

    var request = URLRequest(url: endpoint)
    request.httpMethod = "POST"
    request.addValue("application/json", forHTTPHeaderField: "Content-Type")
    request.timeoutInterval = baseTimeout

    // Provider-specific pieces: auth headers, JSON body, and where the reply text
    // lives inside the decoded response.
    let requestBody: [String: Any]
    let extractText: ([String: Any]) -> String?

    switch provider {
    case .anthropic:
        request.addValue(aiService.apiKey, forHTTPHeaderField: "x-api-key")
        request.addValue("2023-06-01", forHTTPHeaderField: "anthropic-version")
        requestBody = [
            "model": aiService.currentModel,
            "max_tokens": 8192,
            "system": systemMessage,
            "messages": [
                ["role": "user", "content": formattedText]
            ]
        ]
        extractText = { json in
            let content = json["content"] as? [[String: Any]]
            return content?.first?["text"] as? String
        }
    default:
        request.addValue("Bearer \(aiService.apiKey)", forHTTPHeaderField: "Authorization")
        let messages: [[String: Any]] = [
            ["role": "system", "content": systemMessage],
            ["role": "user", "content": formattedText]
        ]
        requestBody = [
            "model": aiService.currentModel,
            "messages": messages,
            "temperature": aiService.currentModel.lowercased().hasPrefix("gpt-5") ? 1.0 : 0.3,
            "stream": false
        ]
        extractText = { json in
            let choices = json["choices"] as? [[String: Any]]
            let message = choices?.first?["message"] as? [String: Any]
            return message?["content"] as? String
        }
    }

    do {
        // Serialisation failures now surface as an error instead of silently
        // sending a request with no body.
        request.httpBody = try JSONSerialization.data(withJSONObject: requestBody)

        let (data, response) = try await URLSession.shared.data(for: request)
        guard let httpResponse = response as? HTTPURLResponse else {
            throw EnhancementError.invalidResponse
        }

        switch httpResponse.statusCode {
        case 200:
            guard let jsonResponse = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
                  let enhancedText = extractText(jsonResponse) else {
                throw EnhancementError.enhancementFailed
            }
            return AIEnhancementOutputFilter.filter(enhancedText.trimmingCharacters(in: .whitespacesAndNewlines))
        case 429:
            throw EnhancementError.rateLimitExceeded
        case 500...599:
            throw EnhancementError.serverError
        default:
            let errorString = String(data: data, encoding: .utf8) ?? "Could not decode error response."
            throw EnhancementError.customError("HTTP \(httpResponse.statusCode): \(errorString)")
        }
    } catch let error as EnhancementError {
        throw error
    } catch let error as URLError {
        // Left untouched so the retry layer can inspect the URL error code.
        throw error
    } catch {
        throw EnhancementError.customError(error.localizedDescription)
    }
}
/// Calls `makeRequest(text:mode:)`, retrying transient failures with exponential backoff.
///
/// Retried failures are `EnhancementError.networkError`, `.serverError`,
/// `.rateLimitExceeded`, and `NSURLErrorDomain` errors for "not connected",
/// "timed out" and "connection lost". Everything else is rethrown immediately.
///
/// - Parameters:
///   - text: Text to enhance.
///   - mode: Prompt mode forwarded to `makeRequest`.
///   - maxRetries: Maximum number of attempts.
///   - initialDelay: Delay in seconds before the second attempt; doubled after each retry.
/// - Returns: The enhanced text from the first successful attempt.
/// - Throws: The last `EnhancementError` once attempts are exhausted,
///   `EnhancementError.networkError` when the exhausted failure was a URL error,
///   or any non-retryable error unchanged.
private func makeRequestWithRetry(text: String, mode: EnhancementPrompt, maxRetries: Int = 3, initialDelay: TimeInterval = 1.0) async throws -> String {
    let retryableURLCodes = [NSURLErrorNotConnectedToInternet, NSURLErrorTimedOut, NSURLErrorNetworkConnectionLost]
    var attempt = 0
    var backoff = initialDelay

    while attempt < maxRetries {
        do {
            return try await makeRequest(text: text, mode: mode)
        } catch let failure as EnhancementError {
            // Only transient provider failures are worth another attempt.
            switch failure {
            case .networkError, .serverError, .rateLimitExceeded:
                break
            default:
                throw failure
            }

            attempt += 1
            guard attempt < maxRetries else {
                logger.error("Request failed after \(maxRetries) retries.")
                throw failure
            }
            logger.warning("Request failed, retrying in \(backoff)s... (Attempt \(attempt)/\(maxRetries))")
            try await Task.sleep(nanoseconds: UInt64(backoff * 1_000_000_000))
            backoff *= 2 // Exponential backoff
        } catch {
            // Anything else is retried only when it is a connectivity-type URL error.
            let underlying = error as NSError
            guard underlying.domain == NSURLErrorDomain, retryableURLCodes.contains(underlying.code) else {
                throw error
            }

            attempt += 1
            guard attempt < maxRetries else {
                logger.error("Request failed after \(maxRetries) retries with network error.")
                throw EnhancementError.networkError
            }
            logger.warning("Request failed with network error, retrying in \(backoff)s... (Attempt \(attempt)/\(maxRetries))")
            try await Task.sleep(nanoseconds: UInt64(backoff * 1_000_000_000))
            backoff *= 2 // Exponential backoff
        }
    }

    // Reached only when the loop never ran (e.g. `maxRetries` is not positive).
    throw EnhancementError.enhancementFailed
}
/// Enhances `text` using the transcription-enhancement mode.
///
/// - Parameter text: The text to enhance.
/// - Returns: A tuple of the enhanced text, the elapsed wall-clock time in seconds,
///   and the title of the prompt that was active when the call started (if any).
/// - Throws: Any error raised by `makeRequestWithRetry`.
func enhance(_ text: String) async throws -> (String, TimeInterval, String?) {
    let startedAt = Date()
    // Read before the request is sent so the reported title matches the starting selection.
    let promptTitle = activePrompt?.title
    let enhanced = try await makeRequestWithRetry(text: text, mode: .transcriptionEnhancement)
    return (enhanced, Date().timeIntervalSince(startedAt), promptTitle)
}
/// Asks the screen-capture service to capture and extract text, then notifies
/// observers on the main actor when something was captured.
///
/// Does nothing when `useScreenCaptureContext` is off.
func captureScreenContext() async {
    guard useScreenCaptureContext else { return }
    // Only whether text was captured matters here; the value itself is not used.
    guard await screenCaptureService.captureAndExtractText() != nil else { return }
    await MainActor.run {
        self.objectWillChange.send()
    }
}
/// Appends a new, non-predefined prompt and selects it when it is the only one stored.
///
/// - Parameters:
///   - title: Display title of the prompt.
///   - promptText: Instruction text of the prompt.
///   - icon: Icon to show; defaults to `.documentFill`.
///   - description: Optional description.
///   - triggerWords: Trigger words associated with the prompt.
func addPrompt(title: String, promptText: String, icon: PromptIcon = .documentFill, description: String? = nil, triggerWords: [String] = []) {
    let listWasEmpty = customPrompts.isEmpty
    let created = CustomPrompt(
        title: title,
        promptText: promptText,
        icon: icon,
        description: description,
        isPredefined: false,
        triggerWords: triggerWords
    )
    customPrompts.append(created)
    guard listWasEmpty else { return }
    selectedPromptId = created.id
}
/// Replaces the stored prompt that has the same `id` as `prompt`.
/// Does nothing when no stored prompt matches.
func updatePrompt(_ prompt: CustomPrompt) {
    guard let position = customPrompts.firstIndex(where: { stored in stored.id == prompt.id }) else {
        return
    }
    customPrompts[position] = prompt
}
/// Removes every stored prompt with the same `id` as `prompt`. If that prompt was
/// selected, selection moves to the first remaining prompt (or `nil`).
func deletePrompt(_ prompt: CustomPrompt) {
    let removedID = prompt.id
    customPrompts.removeAll(where: { stored in stored.id == removedID })
    guard selectedPromptId == removedID else { return }
    selectedPromptId = allPrompts.first?.id
}
/// Makes `prompt` the selected prompt by storing its `id` in `selectedPromptId`.
func setActivePrompt(_ prompt: CustomPrompt) {
    self.selectedPromptId = prompt.id
}
/// Merges the predefined prompt templates into `customPrompts`.
///
/// A template whose `id` is already stored refreshes that entry's title, text,
/// icon and description and marks it predefined, while keeping the stored
/// `isActive` flag and trigger words. Templates not yet stored are appended.
private func initializePredefinedPrompts() {
    let predefinedTemplates = PredefinedPrompts.createDefaultPrompts()
    guard !predefinedTemplates.isEmpty else { return }

    // Work on a local copy and assign once: each write to `customPrompts` runs its
    // `didSet`, which re-encodes and persists the whole array.
    var merged = customPrompts
    for template in predefinedTemplates {
        if let existingIndex = merged.firstIndex(where: { $0.id == template.id }) {
            let existing = merged[existingIndex]
            merged[existingIndex] = CustomPrompt(
                id: existing.id,
                title: template.title,
                promptText: template.promptText,
                isActive: existing.isActive,
                icon: template.icon,
                description: template.description,
                isPredefined: true,
                triggerWords: existing.triggerWords
            )
        } else {
            merged.append(template)
        }
    }
    customPrompts = merged
}
}
/// Failures reported by the enhancement pipeline.
enum EnhancementError: Error {
    // Configuration and response-shape problems.
    case notConfigured, invalidResponse, enhancementFailed
    // Connectivity and provider-side problems.
    case networkError, serverError, rateLimitExceeded
    /// Any other failure, carrying a ready-to-display message.
    case customError(String)
}
// MARK: - LocalizedError

extension EnhancementError: LocalizedError {
    /// User-facing message for each failure; `customError` returns its payload verbatim.
    var errorDescription: String? {
        let text: String = switch self {
        case .notConfigured:
            "AI provider not configured. Please check your API key."
        case .invalidResponse:
            "Invalid response from AI provider."
        case .enhancementFailed:
            "AI enhancement failed to process the text."
        case .networkError:
            "Network connection failed. Check your internet."
        case .serverError:
            "The AI provider's server encountered an error. Please try again later."
        case .rateLimitExceeded:
            "Rate limit exceeded. Please try again later."
        case .customError(let message):
            message
        }
        return text
    }
}