512 lines
20 KiB
Swift
512 lines
20 KiB
Swift
import Foundation
|
|
import SwiftData
|
|
import AppKit
|
|
import os
|
|
|
|
/// Kind of enhancement request being made.
///
/// `enhance(_:)` always passes `.transcriptionEnhancement`.
/// NOTE(review): `.aiAssistant` is not referenced by the code visible in this file — confirm its callers.
enum EnhancementPrompt {
    case transcriptionEnhancement, aiAssistant
}
|
|
|
|
class AIEnhancementService: ObservableObject {
|
|
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AIEnhancementService")
|
|
|
|
/// Master switch for AI enhancement, persisted under the `isAIEnhancementEnabled` defaults key.
///
/// Every change is written to `UserDefaults`, picks the first stored prompt when
/// enhancement is on and nothing is selected, and then posts `.AppSettingsDidChange`
/// followed by `.enhancementToggleChanged`.
@Published var isEnhancementEnabled: Bool {
    didSet {
        let defaults = UserDefaults.standard
        defaults.set(isEnhancementEnabled, forKey: "isAIEnhancementEnabled")

        // An enabled service should have a selection whenever at least one prompt exists.
        let needsDefaultSelection = isEnhancementEnabled && selectedPromptId == nil
        if needsDefaultSelection {
            selectedPromptId = customPrompts.first?.id
        }

        let center = NotificationCenter.default
        center.post(name: .AppSettingsDidChange, object: nil)
        center.post(name: .enhancementToggleChanged, object: nil)
    }
}
|
|
|
|
/// Whether captured clipboard text is appended to the system message.
/// Persisted under the `useClipboardContext` defaults key on every change.
@Published var useClipboardContext: Bool {
    didSet {
        let defaults = UserDefaults.standard
        defaults.set(useClipboardContext, forKey: "useClipboardContext")
    }
}
|
|
|
|
/// Whether text captured from the screen is appended to the system message.
/// Persisted under the `useScreenCaptureContext` defaults key; each change also
/// posts `.AppSettingsDidChange`.
@Published var useScreenCaptureContext: Bool {
    didSet {
        let defaults = UserDefaults.standard
        defaults.set(useScreenCaptureContext, forKey: "useScreenCaptureContext")

        let center = NotificationCenter.default
        center.post(name: .AppSettingsDidChange, object: nil)
    }
}
|
|
|
|
/// All stored prompts (user-created and predefined), persisted as JSON under the
/// `customPrompts` defaults key on every change.
///
/// An encoding failure leaves the previously stored data untouched and is logged,
/// rather than being dropped silently.
@Published var customPrompts: [CustomPrompt] {
    didSet {
        do {
            let encoded = try JSONEncoder().encode(customPrompts)
            UserDefaults.standard.set(encoded, forKey: "customPrompts")
        } catch {
            // Keep going: the in-memory list is still valid, only persistence failed.
            logger.error("Failed to persist custom prompts: \(error.localizedDescription, privacy: .public)")
        }
    }
}
|
|
|
|
/// Identifier of the currently selected prompt, or `nil` when nothing is selected.
///
/// Each assignment stores the UUID string under the `selectedPromptId` defaults key
/// (a `nil` value clears the key) and posts `.AppSettingsDidChange` followed by
/// `.promptSelectionChanged`.
@Published var selectedPromptId: UUID? {
    didSet {
        let storedValue = selectedPromptId?.uuidString
        UserDefaults.standard.set(storedValue, forKey: "selectedPromptId")

        let center = NotificationCenter.default
        center.post(name: .AppSettingsDidChange, object: nil)
        center.post(name: .promptSelectionChanged, object: nil)
    }
}
|
|
|
|
/// System message of the most recent enhancement request; written by `makeRequest`
/// on the main actor right before the request is sent.
@Published var lastSystemMessageSent: String?

/// User message (the transcript wrapped in `<TRANSCRIPT>` tags) of the most recent
/// enhancement request; written together with `lastSystemMessageSent`.
@Published var lastUserMessageSent: String?
|
|
|
|
/// The prompt whose identifier equals `selectedPromptId`, or `nil` when no prompt matches.
var activePrompt: CustomPrompt? {
    let selection = selectedPromptId
    return allPrompts.first(where: { candidate in candidate.id == selection })
}
|
|
|
|
/// Every prompt available for selection; currently identical to `customPrompts`.
var allPrompts: [CustomPrompt] { customPrompts }
|
|
|
|
// MARK: - Dependencies

/// Provider facade: supplies the API key, current model, selected provider and the Ollama call.
private let aiService: AIService

/// Source of the on-screen text used for the window-context section of the system message.
private let screenCaptureService: ScreenCaptureService

/// Supplies the dictionary context appended to the system message.
private let dictionaryContextService: DictionaryContextService

/// SwiftData context handed in at init.
/// NOTE(review): not read by any code visible in this class — confirm where it is used.
private let modelContext: ModelContext

// MARK: - Request tuning

/// Timeout applied to each HTTP request, in seconds.
private let baseTimeout: TimeInterval = 30

/// Minimum spacing enforced between consecutive HTTP requests, in seconds.
private let rateLimitInterval: TimeInterval = 1.0

/// Moment the previous request passed the rate-limit gate; `nil` before the first request.
private var lastRequestTime: Date?
|
|
|
|
/// Clipboard text snapshot taken by `captureClipboardContext()` and reset by
/// `clearCapturedContexts()`; `nil` when nothing has been captured.
@Published var lastCapturedClipboard: String?
|
|
|
|
/// Creates the service and restores its persisted state from `UserDefaults`.
///
/// - Parameters:
///   - aiService: Provider facade; a fresh `AIService` is created when omitted.
///   - modelContext: SwiftData context stored for later use.
///
/// Order matters: toggles and stored prompts are restored first, the selection is
/// validated against them, the API-key observer is registered, and only then are the
/// predefined prompts merged in.
/// NOTE(review): because the merge happens last, the selection fallback cannot pick
/// a predefined prompt that was not already stored — confirm this is intended.
init(aiService: AIService = AIService(), modelContext: ModelContext) {
    let defaults = UserDefaults.standard

    // Collaborators.
    self.aiService = aiService
    self.modelContext = modelContext
    self.screenCaptureService = ScreenCaptureService()
    self.dictionaryContextService = DictionaryContextService.shared

    // Persisted toggles; a missing key reads as `false`.
    self.isEnhancementEnabled = defaults.bool(forKey: "isAIEnhancementEnabled")
    self.useClipboardContext = defaults.bool(forKey: "useClipboardContext")
    self.useScreenCaptureContext = defaults.bool(forKey: "useScreenCaptureContext")

    // Stored prompts; anything missing or undecodable falls back to an empty list.
    let restoredPrompts = defaults.data(forKey: "customPrompts")
        .flatMap { try? JSONDecoder().decode([CustomPrompt].self, from: $0) }
    self.customPrompts = restoredPrompts ?? []

    if let storedSelection = defaults.string(forKey: "selectedPromptId") {
        self.selectedPromptId = UUID(uuidString: storedSelection)
    }

    // With enhancement on, a missing or stale selection is replaced by the first prompt.
    let hasValidSelection = selectedPromptId != nil
        && allPrompts.contains(where: { $0.id == selectedPromptId })
    if isEnhancementEnabled && !hasValidSelection {
        self.selectedPromptId = allPrompts.first?.id
    }

    NotificationCenter.default.addObserver(
        self,
        selector: #selector(handleAPIKeyChange),
        name: .aiProviderKeyChanged,
        object: nil
    )

    initializePredefinedPrompts()
}
|
|
|
|
/// Unregisters the notification observer added in `init`.
deinit {
    let center = NotificationCenter.default
    center.removeObserver(self)
}
|
|
|
|
/// Reacts to `.aiProviderKeyChanged`: on the main queue, announces a change to
/// observers and switches enhancement off when the API key is no longer valid.
@objc private func handleAPIKeyChange() {
    DispatchQueue.main.async {
        self.objectWillChange.send()

        guard self.aiService.isAPIKeyValid else {
            self.isEnhancementEnabled = false
            return
        }
    }
}
|
|
|
|
/// Returns the provider facade this service talks to.
/// The result is typed as optional but the stored service is always present.
func getAIService() -> AIService? {
    aiService
}
|
|
|
|
/// `true` when the underlying AI service reports a valid API key.
var isConfigured: Bool {
    return aiService.isAPIKeyValid
}
|
|
|
|
/// Suspends until at least `rateLimitInterval` seconds have passed since the previous
/// request, then records the current time as the new reference point.
///
/// - Throws: Whatever `Task.sleep` throws (e.g. on cancellation); in that case the
///   reference time is left unchanged.
private func waitForRateLimit() async throws {
    // Seconds still to wait; zero when there was no previous request.
    let remaining = lastRequestTime.map { rateLimitInterval - Date().timeIntervalSince($0) } ?? 0

    if remaining > 0 {
        try await Task.sleep(nanoseconds: UInt64(remaining * 1_000_000_000))
    }

    lastRequestTime = Date()
}
|
|
|
|
/// Builds the system message for an enhancement request.
///
/// The base text comes from the active prompt (`promptText` for the predefined
/// assistant prompt, `finalPromptText` otherwise). Without an active prompt it is the
/// predefined default prompt's `finalPromptText`, or `AIPrompts.assistantMode` when
/// that prompt is not stored.
///
/// Context is then appended:
/// - If the selected-text service returns non-empty text, only a
///   `<CONTEXT_INFORMATION>` section with that text is added.
/// - Otherwise the enabled, non-empty clipboard and window sections are added,
///   followed by the dictionary section when the dictionary context is non-empty.
///
/// - Parameter mode: Currently unused; the result does not depend on it.
/// - Returns: The complete system message.
private func getSystemMessage(for mode: EnhancementPrompt) -> String {
    // Queried first so the selection is read before anything else happens.
    let selection = SelectedTextService.fetchSelectedText()

    let basePrompt: String
    if let prompt = activePrompt {
        basePrompt = prompt.id == PredefinedPrompts.assistantPromptId
            ? prompt.promptText
            : prompt.finalPromptText
    } else if let fallback = allPrompts.first(where: { $0.id == PredefinedPrompts.defaultPromptId }) {
        basePrompt = fallback.finalPromptText
    } else {
        basePrompt = AIPrompts.assistantMode
    }

    // Selected text takes precedence over every other context source.
    if let selection, !selection.isEmpty {
        return basePrompt + "\n\n<CONTEXT_INFORMATION>\n\(selection)\n</CONTEXT_INFORMATION>"
    }

    var contextSections = ""

    if useClipboardContext, let clipboardText = lastCapturedClipboard, !clipboardText.isEmpty {
        contextSections += "\n\n<CLIPBOARD_CONTEXT>\n\(clipboardText)\n</CLIPBOARD_CONTEXT>"
    }

    if useScreenCaptureContext, let windowText = screenCaptureService.lastCapturedText, !windowText.isEmpty {
        contextSections += "\n\n<CURRENT_WINDOW_CONTEXT>\n\(windowText)\n</CURRENT_WINDOW_CONTEXT>"
    }

    let dictionaryContext = dictionaryContextService.getDictionaryContext()
    if !dictionaryContext.isEmpty {
        contextSections += "\n\n<DICTIONARY_CONTEXT>\(dictionaryContext)\n</DICTIONARY_CONTEXT>"
    }

    return basePrompt + contextSections
}
|
|
|
|
/// Sends a single enhancement request to the selected provider and returns the
/// filtered reply.
///
/// Ollama requests are delegated to `aiService.enhanceWithOllama` and are not rate
/// limited. All other providers go through `waitForRateLimit()` and an HTTP POST to
/// the provider's `baseURL`: Anthropic uses its messages format with `x-api-key`
/// authentication, every other provider uses the chat-completions format with a
/// bearer token.
///
/// - Parameters:
///   - text: Raw transcript; it is wrapped in `<TRANSCRIPT>` tags before sending.
///   - mode: Forwarded to `getSystemMessage(for:)`.
/// - Returns: The provider's reply after trimming and output filtering, or `""` when
///   `text` is empty.
/// - Throws: `EnhancementError.notConfigured` without a valid API key;
///   `.invalidResponse`, `.enhancementFailed`, `.rateLimitExceeded` (HTTP 429),
///   `.serverError` (HTTP 5xx) or `.customError` for other failures; `URLError` is
///   rethrown unchanged so the retry layer can inspect it.
private func makeRequest(text: String, mode: EnhancementPrompt) async throws -> String {
    guard isConfigured else {
        throw EnhancementError.notConfigured
    }

    // An empty transcript is not an error: there is simply nothing to enhance.
    guard !text.isEmpty else {
        return ""
    }

    let formattedText = "\n<TRANSCRIPT>\n\(text)\n</TRANSCRIPT>"
    let systemMessage = getSystemMessage(for: mode)

    // Persist the exact payload being sent (also used for UI).
    await MainActor.run {
        self.lastSystemMessageSent = systemMessage
        self.lastUserMessageSent = formattedText
    }

    // Both messages can carry dictated speech, clipboard contents and on-screen text,
    // so they are marked private and redacted in the unified log unless private
    // logging is explicitly enabled.
    logger.notice("AI Enhancement - System Message: \(systemMessage, privacy: .private)")
    logger.notice("AI Enhancement - User Message: \(formattedText, privacy: .private)")

    if aiService.selectedProvider == .ollama {
        do {
            let result = try await aiService.enhanceWithOllama(text: formattedText, systemPrompt: systemMessage)
            return AIEnhancementOutputFilter.filter(result)
        } catch let localError as LocalAIError {
            throw EnhancementError.customError(localError.errorDescription ?? "An unknown Ollama error occurred.")
        } catch {
            throw EnhancementError.customError(error.localizedDescription)
        }
    }

    try await waitForRateLimit()

    // A malformed base URL is reported as an error instead of crashing on a force-unwrap.
    guard let endpoint = URL(string: aiService.selectedProvider.baseURL) else {
        throw EnhancementError.customError("Invalid base URL for the selected AI provider.")
    }

    var request = URLRequest(url: endpoint)
    request.httpMethod = "POST"
    request.addValue("application/json", forHTTPHeaderField: "Content-Type")
    request.timeoutInterval = baseTimeout

    // Provider-specific parts: authentication headers, payload, and where the reply
    // text lives inside the decoded JSON.
    let requestBody: [String: Any]
    let extractText: ([String: Any]) -> String?

    switch aiService.selectedProvider {
    case .anthropic:
        request.addValue(aiService.apiKey, forHTTPHeaderField: "x-api-key")
        request.addValue("2023-06-01", forHTTPHeaderField: "anthropic-version")
        requestBody = [
            "model": aiService.currentModel,
            "max_tokens": 8192,
            "system": systemMessage,
            "messages": [
                ["role": "user", "content": formattedText]
            ]
        ]
        extractText = { json in
            let content = json["content"] as? [[String: Any]]
            return content?.first?["text"] as? String
        }

    default:
        request.addValue("Bearer \(aiService.apiKey)", forHTTPHeaderField: "Authorization")
        let messages: [[String: Any]] = [
            ["role": "system", "content": systemMessage],
            ["role": "user", "content": formattedText]
        ]
        requestBody = [
            "model": aiService.currentModel,
            "messages": messages,
            "temperature": aiService.currentModel.lowercased().hasPrefix("gpt-5") ? 1.0 : 0.3,
            "stream": false
        ]
        extractText = { json in
            let choices = json["choices"] as? [[String: Any]]
            let message = choices?.first?["message"] as? [String: Any]
            return message?["content"] as? String
        }
    }

    // Fail loudly rather than sending a request without a body.
    do {
        request.httpBody = try JSONSerialization.data(withJSONObject: requestBody)
    } catch {
        throw EnhancementError.customError("Failed to encode request body: \(error.localizedDescription)")
    }

    do {
        let (data, response) = try await URLSession.shared.data(for: request)

        guard let httpResponse = response as? HTTPURLResponse else {
            throw EnhancementError.invalidResponse
        }

        switch httpResponse.statusCode {
        case 200:
            guard let jsonResponse = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
                  let enhancedText = extractText(jsonResponse) else {
                throw EnhancementError.enhancementFailed
            }
            return AIEnhancementOutputFilter.filter(enhancedText.trimmingCharacters(in: .whitespacesAndNewlines))

        case 429:
            throw EnhancementError.rateLimitExceeded

        case 500...599:
            throw EnhancementError.serverError

        default:
            let errorString = String(data: data, encoding: .utf8) ?? "Could not decode error response."
            throw EnhancementError.customError("HTTP \(httpResponse.statusCode): \(errorString)")
        }
    } catch let error as EnhancementError {
        throw error
    } catch let error as URLError {
        // Left untouched so makeRequestWithRetry can decide whether to retry.
        throw error
    } catch {
        throw EnhancementError.customError(error.localizedDescription)
    }
}
|
|
|
|
/// Calls `makeRequest` and retries transient failures with exponential backoff.
///
/// Retried failures are `EnhancementError.networkError`, `.serverError`,
/// `.rateLimitExceeded`, and URL-loading errors for "not connected", "timed out"
/// and "connection lost". Anything else is rethrown immediately.
///
/// - Parameters:
///   - text: Transcript to enhance.
///   - mode: Forwarded to `makeRequest`.
///   - maxRetries: Upper bound on the number of attempts in total (the first call
///     counts towards it).
///   - initialDelay: Pause before the second attempt, in seconds; doubled after each
///     further failure.
/// - Returns: The enhanced text of the first successful attempt.
/// - Throws: The last `EnhancementError` once attempts are exhausted,
///   `EnhancementError.networkError` when the last failure was a retryable URL error,
///   any non-retryable error unchanged, or `.enhancementFailed` when no attempt
///   was made at all.
private func makeRequestWithRetry(text: String, mode: EnhancementPrompt, maxRetries: Int = 3, initialDelay: TimeInterval = 1.0) async throws -> String {
    var failedAttempts = 0
    var backoff = initialDelay

    while failedAttempts < maxRetries {
        do {
            return try await makeRequest(text: text, mode: mode)
        } catch let error as EnhancementError {
            // Only transient provider failures are worth another attempt.
            switch error {
            case .networkError, .serverError, .rateLimitExceeded:
                break
            default:
                throw error
            }

            failedAttempts += 1
            guard failedAttempts < maxRetries else {
                logger.error("Request failed after \(maxRetries) retries.")
                throw error
            }

            logger.warning("Request failed, retrying in \(backoff)s... (Attempt \(failedAttempts)/\(maxRetries))")
            try await Task.sleep(nanoseconds: UInt64(backoff * 1_000_000_000))
            backoff *= 2 // Exponential backoff
        } catch {
            // Anything else is retried only when it is a connectivity-type URL error.
            let nsError = error as NSError
            let retryableCodes = [NSURLErrorNotConnectedToInternet, NSURLErrorTimedOut, NSURLErrorNetworkConnectionLost]
            guard nsError.domain == NSURLErrorDomain && retryableCodes.contains(nsError.code) else {
                throw error
            }

            failedAttempts += 1
            guard failedAttempts < maxRetries else {
                logger.error("Request failed after \(maxRetries) retries with network error.")
                throw EnhancementError.networkError
            }

            logger.warning("Request failed with network error, retrying in \(backoff)s... (Attempt \(failedAttempts)/\(maxRetries))")
            try await Task.sleep(nanoseconds: UInt64(backoff * 1_000_000_000))
            backoff *= 2 // Exponential backoff
        }
    }

    // Only reachable when the loop body never ran (maxRetries <= 0).
    throw EnhancementError.enhancementFailed
}
|
|
|
|
/// Enhances a transcript with the currently selected prompt.
///
/// - Parameter text: Transcript to enhance.
/// - Returns: A tuple of the enhanced text, the wall-clock duration of the whole
///   operation in seconds (retries included), and the title of the prompt that was
///   active when the call started (`nil` without an active prompt).
/// - Throws: Any error from `makeRequestWithRetry`, unchanged.
func enhance(_ text: String) async throws -> (String, TimeInterval, String?) {
    let startedAt = Date()
    // Read up front so a selection change during the request does not affect the label.
    let promptName = activePrompt?.title

    let enhanced = try await makeRequestWithRetry(text: text, mode: .transcriptionEnhancement)
    let elapsed = Date().timeIntervalSince(startedAt)

    return (enhanced, elapsed, promptName)
}
|
|
|
|
/// Asks the screen-capture service to capture and extract text, and notifies
/// observers on the main actor when the capture produced a result.
///
/// The extracted text itself is not used here; only whether the capture returned
/// something.
func captureScreenContext() async {
    guard await screenCaptureService.captureAndExtractText() != nil else { return }

    await MainActor.run {
        self.objectWillChange.send()
    }
}
|
|
|
|
/// Snapshots the general pasteboard's string content into `lastCapturedClipboard`
/// (`nil` when the pasteboard holds no string).
func captureClipboardContext() {
    let pasteboard = NSPasteboard.general
    lastCapturedClipboard = pasteboard.string(forType: .string)
}
|
|
|
|
/// Discards both captured context snapshots: the clipboard text held by this service
/// and the text held by the screen-capture service.
func clearCapturedContexts() {
    self.lastCapturedClipboard = nil
    self.screenCaptureService.lastCapturedText = nil
}
|
|
|
|
/// Creates a non-predefined prompt from the given values and appends it to the
/// stored prompts. When it is the only stored prompt it also becomes the selection.
///
/// - Parameters:
///   - title: Title of the new prompt.
///   - promptText: Instruction text of the new prompt.
///   - icon: Icon of the new prompt; defaults to `.documentFill`.
///   - description: Optional description.
///   - triggerWords: Trigger words stored with the prompt; empty by default.
///   - useSystemInstructions: Stored with the prompt; defaults to `true`.
func addPrompt(title: String, promptText: String, icon: PromptIcon = .documentFill, description: String? = nil, triggerWords: [String] = [], useSystemInstructions: Bool = true) {
    let wasEmpty = customPrompts.isEmpty

    let created = CustomPrompt(
        title: title,
        promptText: promptText,
        icon: icon,
        description: description,
        isPredefined: false,
        triggerWords: triggerWords,
        useSystemInstructions: useSystemInstructions
    )
    customPrompts.append(created)

    // The very first prompt is selected automatically.
    if wasEmpty {
        selectedPromptId = created.id
    }
}
|
|
|
|
/// Replaces the stored prompt that has the same identifier as `prompt`.
/// Does nothing when no stored prompt matches.
func updatePrompt(_ prompt: CustomPrompt) {
    guard let slot = customPrompts.firstIndex(where: { $0.id == prompt.id }) else {
        return
    }
    customPrompts[slot] = prompt
}
|
|
|
|
/// Removes every stored prompt with the same identifier as `prompt`. When the
/// removed prompt was selected, the selection moves to the first remaining prompt
/// (or becomes `nil` when none is left).
func deletePrompt(_ prompt: CustomPrompt) {
    let removedID = prompt.id
    customPrompts.removeAll(where: { $0.id == removedID })

    guard selectedPromptId == removedID else { return }
    selectedPromptId = allPrompts.first?.id
}
|
|
|
|
/// Makes `prompt` the selected prompt by storing its identifier in `selectedPromptId`.
func setActivePrompt(_ prompt: CustomPrompt) {
    let chosenID = prompt.id
    selectedPromptId = chosenID
}
|
|
|
|
/// Merges the predefined prompt templates into the stored prompts.
///
/// A template whose identifier is already stored refreshes that entry's title, text,
/// icon, description and `useSystemInstructions` from the template, while keeping the
/// stored `isActive` flag and trigger words. Templates that are not stored yet are
/// appended.
///
/// The merge is done on a local copy and assigned back once, so the `customPrompts`
/// observer (JSON encoding, defaults write, change publication) runs a single time
/// instead of once per template.
private func initializePredefinedPrompts() {
    var merged = customPrompts

    for template in PredefinedPrompts.createDefaultPrompts() {
        guard let index = merged.firstIndex(where: { $0.id == template.id }) else {
            merged.append(template)
            continue
        }

        let stored = merged[index]
        merged[index] = CustomPrompt(
            id: stored.id,
            title: template.title,
            promptText: template.promptText,
            isActive: stored.isActive,
            icon: template.icon,
            description: template.description,
            isPredefined: true,
            triggerWords: stored.triggerWords,
            useSystemInstructions: template.useSystemInstructions
        )
    }

    customPrompts = merged
}
|
|
}
|
|
|
|
/// Failures reported by the AI enhancement pipeline.
enum EnhancementError: Error {
    /// The AI provider has no valid API key.
    case notConfigured
    /// The provider's response could not be interpreted as an HTTP response.
    case invalidResponse
    /// The request completed but no enhanced text could be obtained.
    case enhancementFailed
    /// The network connection failed.
    case networkError
    /// The provider answered with a server-side error.
    case serverError
    /// The provider rejected the request because of rate limiting.
    case rateLimitExceeded
    /// Any other failure, carrying its own user-facing message.
    case customError(String)
}

// MARK: - LocalizedError

extension EnhancementError: LocalizedError {
    /// User-facing description of the failure.
    var errorDescription: String? {
        switch self {
        case .notConfigured: "AI provider not configured. Please check your API key."
        case .invalidResponse: "Invalid response from AI provider."
        case .enhancementFailed: "AI enhancement failed to process the text."
        case .networkError: "Network connection failed. Check your internet."
        case .serverError: "The AI provider's server encountered an error. Please try again later."
        case .rateLimitExceeded: "Rate limit exceeded. Please try again later."
        case .customError(let message): message
        }
    }
}
|