Merge pull request #443 from Beingpax/refactor/centralize-transcription-services
Add TranscriptionServiceRegistry to eliminate duplicate service
This commit is contained in:
commit
aa10377a55
@ -17,12 +17,6 @@ class AudioTranscriptionManager: ObservableObject {
|
|||||||
private let audioProcessor = AudioProcessor()
|
private let audioProcessor = AudioProcessor()
|
||||||
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionManager")
|
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionManager")
|
||||||
|
|
||||||
// Transcription services - will be initialized when needed
|
|
||||||
private var localTranscriptionService: LocalTranscriptionService?
|
|
||||||
private lazy var cloudTranscriptionService = CloudTranscriptionService()
|
|
||||||
private lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
|
|
||||||
private var parakeetTranscriptionService: ParakeetTranscriptionService?
|
|
||||||
|
|
||||||
enum ProcessingPhase {
|
enum ProcessingPhase {
|
||||||
case idle
|
case idle
|
||||||
case loading
|
case loading
|
||||||
@ -64,52 +58,31 @@ class AudioTranscriptionManager: ObservableObject {
|
|||||||
guard let currentModel = whisperState.currentTranscriptionModel else {
|
guard let currentModel = whisperState.currentTranscriptionModel else {
|
||||||
throw TranscriptionError.noModelSelected
|
throw TranscriptionError.noModelSelected
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize local transcription service if needed
|
let serviceRegistry = TranscriptionServiceRegistry(whisperState: whisperState, modelsDirectory: whisperState.modelsDirectory)
|
||||||
if localTranscriptionService == nil {
|
defer {
|
||||||
localTranscriptionService = LocalTranscriptionService(modelsDirectory: whisperState.modelsDirectory, whisperState: whisperState)
|
serviceRegistry.cleanup()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize parakeet transcription service if needed
|
|
||||||
if parakeetTranscriptionService == nil {
|
|
||||||
parakeetTranscriptionService = ParakeetTranscriptionService()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process audio file
|
|
||||||
processingPhase = .processingAudio
|
processingPhase = .processingAudio
|
||||||
let samples = try await audioProcessor.processAudioToSamples(url)
|
let samples = try await audioProcessor.processAudioToSamples(url)
|
||||||
|
|
||||||
// Get audio duration
|
|
||||||
let audioAsset = AVURLAsset(url: url)
|
let audioAsset = AVURLAsset(url: url)
|
||||||
let duration = CMTimeGetSeconds(try await audioAsset.load(.duration))
|
let duration = CMTimeGetSeconds(try await audioAsset.load(.duration))
|
||||||
|
|
||||||
// Create permanent copy of the audio file
|
|
||||||
let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
|
let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
|
||||||
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
|
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
|
||||||
.appendingPathComponent("Recordings")
|
.appendingPathComponent("Recordings")
|
||||||
|
|
||||||
let fileName = "transcribed_\(UUID().uuidString).wav"
|
let fileName = "transcribed_\(UUID().uuidString).wav"
|
||||||
let permanentURL = recordingsDirectory.appendingPathComponent(fileName)
|
let permanentURL = recordingsDirectory.appendingPathComponent(fileName)
|
||||||
|
|
||||||
try FileManager.default.createDirectory(at: recordingsDirectory, withIntermediateDirectories: true)
|
try FileManager.default.createDirectory(at: recordingsDirectory, withIntermediateDirectories: true)
|
||||||
try audioProcessor.saveSamplesAsWav(samples: samples, to: permanentURL)
|
try audioProcessor.saveSamplesAsWav(samples: samples, to: permanentURL)
|
||||||
|
|
||||||
// Transcribe using appropriate service
|
|
||||||
processingPhase = .transcribing
|
processingPhase = .transcribing
|
||||||
let transcriptionStart = Date()
|
let transcriptionStart = Date()
|
||||||
var text: String
|
var text = try await serviceRegistry.transcribe(audioURL: permanentURL, model: currentModel)
|
||||||
|
|
||||||
switch currentModel.provider {
|
|
||||||
case .local:
|
|
||||||
text = try await localTranscriptionService!.transcribe(audioURL: permanentURL, model: currentModel)
|
|
||||||
case .parakeet:
|
|
||||||
text = try await parakeetTranscriptionService!.transcribe(audioURL: permanentURL, model: currentModel)
|
|
||||||
case .nativeApple:
|
|
||||||
text = try await nativeAppleTranscriptionService.transcribe(audioURL: permanentURL, model: currentModel)
|
|
||||||
default: // Cloud models
|
|
||||||
text = try await cloudTranscriptionService.transcribe(audioURL: permanentURL, model: currentModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
let transcriptionDuration = Date().timeIntervalSince(transcriptionStart)
|
let transcriptionDuration = Date().timeIntervalSince(transcriptionStart)
|
||||||
text = TranscriptionOutputFilter.filter(text)
|
text = TranscriptionOutputFilter.filter(text)
|
||||||
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||||
|
|||||||
@ -14,12 +14,7 @@ class AudioTranscriptionService: ObservableObject {
|
|||||||
private let whisperState: WhisperState
|
private let whisperState: WhisperState
|
||||||
private let promptDetectionService = PromptDetectionService()
|
private let promptDetectionService = PromptDetectionService()
|
||||||
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionService")
|
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionService")
|
||||||
|
private let serviceRegistry: TranscriptionServiceRegistry
|
||||||
// Transcription services
|
|
||||||
private let localTranscriptionService: LocalTranscriptionService
|
|
||||||
private lazy var cloudTranscriptionService = CloudTranscriptionService()
|
|
||||||
private lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
|
|
||||||
private lazy var parakeetTranscriptionService = ParakeetTranscriptionService()
|
|
||||||
|
|
||||||
enum TranscriptionError: Error {
|
enum TranscriptionError: Error {
|
||||||
case noAudioFile
|
case noAudioFile
|
||||||
@ -32,7 +27,7 @@ class AudioTranscriptionService: ObservableObject {
|
|||||||
self.modelContext = modelContext
|
self.modelContext = modelContext
|
||||||
self.whisperState = whisperState
|
self.whisperState = whisperState
|
||||||
self.enhancementService = whisperState.enhancementService
|
self.enhancementService = whisperState.enhancementService
|
||||||
self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: whisperState.modelsDirectory, whisperState: whisperState)
|
self.serviceRegistry = TranscriptionServiceRegistry(whisperState: whisperState, modelsDirectory: whisperState.modelsDirectory)
|
||||||
}
|
}
|
||||||
|
|
||||||
func retranscribeAudio(from url: URL, using model: any TranscriptionModel) async throws -> Transcription {
|
func retranscribeAudio(from url: URL, using model: any TranscriptionModel) async throws -> Transcription {
|
||||||
@ -45,21 +40,8 @@ class AudioTranscriptionService: ObservableObject {
|
|||||||
}
|
}
|
||||||
|
|
||||||
do {
|
do {
|
||||||
// Delegate transcription to appropriate service
|
|
||||||
let transcriptionStart = Date()
|
let transcriptionStart = Date()
|
||||||
var text: String
|
var text = try await serviceRegistry.transcribe(audioURL: url, model: model)
|
||||||
|
|
||||||
switch model.provider {
|
|
||||||
case .local:
|
|
||||||
text = try await localTranscriptionService.transcribe(audioURL: url, model: model)
|
|
||||||
case .parakeet:
|
|
||||||
text = try await parakeetTranscriptionService.transcribe(audioURL: url, model: model)
|
|
||||||
case .nativeApple:
|
|
||||||
text = try await nativeAppleTranscriptionService.transcribe(audioURL: url, model: model)
|
|
||||||
default: // Cloud models
|
|
||||||
text = try await cloudTranscriptionService.transcribe(audioURL: url, model: model)
|
|
||||||
}
|
|
||||||
|
|
||||||
let transcriptionDuration = Date().timeIntervalSince(transcriptionStart)
|
let transcriptionDuration = Date().timeIntervalSince(transcriptionStart)
|
||||||
text = TranscriptionOutputFilter.filter(text)
|
text = TranscriptionOutputFilter.filter(text)
|
||||||
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||||
@ -75,12 +57,9 @@ class AudioTranscriptionService: ObservableObject {
|
|||||||
|
|
||||||
text = WordReplacementService.shared.applyReplacements(to: text)
|
text = WordReplacementService.shared.applyReplacements(to: text)
|
||||||
logger.notice("✅ Word replacements applied")
|
logger.notice("✅ Word replacements applied")
|
||||||
|
|
||||||
// Get audio duration
|
|
||||||
let audioAsset = AVURLAsset(url: url)
|
let audioAsset = AVURLAsset(url: url)
|
||||||
let duration = CMTimeGetSeconds(try await audioAsset.load(.duration))
|
let duration = CMTimeGetSeconds(try await audioAsset.load(.duration))
|
||||||
|
|
||||||
// Create a permanent copy of the audio file
|
|
||||||
let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
|
let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
|
||||||
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
|
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
|
||||||
.appendingPathComponent("Recordings")
|
.appendingPathComponent("Recordings")
|
||||||
|
|||||||
@ -8,17 +8,11 @@ final class ModelPrewarmService: ObservableObject {
|
|||||||
private let whisperState: WhisperState
|
private let whisperState: WhisperState
|
||||||
private let modelContext: ModelContext
|
private let modelContext: ModelContext
|
||||||
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "ModelPrewarm")
|
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "ModelPrewarm")
|
||||||
|
private lazy var serviceRegistry = TranscriptionServiceRegistry(
|
||||||
// Services (initialized lazily)
|
whisperState: whisperState,
|
||||||
private var localTranscriptionService: LocalTranscriptionService?
|
modelsDirectory: whisperState.modelsDirectory
|
||||||
private var parakeetTranscriptionService: ParakeetTranscriptionService?
|
)
|
||||||
private let nativeAppleTranscriptionService = NativeAppleTranscriptionService()
|
|
||||||
private let cloudTranscriptionService = CloudTranscriptionService()
|
|
||||||
|
|
||||||
// Sample audio for prewarming
|
|
||||||
private let prewarmAudioURL = Bundle.main.url(forResource: "esc", withExtension: "wav")
|
private let prewarmAudioURL = Bundle.main.url(forResource: "esc", withExtension: "wav")
|
||||||
|
|
||||||
// User preference key
|
|
||||||
private let prewarmEnabledKey = "PrewarmModelOnWake"
|
private let prewarmEnabledKey = "PrewarmModelOnWake"
|
||||||
|
|
||||||
init(whisperState: WhisperState, modelContext: ModelContext) {
|
init(whisperState: WhisperState, modelContext: ModelContext) {
|
||||||
@ -83,33 +77,9 @@ final class ModelPrewarmService: ObservableObject {
|
|||||||
let startTime = Date()
|
let startTime = Date()
|
||||||
|
|
||||||
do {
|
do {
|
||||||
// Initialize services lazily
|
let transcribedText = try await serviceRegistry.transcribe(audioURL: audioURL, model: currentModel)
|
||||||
if localTranscriptionService == nil {
|
|
||||||
localTranscriptionService = LocalTranscriptionService(
|
|
||||||
modelsDirectory: whisperState.modelsDirectory,
|
|
||||||
whisperState: whisperState
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if parakeetTranscriptionService == nil {
|
|
||||||
parakeetTranscriptionService = ParakeetTranscriptionService()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run transcription to trigger model loading and ANE compilation
|
|
||||||
let transcribedText: String
|
|
||||||
switch currentModel.provider {
|
|
||||||
case .local:
|
|
||||||
transcribedText = try await localTranscriptionService!.transcribe(audioURL: audioURL, model: currentModel)
|
|
||||||
case .parakeet:
|
|
||||||
transcribedText = try await parakeetTranscriptionService!.transcribe(audioURL: audioURL, model: currentModel)
|
|
||||||
case .nativeApple:
|
|
||||||
transcribedText = try await nativeAppleTranscriptionService.transcribe(audioURL: audioURL, model: currentModel)
|
|
||||||
default:
|
|
||||||
transcribedText = try await cloudTranscriptionService.transcribe(audioURL: audioURL, model: currentModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
let duration = Date().timeIntervalSince(startTime)
|
let duration = Date().timeIntervalSince(startTime)
|
||||||
|
|
||||||
// Save for telemetry
|
|
||||||
let transcription = Transcription(
|
let transcription = Transcription(
|
||||||
text: "[PREWARM] \(transcribedText)",
|
text: "[PREWARM] \(transcribedText)",
|
||||||
duration: 1.0,
|
duration: 1.0,
|
||||||
|
|||||||
46
VoiceInk/Services/TranscriptionServiceRegistry.swift
Normal file
46
VoiceInk/Services/TranscriptionServiceRegistry.swift
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import Foundation
|
||||||
|
import SwiftUI
|
||||||
|
import os
|
||||||
|
|
||||||
|
@MainActor
|
||||||
|
class TranscriptionServiceRegistry {
|
||||||
|
private let whisperState: WhisperState
|
||||||
|
private let modelsDirectory: URL
|
||||||
|
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "TranscriptionServiceRegistry")
|
||||||
|
|
||||||
|
private(set) lazy var localTranscriptionService = LocalTranscriptionService(
|
||||||
|
modelsDirectory: modelsDirectory,
|
||||||
|
whisperState: whisperState
|
||||||
|
)
|
||||||
|
private(set) lazy var cloudTranscriptionService = CloudTranscriptionService()
|
||||||
|
private(set) lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
|
||||||
|
private(set) lazy var parakeetTranscriptionService = ParakeetTranscriptionService()
|
||||||
|
|
||||||
|
init(whisperState: WhisperState, modelsDirectory: URL) {
|
||||||
|
self.whisperState = whisperState
|
||||||
|
self.modelsDirectory = modelsDirectory
|
||||||
|
}
|
||||||
|
|
||||||
|
func service(for provider: ModelProvider) -> TranscriptionService {
|
||||||
|
switch provider {
|
||||||
|
case .local:
|
||||||
|
return localTranscriptionService
|
||||||
|
case .parakeet:
|
||||||
|
return parakeetTranscriptionService
|
||||||
|
case .nativeApple:
|
||||||
|
return nativeAppleTranscriptionService
|
||||||
|
default:
|
||||||
|
return cloudTranscriptionService
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
|
||||||
|
let service = service(for: model.provider)
|
||||||
|
logger.debug("Transcribing with \(model.displayName) using \(String(describing: type(of: service)))")
|
||||||
|
return try await service.transcribe(audioURL: audioURL, model: model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanup() {
|
||||||
|
parakeetTranscriptionService.cleanup()
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -340,8 +340,7 @@ extension WhisperState {
|
|||||||
await whisperContext?.releaseResources()
|
await whisperContext?.releaseResources()
|
||||||
whisperContext = nil
|
whisperContext = nil
|
||||||
isModelLoaded = false
|
isModelLoaded = false
|
||||||
|
serviceRegistry.cleanup()
|
||||||
parakeetTranscriptionService.cleanup()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MARK: - Helper Methods
|
// MARK: - Helper Methods
|
||||||
|
|||||||
@ -68,11 +68,7 @@ class WhisperState: NSObject, ObservableObject {
|
|||||||
|
|
||||||
let modelContext: ModelContext
|
let modelContext: ModelContext
|
||||||
|
|
||||||
// Transcription Services
|
internal var serviceRegistry: TranscriptionServiceRegistry!
|
||||||
private var localTranscriptionService: LocalTranscriptionService!
|
|
||||||
private lazy var cloudTranscriptionService = CloudTranscriptionService()
|
|
||||||
private lazy var nativeAppleTranscriptionService = NativeAppleTranscriptionService()
|
|
||||||
internal lazy var parakeetTranscriptionService = ParakeetTranscriptionService()
|
|
||||||
|
|
||||||
private var modelUrl: URL? {
|
private var modelUrl: URL? {
|
||||||
let possibleURLs = [
|
let possibleURLs = [
|
||||||
@ -122,9 +118,9 @@ class WhisperState: NSObject, ObservableObject {
|
|||||||
if let enhancementService = enhancementService {
|
if let enhancementService = enhancementService {
|
||||||
PowerModeSessionManager.shared.configure(whisperState: self, enhancementService: enhancementService)
|
PowerModeSessionManager.shared.configure(whisperState: self, enhancementService: enhancementService)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the whisperState reference after super.init()
|
// Initialize the transcription service registry
|
||||||
self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: self.modelsDirectory, whisperState: self)
|
self.serviceRegistry = TranscriptionServiceRegistry(whisperState: self, modelsDirectory: self.modelsDirectory)
|
||||||
|
|
||||||
setupNotifications()
|
setupNotifications()
|
||||||
createModelsDirectoryIfNeeded()
|
createModelsDirectoryIfNeeded()
|
||||||
@ -216,7 +212,7 @@ class WhisperState: NSObject, ObservableObject {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if let parakeetModel = await self.currentTranscriptionModel as? ParakeetModel {
|
} else if let parakeetModel = await self.currentTranscriptionModel as? ParakeetModel {
|
||||||
try? await self.parakeetTranscriptionService.loadModel(for: parakeetModel)
|
try? await self.serviceRegistry.parakeetTranscriptionService.loadModel(for: parakeetModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
if let enhancementService = await self.enhancementService {
|
if let enhancementService = await self.enhancementService {
|
||||||
@ -299,20 +295,8 @@ class WhisperState: NSObject, ObservableObject {
|
|||||||
throw WhisperStateError.transcriptionFailed
|
throw WhisperStateError.transcriptionFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
let transcriptionService: TranscriptionService
|
|
||||||
switch model.provider {
|
|
||||||
case .local:
|
|
||||||
transcriptionService = localTranscriptionService
|
|
||||||
case .parakeet:
|
|
||||||
transcriptionService = parakeetTranscriptionService
|
|
||||||
case .nativeApple:
|
|
||||||
transcriptionService = nativeAppleTranscriptionService
|
|
||||||
default:
|
|
||||||
transcriptionService = cloudTranscriptionService
|
|
||||||
}
|
|
||||||
|
|
||||||
let transcriptionStart = Date()
|
let transcriptionStart = Date()
|
||||||
var text = try await transcriptionService.transcribe(audioURL: url, model: model)
|
var text = try await serviceRegistry.transcribe(audioURL: url, model: model)
|
||||||
logger.notice("📝 Raw transcript: \(text, privacy: .public)")
|
logger.notice("📝 Raw transcript: \(text, privacy: .public)")
|
||||||
text = TranscriptionOutputFilter.filter(text)
|
text = TranscriptionOutputFilter.filter(text)
|
||||||
logger.notice("📝 Output filter result: \(text, privacy: .public)")
|
logger.notice("📝 Output filter result: \(text, privacy: .public)")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user