Merge pull request #443 from Beingpax/refactor/centralize-transcription-services

Add TranscriptionServiceRegistry to eliminate duplicate service
This commit is contained in:
Prakash Joshi Pax 2025-12-19 09:33:52 +05:45 committed by GitHub
commit aa10377a55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 73 additions and 122 deletions

View File

@ -17,12 +17,6 @@ class AudioTranscriptionManager: ObservableObject {
private let audioProcessor = AudioProcessor() private let audioProcessor = AudioProcessor()
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionManager") private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionManager")
// Transcription services - will be initialized when needed
private var localTranscriptionService: LocalTranscriptionService?
private lazy var cloudTranscriptionService = CloudTranscriptionService()
private lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
private var parakeetTranscriptionService: ParakeetTranscriptionService?
enum ProcessingPhase { enum ProcessingPhase {
case idle case idle
case loading case loading
@ -64,52 +58,31 @@ class AudioTranscriptionManager: ObservableObject {
guard let currentModel = whisperState.currentTranscriptionModel else { guard let currentModel = whisperState.currentTranscriptionModel else {
throw TranscriptionError.noModelSelected throw TranscriptionError.noModelSelected
} }
// Initialize local transcription service if needed let serviceRegistry = TranscriptionServiceRegistry(whisperState: whisperState, modelsDirectory: whisperState.modelsDirectory)
if localTranscriptionService == nil { defer {
localTranscriptionService = LocalTranscriptionService(modelsDirectory: whisperState.modelsDirectory, whisperState: whisperState) serviceRegistry.cleanup()
} }
// Initialize parakeet transcription service if needed
if parakeetTranscriptionService == nil {
parakeetTranscriptionService = ParakeetTranscriptionService()
}
// Process audio file
processingPhase = .processingAudio processingPhase = .processingAudio
let samples = try await audioProcessor.processAudioToSamples(url) let samples = try await audioProcessor.processAudioToSamples(url)
// Get audio duration
let audioAsset = AVURLAsset(url: url) let audioAsset = AVURLAsset(url: url)
let duration = CMTimeGetSeconds(try await audioAsset.load(.duration)) let duration = CMTimeGetSeconds(try await audioAsset.load(.duration))
// Create permanent copy of the audio file
let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0] let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
.appendingPathComponent("com.prakashjoshipax.VoiceInk") .appendingPathComponent("com.prakashjoshipax.VoiceInk")
.appendingPathComponent("Recordings") .appendingPathComponent("Recordings")
let fileName = "transcribed_\(UUID().uuidString).wav" let fileName = "transcribed_\(UUID().uuidString).wav"
let permanentURL = recordingsDirectory.appendingPathComponent(fileName) let permanentURL = recordingsDirectory.appendingPathComponent(fileName)
try FileManager.default.createDirectory(at: recordingsDirectory, withIntermediateDirectories: true) try FileManager.default.createDirectory(at: recordingsDirectory, withIntermediateDirectories: true)
try audioProcessor.saveSamplesAsWav(samples: samples, to: permanentURL) try audioProcessor.saveSamplesAsWav(samples: samples, to: permanentURL)
// Transcribe using appropriate service
processingPhase = .transcribing processingPhase = .transcribing
let transcriptionStart = Date() let transcriptionStart = Date()
var text: String var text = try await serviceRegistry.transcribe(audioURL: permanentURL, model: currentModel)
switch currentModel.provider {
case .local:
text = try await localTranscriptionService!.transcribe(audioURL: permanentURL, model: currentModel)
case .parakeet:
text = try await parakeetTranscriptionService!.transcribe(audioURL: permanentURL, model: currentModel)
case .nativeApple:
text = try await nativeAppleTranscriptionService.transcribe(audioURL: permanentURL, model: currentModel)
default: // Cloud models
text = try await cloudTranscriptionService.transcribe(audioURL: permanentURL, model: currentModel)
}
let transcriptionDuration = Date().timeIntervalSince(transcriptionStart) let transcriptionDuration = Date().timeIntervalSince(transcriptionStart)
text = TranscriptionOutputFilter.filter(text) text = TranscriptionOutputFilter.filter(text)
text = text.trimmingCharacters(in: .whitespacesAndNewlines) text = text.trimmingCharacters(in: .whitespacesAndNewlines)

View File

@ -14,12 +14,7 @@ class AudioTranscriptionService: ObservableObject {
private let whisperState: WhisperState private let whisperState: WhisperState
private let promptDetectionService = PromptDetectionService() private let promptDetectionService = PromptDetectionService()
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionService") private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionService")
private let serviceRegistry: TranscriptionServiceRegistry
// Transcription services
private let localTranscriptionService: LocalTranscriptionService
private lazy var cloudTranscriptionService = CloudTranscriptionService()
private lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
private lazy var parakeetTranscriptionService = ParakeetTranscriptionService()
enum TranscriptionError: Error { enum TranscriptionError: Error {
case noAudioFile case noAudioFile
@ -32,7 +27,7 @@ class AudioTranscriptionService: ObservableObject {
self.modelContext = modelContext self.modelContext = modelContext
self.whisperState = whisperState self.whisperState = whisperState
self.enhancementService = whisperState.enhancementService self.enhancementService = whisperState.enhancementService
self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: whisperState.modelsDirectory, whisperState: whisperState) self.serviceRegistry = TranscriptionServiceRegistry(whisperState: whisperState, modelsDirectory: whisperState.modelsDirectory)
} }
func retranscribeAudio(from url: URL, using model: any TranscriptionModel) async throws -> Transcription { func retranscribeAudio(from url: URL, using model: any TranscriptionModel) async throws -> Transcription {
@ -45,21 +40,8 @@ class AudioTranscriptionService: ObservableObject {
} }
do { do {
// Delegate transcription to appropriate service
let transcriptionStart = Date() let transcriptionStart = Date()
var text: String var text = try await serviceRegistry.transcribe(audioURL: url, model: model)
switch model.provider {
case .local:
text = try await localTranscriptionService.transcribe(audioURL: url, model: model)
case .parakeet:
text = try await parakeetTranscriptionService.transcribe(audioURL: url, model: model)
case .nativeApple:
text = try await nativeAppleTranscriptionService.transcribe(audioURL: url, model: model)
default: // Cloud models
text = try await cloudTranscriptionService.transcribe(audioURL: url, model: model)
}
let transcriptionDuration = Date().timeIntervalSince(transcriptionStart) let transcriptionDuration = Date().timeIntervalSince(transcriptionStart)
text = TranscriptionOutputFilter.filter(text) text = TranscriptionOutputFilter.filter(text)
text = text.trimmingCharacters(in: .whitespacesAndNewlines) text = text.trimmingCharacters(in: .whitespacesAndNewlines)
@ -75,12 +57,9 @@ class AudioTranscriptionService: ObservableObject {
text = WordReplacementService.shared.applyReplacements(to: text) text = WordReplacementService.shared.applyReplacements(to: text)
logger.notice("✅ Word replacements applied") logger.notice("✅ Word replacements applied")
// Get audio duration
let audioAsset = AVURLAsset(url: url) let audioAsset = AVURLAsset(url: url)
let duration = CMTimeGetSeconds(try await audioAsset.load(.duration)) let duration = CMTimeGetSeconds(try await audioAsset.load(.duration))
// Create a permanent copy of the audio file
let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0] let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
.appendingPathComponent("com.prakashjoshipax.VoiceInk") .appendingPathComponent("com.prakashjoshipax.VoiceInk")
.appendingPathComponent("Recordings") .appendingPathComponent("Recordings")

View File

@ -8,17 +8,11 @@ final class ModelPrewarmService: ObservableObject {
private let whisperState: WhisperState private let whisperState: WhisperState
private let modelContext: ModelContext private let modelContext: ModelContext
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "ModelPrewarm") private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "ModelPrewarm")
private lazy var serviceRegistry = TranscriptionServiceRegistry(
// Services (initialized lazily) whisperState: whisperState,
private var localTranscriptionService: LocalTranscriptionService? modelsDirectory: whisperState.modelsDirectory
private var parakeetTranscriptionService: ParakeetTranscriptionService? )
private let nativeAppleTranscriptionService = NativeAppleTranscriptionService()
private let cloudTranscriptionService = CloudTranscriptionService()
// Sample audio for prewarming
private let prewarmAudioURL = Bundle.main.url(forResource: "esc", withExtension: "wav") private let prewarmAudioURL = Bundle.main.url(forResource: "esc", withExtension: "wav")
// User preference key
private let prewarmEnabledKey = "PrewarmModelOnWake" private let prewarmEnabledKey = "PrewarmModelOnWake"
init(whisperState: WhisperState, modelContext: ModelContext) { init(whisperState: WhisperState, modelContext: ModelContext) {
@ -83,33 +77,9 @@ final class ModelPrewarmService: ObservableObject {
let startTime = Date() let startTime = Date()
do { do {
// Initialize services lazily let transcribedText = try await serviceRegistry.transcribe(audioURL: audioURL, model: currentModel)
if localTranscriptionService == nil {
localTranscriptionService = LocalTranscriptionService(
modelsDirectory: whisperState.modelsDirectory,
whisperState: whisperState
)
}
if parakeetTranscriptionService == nil {
parakeetTranscriptionService = ParakeetTranscriptionService()
}
// Run transcription to trigger model loading and ANE compilation
let transcribedText: String
switch currentModel.provider {
case .local:
transcribedText = try await localTranscriptionService!.transcribe(audioURL: audioURL, model: currentModel)
case .parakeet:
transcribedText = try await parakeetTranscriptionService!.transcribe(audioURL: audioURL, model: currentModel)
case .nativeApple:
transcribedText = try await nativeAppleTranscriptionService.transcribe(audioURL: audioURL, model: currentModel)
default:
transcribedText = try await cloudTranscriptionService.transcribe(audioURL: audioURL, model: currentModel)
}
let duration = Date().timeIntervalSince(startTime) let duration = Date().timeIntervalSince(startTime)
// Save for telemetry
let transcription = Transcription( let transcription = Transcription(
text: "[PREWARM] \(transcribedText)", text: "[PREWARM] \(transcribedText)",
duration: 1.0, duration: 1.0,

View File

@ -0,0 +1,46 @@
import Foundation
import SwiftUI
import os
/// Central lookup point for the app's transcription backends.
///
/// Owns one lazily created instance of each transcription service and maps a
/// `ModelProvider` to the service that handles it, so callers hold a single
/// registry instead of duplicating the individual services.
@MainActor
class TranscriptionServiceRegistry {
    private let whisperState: WhisperState
    private let modelsDirectory: URL
    private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "TranscriptionServiceRegistry")

    // MARK: - Backends (each is built on first access only)

    /// Local on-device service; constructed from the models directory and shared state.
    private(set) lazy var localTranscriptionService = LocalTranscriptionService(
        modelsDirectory: modelsDirectory,
        whisperState: whisperState
    )
    /// Cloud-based transcription service (handles every non-local provider).
    private(set) lazy var cloudTranscriptionService = CloudTranscriptionService()
    /// Service backed by Apple's native transcription support.
    private(set) lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
    /// Parakeet service — the only backend given explicit teardown in `cleanup()`.
    private(set) lazy var parakeetTranscriptionService = ParakeetTranscriptionService()

    /// - Parameters:
    ///   - whisperState: Shared state passed through to the local service.
    ///   - modelsDirectory: Directory containing local model files.
    init(whisperState: WhisperState, modelsDirectory: URL) {
        self.whisperState = whisperState
        self.modelsDirectory = modelsDirectory
    }

    /// Returns the backend responsible for `provider`.
    /// Any provider other than local/parakeet/nativeApple falls through to the cloud service.
    func service(for provider: ModelProvider) -> TranscriptionService {
        switch provider {
        case .local: return localTranscriptionService
        case .parakeet: return parakeetTranscriptionService
        case .nativeApple: return nativeAppleTranscriptionService
        default: return cloudTranscriptionService
        }
    }

    /// Routes `audioURL` to the backend matching `model.provider` and returns the raw transcript.
    /// Rethrows whatever the underlying service throws.
    func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
        let backend = service(for: model.provider)
        logger.debug("Transcribing with \(model.displayName) using \(String(describing: type(of: backend)))")
        return try await backend.transcribe(audioURL: audioURL, model: model)
    }

    /// Releases resources held by backends that require explicit teardown.
    func cleanup() {
        parakeetTranscriptionService.cleanup()
    }
}

View File

@ -340,8 +340,7 @@ extension WhisperState {
await whisperContext?.releaseResources() await whisperContext?.releaseResources()
whisperContext = nil whisperContext = nil
isModelLoaded = false isModelLoaded = false
serviceRegistry.cleanup()
parakeetTranscriptionService.cleanup()
} }
// MARK: - Helper Methods // MARK: - Helper Methods

View File

@ -68,11 +68,7 @@ class WhisperState: NSObject, ObservableObject {
let modelContext: ModelContext let modelContext: ModelContext
// Transcription Services internal var serviceRegistry: TranscriptionServiceRegistry!
private var localTranscriptionService: LocalTranscriptionService!
private lazy var cloudTranscriptionService = CloudTranscriptionService()
private lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
internal lazy var parakeetTranscriptionService = ParakeetTranscriptionService()
private var modelUrl: URL? { private var modelUrl: URL? {
let possibleURLs = [ let possibleURLs = [
@ -122,9 +118,9 @@ class WhisperState: NSObject, ObservableObject {
if let enhancementService = enhancementService { if let enhancementService = enhancementService {
PowerModeSessionManager.shared.configure(whisperState: self, enhancementService: enhancementService) PowerModeSessionManager.shared.configure(whisperState: self, enhancementService: enhancementService)
} }
// Set the whisperState reference after super.init() // Initialize the transcription service registry
self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: self.modelsDirectory, whisperState: self) self.serviceRegistry = TranscriptionServiceRegistry(whisperState: self, modelsDirectory: self.modelsDirectory)
setupNotifications() setupNotifications()
createModelsDirectoryIfNeeded() createModelsDirectoryIfNeeded()
@ -216,7 +212,7 @@ class WhisperState: NSObject, ObservableObject {
} }
} }
} else if let parakeetModel = await self.currentTranscriptionModel as? ParakeetModel { } else if let parakeetModel = await self.currentTranscriptionModel as? ParakeetModel {
try? await self.parakeetTranscriptionService.loadModel(for: parakeetModel) try? await self.serviceRegistry.parakeetTranscriptionService.loadModel(for: parakeetModel)
} }
if let enhancementService = await self.enhancementService { if let enhancementService = await self.enhancementService {
@ -299,20 +295,8 @@ class WhisperState: NSObject, ObservableObject {
throw WhisperStateError.transcriptionFailed throw WhisperStateError.transcriptionFailed
} }
let transcriptionService: TranscriptionService
switch model.provider {
case .local:
transcriptionService = localTranscriptionService
case .parakeet:
transcriptionService = parakeetTranscriptionService
case .nativeApple:
transcriptionService = nativeAppleTranscriptionService
default:
transcriptionService = cloudTranscriptionService
}
let transcriptionStart = Date() let transcriptionStart = Date()
var text = try await transcriptionService.transcribe(audioURL: url, model: model) var text = try await serviceRegistry.transcribe(audioURL: url, model: model)
logger.notice("📝 Raw transcript: \(text, privacy: .public)") logger.notice("📝 Raw transcript: \(text, privacy: .public)")
text = TranscriptionOutputFilter.filter(text) text = TranscriptionOutputFilter.filter(text)
logger.notice("📝 Output filter result: \(text, privacy: .public)") logger.notice("📝 Output filter result: \(text, privacy: .public)")