Improve AI trigger words and retry logic
This commit is contained in:
parent
d659ee8a3d
commit
7781ecc19b
@ -448,7 +448,7 @@
|
||||
"CODE_SIGN_IDENTITY[sdk=macosx*]" = "Apple Development";
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
COMBINE_HIDPI_IMAGES = YES;
|
||||
CURRENT_PROJECT_VERSION = 129;
|
||||
CURRENT_PROJECT_VERSION = 130;
|
||||
DEVELOPMENT_ASSET_PATHS = "\"VoiceInk/Preview Content\"";
|
||||
DEVELOPMENT_TEAM = V6J6A3VWY2;
|
||||
ENABLE_HARDENED_RUNTIME = YES;
|
||||
@ -463,7 +463,7 @@
|
||||
"@executable_path/../Frameworks",
|
||||
);
|
||||
MACOSX_DEPLOYMENT_TARGET = 14.0;
|
||||
MARKETING_VERSION = 1.29;
|
||||
MARKETING_VERSION = 1.30;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.prakashjoshipax.VoiceInk;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SWIFT_EMIT_LOC_STRINGS = YES;
|
||||
@ -481,7 +481,7 @@
|
||||
"CODE_SIGN_IDENTITY[sdk=macosx*]" = "Apple Development";
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
COMBINE_HIDPI_IMAGES = YES;
|
||||
CURRENT_PROJECT_VERSION = 129;
|
||||
CURRENT_PROJECT_VERSION = 130;
|
||||
DEVELOPMENT_ASSET_PATHS = "\"VoiceInk/Preview Content\"";
|
||||
DEVELOPMENT_TEAM = V6J6A3VWY2;
|
||||
ENABLE_HARDENED_RUNTIME = YES;
|
||||
@ -496,7 +496,7 @@
|
||||
"@executable_path/../Frameworks",
|
||||
);
|
||||
MACOSX_DEPLOYMENT_TARGET = 14.0;
|
||||
MARKETING_VERSION = 1.29;
|
||||
MARKETING_VERSION = 1.30;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.prakashjoshipax.VoiceInk;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SWIFT_EMIT_LOC_STRINGS = YES;
|
||||
|
||||
@ -74,4 +74,3 @@ class CursorPaster {
|
||||
cmdUp?.post(tap: .cghidEventTap)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import os
|
||||
import SwiftData
|
||||
import AppKit
|
||||
|
||||
enum EnhancementMode {
|
||||
enum EnhancementPrompt {
|
||||
case transcriptionEnhancement
|
||||
case aiAssistant
|
||||
}
|
||||
@ -74,14 +74,11 @@ class AIEnhancementService: ObservableObject {
|
||||
private let screenCaptureService: ScreenCaptureService
|
||||
private var currentCaptureTask: Task<Void, Never>?
|
||||
private let maxRetries = 3
|
||||
private let baseTimeout: TimeInterval = 4
|
||||
private let baseTimeout: TimeInterval = 10
|
||||
private let rateLimitInterval: TimeInterval = 1.0
|
||||
private var lastRequestTime: Date?
|
||||
private let modelContext: ModelContext
|
||||
|
||||
// Store the original prompt ID when temporarily switching due to trigger word
|
||||
private var originalSelectedPromptId: UUID?
|
||||
|
||||
init(aiService: AIService = AIService(), modelContext: ModelContext) {
|
||||
self.aiService = aiService
|
||||
self.modelContext = modelContext
|
||||
@ -146,43 +143,7 @@ class AIEnhancementService: ObservableObject {
|
||||
lastRequestTime = Date()
|
||||
}
|
||||
|
||||
private func determineMode(text: String) -> EnhancementMode {
|
||||
let lowerText = text.lowercased()
|
||||
|
||||
// First check if the text starts with the global assistant trigger word
|
||||
if lowerText.hasPrefix(assistantTriggerWord.lowercased()) {
|
||||
logger.notice("🔍 Detected assistant trigger word: \(self.assistantTriggerWord)")
|
||||
return .aiAssistant
|
||||
}
|
||||
|
||||
// Then check for custom trigger words in all prompts
|
||||
for prompt in allPrompts {
|
||||
if let triggerWord = prompt.triggerWord?.lowercased().trimmingCharacters(in: .whitespacesAndNewlines),
|
||||
!triggerWord.isEmpty,
|
||||
lowerText.hasPrefix(triggerWord) {
|
||||
|
||||
logger.notice("🔍 Detected custom trigger word: '\(triggerWord)' for mode: \(prompt.title)")
|
||||
|
||||
// Only store the original prompt ID if we haven't already
|
||||
if originalSelectedPromptId == nil {
|
||||
originalSelectedPromptId = selectedPromptId
|
||||
logger.notice("💾 Stored original prompt ID: \(String(describing: self.originalSelectedPromptId))")
|
||||
}
|
||||
|
||||
// Update to the new prompt
|
||||
selectedPromptId = prompt.id
|
||||
logger.notice("🔄 Switched to prompt: \(prompt.title) (ID: \(prompt.id))")
|
||||
|
||||
return .transcriptionEnhancement
|
||||
}
|
||||
}
|
||||
|
||||
// Default to transcription enhancement with currently selected prompt
|
||||
logger.notice("ℹ️ No trigger word detected, using default enhancement mode")
|
||||
return .transcriptionEnhancement
|
||||
}
|
||||
|
||||
private func getSystemMessage(for mode: EnhancementMode) -> String {
|
||||
private func getSystemMessage(for mode: EnhancementPrompt) -> String {
|
||||
let clipboardContext = if useClipboardContext,
|
||||
let clipboardText = NSPasteboard.general.string(forType: .string),
|
||||
!clipboardText.isEmpty {
|
||||
@ -221,7 +182,7 @@ class AIEnhancementService: ObservableObject {
|
||||
}
|
||||
}
|
||||
|
||||
private func makeRequest(text: String, mode: EnhancementMode, retryCount: Int = 0) async throws -> String {
|
||||
private func makeRequest(text: String, mode: EnhancementPrompt, retryCount: Int = 0) async throws -> String {
|
||||
guard isConfigured else {
|
||||
logger.error("AI Enhancement: API not configured")
|
||||
throw EnhancementError.notConfigured
|
||||
@ -317,6 +278,7 @@ class AIEnhancementService: ObservableObject {
|
||||
case 429:
|
||||
throw EnhancementError.rateLimitExceeded
|
||||
case 500...599:
|
||||
logger.error("Server error (HTTP \(httpResponse.statusCode)): \(String(data: data, encoding: .utf8) ?? "No response data")")
|
||||
throw EnhancementError.serverError
|
||||
default:
|
||||
throw EnhancementError.apiError
|
||||
@ -372,6 +334,7 @@ class AIEnhancementService: ObservableObject {
|
||||
case 429:
|
||||
throw EnhancementError.rateLimitExceeded
|
||||
case 500...599:
|
||||
logger.error("Server error (HTTP \(httpResponse.statusCode)): \(String(data: data, encoding: .utf8) ?? "No response data")")
|
||||
throw EnhancementError.serverError
|
||||
default:
|
||||
throw EnhancementError.apiError
|
||||
@ -434,6 +397,7 @@ class AIEnhancementService: ObservableObject {
|
||||
case 429:
|
||||
throw EnhancementError.rateLimitExceeded
|
||||
case 500...599:
|
||||
logger.error("Server error (HTTP \(httpResponse.statusCode)): \(String(data: data, encoding: .utf8) ?? "No response data")")
|
||||
throw EnhancementError.serverError
|
||||
default:
|
||||
throw EnhancementError.apiError
|
||||
@ -454,64 +418,43 @@ class AIEnhancementService: ObservableObject {
|
||||
func enhance(_ text: String) async throws -> String {
|
||||
logger.notice("🚀 Starting AI enhancement for text (\(text.count) characters)")
|
||||
|
||||
// Determine the mode and potentially set the active prompt based on trigger word
|
||||
let mode = determineMode(text: text)
|
||||
|
||||
// If a custom trigger word was detected, remove it from the text
|
||||
var processedText = text
|
||||
if mode == .transcriptionEnhancement, let activePrompt = activePrompt, let triggerWord = activePrompt.triggerWord, !triggerWord.isEmpty {
|
||||
// Check if the text starts with the trigger word (case insensitive)
|
||||
if text.lowercased().hasPrefix(triggerWord.lowercased()) {
|
||||
// Remove the trigger word from the beginning of the text
|
||||
let index = text.index(text.startIndex, offsetBy: triggerWord.count)
|
||||
processedText = String(text[index...]).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
logger.notice("🔍 Detected trigger word '\(triggerWord)' for mode '\(activePrompt.title)'. Processing: \(processedText)")
|
||||
let enhancementPrompt: EnhancementPrompt = {
|
||||
if let activePrompt = activePrompt, activePrompt.id == PredefinedPrompts.assistantPromptId {
|
||||
return .aiAssistant
|
||||
}
|
||||
} else if mode == .aiAssistant {
|
||||
// Remove the assistant trigger word if present
|
||||
if text.lowercased().hasPrefix(assistantTriggerWord.lowercased()) {
|
||||
let index = text.index(text.startIndex, offsetBy: assistantTriggerWord.count)
|
||||
processedText = String(text[index...]).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
}
|
||||
}
|
||||
return .transcriptionEnhancement
|
||||
}()
|
||||
|
||||
// Process the text with the appropriate mode
|
||||
var retryCount = 0
|
||||
while retryCount < maxRetries {
|
||||
do {
|
||||
let result = try await makeRequest(text: processedText, mode: mode, retryCount: retryCount)
|
||||
let result = try await makeRequest(text: text, mode: enhancementPrompt, retryCount: retryCount)
|
||||
logger.notice("✅ AI enhancement completed successfully (\(result.count) characters)")
|
||||
|
||||
// After successful enhancement, restore the original prompt if we temporarily switched
|
||||
// due to a trigger word
|
||||
Task { @MainActor in
|
||||
self.restoreOriginalPrompt()
|
||||
}
|
||||
|
||||
return result
|
||||
} catch EnhancementError.rateLimitExceeded where retryCount < maxRetries - 1 {
|
||||
logger.notice("⚠️ Rate limit exceeded, retrying AI enhancement (attempt \(retryCount + 1) of \(self.maxRetries))")
|
||||
retryCount += 1
|
||||
try await Task.sleep(nanoseconds: UInt64(pow(2.0, Double(retryCount)) * 1_000_000_000))
|
||||
continue
|
||||
} catch let error as EnhancementError {
|
||||
if shouldRetry(error: error, retryCount: retryCount) {
|
||||
let errorType = switch error {
|
||||
case .rateLimitExceeded: "Rate limit exceeded"
|
||||
case .serverError: "Server error occurred"
|
||||
case .networkError: "Network error occurred"
|
||||
default: "Unknown error"
|
||||
}
|
||||
|
||||
logger.notice("⚠️ \(errorType), retrying AI enhancement (attempt \(retryCount + 1) of \(self.maxRetries))")
|
||||
retryCount += 1
|
||||
let delaySeconds = getRetryDelay(for: retryCount)
|
||||
try await Task.sleep(nanoseconds: UInt64(delaySeconds * 1_000_000_000))
|
||||
continue
|
||||
} else {
|
||||
logger.notice("❌ AI enhancement failed: \(error.localizedDescription)")
|
||||
throw error
|
||||
}
|
||||
} catch {
|
||||
logger.notice("❌ AI enhancement failed: \(error.localizedDescription)")
|
||||
|
||||
// Even if enhancement fails, we should restore the original prompt
|
||||
Task { @MainActor in
|
||||
self.restoreOriginalPrompt()
|
||||
}
|
||||
|
||||
throw error
|
||||
}
|
||||
}
|
||||
logger.notice("❌ AI enhancement failed: maximum retries exceeded")
|
||||
|
||||
// If we exceed max retries, also restore the original prompt
|
||||
Task { @MainActor in
|
||||
self.restoreOriginalPrompt()
|
||||
}
|
||||
|
||||
throw EnhancementError.maxRetriesExceeded
|
||||
}
|
||||
|
||||
@ -554,14 +497,20 @@ class AIEnhancementService: ObservableObject {
|
||||
selectedPromptId = prompt.id
|
||||
}
|
||||
|
||||
/// Restores the original prompt ID if it was temporarily changed due to a trigger word
|
||||
func restoreOriginalPrompt() {
|
||||
if let originalId = originalSelectedPromptId {
|
||||
selectedPromptId = originalId
|
||||
originalSelectedPromptId = nil
|
||||
logger.notice("🔄 Restored original enhancement mode after trigger word activation")
|
||||
private func shouldRetry(error: EnhancementError, retryCount: Int) -> Bool {
|
||||
guard retryCount < maxRetries - 1 else { return false }
|
||||
|
||||
switch error {
|
||||
case .rateLimitExceeded, .serverError, .networkError:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private func getRetryDelay(for retryCount: Int) -> TimeInterval {
|
||||
return retryCount == 1 ? 1.0 : 2.0
|
||||
}
|
||||
}
|
||||
|
||||
enum EnhancementError: Error {
|
||||
|
||||
127
VoiceInk/Services/PromptDetectionService.swift
Normal file
127
VoiceInk/Services/PromptDetectionService.swift
Normal file
@ -0,0 +1,127 @@
|
||||
import Foundation
|
||||
import os
|
||||
|
||||
class PromptDetectionService {
|
||||
private let logger = Logger(
|
||||
subsystem: "com.prakashjoshipax.VoiceInk",
|
||||
category: "promptdetection"
|
||||
)
|
||||
|
||||
struct PromptDetectionResult {
|
||||
let shouldEnableAI: Bool
|
||||
let selectedPromptId: UUID?
|
||||
let processedText: String
|
||||
let detectedTriggerWord: String?
|
||||
let originalEnhancementState: Bool
|
||||
let originalPromptId: UUID?
|
||||
}
|
||||
|
||||
func analyzeText(_ text: String, with enhancementService: AIEnhancementService) -> PromptDetectionResult {
|
||||
let originalEnhancementState = enhancementService.isEnhancementEnabled
|
||||
let originalPromptId = enhancementService.selectedPromptId
|
||||
|
||||
if let result = checkAssistantTrigger(text: text, triggerWord: enhancementService.assistantTriggerWord) {
|
||||
return PromptDetectionResult(
|
||||
shouldEnableAI: true,
|
||||
selectedPromptId: PredefinedPrompts.assistantPromptId,
|
||||
processedText: result,
|
||||
detectedTriggerWord: enhancementService.assistantTriggerWord,
|
||||
originalEnhancementState: originalEnhancementState,
|
||||
originalPromptId: originalPromptId
|
||||
)
|
||||
}
|
||||
|
||||
for prompt in enhancementService.allPrompts {
|
||||
if let triggerWord = prompt.triggerWord?.trimmingCharacters(in: .whitespacesAndNewlines),
|
||||
!triggerWord.isEmpty,
|
||||
let result = checkCustomTrigger(text: text, triggerWord: triggerWord) {
|
||||
|
||||
return PromptDetectionResult(
|
||||
shouldEnableAI: true,
|
||||
selectedPromptId: prompt.id,
|
||||
processedText: result,
|
||||
detectedTriggerWord: triggerWord,
|
||||
originalEnhancementState: originalEnhancementState,
|
||||
originalPromptId: originalPromptId
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return PromptDetectionResult(
|
||||
shouldEnableAI: false,
|
||||
selectedPromptId: nil,
|
||||
processedText: text,
|
||||
detectedTriggerWord: nil,
|
||||
originalEnhancementState: originalEnhancementState,
|
||||
originalPromptId: originalPromptId
|
||||
)
|
||||
}
|
||||
|
||||
func applyDetectionResult(_ result: PromptDetectionResult, to enhancementService: AIEnhancementService) async {
|
||||
await MainActor.run {
|
||||
if result.shouldEnableAI {
|
||||
if !enhancementService.isEnhancementEnabled {
|
||||
enhancementService.isEnhancementEnabled = true
|
||||
}
|
||||
if let promptId = result.selectedPromptId {
|
||||
enhancementService.selectedPromptId = promptId
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if result.shouldEnableAI {
|
||||
try? await Task.sleep(nanoseconds: 50_000_000)
|
||||
}
|
||||
}
|
||||
|
||||
func restoreOriginalSettings(_ result: PromptDetectionResult, to enhancementService: AIEnhancementService) async {
|
||||
if result.shouldEnableAI {
|
||||
await MainActor.run {
|
||||
if enhancementService.isEnhancementEnabled != result.originalEnhancementState {
|
||||
enhancementService.isEnhancementEnabled = result.originalEnhancementState
|
||||
}
|
||||
if let originalId = result.originalPromptId, enhancementService.selectedPromptId != originalId {
|
||||
enhancementService.selectedPromptId = originalId
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func checkAssistantTrigger(text: String, triggerWord: String) -> String? {
|
||||
return removeTriggerWord(from: text, triggerWord: triggerWord)
|
||||
}
|
||||
|
||||
private func checkCustomTrigger(text: String, triggerWord: String) -> String? {
|
||||
return removeTriggerWord(from: text, triggerWord: triggerWord)
|
||||
}
|
||||
|
||||
private func removeTriggerWord(from text: String, triggerWord: String) -> String? {
|
||||
let trimmedText = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
let lowerText = trimmedText.lowercased()
|
||||
let lowerTrigger = triggerWord.lowercased()
|
||||
|
||||
guard lowerText.hasPrefix(lowerTrigger) else { return nil }
|
||||
|
||||
let triggerEndIndex = trimmedText.index(trimmedText.startIndex, offsetBy: triggerWord.count)
|
||||
|
||||
if triggerEndIndex >= trimmedText.endIndex {
|
||||
return ""
|
||||
}
|
||||
|
||||
var remainingText = String(trimmedText[triggerEndIndex...])
|
||||
|
||||
remainingText = remainingText.replacingOccurrences(
|
||||
of: "^[,\\.!\\?;:\\s]+",
|
||||
with: "",
|
||||
options: .regularExpression
|
||||
)
|
||||
|
||||
remainingText = remainingText.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
if !remainingText.isEmpty {
|
||||
remainingText = remainingText.prefix(1).uppercased() + remainingText.dropFirst()
|
||||
}
|
||||
|
||||
return remainingText
|
||||
}
|
||||
}
|
||||
@ -48,6 +48,9 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
var recordedFile: URL? = nil
|
||||
let whisperPrompt = WhisperPrompt()
|
||||
|
||||
// Prompt detection service for trigger word handling
|
||||
private let promptDetectionService = PromptDetectionService()
|
||||
|
||||
let modelContext: ModelContext
|
||||
|
||||
private var modelUrl: URL? {
|
||||
@ -295,24 +298,36 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
text = WordReplacementService.shared.applyReplacements(to: text)
|
||||
logger.notice("✅ Word replacements applied")
|
||||
}
|
||||
|
||||
var promptDetectionResult: PromptDetectionService.PromptDetectionResult? = nil
|
||||
let originalText = text
|
||||
|
||||
if let enhancementService = enhancementService, enhancementService.isConfigured {
|
||||
let detectionResult = promptDetectionService.analyzeText(text, with: enhancementService)
|
||||
promptDetectionResult = detectionResult
|
||||
await promptDetectionService.applyDetectionResult(detectionResult, to: enhancementService)
|
||||
}
|
||||
|
||||
if let enhancementService = enhancementService,
|
||||
enhancementService.isEnhancementEnabled,
|
||||
enhancementService.isConfigured {
|
||||
do {
|
||||
if shouldCancelRecording { return }
|
||||
let enhancedText = try await enhancementService.enhance(text)
|
||||
// Use processed text (without trigger words) for AI enhancement
|
||||
let textForAI = promptDetectionResult?.processedText ?? text
|
||||
let enhancedText = try await enhancementService.enhance(textForAI)
|
||||
let newTranscription = Transcription(
|
||||
text: text,
|
||||
text: originalText,
|
||||
duration: actualDuration,
|
||||
enhancedText: enhancedText,
|
||||
audioFileURL: permanentURLString
|
||||
)
|
||||
modelContext.insert(newTranscription)
|
||||
try? modelContext.save()
|
||||
text = enhancedText
|
||||
text = enhancedText
|
||||
} catch {
|
||||
let newTranscription = Transcription(
|
||||
text: text,
|
||||
text: originalText,
|
||||
duration: actualDuration,
|
||||
audioFileURL: permanentURLString
|
||||
)
|
||||
@ -321,7 +336,7 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
}
|
||||
} else {
|
||||
let newTranscription = Transcription(
|
||||
text: text,
|
||||
text: originalText,
|
||||
duration: actualDuration,
|
||||
audioFileURL: permanentURLString
|
||||
)
|
||||
@ -353,6 +368,13 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
}
|
||||
}
|
||||
try? FileManager.default.removeItem(at: url)
|
||||
|
||||
if let result = promptDetectionResult,
|
||||
let enhancementService = enhancementService,
|
||||
result.shouldEnableAI {
|
||||
await promptDetectionService.restoreOriginalSettings(result, to: enhancementService)
|
||||
}
|
||||
|
||||
await dismissMiniRecorder()
|
||||
await cleanupModelResources()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user