feat: Added support for cloud transcription service
This commit is contained in:
parent
a821b201c7
commit
9b507ce615
@ -1,6 +1,6 @@
|
||||
import Foundation
|
||||
|
||||
struct PredefinedModel: Identifiable, Hashable {
|
||||
struct LocalModel: TranscriptionModel {
|
||||
let id = UUID()
|
||||
let name: String
|
||||
let displayName: String
|
||||
@ -11,6 +11,7 @@ import Foundation
|
||||
let accuracy: Double
|
||||
let ramUsage: Double
|
||||
let hash: String
|
||||
let provider: ModelProvider = .local
|
||||
|
||||
var downloadURL: String {
|
||||
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/\(filename)"
|
||||
@ -20,21 +21,9 @@ import Foundation
|
||||
"\(name).bin"
|
||||
}
|
||||
|
||||
func hash(into hasher: inout Hasher) {
|
||||
hasher.combine(id)
|
||||
}
|
||||
|
||||
static func == (lhs: PredefinedModel, rhs: PredefinedModel) -> Bool {
|
||||
lhs.id == rhs.id
|
||||
}
|
||||
|
||||
var isMultilingualModel: Bool {
|
||||
supportedLanguages.count > 1
|
||||
}
|
||||
|
||||
var language: String {
|
||||
isMultilingualModel ? "Multilingual" : "English-only"
|
||||
}
|
||||
}
|
||||
|
||||
enum PredefinedModels {
|
||||
@ -53,8 +42,9 @@ import Foundation
|
||||
}
|
||||
}
|
||||
|
||||
static let models: [PredefinedModel] = [
|
||||
PredefinedModel(
|
||||
static let models: [any TranscriptionModel] = [
|
||||
// Local Models
|
||||
LocalModel(
|
||||
name: "ggml-tiny",
|
||||
displayName: "Tiny",
|
||||
size: "75 MiB",
|
||||
@ -65,7 +55,7 @@ import Foundation
|
||||
ramUsage: 0.3,
|
||||
hash: "bd577a113a864445d4c299885e0cb97d4ba92b5f"
|
||||
),
|
||||
PredefinedModel(
|
||||
LocalModel(
|
||||
name: "ggml-tiny.en",
|
||||
displayName: "Tiny (English)",
|
||||
size: "75 MiB",
|
||||
@ -76,7 +66,7 @@ import Foundation
|
||||
ramUsage: 0.3,
|
||||
hash: "c78c86eb1a8faa21b369bcd33207cc90d64ae9df"
|
||||
),
|
||||
PredefinedModel(
|
||||
LocalModel(
|
||||
name: "ggml-base.en",
|
||||
displayName: "Base (English)",
|
||||
size: "142 MiB",
|
||||
@ -87,7 +77,7 @@ import Foundation
|
||||
ramUsage: 0.5,
|
||||
hash: "137c40403d78fd54d454da0f9bd998f78703390c"
|
||||
),
|
||||
PredefinedModel(
|
||||
LocalModel(
|
||||
name: "ggml-large-v2",
|
||||
displayName: "Large v2",
|
||||
size: "2.9 GiB",
|
||||
@ -98,7 +88,7 @@ import Foundation
|
||||
ramUsage: 3.8,
|
||||
hash: "0f4c8e30f21cf1769f637135f521436792c48186"
|
||||
),
|
||||
PredefinedModel(
|
||||
LocalModel(
|
||||
name: "ggml-large-v3",
|
||||
displayName: "Large v3",
|
||||
size: "2.9 GiB",
|
||||
@ -109,7 +99,7 @@ import Foundation
|
||||
ramUsage: 3.9,
|
||||
hash: "ad82bf6a9043ceed055076d0fd39f5f186ff8062"
|
||||
),
|
||||
PredefinedModel(
|
||||
LocalModel(
|
||||
name: "ggml-large-v3-turbo",
|
||||
displayName: "Large v3 Turbo",
|
||||
size: "1.5 GiB",
|
||||
@ -121,7 +111,7 @@ import Foundation
|
||||
ramUsage: 1.8,
|
||||
hash: "4af2b29d7ec73d781377bfd1758ca957a807e941"
|
||||
),
|
||||
PredefinedModel(
|
||||
LocalModel(
|
||||
name: "ggml-large-v3-turbo-q5_0",
|
||||
displayName: "Large v3 Turbo (Quantized)",
|
||||
size: "547 MiB",
|
||||
@ -132,6 +122,29 @@ import Foundation
|
||||
ramUsage: 1.0,
|
||||
hash: "e050f7970618a659205450ad97eb95a18d69c9ee"
|
||||
),
|
||||
|
||||
// Cloud Models
|
||||
CloudModel(
|
||||
name: "whisper-large-v3-turbo",
|
||||
displayName: "Whisper Large v3 Turbo (Groq)",
|
||||
description: "Groq's ultra-fast Whisper Large v3 Turbo model with lightning-speed inference",
|
||||
provider: .groq,
|
||||
speed: 0.7,
|
||||
accuracy: 0.96,
|
||||
isMultilingual: true,
|
||||
supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true)
|
||||
),
|
||||
CloudModel(
|
||||
name: "scribe_v1",
|
||||
displayName: "Scribe v1 (ElevenLabs)",
|
||||
description: "ElevenLabs' Scribe model for fast and accurate transcription.",
|
||||
provider: .elevenLabs,
|
||||
speed: 0.8,
|
||||
accuracy: 0.95,
|
||||
isMultilingual: true,
|
||||
supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true)
|
||||
),
|
||||
|
||||
]
|
||||
|
||||
static let allLanguages = [
|
||||
|
||||
57
VoiceInk/Models/TranscriptionModel.swift
Normal file
57
VoiceInk/Models/TranscriptionModel.swift
Normal file
@ -0,0 +1,57 @@
|
||||
import Foundation
|
||||
|
||||
// Enum to differentiate between model providers
|
||||
enum ModelProvider: String, Codable, Hashable, CaseIterable {
|
||||
case local = "Local"
|
||||
case groq = "Groq"
|
||||
case elevenLabs = "ElevenLabs"
|
||||
// Future providers can be added here
|
||||
}
|
||||
|
||||
// A unified protocol for any transcription model
|
||||
protocol TranscriptionModel: Identifiable, Hashable {
|
||||
var id: UUID { get }
|
||||
var name: String { get }
|
||||
var displayName: String { get }
|
||||
var description: String { get }
|
||||
var provider: ModelProvider { get }
|
||||
|
||||
// Language capabilities
|
||||
var isMultilingualModel: Bool { get }
|
||||
var supportedLanguages: [String: String] { get }
|
||||
}
|
||||
|
||||
extension TranscriptionModel {
|
||||
func hash(into hasher: inout Hasher) {
|
||||
hasher.combine(id)
|
||||
}
|
||||
|
||||
var language: String {
|
||||
isMultilingualModel ? "Multilingual" : "English-only"
|
||||
}
|
||||
}
|
||||
|
||||
// A new struct for cloud models
|
||||
struct CloudModel: TranscriptionModel {
|
||||
let id: UUID
|
||||
let name: String
|
||||
let displayName: String
|
||||
let description: String
|
||||
let provider: ModelProvider
|
||||
let speed: Double
|
||||
let accuracy: Double
|
||||
let isMultilingualModel: Bool
|
||||
let supportedLanguages: [String: String]
|
||||
|
||||
init(id: UUID = UUID(), name: String, displayName: String, description: String, provider: ModelProvider, speed: Double, accuracy: Double, isMultilingual: Bool, supportedLanguages: [String: String]) {
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.displayName = displayName
|
||||
self.description = description
|
||||
self.provider = provider
|
||||
self.speed = speed
|
||||
self.accuracy = accuracy
|
||||
self.isMultilingualModel = isMultilingual
|
||||
self.supportedLanguages = supportedLanguages
|
||||
}
|
||||
}
|
||||
@ -116,20 +116,35 @@ class ActiveWindowService: ObservableObject {
|
||||
}
|
||||
|
||||
if let whisperState = self.whisperState,
|
||||
let modelName = config.selectedWhisperModel {
|
||||
let models = await MainActor.run { whisperState.availableModels }
|
||||
if let selectedModel = models.first(where: { $0.name == modelName }) {
|
||||
let modelName = config.selectedWhisperModel,
|
||||
let selectedModel = await whisperState.allAvailableModels.first(where: { $0.name == modelName }) {
|
||||
|
||||
let currentModelName = await MainActor.run { whisperState.currentTranscriptionModel?.name }
|
||||
|
||||
// Only change the model if it's different from the current one.
|
||||
if currentModelName != modelName {
|
||||
// Set the new model as default. This works for both local and cloud models.
|
||||
await whisperState.setDefaultTranscriptionModel(selectedModel)
|
||||
|
||||
let currentModelName = await MainActor.run { whisperState.currentModel?.name }
|
||||
if currentModelName != modelName {
|
||||
await whisperState.setDefaultModel(selectedModel)
|
||||
// The cleanup and load cycle is only necessary for local models.
|
||||
if selectedModel.provider == .local {
|
||||
// Unload any previously loaded model to free up memory.
|
||||
await whisperState.cleanupModelResources()
|
||||
|
||||
do {
|
||||
try await whisperState.loadModel(selectedModel)
|
||||
} catch {
|
||||
|
||||
// Load the new local model into memory.
|
||||
if let localModel = await whisperState.availableModels.first(where: { $0.name == selectedModel.name }) {
|
||||
do {
|
||||
try await whisperState.loadModel(localModel)
|
||||
logger.info("✅ Power Mode: Successfully loaded local model '\(localModel.name)'.")
|
||||
} catch {
|
||||
logger.error("❌ Power Mode: Failed to load local model '\(localModel.name)': \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For cloud models, no in-memory loading is needed, but we should still
|
||||
// clean up if the *previous* model was a local one.
|
||||
await whisperState.cleanupModelResources()
|
||||
logger.info("✅ Power Mode: Switched to cloud model '\(selectedModel.name)'. No local load needed.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -375,9 +375,12 @@ struct ConfigurationView: View {
|
||||
.foregroundColor(.secondary)
|
||||
|
||||
Picker("", selection: modelBinding) {
|
||||
ForEach(whisperState.availableModels) { model in
|
||||
let displayName = whisperState.predefinedModels.first { $0.name == model.name }?.displayName ?? model.name
|
||||
Text(displayName).tag(model.name as String?)
|
||||
Text("Default")
|
||||
.tag(nil as String?)
|
||||
|
||||
ForEach(whisperState.usableModels, id: \.name) { model in
|
||||
Text(model.displayName)
|
||||
.tag(model.name as String?)
|
||||
}
|
||||
}
|
||||
.labelsHidden()
|
||||
@ -387,7 +390,7 @@ struct ConfigurationView: View {
|
||||
|
||||
// Language Selection Subsection
|
||||
if let selectedModel = effectiveModelName,
|
||||
let modelInfo = whisperState.predefinedModels.first(where: { $0.name == selectedModel }),
|
||||
let modelInfo = whisperState.allAvailableModels.first(where: { $0.name == selectedModel }),
|
||||
modelInfo.isMultilingualModel {
|
||||
|
||||
// Create a simple binding that uses UserDefaults language if nil
|
||||
@ -416,7 +419,7 @@ struct ConfigurationView: View {
|
||||
.frame(maxWidth: .infinity)
|
||||
}
|
||||
} else if let selectedModel = effectiveModelName,
|
||||
let modelInfo = whisperState.predefinedModels.first(where: { $0.name == selectedModel }),
|
||||
let modelInfo = whisperState.allAvailableModels.first(where: { $0.name == selectedModel }),
|
||||
!modelInfo.isMultilingualModel {
|
||||
// Silently set to English without showing UI
|
||||
EmptyView()
|
||||
@ -488,7 +491,7 @@ struct ConfigurationView: View {
|
||||
.frame(maxWidth: .infinity, alignment: .leading)
|
||||
} else {
|
||||
Picker("", selection: providerBinding) {
|
||||
ForEach(aiService.connectedProviders, id: \.self) { provider in
|
||||
ForEach(aiService.connectedProviders.filter { $0 != .elevenLabs }, id: \.self) { provider in
|
||||
Text(provider.rawValue).tag(provider)
|
||||
}
|
||||
}
|
||||
|
||||
@ -105,7 +105,7 @@ struct ConfigurationRow: View {
|
||||
|
||||
private var selectedModel: String? {
|
||||
if let modelName = config.selectedWhisperModel,
|
||||
let model = whisperState.predefinedModels.first(where: { $0.name == modelName }) {
|
||||
let model = whisperState.allAvailableModels.first(where: { $0.name == modelName }) {
|
||||
return model.displayName
|
||||
}
|
||||
return "Default"
|
||||
@ -117,7 +117,7 @@ struct ConfigurationRow: View {
|
||||
if langCode == "en" { return "English" }
|
||||
|
||||
if let modelName = config.selectedWhisperModel,
|
||||
let model = whisperState.predefinedModels.first(where: { $0.name == modelName }),
|
||||
let model = whisperState.allAvailableModels.first(where: { $0.name == modelName }),
|
||||
let langName = model.supportedLanguages[langCode] {
|
||||
return langName
|
||||
}
|
||||
|
||||
@ -9,6 +9,7 @@ enum AIProvider: String, CaseIterable {
|
||||
case anthropic = "Anthropic"
|
||||
case mistral = "Mistral"
|
||||
case ollama = "Ollama"
|
||||
case elevenLabs = "ElevenLabs"
|
||||
case custom = "Custom"
|
||||
|
||||
var baseURL: String {
|
||||
@ -25,6 +26,8 @@ enum AIProvider: String, CaseIterable {
|
||||
return "https://api.anthropic.com/v1/messages"
|
||||
case .mistral:
|
||||
return "https://api.mistral.ai/v1/chat/completions"
|
||||
case .elevenLabs:
|
||||
return "https://api.elevenlabs.io/v1/speech-to-text"
|
||||
case .ollama:
|
||||
return UserDefaults.standard.string(forKey: "ollamaBaseURL") ?? "http://localhost:11434"
|
||||
case .custom:
|
||||
@ -46,6 +49,8 @@ enum AIProvider: String, CaseIterable {
|
||||
return "claude-3-5-sonnet-20241022"
|
||||
case .mistral:
|
||||
return "mistral-large-latest"
|
||||
case .elevenLabs:
|
||||
return "scribe_v1"
|
||||
case .ollama:
|
||||
return UserDefaults.standard.string(forKey: "ollamaSelectedModel") ?? "mistral"
|
||||
case .custom:
|
||||
@ -88,6 +93,8 @@ enum AIProvider: String, CaseIterable {
|
||||
"mistral-small-latest",
|
||||
"mistral-saba-latest"
|
||||
]
|
||||
case .elevenLabs:
|
||||
return ["scribe_v1", "scribe_v1_experimental"]
|
||||
case .ollama:
|
||||
return []
|
||||
case .custom:
|
||||
@ -257,6 +264,8 @@ class AIService: ObservableObject {
|
||||
verifyGeminiAPIKey(key, completion: completion)
|
||||
case .anthropic:
|
||||
verifyAnthropicAPIKey(key, completion: completion)
|
||||
case .elevenLabs:
|
||||
verifyElevenLabsAPIKey(key, completion: completion)
|
||||
default:
|
||||
verifyOpenAICompatibleAPIKey(key, completion: completion)
|
||||
}
|
||||
@ -369,6 +378,30 @@ class AIService: ObservableObject {
|
||||
}.resume()
|
||||
}
|
||||
|
||||
private func verifyElevenLabsAPIKey(_ key: String, completion: @escaping (Bool) -> Void) {
|
||||
let url = URL(string: "https://api.elevenlabs.io/v1/user")!
|
||||
var request = URLRequest(url: url)
|
||||
request.httpMethod = "GET"
|
||||
request.addValue("application/json", forHTTPHeaderField: "Content-Type")
|
||||
request.addValue(key, forHTTPHeaderField: "xi-api-key")
|
||||
|
||||
URLSession.shared.dataTask(with: request) { _, response, error in
|
||||
if let error = error {
|
||||
self.logger.error("ElevenLabs API key verification failed: \(error.localizedDescription)")
|
||||
completion(false)
|
||||
return
|
||||
}
|
||||
|
||||
if let httpResponse = response as? HTTPURLResponse {
|
||||
completion(httpResponse.statusCode == 200)
|
||||
} else {
|
||||
completion(false)
|
||||
}
|
||||
}.resume()
|
||||
}
|
||||
|
||||
|
||||
|
||||
func clearAPIKey() {
|
||||
guard selectedProvider.requiresAPIKey else { return }
|
||||
|
||||
|
||||
@ -173,4 +173,47 @@ class AudioProcessor {
|
||||
logger.notice("✅ Successfully converted \(samples.count) samples")
|
||||
return samples
|
||||
}
|
||||
}
|
||||
func saveSamplesAsWav(samples: [Float], to url: URL) throws {
|
||||
let outputFormat = AVAudioFormat(
|
||||
commonFormat: .pcmFormatInt16,
|
||||
sampleRate: AudioFormat.targetSampleRate,
|
||||
channels: AudioFormat.targetChannels,
|
||||
interleaved: true
|
||||
)
|
||||
|
||||
guard let outputFormat = outputFormat else {
|
||||
throw AudioProcessingError.unsupportedFormat
|
||||
}
|
||||
|
||||
let buffer = AVAudioPCMBuffer(
|
||||
pcmFormat: outputFormat,
|
||||
frameCapacity: AVAudioFrameCount(samples.count)
|
||||
)
|
||||
|
||||
guard let buffer = buffer else {
|
||||
throw AudioProcessingError.conversionFailed
|
||||
}
|
||||
|
||||
// Convert float samples to int16
|
||||
let int16Samples = samples.map { max(-1.0, min(1.0, $0)) * Float(Int16.max) }.map { Int16($0) }
|
||||
|
||||
// Copy samples to buffer
|
||||
int16Samples.withUnsafeBufferPointer { int16Buffer in
|
||||
let int16Pointer = int16Buffer.baseAddress!
|
||||
buffer.int16ChannelData![0].update(from: int16Pointer, count: int16Samples.count)
|
||||
}
|
||||
buffer.frameLength = AVAudioFrameCount(samples.count)
|
||||
|
||||
// Create audio file
|
||||
let audioFile = try AVAudioFile(
|
||||
forWriting: url,
|
||||
settings: outputFormat.settings,
|
||||
commonFormat: .pcmFormatInt16,
|
||||
interleaved: true
|
||||
)
|
||||
|
||||
try audioFile.write(from: buffer)
|
||||
logger.notice("✅ Successfully saved processed audio to \(url.lastPathComponent)")
|
||||
}
|
||||
}
|
||||
|
||||
@ -15,10 +15,13 @@ class AudioTranscriptionManager: ObservableObject {
|
||||
@Published var errorMessage: String?
|
||||
|
||||
private var currentTask: Task<Void, Error>?
|
||||
private var whisperContext: WhisperContext?
|
||||
private let audioProcessor = AudioProcessor()
|
||||
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionManager")
|
||||
|
||||
// Transcription services - will be initialized when needed
|
||||
private var localTranscriptionService: LocalTranscriptionService?
|
||||
private let cloudTranscriptionService = CloudTranscriptionService()
|
||||
|
||||
enum ProcessingPhase {
|
||||
case idle
|
||||
case loading
|
||||
@ -58,12 +61,14 @@ class AudioTranscriptionManager: ObservableObject {
|
||||
|
||||
currentTask = Task {
|
||||
do {
|
||||
guard let currentModel = whisperState.currentModel else {
|
||||
guard let currentModel = whisperState.currentTranscriptionModel else {
|
||||
throw TranscriptionError.noModelSelected
|
||||
}
|
||||
|
||||
// Load Whisper model
|
||||
whisperContext = try await WhisperContext.createContext(path: currentModel.url.path)
|
||||
// Initialize local transcription service if needed
|
||||
if localTranscriptionService == nil {
|
||||
localTranscriptionService = LocalTranscriptionService(modelsDirectory: whisperState.modelsDirectory, whisperState: whisperState)
|
||||
}
|
||||
|
||||
// Process audio file
|
||||
processingPhase = .processingAudio
|
||||
@ -82,15 +87,19 @@ class AudioTranscriptionManager: ObservableObject {
|
||||
let permanentURL = recordingsDirectory.appendingPathComponent(fileName)
|
||||
|
||||
try FileManager.default.createDirectory(at: recordingsDirectory, withIntermediateDirectories: true)
|
||||
try FileManager.default.copyItem(at: url, to: permanentURL)
|
||||
try audioProcessor.saveSamplesAsWav(samples: samples, to: permanentURL)
|
||||
|
||||
// Transcribe
|
||||
// Transcribe using appropriate service
|
||||
processingPhase = .transcribing
|
||||
await whisperContext?.setPrompt(whisperState.whisperPrompt.transcriptionPrompt)
|
||||
try await whisperContext?.fullTranscribe(samples: samples)
|
||||
var text = await whisperContext?.getTranscription() ?? ""
|
||||
var text: String
|
||||
|
||||
if currentModel.provider == .local {
|
||||
text = try await localTranscriptionService!.transcribe(audioURL: permanentURL, model: currentModel)
|
||||
} else {
|
||||
text = try await cloudTranscriptionService.transcribe(audioURL: permanentURL, model: currentModel)
|
||||
}
|
||||
|
||||
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
text = WhisperTextFormatter.format(text)
|
||||
|
||||
// Apply word replacements if enabled
|
||||
if UserDefaults.standard.bool(forKey: "IsWordReplacementEnabled") {
|
||||
@ -115,7 +124,6 @@ class AudioTranscriptionManager: ObservableObject {
|
||||
currentTranscription = transcription
|
||||
} catch {
|
||||
logger.error("Enhancement failed: \(error.localizedDescription)")
|
||||
messageLog += "Enhancement failed: \(error.localizedDescription). Using original transcription.\n"
|
||||
let transcription = Transcription(
|
||||
text: text,
|
||||
duration: duration,
|
||||
@ -148,28 +156,20 @@ class AudioTranscriptionManager: ObservableObject {
|
||||
|
||||
func cancelProcessing() {
|
||||
currentTask?.cancel()
|
||||
cleanupResources()
|
||||
}
|
||||
|
||||
private func finishProcessing() {
|
||||
isProcessing = false
|
||||
processingPhase = .idle
|
||||
currentTask = nil
|
||||
cleanupResources()
|
||||
}
|
||||
|
||||
private func handleError(_ error: Error) {
|
||||
logger.error("Transcription error: \(error.localizedDescription)")
|
||||
errorMessage = error.localizedDescription
|
||||
messageLog += "Error: \(error.localizedDescription)\n"
|
||||
isProcessing = false
|
||||
processingPhase = .idle
|
||||
currentTask = nil
|
||||
cleanupResources()
|
||||
}
|
||||
|
||||
private func cleanupResources() {
|
||||
whisperContext = nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -10,12 +10,15 @@ class AudioTranscriptionService: ObservableObject {
|
||||
@Published var messageLog = ""
|
||||
@Published var currentError: TranscriptionError?
|
||||
|
||||
private var whisperContext: WhisperContext?
|
||||
private let modelContext: ModelContext
|
||||
private let enhancementService: AIEnhancementService?
|
||||
private let whisperState: WhisperState
|
||||
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionService")
|
||||
|
||||
// Transcription services
|
||||
private let localTranscriptionService: LocalTranscriptionService
|
||||
private let cloudTranscriptionService = CloudTranscriptionService()
|
||||
|
||||
enum TranscriptionError: Error {
|
||||
case noAudioFile
|
||||
case transcriptionFailed
|
||||
@ -27,82 +30,67 @@ class AudioTranscriptionService: ObservableObject {
|
||||
self.modelContext = modelContext
|
||||
self.whisperState = whisperState
|
||||
self.enhancementService = whisperState.enhancementService
|
||||
self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: whisperState.modelsDirectory, whisperState: whisperState)
|
||||
}
|
||||
|
||||
func retranscribeAudio(from url: URL, using whisperModel: WhisperModel) async throws -> Transcription {
|
||||
func retranscribeAudio(from url: URL, using model: any TranscriptionModel) async throws -> Transcription {
|
||||
guard FileManager.default.fileExists(atPath: url.path) else {
|
||||
throw TranscriptionError.noAudioFile
|
||||
}
|
||||
|
||||
await MainActor.run {
|
||||
isTranscribing = true
|
||||
messageLog = "Loading model...\n"
|
||||
messageLog = "Starting retranscription...\n"
|
||||
}
|
||||
|
||||
// Load the whisper model if needed
|
||||
if whisperContext == nil {
|
||||
do {
|
||||
whisperContext = try await WhisperContext.createContext(path: whisperModel.url.path)
|
||||
messageLog += "Model loaded successfully.\n"
|
||||
} catch {
|
||||
logger.error("❌ Failed to load model: \(error.localizedDescription)")
|
||||
messageLog += "Failed to load model: \(error.localizedDescription)\n"
|
||||
isTranscribing = false
|
||||
throw TranscriptionError.modelNotLoaded
|
||||
do {
|
||||
// Delegate transcription to appropriate service
|
||||
var text: String
|
||||
|
||||
if model.provider == .local {
|
||||
messageLog += "Using local transcription service...\n"
|
||||
text = try await localTranscriptionService.transcribe(audioURL: url, model: model)
|
||||
messageLog += "Local transcription completed.\n"
|
||||
} else {
|
||||
messageLog += "Using cloud transcription service...\n"
|
||||
text = try await cloudTranscriptionService.transcribe(audioURL: url, model: model)
|
||||
messageLog += "Cloud transcription completed.\n"
|
||||
}
|
||||
}
|
||||
|
||||
guard let whisperContext = whisperContext else {
|
||||
isTranscribing = false
|
||||
throw TranscriptionError.modelNotLoaded
|
||||
}
|
||||
|
||||
// Get audio duration
|
||||
let audioAsset = AVURLAsset(url: url)
|
||||
let duration = CMTimeGetSeconds(try await audioAsset.load(.duration))
|
||||
|
||||
// Create a permanent copy of the audio file
|
||||
let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
|
||||
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
|
||||
.appendingPathComponent("Recordings")
|
||||
|
||||
let fileName = "retranscribed_\(UUID().uuidString).wav"
|
||||
let permanentURL = recordingsDirectory.appendingPathComponent(fileName)
|
||||
|
||||
do {
|
||||
try FileManager.default.copyItem(at: url, to: permanentURL)
|
||||
} catch {
|
||||
logger.error("❌ Failed to create permanent copy of audio: \(error.localizedDescription)")
|
||||
messageLog += "Failed to create permanent copy of audio: \(error.localizedDescription)\n"
|
||||
isTranscribing = false
|
||||
throw error
|
||||
}
|
||||
|
||||
let permanentURLString = permanentURL.absoluteString
|
||||
|
||||
// Transcribe the audio
|
||||
messageLog += "Transcribing audio...\n"
|
||||
|
||||
do {
|
||||
// Read audio samples
|
||||
let samples = try readAudioSamples(permanentURL)
|
||||
|
||||
// Process with Whisper - using the same prompt as WhisperState
|
||||
messageLog += "Setting prompt: \(whisperState.whisperPrompt.transcriptionPrompt)\n"
|
||||
await whisperContext.setPrompt(whisperState.whisperPrompt.transcriptionPrompt)
|
||||
|
||||
try await whisperContext.fullTranscribe(samples: samples)
|
||||
var text = await whisperContext.getTranscription()
|
||||
// Common post-processing for both local and cloud transcriptions
|
||||
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
logger.notice("✅ Retranscription completed successfully, length: \(text.count) characters")
|
||||
|
||||
// Apply word replacements if enabled
|
||||
if UserDefaults.standard.bool(forKey: "IsWordReplacementEnabled") {
|
||||
text = WordReplacementService.shared.applyReplacements(to: text)
|
||||
messageLog += "Word replacements applied.\n"
|
||||
logger.notice("✅ Word replacements applied")
|
||||
}
|
||||
|
||||
// Apply AI enhancement if enabled - using the same enhancement service as WhisperState
|
||||
// Get audio duration
|
||||
let audioAsset = AVURLAsset(url: url)
|
||||
let duration = CMTimeGetSeconds(try await audioAsset.load(.duration))
|
||||
|
||||
// Create a permanent copy of the audio file
|
||||
let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
|
||||
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
|
||||
.appendingPathComponent("Recordings")
|
||||
|
||||
let fileName = "retranscribed_\(UUID().uuidString).wav"
|
||||
let permanentURL = recordingsDirectory.appendingPathComponent(fileName)
|
||||
|
||||
do {
|
||||
try FileManager.default.copyItem(at: url, to: permanentURL)
|
||||
} catch {
|
||||
logger.error("❌ Failed to create permanent copy of audio: \(error.localizedDescription)")
|
||||
messageLog += "Failed to create permanent copy of audio: \(error.localizedDescription)\n"
|
||||
isTranscribing = false
|
||||
throw error
|
||||
}
|
||||
|
||||
let permanentURLString = permanentURL.absoluteString
|
||||
|
||||
// Apply AI enhancement if enabled
|
||||
if let enhancementService = enhancementService,
|
||||
enhancementService.isEnhancementEnabled,
|
||||
enhancementService.isConfigured {
|
||||
@ -182,19 +170,4 @@ class AudioTranscriptionService: ObservableObject {
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
private func readAudioSamples(_ url: URL) throws -> [Float] {
|
||||
return try decodeWaveFile(url)
|
||||
}
|
||||
|
||||
private func decodeWaveFile(_ url: URL) throws -> [Float] {
|
||||
let data = try Data(contentsOf: url)
|
||||
let floats = stride(from: 44, to: data.count, by: 2).map {
|
||||
return data[$0..<$0 + 2].withUnsafeBytes {
|
||||
let short = Int16(littleEndian: $0.load(as: Int16.self))
|
||||
return max(-1.0, min(Float(short) / 32767.0, 1.0))
|
||||
}
|
||||
}
|
||||
return floats
|
||||
}
|
||||
}
|
||||
211
VoiceInk/Services/CloudTranscriptionService.swift
Normal file
211
VoiceInk/Services/CloudTranscriptionService.swift
Normal file
@ -0,0 +1,211 @@
|
||||
import Foundation
|
||||
import os
|
||||
|
||||
enum CloudTranscriptionError: Error, LocalizedError {
|
||||
case unsupportedProvider
|
||||
case missingAPIKey
|
||||
case invalidAPIKey
|
||||
case audioFileNotFound
|
||||
case apiRequestFailed(statusCode: Int, message: String)
|
||||
case networkError(Error)
|
||||
case noTranscriptionReturned
|
||||
case dataEncodingError
|
||||
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .unsupportedProvider:
|
||||
return "The model provider is not supported by this service."
|
||||
case .missingAPIKey:
|
||||
return "API key for this service is missing. Please configure it in the settings."
|
||||
case .invalidAPIKey:
|
||||
return "The provided API key is invalid."
|
||||
case .audioFileNotFound:
|
||||
return "The audio file to transcribe could not be found."
|
||||
case .apiRequestFailed(let statusCode, let message):
|
||||
return "The API request failed with status code \(statusCode): \(message)"
|
||||
case .networkError(let error):
|
||||
return "A network error occurred: \(error.localizedDescription)"
|
||||
case .noTranscriptionReturned:
|
||||
return "The API returned an empty or invalid response."
|
||||
case .dataEncodingError:
|
||||
return "Failed to encode the request body."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class CloudTranscriptionService: TranscriptionService {
|
||||
|
||||
private struct APIConfig {
|
||||
let url: URL
|
||||
let apiKey: String
|
||||
let modelName: String
|
||||
}
|
||||
|
||||
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
|
||||
let config = try getAPIConfig(for: model)
|
||||
|
||||
var request: URLRequest
|
||||
var body: Data
|
||||
|
||||
switch model.provider {
|
||||
case .elevenLabs:
|
||||
let boundary = "Boundary-\(UUID().uuidString)"
|
||||
var elevenLabsRequest = URLRequest(url: config.url)
|
||||
elevenLabsRequest.httpMethod = "POST"
|
||||
elevenLabsRequest.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type")
|
||||
elevenLabsRequest.setValue(config.apiKey, forHTTPHeaderField: "xi-api-key")
|
||||
body = try createElevenLabsRequestBody(audioURL: audioURL, modelName: config.modelName, boundary: boundary)
|
||||
request = elevenLabsRequest
|
||||
|
||||
case .groq:
|
||||
let boundary = "Boundary-\(UUID().uuidString)"
|
||||
var openAICompatibleRequest = URLRequest(url: config.url)
|
||||
openAICompatibleRequest.httpMethod = "POST"
|
||||
openAICompatibleRequest.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type")
|
||||
openAICompatibleRequest.setValue("Bearer \(config.apiKey)", forHTTPHeaderField: "Authorization")
|
||||
body = try createOpenAICompatibleRequestBody(audioURL: audioURL, modelName: config.modelName, boundary: boundary)
|
||||
request = openAICompatibleRequest
|
||||
|
||||
|
||||
|
||||
default:
|
||||
throw CloudTranscriptionError.unsupportedProvider
|
||||
}
|
||||
|
||||
let (data, response) = try await URLSession.shared.upload(for: request, from: body)
|
||||
guard let httpResponse = response as? HTTPURLResponse else {
|
||||
throw CloudTranscriptionError.networkError(URLError(.badServerResponse))
|
||||
}
|
||||
|
||||
if !(200...299).contains(httpResponse.statusCode) {
|
||||
let errorMessage = String(data: data, encoding: .utf8) ?? "No error message"
|
||||
throw CloudTranscriptionError.apiRequestFailed(statusCode: httpResponse.statusCode, message: errorMessage)
|
||||
}
|
||||
|
||||
do {
|
||||
let transcriptionResponse = try JSONDecoder().decode(TranscriptionResponse.self, from: data)
|
||||
return transcriptionResponse.text
|
||||
} catch {
|
||||
throw CloudTranscriptionError.noTranscriptionReturned
|
||||
}
|
||||
}
|
||||
|
||||
private func getAPIConfig(for model: any TranscriptionModel) throws -> APIConfig {
|
||||
let providerKey: String
|
||||
let apiURL: URL
|
||||
|
||||
switch model.provider {
|
||||
case .groq:
|
||||
providerKey = "GROQ"
|
||||
apiURL = URL(string: "https://api.groq.com/openai/v1/audio/transcriptions")!
|
||||
case .elevenLabs:
|
||||
providerKey = "ElevenLabs"
|
||||
apiURL = URL(string: "https://api.elevenlabs.io/v1/speech-to-text")!
|
||||
default:
|
||||
throw CloudTranscriptionError.unsupportedProvider
|
||||
}
|
||||
|
||||
guard let apiKey = UserDefaults.standard.string(forKey: "\(providerKey)APIKey"), !apiKey.isEmpty else {
|
||||
throw CloudTranscriptionError.missingAPIKey
|
||||
}
|
||||
|
||||
return APIConfig(url: apiURL, apiKey: apiKey, modelName: model.name)
|
||||
}
|
||||
|
||||
private func createElevenLabsRequestBody(audioURL: URL, modelName: String, boundary: String) throws -> Data {
|
||||
var body = Data()
|
||||
let crlf = "\r\n"
|
||||
|
||||
guard let audioData = try? Data(contentsOf: audioURL) else {
|
||||
throw CloudTranscriptionError.audioFileNotFound
|
||||
}
|
||||
|
||||
// File
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"file\"; filename=\"\(audioURL.lastPathComponent)\"\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Type: audio/wav\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append(audioData)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
|
||||
// Model ID
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"model_id\"\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append(modelName.data(using: .utf8)!)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
|
||||
let selectedLanguage = UserDefaults.standard.string(forKey: "SelectedLanguage") ?? "auto"
|
||||
if selectedLanguage != "auto", !selectedLanguage.isEmpty {
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"language_code\"\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append(selectedLanguage.data(using: .utf8)!)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
}
|
||||
|
||||
body.append("--\(boundary)--\(crlf)".data(using: .utf8)!)
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
private func createOpenAICompatibleRequestBody(audioURL: URL, modelName: String, boundary: String) throws -> Data {
|
||||
var body = Data()
|
||||
let crlf = "\r\n"
|
||||
|
||||
guard let audioData = try? Data(contentsOf: audioURL) else {
|
||||
throw CloudTranscriptionError.audioFileNotFound
|
||||
}
|
||||
|
||||
let selectedLanguage = UserDefaults.standard.string(forKey: "SelectedLanguage") ?? "auto"
|
||||
let prompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? ""
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"file\"; filename=\"\(audioURL.lastPathComponent)\"\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Type: audio/wav\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append(audioData)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"model\"\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append(modelName.data(using: .utf8)!)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
|
||||
if selectedLanguage != "auto", !selectedLanguage.isEmpty {
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"language\"\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append(selectedLanguage.data(using: .utf8)!)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
}
|
||||
|
||||
// Include prompt for OpenAI-compatible APIs
|
||||
if !prompt.isEmpty {
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"prompt\"\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append(prompt.data(using: .utf8)!)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
}
|
||||
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"response_format\"\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append("json".data(using: .utf8)!)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"temperature\"\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append("0".data(using: .utf8)!)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
body.append("--\(boundary)--\(crlf)".data(using: .utf8)!)
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
private struct TranscriptionResponse: Decodable {
|
||||
let text: String
|
||||
let language: String?
|
||||
let duration: Double?
|
||||
let x_groq: GroqMetadata?
|
||||
|
||||
struct GroqMetadata: Decodable {
|
||||
let id: String?
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
87
VoiceInk/Services/LocalTranscriptionService.swift
Normal file
87
VoiceInk/Services/LocalTranscriptionService.swift
Normal file
@ -0,0 +1,87 @@
|
||||
import Foundation
|
||||
import AVFoundation
|
||||
import os
|
||||
|
||||
/// Transcribes audio on-device with whisper.cpp, reusing the app-wide loaded
/// model context when it matches the requested model.
class LocalTranscriptionService: TranscriptionService {

    // Context for the model in use — either created here or borrowed from `whisperState`.
    private var whisperContext: WhisperContext?
    private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "LocalTranscriptionService")
    // Directory containing the downloaded .bin model files.
    private let modelsDirectory: URL
    // Weak link back to the shared state so an already-loaded model can be reused.
    private weak var whisperState: WhisperState?

    init(modelsDirectory: URL, whisperState: WhisperState? = nil) {
        self.modelsDirectory = modelsDirectory
        self.whisperState = whisperState
    }

    /// Transcribes the audio file at `audioURL` using a local Whisper model.
    ///
    /// - Parameters:
    ///   - audioURL: WAV file to transcribe (16-bit PCM expected by `readAudioSamples`).
    ///   - model: Must be a `LocalModel`; anything else is rejected.
    /// - Returns: The transcribed text.
    /// - Throws: `WhisperError.couldNotInitializeContext` when the model is not local,
    ///   the model file is missing, or loading the context fails.
    func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
        guard let localModel = model as? LocalModel else {
            throw WhisperError.couldNotInitializeContext
        }

        logger.notice("Initiating local transcription for model: \(localModel.displayName)")

        // Reuse the context already loaded in WhisperState when it matches.
        if let whisperState = whisperState,
           await whisperState.isModelLoaded,
           let loadedContext = await whisperState.whisperContext,
           await whisperState.currentModel?.name == localModel.name {

            logger.notice("✅ Using already loaded model: \(localModel.name)")
            whisperContext = loadedContext
        } else {
            // Model not loaded or wrong model loaded, proceed with loading.
            let modelURL = modelsDirectory.appendingPathComponent(localModel.filename)

            guard FileManager.default.fileExists(atPath: modelURL.path) else {
                logger.error("Model file not found at path: \(modelURL.path)")
                throw WhisperError.couldNotInitializeContext
            }

            logger.notice("Loading model: \(localModel.name)")
            do {
                whisperContext = try await WhisperContext.createContext(path: modelURL.path)
            } catch {
                logger.error("Failed to load model: \(localModel.name) - \(error.localizedDescription)")
                throw WhisperError.couldNotInitializeContext
            }
        }

        guard let whisperContext = whisperContext else {
            logger.error("Cannot transcribe: Model could not be loaded")
            throw WhisperError.couldNotInitializeContext
        }

        // Read audio data
        let data = try readAudioSamples(audioURL)

        // Set prompt
        let currentPrompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? ""
        await whisperContext.setPrompt(currentPrompt)

        // Transcribe
        await whisperContext.fullTranscribe(samples: data)
        let text = await whisperContext.getTranscription()

        logger.notice("✅ Local transcription completed successfully.")

        // Only release resources if we created a new context (not using the shared one)
        if await whisperState?.whisperContext !== whisperContext {
            await whisperContext.releaseResources()
            self.whisperContext = nil
        }

        return text
    }

    /// Converts a 16-bit little-endian PCM WAV file into normalized Float samples
    /// in [-1, 1]. Skips a fixed 44-byte header — assumes the canonical WAV layout
    /// with no extra chunks; TODO confirm against the app's recorder output.
    private func readAudioSamples(_ url: URL) throws -> [Float] {
        let data = try Data(contentsOf: url)
        let headerSize = 44
        // Guard against files shorter than the header (original subtracted nothing
        // and would stride over header bytes or crash on tiny files).
        guard data.count > headerSize else { return [] }

        // Fix: iterate only over complete 2-byte frames. The original strode to
        // `data.count` and then read `$0..<$0 + 2`, which indexes one byte past the
        // end of the buffer whenever the payload length is odd.
        let sampleCount = (data.count - headerSize) / 2
        var samples = [Float]()
        samples.reserveCapacity(sampleCount)
        for frame in 0..<sampleCount {
            let offset = headerSize + frame * 2
            let short = data[offset..<offset + 2].withUnsafeBytes {
                Int16(littleEndian: $0.load(as: Int16.self))
            }
            samples.append(max(-1.0, min(Float(short) / 32767.0, 1.0)))
        }
        return samples
    }
}
|
||||
14
VoiceInk/Services/TranscriptionService.swift
Normal file
14
VoiceInk/Services/TranscriptionService.swift
Normal file
@ -0,0 +1,14 @@
|
||||
import Foundation
|
||||
|
||||
/// Abstraction over a transcription backend, giving local (whisper.cpp) and
/// cloud (API-based) engines a single calling convention.
protocol TranscriptionService {
    /// Transcribes the audio file at `audioURL`.
    ///
    /// - Parameters:
    ///   - audioURL: Location of the audio file to transcribe.
    ///   - model: The `TranscriptionModel` to use; its provider identifies the
    ///     backend (local, Groq, ElevenLabs, …).
    /// - Returns: The transcribed text.
    /// - Throws: A backend-specific error when transcription fails.
    func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String
}
|
||||
@ -44,7 +44,7 @@ struct APIKeyManagementView: View {
|
||||
|
||||
// Provider Selection
|
||||
Picker("AI Provider", selection: $aiService.selectedProvider) {
|
||||
ForEach(AIProvider.allCases, id: \.self) { provider in
|
||||
ForEach(AIProvider.allCases.filter { $0 != .elevenLabs }, id: \.self) { provider in
|
||||
Text(provider.rawValue).tag(provider)
|
||||
}
|
||||
}
|
||||
@ -385,6 +385,8 @@ struct APIKeyManagementView: View {
|
||||
URL(string: "https://console.anthropic.com/settings/keys")!
|
||||
case .mistral:
|
||||
URL(string: "https://console.mistral.ai/api-keys")!
|
||||
case .elevenLabs:
|
||||
URL(string: "https://elevenlabs.io/speech-synthesis")!
|
||||
case .ollama, .custom:
|
||||
URL(string: "")! // This case should never be reached
|
||||
}
|
||||
|
||||
@ -388,7 +388,7 @@ struct AudioPlayerView: View {
|
||||
}
|
||||
|
||||
private func retranscribeAudio() {
|
||||
guard let currentModel = whisperState.currentModel else {
|
||||
guard let currentTranscriptionModel = whisperState.currentTranscriptionModel else {
|
||||
errorMessage = "No transcription model selected"
|
||||
showRetranscribeError = true
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 3) {
|
||||
@ -401,7 +401,7 @@ struct AudioPlayerView: View {
|
||||
|
||||
Task {
|
||||
do {
|
||||
let _ = try await transcriptionService.retranscribeAudio(from: url, using: currentModel)
|
||||
let _ = try await transcriptionService.retranscribeAudio(from: url, using: currentTranscriptionModel)
|
||||
await MainActor.run {
|
||||
isRetranscribing = false
|
||||
showRetranscribeSuccess = true
|
||||
|
||||
@ -26,23 +26,18 @@ struct LanguageSelectionView: View {
|
||||
|
||||
// Function to check if current model is multilingual
|
||||
private func isMultilingualModel() -> Bool {
|
||||
guard let currentModel = whisperState.currentModel,
|
||||
let predefinedModel = PredefinedModels.models.first(where: { $0.name == currentModel.name }) else {
|
||||
guard let currentModel = whisperState.currentTranscriptionModel else {
|
||||
return false
|
||||
}
|
||||
return predefinedModel.isMultilingualModel
|
||||
return currentModel.isMultilingualModel
|
||||
}
|
||||
|
||||
// Function to get current model's supported languages
|
||||
private func getCurrentModelLanguages() -> [String: String] {
|
||||
guard let currentModel = whisperState.currentModel,
|
||||
let predefinedModel = PredefinedModels.models.first(where: {
|
||||
$0.name == currentModel.name
|
||||
})
|
||||
else {
|
||||
guard let currentModel = whisperState.currentTranscriptionModel else {
|
||||
return ["en": "English"] // Default to English if no model found
|
||||
}
|
||||
return predefinedModel.supportedLanguages
|
||||
return currentModel.supportedLanguages
|
||||
}
|
||||
|
||||
// Get the display name of the current language
|
||||
@ -74,16 +69,13 @@ struct LanguageSelectionView: View {
|
||||
Text("Transcription Language")
|
||||
.font(.headline)
|
||||
|
||||
if let currentModel = whisperState.currentModel,
|
||||
let predefinedModel = PredefinedModels.models.first(where: {
|
||||
$0.name == currentModel.name
|
||||
})
|
||||
if let currentModel = whisperState.currentTranscriptionModel
|
||||
{
|
||||
if isMultilingualModel() {
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
Picker("Select Language", selection: $selectedLanguage) {
|
||||
ForEach(
|
||||
predefinedModel.supportedLanguages.sorted(by: {
|
||||
currentModel.supportedLanguages.sorted(by: {
|
||||
if $0.key == "auto" { return true }
|
||||
if $1.key == "auto" { return false }
|
||||
return $0.value < $1.value
|
||||
@ -97,7 +89,7 @@ struct LanguageSelectionView: View {
|
||||
updateLanguage(newValue)
|
||||
}
|
||||
|
||||
Text("Current model: \(predefinedModel.displayName)")
|
||||
Text("Current model: \(currentModel.displayName)")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
|
||||
@ -114,7 +106,7 @@ struct LanguageSelectionView: View {
|
||||
.font(.subheadline)
|
||||
.foregroundColor(.primary)
|
||||
|
||||
Text("Current model: \(predefinedModel.displayName)")
|
||||
Text("Current model: \(currentModel.displayName)")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
|
||||
|
||||
@ -80,26 +80,21 @@ struct MenuBarView: View {
|
||||
}
|
||||
|
||||
Menu {
|
||||
ForEach(whisperState.availableModels) { model in
|
||||
ForEach(whisperState.usableModels, id: \.id) { model in
|
||||
Button {
|
||||
Task {
|
||||
await whisperState.setDefaultModel(model)
|
||||
await whisperState.setDefaultTranscriptionModel(model)
|
||||
}
|
||||
} label: {
|
||||
HStack {
|
||||
Text(PredefinedModels.models.first { $0.name == model.name }?.displayName ?? model.name)
|
||||
if whisperState.currentModel?.name == model.name {
|
||||
Text(model.displayName)
|
||||
if whisperState.currentTranscriptionModel?.id == model.id {
|
||||
Image(systemName: "checkmark")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if whisperState.availableModels.isEmpty {
|
||||
Text("No models downloaded")
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
|
||||
Divider()
|
||||
|
||||
Button("Manage Models") {
|
||||
@ -107,7 +102,7 @@ struct MenuBarView: View {
|
||||
}
|
||||
} label: {
|
||||
HStack {
|
||||
Text("Model: \(PredefinedModels.models.first { $0.name == whisperState.currentModel?.name }?.displayName ?? "None")")
|
||||
Text("Model: \(whisperState.currentTranscriptionModel?.displayName ?? "None")")
|
||||
Image(systemName: "chevron.up.chevron.down")
|
||||
.font(.system(size: 10))
|
||||
}
|
||||
|
||||
@ -2,7 +2,49 @@ import SwiftUI
|
||||
import AppKit
|
||||
|
||||
struct ModelCardRowView: View {
|
||||
let model: PredefinedModel
|
||||
let model: any TranscriptionModel
|
||||
let isDownloaded: Bool
|
||||
let isCurrent: Bool
|
||||
let downloadProgress: [String: Double]
|
||||
let modelURL: URL?
|
||||
|
||||
// Actions
|
||||
var deleteAction: () -> Void
|
||||
var setDefaultAction: () -> Void
|
||||
var downloadAction: () -> Void
|
||||
|
||||
var body: some View {
|
||||
Group {
|
||||
switch model.provider {
|
||||
case .local:
|
||||
if let localModel = model as? LocalModel {
|
||||
LocalModelCardView(
|
||||
model: localModel,
|
||||
isDownloaded: isDownloaded,
|
||||
isCurrent: isCurrent,
|
||||
downloadProgress: downloadProgress,
|
||||
modelURL: modelURL,
|
||||
deleteAction: deleteAction,
|
||||
setDefaultAction: setDefaultAction,
|
||||
downloadAction: downloadAction
|
||||
)
|
||||
}
|
||||
case .groq, .elevenLabs:
|
||||
if let cloudModel = model as? CloudModel {
|
||||
CloudModelCardView(
|
||||
model: cloudModel,
|
||||
isCurrent: isCurrent,
|
||||
setDefaultAction: setDefaultAction
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Local Model Card View
|
||||
struct LocalModelCardView: View {
|
||||
let model: LocalModel
|
||||
let isDownloaded: Bool
|
||||
let isCurrent: Bool
|
||||
let downloadProgress: [String: Double]
|
||||
@ -36,8 +78,6 @@ struct ModelCardRowView: View {
|
||||
.background(CardBackground(isSelected: isCurrent, useAccentGradientWhenSelected: isCurrent))
|
||||
}
|
||||
|
||||
// MARK: - Components
|
||||
|
||||
private var headerSection: some View {
|
||||
HStack(alignment: .firstTextBaseline) {
|
||||
Text(model.displayName)
|
||||
@ -162,14 +202,210 @@ struct ModelCardRowView: View {
|
||||
Label("Delete Model", systemImage: "trash")
|
||||
}
|
||||
|
||||
if isDownloaded {
|
||||
Button {
|
||||
if let modelURL = modelURL {
|
||||
NSWorkspace.shared.selectFile(modelURL.path, inFileViewerRootedAtPath: "")
|
||||
}
|
||||
} label: {
|
||||
Label("Show in Finder", systemImage: "folder")
|
||||
Button {
|
||||
if let modelURL = modelURL {
|
||||
NSWorkspace.shared.selectFile(modelURL.path, inFileViewerRootedAtPath: "")
|
||||
}
|
||||
} label: {
|
||||
Label("Show in Finder", systemImage: "folder")
|
||||
}
|
||||
} label: {
|
||||
Image(systemName: "ellipsis.circle")
|
||||
.font(.system(size: 14))
|
||||
}
|
||||
.menuStyle(.borderlessButton)
|
||||
.menuIndicator(.hidden)
|
||||
.frame(width: 20, height: 20)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Cloud Model Card View
|
||||
struct CloudModelCardView: View {
|
||||
let model: CloudModel
|
||||
let isCurrent: Bool
|
||||
var setDefaultAction: () -> Void
|
||||
|
||||
@StateObject private var aiService = AIService()
|
||||
@State private var isExpanded = false
|
||||
@State private var apiKey = ""
|
||||
@State private var isVerifying = false
|
||||
@State private var verificationStatus: VerificationStatus = .none
|
||||
|
||||
enum VerificationStatus {
|
||||
case none, verifying, success, failure
|
||||
}
|
||||
|
||||
private var isConfigured: Bool {
|
||||
guard let savedKey = UserDefaults.standard.string(forKey: "\(providerKey)APIKey") else {
|
||||
return false
|
||||
}
|
||||
return !savedKey.isEmpty
|
||||
}
|
||||
|
||||
private var providerKey: String {
|
||||
switch model.provider {
|
||||
case .groq:
|
||||
return "GROQ"
|
||||
case .elevenLabs:
|
||||
return "ElevenLabs"
|
||||
default:
|
||||
return model.provider.rawValue
|
||||
}
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 0) {
|
||||
// Main card content
|
||||
HStack(alignment: .top, spacing: 16) {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
headerSection
|
||||
metadataSection
|
||||
descriptionSection
|
||||
}
|
||||
.frame(maxWidth: .infinity, alignment: .leading)
|
||||
|
||||
actionSection
|
||||
}
|
||||
.padding(16)
|
||||
|
||||
// Expandable configuration section
|
||||
if isExpanded {
|
||||
Divider()
|
||||
.padding(.horizontal, 16)
|
||||
|
||||
configurationSection
|
||||
.padding(16)
|
||||
}
|
||||
}
|
||||
.background(CardBackground(isSelected: isCurrent, useAccentGradientWhenSelected: isCurrent))
|
||||
.onAppear {
|
||||
loadSavedAPIKey()
|
||||
}
|
||||
}
|
||||
|
||||
private var headerSection: some View {
|
||||
HStack(alignment: .firstTextBaseline) {
|
||||
Text(model.displayName)
|
||||
.font(.system(size: 13, weight: .semibold))
|
||||
.foregroundColor(Color(.labelColor))
|
||||
|
||||
statusBadge
|
||||
|
||||
Spacer()
|
||||
}
|
||||
}
|
||||
|
||||
private var statusBadge: some View {
|
||||
Group {
|
||||
if isCurrent {
|
||||
Text("Default")
|
||||
.font(.system(size: 11, weight: .medium))
|
||||
.padding(.horizontal, 6)
|
||||
.padding(.vertical, 2)
|
||||
.background(Capsule().fill(Color.accentColor))
|
||||
.foregroundColor(.white)
|
||||
} else if isConfigured {
|
||||
Text("Configured")
|
||||
.font(.system(size: 11, weight: .medium))
|
||||
.padding(.horizontal, 6)
|
||||
.padding(.vertical, 2)
|
||||
.background(Capsule().fill(Color(.systemGreen).opacity(0.2)))
|
||||
.foregroundColor(Color(.systemGreen))
|
||||
} else {
|
||||
Text("Setup Required")
|
||||
.font(.system(size: 11, weight: .medium))
|
||||
.padding(.horizontal, 6)
|
||||
.padding(.vertical, 2)
|
||||
.background(Capsule().fill(Color(.systemOrange).opacity(0.2)))
|
||||
.foregroundColor(Color(.systemOrange))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var metadataSection: some View {
|
||||
HStack(spacing: 16) {
|
||||
// Provider
|
||||
Label(model.provider.rawValue, systemImage: "cloud")
|
||||
.font(.system(size: 11))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
|
||||
// Language
|
||||
Label(model.language, systemImage: "globe")
|
||||
.font(.system(size: 11))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
|
||||
// Speed
|
||||
HStack(spacing: 4) {
|
||||
Text("Speed")
|
||||
.font(.system(size: 11, weight: .medium))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
progressDotsWithNumber(value: model.speed * 10)
|
||||
}
|
||||
|
||||
// Accuracy
|
||||
HStack(spacing: 4) {
|
||||
Text("Accuracy")
|
||||
.font(.system(size: 11, weight: .medium))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
progressDotsWithNumber(value: model.accuracy * 10)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var descriptionSection: some View {
|
||||
Text(model.description)
|
||||
.font(.system(size: 11))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
.lineLimit(2)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
.padding(.top, 4)
|
||||
}
|
||||
|
||||
private var actionSection: some View {
|
||||
HStack(spacing: 8) {
|
||||
if isCurrent {
|
||||
Text("Default Model")
|
||||
.font(.system(size: 12))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
} else if isConfigured {
|
||||
Button(action: setDefaultAction) {
|
||||
Text("Set as Default")
|
||||
.font(.system(size: 12))
|
||||
}
|
||||
.buttonStyle(.bordered)
|
||||
.controlSize(.small)
|
||||
} else {
|
||||
Button(action: {
|
||||
withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) {
|
||||
isExpanded.toggle()
|
||||
}
|
||||
}) {
|
||||
HStack(spacing: 4) {
|
||||
Text("Configure")
|
||||
.font(.system(size: 12, weight: .medium))
|
||||
Image(systemName: "gear")
|
||||
.font(.system(size: 12, weight: .medium))
|
||||
}
|
||||
.foregroundColor(.white)
|
||||
.padding(.horizontal, 12)
|
||||
.padding(.vertical, 6)
|
||||
.background(
|
||||
Capsule()
|
||||
.fill(Color(.controlAccentColor))
|
||||
.shadow(color: Color(.controlAccentColor).opacity(0.2), radius: 2, x: 0, y: 1)
|
||||
)
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
}
|
||||
|
||||
if isConfigured {
|
||||
Menu {
|
||||
Button {
|
||||
clearAPIKey()
|
||||
} label: {
|
||||
Label("Remove API Key", systemImage: "trash")
|
||||
}
|
||||
} label: {
|
||||
Image(systemName: "ellipsis.circle")
|
||||
@ -182,40 +418,134 @@ struct ModelCardRowView: View {
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
private var downloadComponents: [(String, Double)] {
|
||||
[
|
||||
("Model", downloadProgress[model.name + "_main"] ?? 0),
|
||||
("CoreML", downloadProgress[model.name + "_coreml"] ?? 0)
|
||||
].filter { $0.1 > 0 }
|
||||
}
|
||||
|
||||
private func progressDotsWithNumber(value: Double) -> some View {
|
||||
HStack(spacing: 4) {
|
||||
progressDots(value: value)
|
||||
Text(String(format: "%.1f", value))
|
||||
.font(.system(size: 10, weight: .medium, design: .monospaced))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
}
|
||||
}
|
||||
|
||||
private func progressDots(value: Double) -> some View {
|
||||
HStack(spacing: 2) {
|
||||
ForEach(0..<5) { index in
|
||||
Circle()
|
||||
.fill(index < Int(value / 2) ? performanceColor(value: value / 10) : Color(.quaternaryLabelColor))
|
||||
.frame(width: 6, height: 6)
|
||||
private var configurationSection: some View {
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
Text("API Key Configuration")
|
||||
.font(.system(size: 13, weight: .semibold))
|
||||
.foregroundColor(Color(.labelColor))
|
||||
|
||||
HStack(spacing: 8) {
|
||||
SecureField("Enter your \(model.provider.rawValue) API key", text: $apiKey)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
.disabled(isVerifying)
|
||||
|
||||
Button(action: verifyAPIKey) {
|
||||
HStack(spacing: 4) {
|
||||
if isVerifying {
|
||||
ProgressView()
|
||||
.scaleEffect(0.7)
|
||||
.frame(width: 12, height: 12)
|
||||
} else {
|
||||
Image(systemName: verificationStatus == .success ? "checkmark" : "checkmark.shield")
|
||||
.font(.system(size: 12, weight: .medium))
|
||||
}
|
||||
Text(isVerifying ? "Verifying..." : "Verify")
|
||||
.font(.system(size: 12, weight: .medium))
|
||||
}
|
||||
.foregroundColor(.white)
|
||||
.padding(.horizontal, 12)
|
||||
.padding(.vertical, 6)
|
||||
.background(
|
||||
Capsule()
|
||||
.fill(verificationStatus == .success ? Color(.systemGreen) : Color(.controlAccentColor))
|
||||
)
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.disabled(apiKey.isEmpty || isVerifying)
|
||||
}
|
||||
|
||||
if verificationStatus == .failure {
|
||||
Text("Invalid API key. Please check your key and try again.")
|
||||
.font(.caption)
|
||||
.foregroundColor(Color(.systemRed))
|
||||
} else if verificationStatus == .success {
|
||||
Text("API key verified successfully!")
|
||||
.font(.caption)
|
||||
.foregroundColor(Color(.systemGreen))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func performanceColor(value: Double) -> Color {
|
||||
switch value {
|
||||
case 0.8...1.0: return Color(.systemGreen)
|
||||
case 0.6..<0.8: return Color(.systemYellow)
|
||||
case 0.4..<0.6: return Color(.systemOrange)
|
||||
default: return Color(.systemRed)
|
||||
private func loadSavedAPIKey() {
|
||||
if let savedKey = UserDefaults.standard.string(forKey: "\(providerKey)APIKey") {
|
||||
apiKey = savedKey
|
||||
verificationStatus = .success
|
||||
}
|
||||
}
|
||||
|
||||
private func verifyAPIKey() {
|
||||
guard !apiKey.isEmpty else { return }
|
||||
|
||||
isVerifying = true
|
||||
verificationStatus = .verifying
|
||||
|
||||
// Set the provider in AIService temporarily for verification
|
||||
let originalProvider = aiService.selectedProvider
|
||||
if model.provider == .groq {
|
||||
aiService.selectedProvider = .groq
|
||||
} else if model.provider == .elevenLabs {
|
||||
aiService.selectedProvider = .elevenLabs
|
||||
}
|
||||
|
||||
aiService.verifyAPIKey(apiKey) { [self] isValid in
|
||||
DispatchQueue.main.async {
|
||||
self.isVerifying = false
|
||||
if isValid {
|
||||
self.verificationStatus = .success
|
||||
// Save the API key
|
||||
UserDefaults.standard.set(self.apiKey, forKey: "\(self.providerKey)APIKey")
|
||||
|
||||
// Collapse the configuration section after successful verification
|
||||
withAnimation(.easeInOut(duration: 0.3)) {
|
||||
self.isExpanded = false
|
||||
}
|
||||
} else {
|
||||
self.verificationStatus = .failure
|
||||
}
|
||||
|
||||
// Restore original provider
|
||||
aiService.selectedProvider = originalProvider
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func clearAPIKey() {
|
||||
UserDefaults.standard.removeObject(forKey: "\(providerKey)APIKey")
|
||||
apiKey = ""
|
||||
verificationStatus = .none
|
||||
|
||||
withAnimation(.easeInOut(duration: 0.3)) {
|
||||
isExpanded = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Helper Views and Functions
|
||||
|
||||
private func progressDotsWithNumber(value: Double) -> some View {
|
||||
HStack(spacing: 4) {
|
||||
progressDots(value: value)
|
||||
Text(String(format: "%.1f", value))
|
||||
.font(.system(size: 10, weight: .medium, design: .monospaced))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
}
|
||||
}
|
||||
|
||||
private func progressDots(value: Double) -> some View {
|
||||
HStack(spacing: 2) {
|
||||
ForEach(0..<5) { index in
|
||||
Circle()
|
||||
.fill(index < Int(value / 2) ? performanceColor(value: value / 10) : Color(.quaternaryLabelColor))
|
||||
.frame(width: 6, height: 6)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func performanceColor(value: Double) -> Color {
|
||||
switch value {
|
||||
case 0.8...1.0: return Color(.systemGreen)
|
||||
case 0.6..<0.8: return Color(.systemYellow)
|
||||
case 0.4..<0.6: return Color(.systemOrange)
|
||||
default: return Color(.systemRed)
|
||||
}
|
||||
}
|
||||
|
||||
@ -39,9 +39,7 @@ struct ModelManagementView: View {
|
||||
Text("Default Model")
|
||||
.font(.headline)
|
||||
.foregroundColor(.secondary)
|
||||
Text(whisperState.currentModel.flatMap { model in
|
||||
PredefinedModels.models.first { $0.name == model.name }?.displayName
|
||||
} ?? "No model selected")
|
||||
Text(whisperState.currentTranscriptionModel?.displayName ?? "No model selected")
|
||||
.font(.title2)
|
||||
.fontWeight(.bold)
|
||||
}
|
||||
@ -62,7 +60,7 @@ struct ModelManagementView: View {
|
||||
.font(.title3)
|
||||
.fontWeight(.semibold)
|
||||
|
||||
Text("(\(whisperState.predefinedModels.count))")
|
||||
Text("(\(whisperState.allAvailableModels.count))")
|
||||
.foregroundColor(.secondary)
|
||||
.font(.subheadline)
|
||||
|
||||
@ -70,11 +68,11 @@ struct ModelManagementView: View {
|
||||
}
|
||||
|
||||
VStack(spacing: 12) {
|
||||
ForEach(whisperState.predefinedModels) { model in
|
||||
ForEach(whisperState.allAvailableModels, id: \.id) { model in
|
||||
ModelCardRowView(
|
||||
model: model,
|
||||
isDownloaded: whisperState.availableModels.contains { $0.name == model.name },
|
||||
isCurrent: whisperState.currentModel?.name == model.name,
|
||||
isCurrent: whisperState.currentTranscriptionModel?.name == model.name,
|
||||
downloadProgress: whisperState.downloadProgress,
|
||||
modelURL: whisperState.availableModels.first { $0.name == model.name }?.url,
|
||||
deleteAction: {
|
||||
@ -83,15 +81,15 @@ struct ModelManagementView: View {
|
||||
}
|
||||
},
|
||||
setDefaultAction: {
|
||||
if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) {
|
||||
Task {
|
||||
await whisperState.setDefaultModel(downloadedModel)
|
||||
}
|
||||
Task {
|
||||
await whisperState.setDefaultTranscriptionModel(model)
|
||||
}
|
||||
},
|
||||
downloadAction: {
|
||||
Task {
|
||||
await whisperState.downloadModel(model)
|
||||
if let localModel = model as? LocalModel {
|
||||
Task {
|
||||
await whisperState.downloadModel(localModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@ -9,7 +9,7 @@ struct OnboardingModelDownloadView: View {
|
||||
@State private var isModelSet = false
|
||||
@State private var showTutorial = false
|
||||
|
||||
private let turboModel = PredefinedModels.models.first { $0.name == "ggml-large-v3-turbo-q5_0" }!
|
||||
private let turboModel = PredefinedModels.models.first { $0.name == "ggml-large-v3-turbo-q5_0" } as! LocalModel
|
||||
|
||||
var body: some View {
|
||||
ZStack {
|
||||
|
||||
@ -122,7 +122,7 @@ extension WhisperState {
|
||||
}
|
||||
|
||||
// Shows an alert about Core ML support and first-run optimization
|
||||
private func showCoreMLAlert(for model: PredefinedModel, completion: @escaping () -> Void) {
|
||||
private func showCoreMLAlert(for model: LocalModel, completion: @escaping () -> Void) {
|
||||
Task { @MainActor in
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Core ML Support for \(model.displayName) Model"
|
||||
@ -138,7 +138,7 @@ extension WhisperState {
|
||||
}
|
||||
}
|
||||
|
||||
func downloadModel(_ model: PredefinedModel) async {
|
||||
func downloadModel(_ model: LocalModel) async {
|
||||
guard let url = URL(string: model.downloadURL) else { return }
|
||||
|
||||
// Check if model supports Core ML (non-quantized models)
|
||||
@ -160,7 +160,7 @@ extension WhisperState {
|
||||
}
|
||||
}
|
||||
|
||||
private func performModelDownload(_ model: PredefinedModel, _ url: URL) async {
|
||||
private func performModelDownload(_ model: LocalModel, _ url: URL) async {
|
||||
do {
|
||||
let whisperModel = try await downloadMainModel(model, from: url)
|
||||
|
||||
@ -176,7 +176,7 @@ extension WhisperState {
|
||||
}
|
||||
}
|
||||
|
||||
private func downloadMainModel(_ model: PredefinedModel, from url: URL) async throws -> WhisperModel {
|
||||
private func downloadMainModel(_ model: LocalModel, from url: URL) async throws -> WhisperModel {
|
||||
let progressKeyMain = model.name + "_main"
|
||||
let data = try await downloadFileWithProgress(from: url, progressKey: progressKeyMain)
|
||||
|
||||
@ -232,8 +232,7 @@ extension WhisperState {
|
||||
return model
|
||||
}
|
||||
|
||||
private func handleModelDownloadError(_ model: PredefinedModel, _ error: Error) {
|
||||
currentError = .modelDownloadFailed
|
||||
private func handleModelDownloadError(_ model: LocalModel, _ error: Error) {
|
||||
self.downloadProgress.removeValue(forKey: model.name + "_main")
|
||||
self.downloadProgress.removeValue(forKey: model.name + "_coreml")
|
||||
}
|
||||
@ -262,7 +261,6 @@ extension WhisperState {
|
||||
}
|
||||
} catch {
|
||||
logError("Error deleting model: \(model.name)", error)
|
||||
currentError = .modelDeletionFailed
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
18
VoiceInk/Whisper/WhisperState+ModelQueries.swift
Normal file
18
VoiceInk/Whisper/WhisperState+ModelQueries.swift
Normal file
@ -0,0 +1,18 @@
|
||||
import Foundation
|
||||
|
||||
extension WhisperState {
|
||||
var usableModels: [any TranscriptionModel] {
|
||||
allAvailableModels.filter { model in
|
||||
switch model.provider {
|
||||
case .local:
|
||||
return availableModels.contains { $0.name == model.name }
|
||||
case .groq:
|
||||
let key = UserDefaults.standard.string(forKey: "GROQAPIKey")
|
||||
return key != nil && !key!.isEmpty
|
||||
case .elevenLabs:
|
||||
let key = UserDefaults.standard.string(forKey: "ElevenLabsAPIKey")
|
||||
return key != nil && !key!.isEmpty
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -12,9 +12,10 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
@Published var canTranscribe = false
|
||||
@Published var isRecording = false
|
||||
@Published var currentModel: WhisperModel?
|
||||
@Published var currentTranscriptionModel: (any TranscriptionModel)?
|
||||
@Published var isModelLoading = false
|
||||
@Published var availableModels: [WhisperModel] = []
|
||||
@Published var predefinedModels: [PredefinedModel] = PredefinedModels.models
|
||||
@Published var allAvailableModels: [any TranscriptionModel] = PredefinedModels.models
|
||||
@Published var clipboardMessage = ""
|
||||
@Published var miniRecorderError: String?
|
||||
@Published var isProcessing = false
|
||||
@ -53,6 +54,10 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
|
||||
let modelContext: ModelContext
|
||||
|
||||
// Transcription Services
|
||||
private var localTranscriptionService: LocalTranscriptionService
|
||||
private let cloudTranscriptionService = CloudTranscriptionService()
|
||||
|
||||
private var modelUrl: URL? {
|
||||
let possibleURLs = [
|
||||
Bundle.main.url(forResource: "ggml-base.en", withExtension: "bin", subdirectory: "Models"),
|
||||
@ -85,21 +90,28 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
|
||||
init(modelContext: ModelContext, enhancementService: AIEnhancementService? = nil) {
|
||||
self.modelContext = modelContext
|
||||
self.modelsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
|
||||
let appSupportDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
|
||||
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
|
||||
.appendingPathComponent("WhisperModels")
|
||||
self.recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
|
||||
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
|
||||
.appendingPathComponent("Recordings")
|
||||
|
||||
self.modelsDirectory = appSupportDirectory.appendingPathComponent("WhisperModels")
|
||||
self.recordingsDirectory = appSupportDirectory.appendingPathComponent("Recordings")
|
||||
|
||||
// Initialize services without whisperState reference first
|
||||
self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: self.modelsDirectory)
|
||||
|
||||
self.enhancementService = enhancementService
|
||||
self.licenseViewModel = LicenseViewModel()
|
||||
|
||||
super.init()
|
||||
|
||||
// Set the whisperState reference after super.init()
|
||||
self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: self.modelsDirectory, whisperState: self)
|
||||
|
||||
setupNotifications()
|
||||
createModelsDirectoryIfNeeded()
|
||||
createRecordingsDirectoryIfNeeded()
|
||||
loadAvailableModels()
|
||||
loadCurrentTranscriptionModel()
|
||||
|
||||
if let savedModelName = UserDefaults.standard.string(forKey: "CurrentModel"),
|
||||
let savedModel = availableModels.first(where: { $0.name == savedModelName }) {
|
||||
@ -139,11 +151,11 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
logger.error("❌ No recorded file found after stopping recording")
|
||||
}
|
||||
} else {
|
||||
guard currentModel != nil else {
|
||||
guard currentTranscriptionModel != nil else {
|
||||
await MainActor.run {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "No Whisper Model Selected"
|
||||
alert.informativeText = "Please select a default whisper model in AI Models tab before recording."
|
||||
alert.messageText = "No AI Model Selected"
|
||||
alert.informativeText = "Please select a default AI model in AI Models tab before recording."
|
||||
alert.alertStyle = .warning
|
||||
alert.addButton(withTitle: "OK")
|
||||
alert.runModal()
|
||||
@ -174,11 +186,15 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
|
||||
await ActiveWindowService.shared.applyConfigurationForCurrentApp()
|
||||
|
||||
if let currentModel = await self.currentModel, await self.whisperContext == nil {
|
||||
do {
|
||||
try await self.loadModel(currentModel)
|
||||
} catch {
|
||||
self.logger.error("❌ Model loading failed: \(error.localizedDescription)")
|
||||
// Only load model if it's a local model and not already loaded
|
||||
if let model = self.currentTranscriptionModel, model.provider == .local {
|
||||
if let localWhisperModel = self.availableModels.first(where: { $0.name == model.name }),
|
||||
self.whisperContext == nil {
|
||||
do {
|
||||
try await self.loadModel(localWhisperModel)
|
||||
} catch {
|
||||
self.logger.error("❌ Model loading failed: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -256,11 +272,13 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
await cleanupModelResources()
|
||||
return
|
||||
}
|
||||
|
||||
await MainActor.run {
|
||||
isProcessing = true
|
||||
isTranscribing = true
|
||||
canTranscribe = false
|
||||
}
|
||||
|
||||
defer {
|
||||
if shouldCancelRecording {
|
||||
Task {
|
||||
@ -268,56 +286,33 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
}
|
||||
}
|
||||
}
|
||||
guard let currentModel = currentModel else {
|
||||
|
||||
guard let model = currentTranscriptionModel else {
|
||||
logger.error("❌ Cannot transcribe: No model selected")
|
||||
currentError = .modelLoadFailed
|
||||
return
|
||||
}
|
||||
if whisperContext == nil {
|
||||
logger.notice("🔄 Model not loaded yet, attempting to load now: \(currentModel.name)")
|
||||
do {
|
||||
try await loadModel(currentModel)
|
||||
} catch {
|
||||
logger.error("❌ Failed to load model: \(currentModel.name) - \(error.localizedDescription)")
|
||||
currentError = .modelLoadFailed
|
||||
return
|
||||
}
|
||||
}
|
||||
guard let whisperContext = whisperContext else {
|
||||
logger.error("❌ Cannot transcribe: Model could not be loaded")
|
||||
currentError = .modelLoadFailed
|
||||
return
|
||||
}
|
||||
logger.notice("🔄 Starting transcription with model: \(currentModel.name)")
|
||||
|
||||
logger.notice("🔄 Starting transcription with model: \(model.displayName)")
|
||||
|
||||
do {
|
||||
let permanentURL = try saveRecordingPermanently(url)
|
||||
let permanentURLString = permanentURL.absoluteString
|
||||
if shouldCancelRecording { return }
|
||||
let data = try readAudioSamples(url)
|
||||
if shouldCancelRecording { return }
|
||||
|
||||
// Get the actual audio duration from the file
|
||||
let audioAsset = AVURLAsset(url: url)
|
||||
let actualDuration = CMTimeGetSeconds(try await audioAsset.load(.duration))
|
||||
logger.notice("📊 Audio file duration: \(actualDuration) seconds")
|
||||
|
||||
// Ensure we're using the most recent prompt from UserDefaults
|
||||
let currentPrompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? whisperPrompt.transcriptionPrompt
|
||||
await whisperContext.setPrompt(currentPrompt)
|
||||
|
||||
if shouldCancelRecording { return }
|
||||
await whisperContext.fullTranscribe(samples: data)
|
||||
if shouldCancelRecording { return }
|
||||
var text = await whisperContext.getTranscription()
|
||||
// --- Core Transcription Logic ---
|
||||
let transcriptionService: TranscriptionService = (model.provider == .local) ? localTranscriptionService : cloudTranscriptionService
|
||||
var text = try await transcriptionService.transcribe(audioURL: url, model: model)
|
||||
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
logger.notice("✅ Transcription completed successfully, length: \(text.count) characters")
|
||||
|
||||
// --- Post-processing and Saving ---
|
||||
let permanentURL = try saveRecordingPermanently(url)
|
||||
if UserDefaults.standard.bool(forKey: "IsWordReplacementEnabled") {
|
||||
text = WordReplacementService.shared.applyReplacements(to: text)
|
||||
logger.notice("✅ Word replacements applied")
|
||||
}
|
||||
|
||||
let audioAsset = AVURLAsset(url: url)
|
||||
let actualDuration = CMTimeGetSeconds(try await audioAsset.load(.duration))
|
||||
var promptDetectionResult: PromptDetectionService.PromptDetectionResult? = nil
|
||||
let originalText = text
|
||||
let originalText = text
|
||||
|
||||
if let enhancementService = enhancementService, enhancementService.isConfigured {
|
||||
let detectionResult = promptDetectionService.analyzeText(text, with: enhancementService)
|
||||
@ -330,23 +325,22 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
enhancementService.isConfigured {
|
||||
do {
|
||||
if shouldCancelRecording { return }
|
||||
// Use processed text (without trigger words) for AI enhancement
|
||||
let textForAI = promptDetectionResult?.processedText ?? text
|
||||
let enhancedText = try await enhancementService.enhance(textForAI)
|
||||
let newTranscription = Transcription(
|
||||
text: originalText,
|
||||
text: originalText,
|
||||
duration: actualDuration,
|
||||
enhancedText: enhancedText,
|
||||
audioFileURL: permanentURLString
|
||||
audioFileURL: permanentURL.absoluteString
|
||||
)
|
||||
modelContext.insert(newTranscription)
|
||||
try? modelContext.save()
|
||||
text = enhancedText
|
||||
text = enhancedText
|
||||
} catch {
|
||||
let newTranscription = Transcription(
|
||||
text: originalText,
|
||||
duration: actualDuration,
|
||||
audioFileURL: permanentURLString
|
||||
audioFileURL: permanentURL.absoluteString
|
||||
)
|
||||
modelContext.insert(newTranscription)
|
||||
try? modelContext.save()
|
||||
@ -355,11 +349,12 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
let newTranscription = Transcription(
|
||||
text: originalText,
|
||||
duration: actualDuration,
|
||||
audioFileURL: permanentURLString
|
||||
audioFileURL: permanentURL.absoluteString
|
||||
)
|
||||
modelContext.insert(newTranscription)
|
||||
try? modelContext.save()
|
||||
}
|
||||
|
||||
if case .trialExpired = licenseViewModel.licenseState {
|
||||
text = """
|
||||
Your trial has expired. Upgrade to VoiceInk Pro at tryvoiceink.com/buy
|
||||
@ -367,7 +362,6 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
"""
|
||||
}
|
||||
|
||||
// Add a space to the end of the text
|
||||
text += " "
|
||||
|
||||
SoundManager.shared.playStopSound()
|
||||
@ -386,8 +380,8 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
}
|
||||
try? FileManager.default.removeItem(at: url)
|
||||
|
||||
if let result = promptDetectionResult,
|
||||
let enhancementService = enhancementService,
|
||||
if let result = promptDetectionResult,
|
||||
let enhancementService = enhancementService,
|
||||
result.shouldEnableAI {
|
||||
await promptDetectionService.restoreOriginalSettings(result, to: enhancementService)
|
||||
}
|
||||
@ -396,39 +390,64 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
await cleanupModelResources()
|
||||
|
||||
} catch {
|
||||
currentError = .transcriptionFailed
|
||||
logger.error("❌ Transcription failed: \(error.localizedDescription)")
|
||||
await cleanupModelResources()
|
||||
await dismissMiniRecorder()
|
||||
}
|
||||
}
|
||||
|
||||
private func readAudioSamples(_ url: URL) throws -> [Float] {
|
||||
return try decodeWaveFile(url)
|
||||
}
|
||||
|
||||
private func decodeWaveFile(_ url: URL) throws -> [Float] {
|
||||
let data = try Data(contentsOf: url)
|
||||
let floats = stride(from: 44, to: data.count, by: 2).map {
|
||||
return data[$0..<$0 + 2].withUnsafeBytes {
|
||||
let short = Int16(littleEndian: $0.load(as: Int16.self))
|
||||
return max(-1.0, min(Float(short) / 32767.0, 1.0))
|
||||
}
|
||||
}
|
||||
return floats
|
||||
}
|
||||
|
||||
@Published var currentError: WhisperStateError?
|
||||
|
||||
func getEnhancementService() -> AIEnhancementService? {
|
||||
return enhancementService
|
||||
}
|
||||
|
||||
private func saveRecordingPermanently(_ tempURL: URL) throws -> URL {
|
||||
let fileName = "\(UUID().uuidString).wav"
|
||||
let permanentURL = recordingsDirectory.appendingPathComponent(fileName)
|
||||
try FileManager.default.copyItem(at: tempURL, to: permanentURL)
|
||||
return permanentURL
|
||||
}
|
||||
|
||||
private func loadCurrentTranscriptionModel() {
|
||||
if let savedModelName = UserDefaults.standard.string(forKey: "CurrentTranscriptionModel"),
|
||||
let savedModel = allAvailableModels.first(where: { $0.name == savedModelName }) {
|
||||
currentTranscriptionModel = savedModel
|
||||
|
||||
// If it's a local model, also set it as currentModel for backward compatibility
|
||||
if let localModel = savedModel as? LocalModel,
|
||||
let whisperModel = availableModels.first(where: { $0.name == localModel.name }) {
|
||||
currentModel = whisperModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Function to set any transcription model as default
|
||||
func setDefaultTranscriptionModel(_ model: any TranscriptionModel) async {
|
||||
await MainActor.run {
|
||||
self.currentTranscriptionModel = model
|
||||
UserDefaults.standard.set(model.name, forKey: "CurrentTranscriptionModel")
|
||||
|
||||
// If it's a local model, also update currentModel for backward compatibility
|
||||
if let localModel = model as? LocalModel,
|
||||
let whisperModel = self.availableModels.first(where: { $0.name == localModel.name }) {
|
||||
self.currentModel = whisperModel
|
||||
UserDefaults.standard.set(whisperModel.name, forKey: "CurrentModel")
|
||||
} else {
|
||||
// For cloud models, clear the old currentModel
|
||||
self.currentModel = nil
|
||||
}
|
||||
|
||||
// Enable transcription for cloud models immediately since they don't need loading
|
||||
if model.provider != .local {
|
||||
self.canTranscribe = true
|
||||
self.isModelLoaded = true
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("Default transcription model set to: \(model.name) (\(model.provider.rawValue))")
|
||||
|
||||
// Post notification about the model change
|
||||
NotificationCenter.default.post(name: .didChangeModel, object: nil, userInfo: ["modelName": model.name])
|
||||
}
|
||||
|
||||
func getEnhancementService() -> AIEnhancementService? {
|
||||
return enhancementService
|
||||
}
|
||||
}
|
||||
|
||||
struct WhisperModel: Identifiable {
|
||||
@ -473,4 +492,5 @@ private class TaskDelegate: NSObject, URLSessionTaskDelegate {
|
||||
|
||||
extension Notification.Name {
|
||||
static let toggleMiniRecorder = Notification.Name("toggleMiniRecorder")
|
||||
static let didChangeModel = Notification.Name("didChangeModel")
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user