From 7781ecc19bcac4adb6c5cf8be003e024c6a45887 Mon Sep 17 00:00:00 2001 From: Beingpax Date: Thu, 29 May 2025 09:38:49 +0545 Subject: [PATCH] Improve AI trigger words and retry logic --- VoiceInk.xcodeproj/project.pbxproj | 8 +- VoiceInk/CursorPaster.swift | 1 - VoiceInk/Services/AIEnhancementService.swift | 137 ++++++------------ .../Services/PromptDetectionService.swift | 127 ++++++++++++++++ VoiceInk/Whisper/WhisperState.swift | 32 +++- 5 files changed, 201 insertions(+), 104 deletions(-) create mode 100644 VoiceInk/Services/PromptDetectionService.swift diff --git a/VoiceInk.xcodeproj/project.pbxproj b/VoiceInk.xcodeproj/project.pbxproj index ae8bd8c..b4edee4 100644 --- a/VoiceInk.xcodeproj/project.pbxproj +++ b/VoiceInk.xcodeproj/project.pbxproj @@ -448,7 +448,7 @@ "CODE_SIGN_IDENTITY[sdk=macosx*]" = "Apple Development"; CODE_SIGN_STYLE = Automatic; COMBINE_HIDPI_IMAGES = YES; - CURRENT_PROJECT_VERSION = 129; + CURRENT_PROJECT_VERSION = 130; DEVELOPMENT_ASSET_PATHS = "\"VoiceInk/Preview Content\""; DEVELOPMENT_TEAM = V6J6A3VWY2; ENABLE_HARDENED_RUNTIME = YES; @@ -463,7 +463,7 @@ "@executable_path/../Frameworks", ); MACOSX_DEPLOYMENT_TARGET = 14.0; - MARKETING_VERSION = 1.29; + MARKETING_VERSION = 1.30; PRODUCT_BUNDLE_IDENTIFIER = com.prakashjoshipax.VoiceInk; PRODUCT_NAME = "$(TARGET_NAME)"; SWIFT_EMIT_LOC_STRINGS = YES; @@ -481,7 +481,7 @@ "CODE_SIGN_IDENTITY[sdk=macosx*]" = "Apple Development"; CODE_SIGN_STYLE = Automatic; COMBINE_HIDPI_IMAGES = YES; - CURRENT_PROJECT_VERSION = 129; + CURRENT_PROJECT_VERSION = 130; DEVELOPMENT_ASSET_PATHS = "\"VoiceInk/Preview Content\""; DEVELOPMENT_TEAM = V6J6A3VWY2; ENABLE_HARDENED_RUNTIME = YES; @@ -496,7 +496,7 @@ "@executable_path/../Frameworks", ); MACOSX_DEPLOYMENT_TARGET = 14.0; - MARKETING_VERSION = 1.29; + MARKETING_VERSION = 1.30; PRODUCT_BUNDLE_IDENTIFIER = com.prakashjoshipax.VoiceInk; PRODUCT_NAME = "$(TARGET_NAME)"; SWIFT_EMIT_LOC_STRINGS = YES; diff --git a/VoiceInk/CursorPaster.swift b/VoiceInk/CursorPaster.swift index 2cac84b..eb2c651 100644 --- a/VoiceInk/CursorPaster.swift +++ b/VoiceInk/CursorPaster.swift @@ -74,4 +74,3 @@ class CursorPaster { cmdUp?.post(tap: .cghidEventTap) } } - diff --git a/VoiceInk/Services/AIEnhancementService.swift b/VoiceInk/Services/AIEnhancementService.swift index c8a0179..71f9c9a 100644 --- a/VoiceInk/Services/AIEnhancementService.swift +++ b/VoiceInk/Services/AIEnhancementService.swift @@ -3,7 +3,7 @@ import os import SwiftData import AppKit -enum EnhancementMode { +enum EnhancementPrompt { case transcriptionEnhancement case aiAssistant } @@ -74,14 +74,11 @@ class AIEnhancementService: ObservableObject { private let screenCaptureService: ScreenCaptureService private var currentCaptureTask: Task? private let maxRetries = 3 - private let baseTimeout: TimeInterval = 4 + private let baseTimeout: TimeInterval = 10 private let rateLimitInterval: TimeInterval = 1.0 private var lastRequestTime: Date? private let modelContext: ModelContext - // Store the original prompt ID when temporarily switching due to trigger word - private var originalSelectedPromptId: UUID? - init(aiService: AIService = AIService(), modelContext: ModelContext) { self.aiService = aiService self.modelContext = modelContext @@ -146,43 +143,7 @@ class AIEnhancementService: ObservableObject { lastRequestTime = Date() } - private func determineMode(text: String) -> EnhancementMode { - let lowerText = text.lowercased() - - // First check if the text starts with the global assistant trigger word - if lowerText.hasPrefix(assistantTriggerWord.lowercased()) { - logger.notice("🔍 Detected assistant trigger word: \(self.assistantTriggerWord)") - return .aiAssistant - } - - // Then check for custom trigger words in all prompts - for prompt in allPrompts { - if let triggerWord = prompt.triggerWord?.lowercased().trimmingCharacters(in: .whitespacesAndNewlines), - !triggerWord.isEmpty, - lowerText.hasPrefix(triggerWord) { - - logger.notice("🔍 Detected custom trigger word: '\(triggerWord)' for mode: \(prompt.title)") - - // Only store the original prompt ID if we haven't already - if originalSelectedPromptId == nil { - originalSelectedPromptId = selectedPromptId - logger.notice("💾 Stored original prompt ID: \(String(describing: self.originalSelectedPromptId))") - } - - // Update to the new prompt - selectedPromptId = prompt.id - logger.notice("🔄 Switched to prompt: \(prompt.title) (ID: \(prompt.id))") - - return .transcriptionEnhancement - } - } - - // Default to transcription enhancement with currently selected prompt - logger.notice("â„šī¸ No trigger word detected, using default enhancement mode") - return .transcriptionEnhancement - } - - private func getSystemMessage(for mode: EnhancementMode) -> String { + private func getSystemMessage(for mode: EnhancementPrompt) -> String { let clipboardContext = if useClipboardContext, let clipboardText = NSPasteboard.general.string(forType: .string), !clipboardText.isEmpty { @@ -221,7 +182,7 @@ class AIEnhancementService: ObservableObject { } } - private func makeRequest(text: String, mode: EnhancementMode, retryCount: Int = 0) async throws -> String { + private func makeRequest(text: String, mode: EnhancementPrompt, retryCount: Int = 0) async throws -> String { guard isConfigured else { logger.error("AI Enhancement: API not configured") throw EnhancementError.notConfigured @@ -317,6 +278,7 @@ class AIEnhancementService: ObservableObject { case 429: throw EnhancementError.rateLimitExceeded case 500...599: + logger.error("Server error (HTTP \(httpResponse.statusCode)): \(String(data: data, encoding: .utf8) ?? "No response data")") throw EnhancementError.serverError default: throw EnhancementError.apiError @@ -372,6 +334,7 @@ class AIEnhancementService: ObservableObject { case 429: throw EnhancementError.rateLimitExceeded case 500...599: + logger.error("Server error (HTTP \(httpResponse.statusCode)): \(String(data: data, encoding: .utf8) ?? "No response data")") throw EnhancementError.serverError default: throw EnhancementError.apiError @@ -434,6 +397,7 @@ class AIEnhancementService: ObservableObject { case 429: throw EnhancementError.rateLimitExceeded case 500...599: + logger.error("Server error (HTTP \(httpResponse.statusCode)): \(String(data: data, encoding: .utf8) ?? "No response data")") throw EnhancementError.serverError default: throw EnhancementError.apiError @@ -454,64 +418,43 @@ class AIEnhancementService: ObservableObject { func enhance(_ text: String) async throws -> String { logger.notice("🚀 Starting AI enhancement for text (\(text.count) characters)") - // Determine the mode and potentially set the active prompt based on trigger word - let mode = determineMode(text: text) - - // If a custom trigger word was detected, remove it from the text - var processedText = text - if mode == .transcriptionEnhancement, let activePrompt = activePrompt, let triggerWord = activePrompt.triggerWord, !triggerWord.isEmpty { - // Check if the text starts with the trigger word (case insensitive) - if text.lowercased().hasPrefix(triggerWord.lowercased()) { - // Remove the trigger word from the beginning of the text - let index = text.index(text.startIndex, offsetBy: triggerWord.count) - processedText = String(text[index...]).trimmingCharacters(in: .whitespacesAndNewlines) - logger.notice("🔍 Detected trigger word '\(triggerWord)' for mode '\(activePrompt.title)'. Processing: \(processedText)") + let enhancementPrompt: EnhancementPrompt = { + if let activePrompt = activePrompt, activePrompt.id == PredefinedPrompts.assistantPromptId { + return .aiAssistant } - } else if mode == .aiAssistant { - // Remove the assistant trigger word if present - if text.lowercased().hasPrefix(assistantTriggerWord.lowercased()) { - let index = text.index(text.startIndex, offsetBy: assistantTriggerWord.count) - processedText = String(text[index...]).trimmingCharacters(in: .whitespacesAndNewlines) - } - } + return .transcriptionEnhancement + }() - // Process the text with the appropriate mode var retryCount = 0 while retryCount < maxRetries { do { - let result = try await makeRequest(text: processedText, mode: mode, retryCount: retryCount) + let result = try await makeRequest(text: text, mode: enhancementPrompt, retryCount: retryCount) logger.notice("✅ AI enhancement completed successfully (\(result.count) characters)") - - // After successful enhancement, restore the original prompt if we temporarily switched - // due to a trigger word - Task { @MainActor in - self.restoreOriginalPrompt() - } - return result - } catch EnhancementError.rateLimitExceeded where retryCount < maxRetries - 1 { - logger.notice("âš ī¸ Rate limit exceeded, retrying AI enhancement (attempt \(retryCount + 1) of \(self.maxRetries))") - retryCount += 1 - try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000)) - continue + } catch let error as EnhancementError { + if shouldRetry(error: error, retryCount: retryCount) { + let errorType = switch error { + case .rateLimitExceeded: "Rate limit exceeded" + case .serverError: "Server error occurred" + case .networkError: "Network error occurred" + default: "Unknown error" + } + + logger.notice("âš ī¸ \(errorType), retrying AI enhancement (attempt \(retryCount + 1) of \(self.maxRetries))") + retryCount += 1 + let delaySeconds = getRetryDelay(for: retryCount) + try await Task.sleep(nanoseconds: UInt64(delaySeconds * 1_000_000_000)) + continue + } else { + logger.notice("❌ AI enhancement failed: \(error.localizedDescription)") + throw error + } } catch { logger.notice("❌ AI enhancement failed: \(error.localizedDescription)") - - // Even if enhancement fails, we should restore the original prompt - Task { @MainActor in - self.restoreOriginalPrompt() - } - throw error } } logger.notice("❌ AI enhancement failed: maximum retries exceeded") - - // If we exceed max retries, also restore the original prompt - Task { @MainActor in - self.restoreOriginalPrompt() - } - throw EnhancementError.maxRetriesExceeded } @@ -554,14 +497,20 @@ class AIEnhancementService: ObservableObject { selectedPromptId = prompt.id } - /// Restores the original prompt ID if it was temporarily changed due to a trigger word - func restoreOriginalPrompt() { - if let originalId = originalSelectedPromptId { - selectedPromptId = originalId - originalSelectedPromptId = nil - logger.notice("🔄 Restored original enhancement mode after trigger word activation") + private func shouldRetry(error: EnhancementError, retryCount: Int) -> Bool { + guard retryCount < maxRetries - 1 else { return false } + + switch error { + case .rateLimitExceeded, .serverError, .networkError: + return true + default: + return false } } + + private func getRetryDelay(for retryCount: Int) -> TimeInterval { + return retryCount == 1 ? 1.0 : 2.0 + } } enum EnhancementError: Error { diff --git a/VoiceInk/Services/PromptDetectionService.swift b/VoiceInk/Services/PromptDetectionService.swift new file mode 100644 index 0000000..8b17b87 --- /dev/null +++ b/VoiceInk/Services/PromptDetectionService.swift @@ -0,0 +1,127 @@ +import Foundation +import os + +class PromptDetectionService { + private let logger = Logger( + subsystem: "com.prakashjoshipax.VoiceInk", + category: "promptdetection" + ) + + struct PromptDetectionResult { + let shouldEnableAI: Bool + let selectedPromptId: UUID? + let processedText: String + let detectedTriggerWord: String? + let originalEnhancementState: Bool + let originalPromptId: UUID? + } + + func analyzeText(_ text: String, with enhancementService: AIEnhancementService) -> PromptDetectionResult { + let originalEnhancementState = enhancementService.isEnhancementEnabled + let originalPromptId = enhancementService.selectedPromptId + + if let result = checkAssistantTrigger(text: text, triggerWord: enhancementService.assistantTriggerWord) { + return PromptDetectionResult( + shouldEnableAI: true, + selectedPromptId: PredefinedPrompts.assistantPromptId, + processedText: result, + detectedTriggerWord: enhancementService.assistantTriggerWord, + originalEnhancementState: originalEnhancementState, + originalPromptId: originalPromptId + ) + } + + for prompt in enhancementService.allPrompts { + if let triggerWord = prompt.triggerWord?.trimmingCharacters(in: .whitespacesAndNewlines), + !triggerWord.isEmpty, + let result = checkCustomTrigger(text: text, triggerWord: triggerWord) { + + return PromptDetectionResult( + shouldEnableAI: true, + selectedPromptId: prompt.id, + processedText: result, + detectedTriggerWord: triggerWord, + originalEnhancementState: originalEnhancementState, + originalPromptId: originalPromptId + ) + } + } + + return PromptDetectionResult( + shouldEnableAI: false, + selectedPromptId: nil, + processedText: text, + detectedTriggerWord: nil, + originalEnhancementState: originalEnhancementState, + originalPromptId: originalPromptId + ) + } + + func applyDetectionResult(_ result: PromptDetectionResult, to enhancementService: AIEnhancementService) async { + await MainActor.run { + if result.shouldEnableAI { + if !enhancementService.isEnhancementEnabled { + enhancementService.isEnhancementEnabled = true + } + if let promptId = result.selectedPromptId { + enhancementService.selectedPromptId = promptId + } + } + } + + if result.shouldEnableAI { + try? await Task.sleep(nanoseconds: 50_000_000) + } + } + + func restoreOriginalSettings(_ result: PromptDetectionResult, to enhancementService: AIEnhancementService) async { + if result.shouldEnableAI { + await MainActor.run { + if enhancementService.isEnhancementEnabled != result.originalEnhancementState { + enhancementService.isEnhancementEnabled = result.originalEnhancementState + } + if let originalId = result.originalPromptId, enhancementService.selectedPromptId != originalId { + enhancementService.selectedPromptId = originalId + } + } + } + } + + private func checkAssistantTrigger(text: String, triggerWord: String) -> String? { + return removeTriggerWord(from: text, triggerWord: triggerWord) + } + + private func checkCustomTrigger(text: String, triggerWord: String) -> String? { + return removeTriggerWord(from: text, triggerWord: triggerWord) + } + + private func removeTriggerWord(from text: String, triggerWord: String) -> String? { + let trimmedText = text.trimmingCharacters(in: .whitespacesAndNewlines) + let lowerText = trimmedText.lowercased() + let lowerTrigger = triggerWord.lowercased() + + guard lowerText.hasPrefix(lowerTrigger) else { return nil } + + let triggerEndIndex = trimmedText.index(trimmedText.startIndex, offsetBy: triggerWord.count) + + if triggerEndIndex >= trimmedText.endIndex { + return "" + } + + var remainingText = String(trimmedText[triggerEndIndex...]) + + remainingText = remainingText.replacingOccurrences( + of: "^[,\\.!\\?;:\\s]+", + with: "", + options: .regularExpression + ) + + remainingText = remainingText.trimmingCharacters(in: .whitespacesAndNewlines) + + if !remainingText.isEmpty { + remainingText = remainingText.prefix(1).uppercased() + remainingText.dropFirst() + } + + return remainingText + } +} \ No newline at end of file diff --git a/VoiceInk/Whisper/WhisperState.swift b/VoiceInk/Whisper/WhisperState.swift index 7349a32..7851bb1 100644 --- a/VoiceInk/Whisper/WhisperState.swift +++ b/VoiceInk/Whisper/WhisperState.swift @@ -48,6 +48,9 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { var recordedFile: URL? = nil let whisperPrompt = WhisperPrompt() + // Prompt detection service for trigger word handling + private let promptDetectionService = PromptDetectionService() + let modelContext: ModelContext private var modelUrl: URL? { @@ -295,24 +298,36 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { text = WordReplacementService.shared.applyReplacements(to: text) logger.notice("✅ Word replacements applied") } + + var promptDetectionResult: PromptDetectionService.PromptDetectionResult? = nil + let originalText = text + + if let enhancementService = enhancementService, enhancementService.isConfigured { + let detectionResult = promptDetectionService.analyzeText(text, with: enhancementService) + promptDetectionResult = detectionResult + await promptDetectionService.applyDetectionResult(detectionResult, to: enhancementService) + } + if let enhancementService = enhancementService, enhancementService.isEnhancementEnabled, enhancementService.isConfigured { do { if shouldCancelRecording { return } - let enhancedText = try await enhancementService.enhance(text) + // Use processed text (without trigger words) for AI enhancement + let textForAI = promptDetectionResult?.processedText ?? text + let enhancedText = try await enhancementService.enhance(textForAI) let newTranscription = Transcription( - text: text, + text: originalText, duration: actualDuration, enhancedText: enhancedText, audioFileURL: permanentURLString ) modelContext.insert(newTranscription) try? modelContext.save() - text = enhancedText + text = enhancedText } catch { let newTranscription = Transcription( - text: text, + text: originalText, duration: actualDuration, audioFileURL: permanentURLString ) @@ -321,7 +336,7 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { } } else { let newTranscription = Transcription( - text: text, + text: originalText, duration: actualDuration, audioFileURL: permanentURLString ) @@ -353,6 +368,13 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate { } } try? FileManager.default.removeItem(at: url) + + if let result = promptDetectionResult, + let enhancementService = enhancementService, + result.shouldEnableAI { + await promptDetectionService.restoreOriginalSettings(result, to: enhancementService) + } + await dismissMiniRecorder() await cleanupModelResources()