Merge pull request #443 from Beingpax/refactor/centralize-transcription-services
Add TranscriptionServiceRegistry to eliminate duplicate service
This commit is contained in:
commit
aa10377a55
@ -17,12 +17,6 @@ class AudioTranscriptionManager: ObservableObject {
|
||||
private let audioProcessor = AudioProcessor()
|
||||
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionManager")
|
||||
|
||||
// Transcription services - will be initialized when needed
|
||||
private var localTranscriptionService: LocalTranscriptionService?
|
||||
private lazy var cloudTranscriptionService = CloudTranscriptionService()
|
||||
private lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
|
||||
private var parakeetTranscriptionService: ParakeetTranscriptionService?
|
||||
|
||||
enum ProcessingPhase {
|
||||
case idle
|
||||
case loading
|
||||
@ -64,52 +58,31 @@ class AudioTranscriptionManager: ObservableObject {
|
||||
guard let currentModel = whisperState.currentTranscriptionModel else {
|
||||
throw TranscriptionError.noModelSelected
|
||||
}
|
||||
|
||||
// Initialize local transcription service if needed
|
||||
if localTranscriptionService == nil {
|
||||
localTranscriptionService = LocalTranscriptionService(modelsDirectory: whisperState.modelsDirectory, whisperState: whisperState)
|
||||
|
||||
let serviceRegistry = TranscriptionServiceRegistry(whisperState: whisperState, modelsDirectory: whisperState.modelsDirectory)
|
||||
defer {
|
||||
serviceRegistry.cleanup()
|
||||
}
|
||||
|
||||
// Initialize parakeet transcription service if needed
|
||||
if parakeetTranscriptionService == nil {
|
||||
parakeetTranscriptionService = ParakeetTranscriptionService()
|
||||
}
|
||||
|
||||
// Process audio file
|
||||
|
||||
processingPhase = .processingAudio
|
||||
let samples = try await audioProcessor.processAudioToSamples(url)
|
||||
|
||||
// Get audio duration
|
||||
|
||||
let audioAsset = AVURLAsset(url: url)
|
||||
let duration = CMTimeGetSeconds(try await audioAsset.load(.duration))
|
||||
|
||||
// Create permanent copy of the audio file
|
||||
|
||||
let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
|
||||
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
|
||||
.appendingPathComponent("Recordings")
|
||||
|
||||
|
||||
let fileName = "transcribed_\(UUID().uuidString).wav"
|
||||
let permanentURL = recordingsDirectory.appendingPathComponent(fileName)
|
||||
|
||||
|
||||
try FileManager.default.createDirectory(at: recordingsDirectory, withIntermediateDirectories: true)
|
||||
try audioProcessor.saveSamplesAsWav(samples: samples, to: permanentURL)
|
||||
|
||||
// Transcribe using appropriate service
|
||||
|
||||
processingPhase = .transcribing
|
||||
let transcriptionStart = Date()
|
||||
var text: String
|
||||
|
||||
switch currentModel.provider {
|
||||
case .local:
|
||||
text = try await localTranscriptionService!.transcribe(audioURL: permanentURL, model: currentModel)
|
||||
case .parakeet:
|
||||
text = try await parakeetTranscriptionService!.transcribe(audioURL: permanentURL, model: currentModel)
|
||||
case .nativeApple:
|
||||
text = try await nativeAppleTranscriptionService.transcribe(audioURL: permanentURL, model: currentModel)
|
||||
default: // Cloud models
|
||||
text = try await cloudTranscriptionService.transcribe(audioURL: permanentURL, model: currentModel)
|
||||
}
|
||||
|
||||
var text = try await serviceRegistry.transcribe(audioURL: permanentURL, model: currentModel)
|
||||
let transcriptionDuration = Date().timeIntervalSince(transcriptionStart)
|
||||
text = TranscriptionOutputFilter.filter(text)
|
||||
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
@ -14,12 +14,7 @@ class AudioTranscriptionService: ObservableObject {
|
||||
private let whisperState: WhisperState
|
||||
private let promptDetectionService = PromptDetectionService()
|
||||
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionService")
|
||||
|
||||
// Transcription services
|
||||
private let localTranscriptionService: LocalTranscriptionService
|
||||
private lazy var cloudTranscriptionService = CloudTranscriptionService()
|
||||
private lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
|
||||
private lazy var parakeetTranscriptionService = ParakeetTranscriptionService()
|
||||
private let serviceRegistry: TranscriptionServiceRegistry
|
||||
|
||||
enum TranscriptionError: Error {
|
||||
case noAudioFile
|
||||
@ -32,7 +27,7 @@ class AudioTranscriptionService: ObservableObject {
|
||||
self.modelContext = modelContext
|
||||
self.whisperState = whisperState
|
||||
self.enhancementService = whisperState.enhancementService
|
||||
self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: whisperState.modelsDirectory, whisperState: whisperState)
|
||||
self.serviceRegistry = TranscriptionServiceRegistry(whisperState: whisperState, modelsDirectory: whisperState.modelsDirectory)
|
||||
}
|
||||
|
||||
func retranscribeAudio(from url: URL, using model: any TranscriptionModel) async throws -> Transcription {
|
||||
@ -45,21 +40,8 @@ class AudioTranscriptionService: ObservableObject {
|
||||
}
|
||||
|
||||
do {
|
||||
// Delegate transcription to appropriate service
|
||||
let transcriptionStart = Date()
|
||||
var text: String
|
||||
|
||||
switch model.provider {
|
||||
case .local:
|
||||
text = try await localTranscriptionService.transcribe(audioURL: url, model: model)
|
||||
case .parakeet:
|
||||
text = try await parakeetTranscriptionService.transcribe(audioURL: url, model: model)
|
||||
case .nativeApple:
|
||||
text = try await nativeAppleTranscriptionService.transcribe(audioURL: url, model: model)
|
||||
default: // Cloud models
|
||||
text = try await cloudTranscriptionService.transcribe(audioURL: url, model: model)
|
||||
}
|
||||
|
||||
var text = try await serviceRegistry.transcribe(audioURL: url, model: model)
|
||||
let transcriptionDuration = Date().timeIntervalSince(transcriptionStart)
|
||||
text = TranscriptionOutputFilter.filter(text)
|
||||
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
@ -75,12 +57,9 @@ class AudioTranscriptionService: ObservableObject {
|
||||
|
||||
text = WordReplacementService.shared.applyReplacements(to: text)
|
||||
logger.notice("✅ Word replacements applied")
|
||||
|
||||
// Get audio duration
|
||||
|
||||
let audioAsset = AVURLAsset(url: url)
|
||||
let duration = CMTimeGetSeconds(try await audioAsset.load(.duration))
|
||||
|
||||
// Create a permanent copy of the audio file
|
||||
let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
|
||||
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
|
||||
.appendingPathComponent("Recordings")
|
||||
|
||||
@ -8,17 +8,11 @@ final class ModelPrewarmService: ObservableObject {
|
||||
private let whisperState: WhisperState
|
||||
private let modelContext: ModelContext
|
||||
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "ModelPrewarm")
|
||||
|
||||
// Services (initialized lazily)
|
||||
private var localTranscriptionService: LocalTranscriptionService?
|
||||
private var parakeetTranscriptionService: ParakeetTranscriptionService?
|
||||
private let nativeAppleTranscriptionService = NativeAppleTranscriptionService()
|
||||
private let cloudTranscriptionService = CloudTranscriptionService()
|
||||
|
||||
// Sample audio for prewarming
|
||||
private lazy var serviceRegistry = TranscriptionServiceRegistry(
|
||||
whisperState: whisperState,
|
||||
modelsDirectory: whisperState.modelsDirectory
|
||||
)
|
||||
private let prewarmAudioURL = Bundle.main.url(forResource: "esc", withExtension: "wav")
|
||||
|
||||
// User preference key
|
||||
private let prewarmEnabledKey = "PrewarmModelOnWake"
|
||||
|
||||
init(whisperState: WhisperState, modelContext: ModelContext) {
|
||||
@ -83,33 +77,9 @@ final class ModelPrewarmService: ObservableObject {
|
||||
let startTime = Date()
|
||||
|
||||
do {
|
||||
// Initialize services lazily
|
||||
if localTranscriptionService == nil {
|
||||
localTranscriptionService = LocalTranscriptionService(
|
||||
modelsDirectory: whisperState.modelsDirectory,
|
||||
whisperState: whisperState
|
||||
)
|
||||
}
|
||||
if parakeetTranscriptionService == nil {
|
||||
parakeetTranscriptionService = ParakeetTranscriptionService()
|
||||
}
|
||||
|
||||
// Run transcription to trigger model loading and ANE compilation
|
||||
let transcribedText: String
|
||||
switch currentModel.provider {
|
||||
case .local:
|
||||
transcribedText = try await localTranscriptionService!.transcribe(audioURL: audioURL, model: currentModel)
|
||||
case .parakeet:
|
||||
transcribedText = try await parakeetTranscriptionService!.transcribe(audioURL: audioURL, model: currentModel)
|
||||
case .nativeApple:
|
||||
transcribedText = try await nativeAppleTranscriptionService.transcribe(audioURL: audioURL, model: currentModel)
|
||||
default:
|
||||
transcribedText = try await cloudTranscriptionService.transcribe(audioURL: audioURL, model: currentModel)
|
||||
}
|
||||
|
||||
let transcribedText = try await serviceRegistry.transcribe(audioURL: audioURL, model: currentModel)
|
||||
let duration = Date().timeIntervalSince(startTime)
|
||||
|
||||
// Save for telemetry
|
||||
let transcription = Transcription(
|
||||
text: "[PREWARM] \(transcribedText)",
|
||||
duration: 1.0,
|
||||
|
||||
46
VoiceInk/Services/TranscriptionServiceRegistry.swift
Normal file
46
VoiceInk/Services/TranscriptionServiceRegistry.swift
Normal file
@ -0,0 +1,46 @@
|
||||
import Foundation
|
||||
import SwiftUI
|
||||
import os
|
||||
|
||||
/// Single owner of the per-provider transcription services.
///
/// Each backend service is created the first time it is requested and then
/// reused, so callers hold one registry instead of constructing their own
/// service instances and repeating the provider `switch`.
@MainActor
class TranscriptionServiceRegistry {
    private let whisperState: WhisperState
    private let modelsDirectory: URL
    private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "TranscriptionServiceRegistry")

    private(set) lazy var localTranscriptionService = LocalTranscriptionService(
        modelsDirectory: modelsDirectory,
        whisperState: whisperState
    )
    private(set) lazy var cloudTranscriptionService = CloudTranscriptionService()
    private(set) lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()

    /// Backing storage for `parakeetTranscriptionService`. Stays `nil` until the
    /// service is first requested, which lets `cleanup()` tell "never created"
    /// apart from "created and in use" (a plain `lazy var` cannot be inspected
    /// without triggering its initializer).
    private var parakeetServiceStorage: ParakeetTranscriptionService?

    /// The Parakeet service, created on first access and reused afterwards.
    var parakeetTranscriptionService: ParakeetTranscriptionService {
        if let existing = parakeetServiceStorage {
            return existing
        }
        let created = ParakeetTranscriptionService()
        parakeetServiceStorage = created
        return created
    }

    /// - Parameters:
    ///   - whisperState: State object handed to the local transcription service.
    ///   - modelsDirectory: Directory handed to the local transcription service.
    init(whisperState: WhisperState, modelsDirectory: URL) {
        self.whisperState = whisperState
        self.modelsDirectory = modelsDirectory
    }

    /// Returns the service responsible for `provider`.
    ///
    /// `.local`, `.parakeet` and `.nativeApple` map to their dedicated services;
    /// every other provider falls through to the cloud service.
    func service(for provider: ModelProvider) -> TranscriptionService {
        switch provider {
        case .local:
            return localTranscriptionService
        case .parakeet:
            return parakeetTranscriptionService
        case .nativeApple:
            return nativeAppleTranscriptionService
        default:
            return cloudTranscriptionService
        }
    }

    /// Transcribes the audio at `audioURL` with the service matching `model.provider`.
    ///
    /// - Returns: The text produced by the selected service.
    /// - Throws: Whatever the selected service's `transcribe` throws.
    func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
        // Distinct local name so the constant does not shadow `service(for:)`.
        let selectedService = self.service(for: model.provider)
        logger.debug("Transcribing with \(model.displayName) using \(String(describing: type(of: selectedService)))")
        return try await selectedService.transcribe(audioURL: audioURL, model: model)
    }

    /// Cleans up the Parakeet service if one has been created.
    ///
    /// Does nothing when the Parakeet service was never requested, so a
    /// short-lived registry used for another provider no longer instantiates
    /// the service only to clean it up.
    /// NOTE(review): assumes `cleanup()` on a never-used Parakeet service has
    /// no required side effects — confirm against `ParakeetTranscriptionService`.
    func cleanup() {
        parakeetServiceStorage?.cleanup()
    }
}
|
||||
@ -340,8 +340,7 @@ extension WhisperState {
|
||||
await whisperContext?.releaseResources()
|
||||
whisperContext = nil
|
||||
isModelLoaded = false
|
||||
|
||||
parakeetTranscriptionService.cleanup()
|
||||
serviceRegistry.cleanup()
|
||||
}
|
||||
|
||||
// MARK: - Helper Methods
|
||||
|
||||
@ -68,11 +68,7 @@ class WhisperState: NSObject, ObservableObject {
|
||||
|
||||
let modelContext: ModelContext
|
||||
|
||||
// Transcription Services
|
||||
private var localTranscriptionService: LocalTranscriptionService!
|
||||
private lazy var cloudTranscriptionService = CloudTranscriptionService()
|
||||
private lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
|
||||
internal lazy var parakeetTranscriptionService = ParakeetTranscriptionService()
|
||||
internal var serviceRegistry: TranscriptionServiceRegistry!
|
||||
|
||||
private var modelUrl: URL? {
|
||||
let possibleURLs = [
|
||||
@ -122,9 +118,9 @@ class WhisperState: NSObject, ObservableObject {
|
||||
if let enhancementService = enhancementService {
|
||||
PowerModeSessionManager.shared.configure(whisperState: self, enhancementService: enhancementService)
|
||||
}
|
||||
|
||||
// Set the whisperState reference after super.init()
|
||||
self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: self.modelsDirectory, whisperState: self)
|
||||
|
||||
// Initialize the transcription service registry
|
||||
self.serviceRegistry = TranscriptionServiceRegistry(whisperState: self, modelsDirectory: self.modelsDirectory)
|
||||
|
||||
setupNotifications()
|
||||
createModelsDirectoryIfNeeded()
|
||||
@ -216,7 +212,7 @@ class WhisperState: NSObject, ObservableObject {
|
||||
}
|
||||
}
|
||||
} else if let parakeetModel = await self.currentTranscriptionModel as? ParakeetModel {
|
||||
try? await self.parakeetTranscriptionService.loadModel(for: parakeetModel)
|
||||
try? await self.serviceRegistry.parakeetTranscriptionService.loadModel(for: parakeetModel)
|
||||
}
|
||||
|
||||
if let enhancementService = await self.enhancementService {
|
||||
@ -299,20 +295,8 @@ class WhisperState: NSObject, ObservableObject {
|
||||
throw WhisperStateError.transcriptionFailed
|
||||
}
|
||||
|
||||
let transcriptionService: TranscriptionService
|
||||
switch model.provider {
|
||||
case .local:
|
||||
transcriptionService = localTranscriptionService
|
||||
case .parakeet:
|
||||
transcriptionService = parakeetTranscriptionService
|
||||
case .nativeApple:
|
||||
transcriptionService = nativeAppleTranscriptionService
|
||||
default:
|
||||
transcriptionService = cloudTranscriptionService
|
||||
}
|
||||
|
||||
let transcriptionStart = Date()
|
||||
var text = try await transcriptionService.transcribe(audioURL: url, model: model)
|
||||
var text = try await serviceRegistry.transcribe(audioURL: url, model: model)
|
||||
logger.notice("📝 Raw transcript: \(text, privacy: .public)")
|
||||
text = TranscriptionOutputFilter.filter(text)
|
||||
logger.notice("📝 Output filter result: \(text, privacy: .public)")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user