feat: Add support for cloud transcription service

Beingpax 2025-06-09 19:15:50 +05:45
parent a821b201c7
commit 9b507ce615
22 changed files with 1105 additions and 303 deletions

View File

@ -1,6 +1,6 @@
import Foundation
struct PredefinedModel: Identifiable, Hashable {
struct LocalModel: TranscriptionModel {
let id = UUID()
let name: String
let displayName: String
@ -11,6 +11,7 @@ import Foundation
let accuracy: Double
let ramUsage: Double
let hash: String
let provider: ModelProvider = .local
var downloadURL: String {
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/\(filename)"
@ -20,21 +21,9 @@ import Foundation
"\(name).bin"
}
func hash(into hasher: inout Hasher) {
hasher.combine(id)
}
static func == (lhs: PredefinedModel, rhs: PredefinedModel) -> Bool {
lhs.id == rhs.id
}
var isMultilingualModel: Bool {
supportedLanguages.count > 1
}
var language: String {
isMultilingualModel ? "Multilingual" : "English-only"
}
}
enum PredefinedModels {
@ -53,8 +42,9 @@ import Foundation
}
}
static let models: [PredefinedModel] = [
PredefinedModel(
static let models: [any TranscriptionModel] = [
// Local Models
LocalModel(
name: "ggml-tiny",
displayName: "Tiny",
size: "75 MiB",
@ -65,7 +55,7 @@ import Foundation
ramUsage: 0.3,
hash: "bd577a113a864445d4c299885e0cb97d4ba92b5f"
),
PredefinedModel(
LocalModel(
name: "ggml-tiny.en",
displayName: "Tiny (English)",
size: "75 MiB",
@ -76,7 +66,7 @@ import Foundation
ramUsage: 0.3,
hash: "c78c86eb1a8faa21b369bcd33207cc90d64ae9df"
),
PredefinedModel(
LocalModel(
name: "ggml-base.en",
displayName: "Base (English)",
size: "142 MiB",
@ -87,7 +77,7 @@ import Foundation
ramUsage: 0.5,
hash: "137c40403d78fd54d454da0f9bd998f78703390c"
),
PredefinedModel(
LocalModel(
name: "ggml-large-v2",
displayName: "Large v2",
size: "2.9 GiB",
@ -98,7 +88,7 @@ import Foundation
ramUsage: 3.8,
hash: "0f4c8e30f21cf1769f637135f521436792c48186"
),
PredefinedModel(
LocalModel(
name: "ggml-large-v3",
displayName: "Large v3",
size: "2.9 GiB",
@ -109,7 +99,7 @@ import Foundation
ramUsage: 3.9,
hash: "ad82bf6a9043ceed055076d0fd39f5f186ff8062"
),
PredefinedModel(
LocalModel(
name: "ggml-large-v3-turbo",
displayName: "Large v3 Turbo",
size: "1.5 GiB",
@ -121,7 +111,7 @@ import Foundation
ramUsage: 1.8,
hash: "4af2b29d7ec73d781377bfd1758ca957a807e941"
),
PredefinedModel(
LocalModel(
name: "ggml-large-v3-turbo-q5_0",
displayName: "Large v3 Turbo (Quantized)",
size: "547 MiB",
@ -132,6 +122,29 @@ import Foundation
ramUsage: 1.0,
hash: "e050f7970618a659205450ad97eb95a18d69c9ee"
),
// Cloud Models
CloudModel(
name: "whisper-large-v3-turbo",
displayName: "Whisper Large v3 Turbo (Groq)",
description: "Groq's ultra-fast Whisper Large v3 Turbo model with lightning-speed inference",
provider: .groq,
speed: 0.7,
accuracy: 0.96,
isMultilingual: true,
supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true)
),
CloudModel(
name: "scribe_v1",
displayName: "Scribe v1 (ElevenLabs)",
description: "ElevenLabs' Scribe model for fast and accurate transcription.",
provider: .elevenLabs,
speed: 0.8,
accuracy: 0.95,
isMultilingual: true,
supportedLanguages: getLanguageDictionary(isMultilingual: true, isLargeV3: true)
),
]
static let allLanguages = [

View File

@ -0,0 +1,57 @@
import Foundation
// Enum to differentiate between model providers
enum ModelProvider: String, Codable, Hashable, CaseIterable {
case local = "Local"
case groq = "Groq"
case elevenLabs = "ElevenLabs"
// Future providers can be added here
}
// A unified protocol for any transcription model
protocol TranscriptionModel: Identifiable, Hashable {
var id: UUID { get }
var name: String { get }
var displayName: String { get }
var description: String { get }
var provider: ModelProvider { get }
// Language capabilities
var isMultilingualModel: Bool { get }
var supportedLanguages: [String: String] { get }
}
extension TranscriptionModel {
func hash(into hasher: inout Hasher) {
hasher.combine(id)
}
var language: String {
isMultilingualModel ? "Multilingual" : "English-only"
}
}
// A new struct for cloud models
struct CloudModel: TranscriptionModel {
let id: UUID
let name: String
let displayName: String
let description: String
let provider: ModelProvider
let speed: Double
let accuracy: Double
let isMultilingualModel: Bool
let supportedLanguages: [String: String]
init(id: UUID = UUID(), name: String, displayName: String, description: String, provider: ModelProvider, speed: Double, accuracy: Double, isMultilingual: Bool, supportedLanguages: [String: String]) {
self.id = id
self.name = name
self.displayName = displayName
self.description = description
self.provider = provider
self.speed = speed
self.accuracy = accuracy
self.isMultilingualModel = isMultilingual
self.supportedLanguages = supportedLanguages
}
}
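
Usage sketch (not part of this commit): the registry is heterogeneous, so call sites work against the existential type, with hashing and the `language` label supplied by the protocol extension above.

let models: [any TranscriptionModel] = PredefinedModels.models
// Group by provider; ModelProvider is Hashable, so it can key a dictionary.
let byProvider = Dictionary(grouping: models, by: { $0.provider })
for (provider, group) in byProvider {
    let names = group.map { $0.displayName }.joined(separator: ", ")
    print("\(provider.rawValue): \(names)")
}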

View File

@ -116,20 +116,35 @@ class ActiveWindowService: ObservableObject {
}
if let whisperState = self.whisperState,
let modelName = config.selectedWhisperModel {
let models = await MainActor.run { whisperState.availableModels }
if let selectedModel = models.first(where: { $0.name == modelName }) {
let modelName = config.selectedWhisperModel,
let selectedModel = await whisperState.allAvailableModels.first(where: { $0.name == modelName }) {
let currentModelName = await MainActor.run { whisperState.currentTranscriptionModel?.name }
// Only change the model if it's different from the current one.
if currentModelName != modelName {
// Set the new model as default. This works for both local and cloud models.
await whisperState.setDefaultTranscriptionModel(selectedModel)
let currentModelName = await MainActor.run { whisperState.currentModel?.name }
if currentModelName != modelName {
await whisperState.setDefaultModel(selectedModel)
// The cleanup and load cycle is only necessary for local models.
if selectedModel.provider == .local {
// Unload any previously loaded model to free up memory.
await whisperState.cleanupModelResources()
do {
try await whisperState.loadModel(selectedModel)
} catch {
// Load the new local model into memory.
if let localModel = await whisperState.availableModels.first(where: { $0.name == selectedModel.name }) {
do {
try await whisperState.loadModel(localModel)
logger.info("✅ Power Mode: Successfully loaded local model '\(localModel.name)'.")
} catch {
logger.error("❌ Power Mode: Failed to load local model '\(localModel.name)': \(error.localizedDescription)")
}
}
} else {
// For cloud models, no in-memory loading is needed, but we should still
// clean up if the *previous* model was a local one.
await whisperState.cleanupModelResources()
logger.info("✅ Power Mode: Switched to cloud model '\(selectedModel.name)'. No local load needed.")
}
}
}

View File

@ -375,9 +375,12 @@ struct ConfigurationView: View {
.foregroundColor(.secondary)
Picker("", selection: modelBinding) {
ForEach(whisperState.availableModels) { model in
let displayName = whisperState.predefinedModels.first { $0.name == model.name }?.displayName ?? model.name
Text(displayName).tag(model.name as String?)
Text("Default")
.tag(nil as String?)
ForEach(whisperState.usableModels, id: \.name) { model in
Text(model.displayName)
.tag(model.name as String?)
}
}
.labelsHidden()
@ -387,7 +390,7 @@ struct ConfigurationView: View {
// Language Selection Subsection
if let selectedModel = effectiveModelName,
let modelInfo = whisperState.predefinedModels.first(where: { $0.name == selectedModel }),
let modelInfo = whisperState.allAvailableModels.first(where: { $0.name == selectedModel }),
modelInfo.isMultilingualModel {
// Create a simple binding that uses UserDefaults language if nil
@ -416,7 +419,7 @@ struct ConfigurationView: View {
.frame(maxWidth: .infinity)
}
} else if let selectedModel = effectiveModelName,
let modelInfo = whisperState.predefinedModels.first(where: { $0.name == selectedModel }),
let modelInfo = whisperState.allAvailableModels.first(where: { $0.name == selectedModel }),
!modelInfo.isMultilingualModel {
// Silently set to English without showing UI
EmptyView()
@ -488,7 +491,7 @@ struct ConfigurationView: View {
.frame(maxWidth: .infinity, alignment: .leading)
} else {
Picker("", selection: providerBinding) {
ForEach(aiService.connectedProviders, id: \.self) { provider in
ForEach(aiService.connectedProviders.filter { $0 != .elevenLabs }, id: \.self) { provider in
Text(provider.rawValue).tag(provider)
}
}

View File

@ -105,7 +105,7 @@ struct ConfigurationRow: View {
private var selectedModel: String? {
if let modelName = config.selectedWhisperModel,
let model = whisperState.predefinedModels.first(where: { $0.name == modelName }) {
let model = whisperState.allAvailableModels.first(where: { $0.name == modelName }) {
return model.displayName
}
return "Default"
@ -117,7 +117,7 @@ struct ConfigurationRow: View {
if langCode == "en" { return "English" }
if let modelName = config.selectedWhisperModel,
let model = whisperState.predefinedModels.first(where: { $0.name == modelName }),
let model = whisperState.allAvailableModels.first(where: { $0.name == modelName }),
let langName = model.supportedLanguages[langCode] {
return langName
}

View File

@ -9,6 +9,7 @@ enum AIProvider: String, CaseIterable {
case anthropic = "Anthropic"
case mistral = "Mistral"
case ollama = "Ollama"
case elevenLabs = "ElevenLabs"
case custom = "Custom"
var baseURL: String {
@ -25,6 +26,8 @@ enum AIProvider: String, CaseIterable {
return "https://api.anthropic.com/v1/messages"
case .mistral:
return "https://api.mistral.ai/v1/chat/completions"
case .elevenLabs:
return "https://api.elevenlabs.io/v1/speech-to-text"
case .ollama:
return UserDefaults.standard.string(forKey: "ollamaBaseURL") ?? "http://localhost:11434"
case .custom:
@ -46,6 +49,8 @@ enum AIProvider: String, CaseIterable {
return "claude-3-5-sonnet-20241022"
case .mistral:
return "mistral-large-latest"
case .elevenLabs:
return "scribe_v1"
case .ollama:
return UserDefaults.standard.string(forKey: "ollamaSelectedModel") ?? "mistral"
case .custom:
@ -88,6 +93,8 @@ enum AIProvider: String, CaseIterable {
"mistral-small-latest",
"mistral-saba-latest"
]
case .elevenLabs:
return ["scribe_v1", "scribe_v1_experimental"]
case .ollama:
return []
case .custom:
@ -257,6 +264,8 @@ class AIService: ObservableObject {
verifyGeminiAPIKey(key, completion: completion)
case .anthropic:
verifyAnthropicAPIKey(key, completion: completion)
case .elevenLabs:
verifyElevenLabsAPIKey(key, completion: completion)
default:
verifyOpenAICompatibleAPIKey(key, completion: completion)
}
@ -369,6 +378,30 @@ class AIService: ObservableObject {
}.resume()
}
private func verifyElevenLabsAPIKey(_ key: String, completion: @escaping (Bool) -> Void) {
let url = URL(string: "https://api.elevenlabs.io/v1/user")!
var request = URLRequest(url: url)
request.httpMethod = "GET"
request.addValue("application/json", forHTTPHeaderField: "Content-Type")
request.addValue(key, forHTTPHeaderField: "xi-api-key")
URLSession.shared.dataTask(with: request) { _, response, error in
if let error = error {
self.logger.error("ElevenLabs API key verification failed: \(error.localizedDescription)")
completion(false)
return
}
if let httpResponse = response as? HTTPURLResponse {
completion(httpResponse.statusCode == 200)
} else {
completion(false)
}
}.resume()
}
func clearAPIKey() {
guard selectedProvider.requiresAPIKey else { return }

View File

@ -173,4 +173,47 @@ class AudioProcessor {
logger.notice("✅ Successfully converted \(samples.count) samples")
return samples
}
}
func saveSamplesAsWav(samples: [Float], to url: URL) throws {
let outputFormat = AVAudioFormat(
commonFormat: .pcmFormatInt16,
sampleRate: AudioFormat.targetSampleRate,
channels: AudioFormat.targetChannels,
interleaved: true
)
guard let outputFormat = outputFormat else {
throw AudioProcessingError.unsupportedFormat
}
let buffer = AVAudioPCMBuffer(
pcmFormat: outputFormat,
frameCapacity: AVAudioFrameCount(samples.count)
)
guard let buffer = buffer else {
throw AudioProcessingError.conversionFailed
}
// Convert float samples to int16
let int16Samples = samples.map { max(-1.0, min(1.0, $0)) * Float(Int16.max) }.map { Int16($0) }
// Copy samples to buffer
int16Samples.withUnsafeBufferPointer { int16Buffer in
let int16Pointer = int16Buffer.baseAddress!
buffer.int16ChannelData![0].update(from: int16Pointer, count: int16Samples.count)
}
buffer.frameLength = AVAudioFrameCount(samples.count)
// Create audio file
let audioFile = try AVAudioFile(
forWriting: url,
settings: outputFormat.settings,
commonFormat: .pcmFormatInt16,
interleaved: true
)
try audioFile.write(from: buffer)
logger.notice("✅ Successfully saved processed audio to \(url.lastPathComponent)")
}
}
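
Usage sketch (not part of this commit): writing one second of a 440 Hz test tone through saveSamplesAsWav. This assumes AudioFormat.targetSampleRate is 16,000 Hz mono (the rate whisper.cpp expects) and that AudioProcessor has a parameterless initializer.

let sampleRate: Float = 16_000
// Generate a half-amplitude sine wave, one sample per tick.
let samples: [Float] = (0..<Int(sampleRate)).map { i in
    0.5 * sin(2 * Float.pi * 440 * Float(i) / sampleRate)
}
let url = FileManager.default.temporaryDirectory.appendingPathComponent("tone.wav")
try AudioProcessor().saveSamplesAsWav(samples: samples, to: url)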

View File

@ -15,10 +15,13 @@ class AudioTranscriptionManager: ObservableObject {
@Published var errorMessage: String?
private var currentTask: Task<Void, Error>?
private var whisperContext: WhisperContext?
private let audioProcessor = AudioProcessor()
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionManager")
// Transcription services - will be initialized when needed
private var localTranscriptionService: LocalTranscriptionService?
private let cloudTranscriptionService = CloudTranscriptionService()
enum ProcessingPhase {
case idle
case loading
@ -58,12 +61,14 @@ class AudioTranscriptionManager: ObservableObject {
currentTask = Task {
do {
guard let currentModel = whisperState.currentModel else {
guard let currentModel = whisperState.currentTranscriptionModel else {
throw TranscriptionError.noModelSelected
}
// Load Whisper model
whisperContext = try await WhisperContext.createContext(path: currentModel.url.path)
// Initialize local transcription service if needed
if localTranscriptionService == nil {
localTranscriptionService = LocalTranscriptionService(modelsDirectory: whisperState.modelsDirectory, whisperState: whisperState)
}
// Process audio file
processingPhase = .processingAudio
@ -82,15 +87,19 @@ class AudioTranscriptionManager: ObservableObject {
let permanentURL = recordingsDirectory.appendingPathComponent(fileName)
try FileManager.default.createDirectory(at: recordingsDirectory, withIntermediateDirectories: true)
try FileManager.default.copyItem(at: url, to: permanentURL)
try audioProcessor.saveSamplesAsWav(samples: samples, to: permanentURL)
// Transcribe
// Transcribe using appropriate service
processingPhase = .transcribing
await whisperContext?.setPrompt(whisperState.whisperPrompt.transcriptionPrompt)
try await whisperContext?.fullTranscribe(samples: samples)
var text = await whisperContext?.getTranscription() ?? ""
var text: String
if currentModel.provider == .local {
text = try await localTranscriptionService!.transcribe(audioURL: permanentURL, model: currentModel)
} else {
text = try await cloudTranscriptionService.transcribe(audioURL: permanentURL, model: currentModel)
}
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
text = WhisperTextFormatter.format(text)
// Apply word replacements if enabled
if UserDefaults.standard.bool(forKey: "IsWordReplacementEnabled") {
@ -115,7 +124,6 @@ class AudioTranscriptionManager: ObservableObject {
currentTranscription = transcription
} catch {
logger.error("Enhancement failed: \(error.localizedDescription)")
messageLog += "Enhancement failed: \(error.localizedDescription). Using original transcription.\n"
let transcription = Transcription(
text: text,
duration: duration,
@ -148,28 +156,20 @@ class AudioTranscriptionManager: ObservableObject {
func cancelProcessing() {
currentTask?.cancel()
cleanupResources()
}
private func finishProcessing() {
isProcessing = false
processingPhase = .idle
currentTask = nil
cleanupResources()
}
private func handleError(_ error: Error) {
logger.error("Transcription error: \(error.localizedDescription)")
errorMessage = error.localizedDescription
messageLog += "Error: \(error.localizedDescription)\n"
isProcessing = false
processingPhase = .idle
currentTask = nil
cleanupResources()
}
private func cleanupResources() {
whisperContext = nil
}
}

View File

@ -10,12 +10,15 @@ class AudioTranscriptionService: ObservableObject {
@Published var messageLog = ""
@Published var currentError: TranscriptionError?
private var whisperContext: WhisperContext?
private let modelContext: ModelContext
private let enhancementService: AIEnhancementService?
private let whisperState: WhisperState
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "AudioTranscriptionService")
// Transcription services
private let localTranscriptionService: LocalTranscriptionService
private let cloudTranscriptionService = CloudTranscriptionService()
enum TranscriptionError: Error {
case noAudioFile
case transcriptionFailed
@ -27,82 +30,67 @@ class AudioTranscriptionService: ObservableObject {
self.modelContext = modelContext
self.whisperState = whisperState
self.enhancementService = whisperState.enhancementService
self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: whisperState.modelsDirectory, whisperState: whisperState)
}
func retranscribeAudio(from url: URL, using whisperModel: WhisperModel) async throws -> Transcription {
func retranscribeAudio(from url: URL, using model: any TranscriptionModel) async throws -> Transcription {
guard FileManager.default.fileExists(atPath: url.path) else {
throw TranscriptionError.noAudioFile
}
await MainActor.run {
isTranscribing = true
messageLog = "Loading model...\n"
messageLog = "Starting retranscription...\n"
}
// Load the whisper model if needed
if whisperContext == nil {
do {
whisperContext = try await WhisperContext.createContext(path: whisperModel.url.path)
messageLog += "Model loaded successfully.\n"
} catch {
logger.error("❌ Failed to load model: \(error.localizedDescription)")
messageLog += "Failed to load model: \(error.localizedDescription)\n"
isTranscribing = false
throw TranscriptionError.modelNotLoaded
do {
// Delegate transcription to appropriate service
var text: String
if model.provider == .local {
messageLog += "Using local transcription service...\n"
text = try await localTranscriptionService.transcribe(audioURL: url, model: model)
messageLog += "Local transcription completed.\n"
} else {
messageLog += "Using cloud transcription service...\n"
text = try await cloudTranscriptionService.transcribe(audioURL: url, model: model)
messageLog += "Cloud transcription completed.\n"
}
}
guard let whisperContext = whisperContext else {
isTranscribing = false
throw TranscriptionError.modelNotLoaded
}
// Get audio duration
let audioAsset = AVURLAsset(url: url)
let duration = CMTimeGetSeconds(try await audioAsset.load(.duration))
// Create a permanent copy of the audio file
let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
.appendingPathComponent("Recordings")
let fileName = "retranscribed_\(UUID().uuidString).wav"
let permanentURL = recordingsDirectory.appendingPathComponent(fileName)
do {
try FileManager.default.copyItem(at: url, to: permanentURL)
} catch {
logger.error("❌ Failed to create permanent copy of audio: \(error.localizedDescription)")
messageLog += "Failed to create permanent copy of audio: \(error.localizedDescription)\n"
isTranscribing = false
throw error
}
let permanentURLString = permanentURL.absoluteString
// Transcribe the audio
messageLog += "Transcribing audio...\n"
do {
// Read audio samples
let samples = try readAudioSamples(permanentURL)
// Process with Whisper - using the same prompt as WhisperState
messageLog += "Setting prompt: \(whisperState.whisperPrompt.transcriptionPrompt)\n"
await whisperContext.setPrompt(whisperState.whisperPrompt.transcriptionPrompt)
try await whisperContext.fullTranscribe(samples: samples)
var text = await whisperContext.getTranscription()
// Common post-processing for both local and cloud transcriptions
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
logger.notice("✅ Retranscription completed successfully, length: \(text.count) characters")
// Apply word replacements if enabled
if UserDefaults.standard.bool(forKey: "IsWordReplacementEnabled") {
text = WordReplacementService.shared.applyReplacements(to: text)
messageLog += "Word replacements applied.\n"
logger.notice("✅ Word replacements applied")
}
// Apply AI enhancement if enabled - using the same enhancement service as WhisperState
// Get audio duration
let audioAsset = AVURLAsset(url: url)
let duration = CMTimeGetSeconds(try await audioAsset.load(.duration))
// Create a permanent copy of the audio file
let recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
.appendingPathComponent("Recordings")
let fileName = "retranscribed_\(UUID().uuidString).wav"
let permanentURL = recordingsDirectory.appendingPathComponent(fileName)
do {
try FileManager.default.copyItem(at: url, to: permanentURL)
} catch {
logger.error("❌ Failed to create permanent copy of audio: \(error.localizedDescription)")
messageLog += "Failed to create permanent copy of audio: \(error.localizedDescription)\n"
isTranscribing = false
throw error
}
let permanentURLString = permanentURL.absoluteString
// Apply AI enhancement if enabled
if let enhancementService = enhancementService,
enhancementService.isEnhancementEnabled,
enhancementService.isConfigured {
@ -182,19 +170,4 @@ class AudioTranscriptionService: ObservableObject {
throw error
}
}
private func readAudioSamples(_ url: URL) throws -> [Float] {
return try decodeWaveFile(url)
}
private func decodeWaveFile(_ url: URL) throws -> [Float] {
let data = try Data(contentsOf: url)
let floats = stride(from: 44, to: data.count, by: 2).map {
return data[$0..<$0 + 2].withUnsafeBytes {
let short = Int16(littleEndian: $0.load(as: Int16.self))
return max(-1.0, min(Float(short) / 32767.0, 1.0))
}
}
return floats
}
}
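
The provider check that chooses between the two services recurs here, in AudioTranscriptionManager, and in WhisperState; a one-line helper could express it once. A sketch (hypothetical, not in this commit):

// Hypothetical consolidation of the repeated provider dispatch.
func transcriptionService(for model: any TranscriptionModel,
                          local: LocalTranscriptionService,
                          cloud: CloudTranscriptionService) -> any TranscriptionService {
    if model.provider == .local { return local }
    return cloud
}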

View File

@ -0,0 +1,211 @@
import Foundation
import os
enum CloudTranscriptionError: Error, LocalizedError {
case unsupportedProvider
case missingAPIKey
case invalidAPIKey
case audioFileNotFound
case apiRequestFailed(statusCode: Int, message: String)
case networkError(Error)
case noTranscriptionReturned
case dataEncodingError
var errorDescription: String? {
switch self {
case .unsupportedProvider:
return "The model provider is not supported by this service."
case .missingAPIKey:
return "API key for this service is missing. Please configure it in the settings."
case .invalidAPIKey:
return "The provided API key is invalid."
case .audioFileNotFound:
return "The audio file to transcribe could not be found."
case .apiRequestFailed(let statusCode, let message):
return "The API request failed with status code \(statusCode): \(message)"
case .networkError(let error):
return "A network error occurred: \(error.localizedDescription)"
case .noTranscriptionReturned:
return "The API returned an empty or invalid response."
case .dataEncodingError:
return "Failed to encode the request body."
}
}
}
class CloudTranscriptionService: TranscriptionService {
private struct APIConfig {
let url: URL
let apiKey: String
let modelName: String
}
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
let config = try getAPIConfig(for: model)
var request: URLRequest
var body: Data
switch model.provider {
case .elevenLabs:
let boundary = "Boundary-\(UUID().uuidString)"
var elevenLabsRequest = URLRequest(url: config.url)
elevenLabsRequest.httpMethod = "POST"
elevenLabsRequest.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type")
elevenLabsRequest.setValue(config.apiKey, forHTTPHeaderField: "xi-api-key")
body = try createElevenLabsRequestBody(audioURL: audioURL, modelName: config.modelName, boundary: boundary)
request = elevenLabsRequest
case .groq:
let boundary = "Boundary-\(UUID().uuidString)"
var openAICompatibleRequest = URLRequest(url: config.url)
openAICompatibleRequest.httpMethod = "POST"
openAICompatibleRequest.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type")
openAICompatibleRequest.setValue("Bearer \(config.apiKey)", forHTTPHeaderField: "Authorization")
body = try createOpenAICompatibleRequestBody(audioURL: audioURL, modelName: config.modelName, boundary: boundary)
request = openAICompatibleRequest
default:
throw CloudTranscriptionError.unsupportedProvider
}
let (data, response) = try await URLSession.shared.upload(for: request, from: body)
guard let httpResponse = response as? HTTPURLResponse else {
throw CloudTranscriptionError.networkError(URLError(.badServerResponse))
}
if !(200...299).contains(httpResponse.statusCode) {
let errorMessage = String(data: data, encoding: .utf8) ?? "No error message"
throw CloudTranscriptionError.apiRequestFailed(statusCode: httpResponse.statusCode, message: errorMessage)
}
do {
let transcriptionResponse = try JSONDecoder().decode(TranscriptionResponse.self, from: data)
return transcriptionResponse.text
} catch {
throw CloudTranscriptionError.noTranscriptionReturned
}
}
private func getAPIConfig(for model: any TranscriptionModel) throws -> APIConfig {
let providerKey: String
let apiURL: URL
switch model.provider {
case .groq:
providerKey = "GROQ"
apiURL = URL(string: "https://api.groq.com/openai/v1/audio/transcriptions")!
case .elevenLabs:
providerKey = "ElevenLabs"
apiURL = URL(string: "https://api.elevenlabs.io/v1/speech-to-text")!
default:
throw CloudTranscriptionError.unsupportedProvider
}
guard let apiKey = UserDefaults.standard.string(forKey: "\(providerKey)APIKey"), !apiKey.isEmpty else {
throw CloudTranscriptionError.missingAPIKey
}
return APIConfig(url: apiURL, apiKey: apiKey, modelName: model.name)
}
private func createElevenLabsRequestBody(audioURL: URL, modelName: String, boundary: String) throws -> Data {
var body = Data()
let crlf = "\r\n"
guard let audioData = try? Data(contentsOf: audioURL) else {
throw CloudTranscriptionError.audioFileNotFound
}
// File
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
body.append("Content-Disposition: form-data; name=\"file\"; filename=\"\(audioURL.lastPathComponent)\"\(crlf)".data(using: .utf8)!)
body.append("Content-Type: audio/wav\(crlf)\(crlf)".data(using: .utf8)!)
body.append(audioData)
body.append(crlf.data(using: .utf8)!)
// Model ID
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
body.append("Content-Disposition: form-data; name=\"model_id\"\(crlf)\(crlf)".data(using: .utf8)!)
body.append(modelName.data(using: .utf8)!)
body.append(crlf.data(using: .utf8)!)
let selectedLanguage = UserDefaults.standard.string(forKey: "SelectedLanguage") ?? "auto"
if selectedLanguage != "auto", !selectedLanguage.isEmpty {
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
body.append("Content-Disposition: form-data; name=\"language_code\"\(crlf)\(crlf)".data(using: .utf8)!)
body.append(selectedLanguage.data(using: .utf8)!)
body.append(crlf.data(using: .utf8)!)
}
body.append("--\(boundary)--\(crlf)".data(using: .utf8)!)
return body
}
private func createOpenAICompatibleRequestBody(audioURL: URL, modelName: String, boundary: String) throws -> Data {
var body = Data()
let crlf = "\r\n"
guard let audioData = try? Data(contentsOf: audioURL) else {
throw CloudTranscriptionError.audioFileNotFound
}
let selectedLanguage = UserDefaults.standard.string(forKey: "SelectedLanguage") ?? "auto"
let prompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? ""
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
body.append("Content-Disposition: form-data; name=\"file\"; filename=\"\(audioURL.lastPathComponent)\"\(crlf)".data(using: .utf8)!)
body.append("Content-Type: audio/wav\(crlf)\(crlf)".data(using: .utf8)!)
body.append(audioData)
body.append(crlf.data(using: .utf8)!)
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
body.append("Content-Disposition: form-data; name=\"model\"\(crlf)\(crlf)".data(using: .utf8)!)
body.append(modelName.data(using: .utf8)!)
body.append(crlf.data(using: .utf8)!)
if selectedLanguage != "auto", !selectedLanguage.isEmpty {
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
body.append("Content-Disposition: form-data; name=\"language\"\(crlf)\(crlf)".data(using: .utf8)!)
body.append(selectedLanguage.data(using: .utf8)!)
body.append(crlf.data(using: .utf8)!)
}
// Include prompt for OpenAI-compatible APIs
if !prompt.isEmpty {
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
body.append("Content-Disposition: form-data; name=\"prompt\"\(crlf)\(crlf)".data(using: .utf8)!)
body.append(prompt.data(using: .utf8)!)
body.append(crlf.data(using: .utf8)!)
}
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
body.append("Content-Disposition: form-data; name=\"response_format\"\(crlf)\(crlf)".data(using: .utf8)!)
body.append("json".data(using: .utf8)!)
body.append(crlf.data(using: .utf8)!)
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
body.append("Content-Disposition: form-data; name=\"temperature\"\(crlf)\(crlf)".data(using: .utf8)!)
body.append("0".data(using: .utf8)!)
body.append(crlf.data(using: .utf8)!)
body.append("--\(boundary)--\(crlf)".data(using: .utf8)!)
return body
}
private struct TranscriptionResponse: Decodable {
let text: String
let language: String?
let duration: Double?
let x_groq: GroqMetadata?
struct GroqMetadata: Decodable {
let id: String?
}
}
}
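
The two body builders repeat the same field-framing boilerplate; a small Data extension could consolidate it. A sketch (hypothetical, not in this commit) whose CRLF framing mirrors what the builders above emit:

import Foundation

extension Data {
    // Append a simple text field: --boundary, Content-Disposition, blank line, value.
    mutating func appendMultipartField(name: String, value: String, boundary: String) {
        let crlf = "\r\n"
        append("--\(boundary)\(crlf)".data(using: .utf8)!)
        append("Content-Disposition: form-data; name=\"\(name)\"\(crlf)\(crlf)".data(using: .utf8)!)
        append(value.data(using: .utf8)!)
        append(crlf.data(using: .utf8)!)
    }

    // Append a file part with an explicit Content-Type header.
    mutating func appendMultipartFile(name: String, filename: String, contentType: String, contents: Data, boundary: String) {
        let crlf = "\r\n"
        append("--\(boundary)\(crlf)".data(using: .utf8)!)
        append("Content-Disposition: form-data; name=\"\(name)\"; filename=\"\(filename)\"\(crlf)".data(using: .utf8)!)
        append("Content-Type: \(contentType)\(crlf)\(crlf)".data(using: .utf8)!)
        append(contents)
        append(crlf.data(using: .utf8)!)
    }
}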

View File

@ -0,0 +1,87 @@
import Foundation
import AVFoundation
import os
class LocalTranscriptionService: TranscriptionService {
private var whisperContext: WhisperContext?
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "LocalTranscriptionService")
private let modelsDirectory: URL
private weak var whisperState: WhisperState?
init(modelsDirectory: URL, whisperState: WhisperState? = nil) {
self.modelsDirectory = modelsDirectory
self.whisperState = whisperState
}
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
guard let localModel = model as? LocalModel else {
throw WhisperError.couldNotInitializeContext
}
logger.notice("Initiating local transcription for model: \(localModel.displayName)")
// Check if the required model is already loaded in WhisperState
if let whisperState = whisperState,
await whisperState.isModelLoaded,
let loadedContext = await whisperState.whisperContext,
await whisperState.currentModel?.name == localModel.name {
logger.notice("✅ Using already loaded model: \(localModel.name)")
whisperContext = loadedContext
} else {
// Model not loaded or wrong model loaded, proceed with loading
let modelURL = modelsDirectory.appendingPathComponent(localModel.filename)
guard FileManager.default.fileExists(atPath: modelURL.path) else {
logger.error("Model file not found at path: \(modelURL.path)")
throw WhisperError.couldNotInitializeContext
}
logger.notice("Loading model: \(localModel.name)")
do {
whisperContext = try await WhisperContext.createContext(path: modelURL.path)
} catch {
logger.error("Failed to load model: \(localModel.name) - \(error.localizedDescription)")
throw WhisperError.couldNotInitializeContext
}
}
guard let whisperContext = whisperContext else {
logger.error("Cannot transcribe: Model could not be loaded")
throw WhisperError.couldNotInitializeContext
}
// Read audio data
let data = try readAudioSamples(audioURL)
// Set prompt
let currentPrompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? ""
await whisperContext.setPrompt(currentPrompt)
// Transcribe
await whisperContext.fullTranscribe(samples: data)
let text = await whisperContext.getTranscription()
logger.notice("✅ Local transcription completed successfully.")
// Only release resources if we created a new context (not using the shared one)
if await whisperState?.whisperContext !== whisperContext {
await whisperContext.releaseResources()
self.whisperContext = nil
}
return text
}
private func readAudioSamples(_ url: URL) throws -> [Float] {
let data = try Data(contentsOf: url)
let floats = stride(from: 44, to: data.count, by: 2).map {
return data[$0..<$0 + 2].withUnsafeBytes {
let short = Int16(littleEndian: $0.load(as: Int16.self))
return max(-1.0, min(Float(short) / 32767.0, 1.0))
}
}
return floats
}
}
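
readAudioSamples assumes a canonical 44-byte WAV header with 16-bit little-endian mono PCM; files with extended headers or other encodings would decode as noise. A more tolerant sketch (not in this commit) uses AVAudioFile, which parses the header properly; note it does not resample, so callers would still need 16 kHz input:

import AVFoundation

func readAudioSamplesViaAVFoundation(_ url: URL) throws -> [Float] {
    let file = try AVAudioFile(forReading: url)
    // processingFormat is deinterleaved Float32 by default.
    guard let buffer = AVAudioPCMBuffer(pcmFormat: file.processingFormat,
                                        frameCapacity: AVAudioFrameCount(file.length)) else {
        throw WhisperError.couldNotInitializeContext // reusing the commit's error type
    }
    try file.read(into: buffer)
    guard let channelData = buffer.floatChannelData else { return [] }
    // Return channel 0 at the file's native sample rate.
    return Array(UnsafeBufferPointer(start: channelData[0], count: Int(buffer.frameLength)))
}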

View File

@ -0,0 +1,14 @@
import Foundation
/// A protocol defining the interface for a transcription service.
/// This allows for a unified way to handle both local and cloud-based transcription models.
protocol TranscriptionService {
/// Transcribes the audio from a given file URL.
///
/// - Parameters:
/// - audioURL: The URL of the audio file to transcribe.
/// - model: The `TranscriptionModel` to use for transcription. This provides context about the provider (local, Groq, ElevenLabs, etc.).
/// - Returns: The transcribed text as a `String`.
/// - Throws: An error if the transcription fails.
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String
}
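
Because both concrete services sit behind this protocol, call sites can be exercised with a test double. A minimal stub sketch (hypothetical, not in this commit):

import Foundation

// Returns canned text regardless of the audio or model passed in.
final class StubTranscriptionService: TranscriptionService {
    private let cannedText: String
    init(returning text: String) { self.cannedText = text }

    func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
        cannedText
    }
}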

View File

@ -44,7 +44,7 @@ struct APIKeyManagementView: View {
// Provider Selection
Picker("AI Provider", selection: $aiService.selectedProvider) {
ForEach(AIProvider.allCases, id: \.self) { provider in
ForEach(AIProvider.allCases.filter { $0 != .elevenLabs }, id: \.self) { provider in
Text(provider.rawValue).tag(provider)
}
}
@ -385,6 +385,8 @@ struct APIKeyManagementView: View {
URL(string: "https://console.anthropic.com/settings/keys")!
case .mistral:
URL(string: "https://console.mistral.ai/api-keys")!
case .elevenLabs:
URL(string: "https://elevenlabs.io/speech-synthesis")!
case .ollama, .custom:
URL(string: "")! // This case should never be reached
}

View File

@ -388,7 +388,7 @@ struct AudioPlayerView: View {
}
private func retranscribeAudio() {
guard let currentModel = whisperState.currentModel else {
guard let currentTranscriptionModel = whisperState.currentTranscriptionModel else {
errorMessage = "No transcription model selected"
showRetranscribeError = true
DispatchQueue.main.asyncAfter(deadline: .now() + 3) {
@ -401,7 +401,7 @@ struct AudioPlayerView: View {
Task {
do {
let _ = try await transcriptionService.retranscribeAudio(from: url, using: currentModel)
let _ = try await transcriptionService.retranscribeAudio(from: url, using: currentTranscriptionModel)
await MainActor.run {
isRetranscribing = false
showRetranscribeSuccess = true

View File

@ -26,23 +26,18 @@ struct LanguageSelectionView: View {
// Function to check if current model is multilingual
private func isMultilingualModel() -> Bool {
guard let currentModel = whisperState.currentModel,
let predefinedModel = PredefinedModels.models.first(where: { $0.name == currentModel.name }) else {
guard let currentModel = whisperState.currentTranscriptionModel else {
return false
}
return predefinedModel.isMultilingualModel
return currentModel.isMultilingualModel
}
// Function to get current model's supported languages
private func getCurrentModelLanguages() -> [String: String] {
guard let currentModel = whisperState.currentModel,
let predefinedModel = PredefinedModels.models.first(where: {
$0.name == currentModel.name
})
else {
guard let currentModel = whisperState.currentTranscriptionModel else {
return ["en": "English"] // Default to English if no model found
}
return predefinedModel.supportedLanguages
return currentModel.supportedLanguages
}
// Get the display name of the current language
@ -74,16 +69,13 @@ struct LanguageSelectionView: View {
Text("Transcription Language")
.font(.headline)
if let currentModel = whisperState.currentModel,
let predefinedModel = PredefinedModels.models.first(where: {
$0.name == currentModel.name
})
if let currentModel = whisperState.currentTranscriptionModel
{
if isMultilingualModel() {
VStack(alignment: .leading, spacing: 8) {
Picker("Select Language", selection: $selectedLanguage) {
ForEach(
predefinedModel.supportedLanguages.sorted(by: {
currentModel.supportedLanguages.sorted(by: {
if $0.key == "auto" { return true }
if $1.key == "auto" { return false }
return $0.value < $1.value
@ -97,7 +89,7 @@ struct LanguageSelectionView: View {
updateLanguage(newValue)
}
Text("Current model: \(predefinedModel.displayName)")
Text("Current model: \(currentModel.displayName)")
.font(.caption)
.foregroundColor(.secondary)
@ -114,7 +106,7 @@ struct LanguageSelectionView: View {
.font(.subheadline)
.foregroundColor(.primary)
Text("Current model: \(predefinedModel.displayName)")
Text("Current model: \(currentModel.displayName)")
.font(.caption)
.foregroundColor(.secondary)

View File

@ -80,26 +80,21 @@ struct MenuBarView: View {
}
Menu {
ForEach(whisperState.availableModels) { model in
ForEach(whisperState.usableModels, id: \.id) { model in
Button {
Task {
await whisperState.setDefaultModel(model)
await whisperState.setDefaultTranscriptionModel(model)
}
} label: {
HStack {
Text(PredefinedModels.models.first { $0.name == model.name }?.displayName ?? model.name)
if whisperState.currentModel?.name == model.name {
Text(model.displayName)
if whisperState.currentTranscriptionModel?.id == model.id {
Image(systemName: "checkmark")
}
}
}
}
if whisperState.availableModels.isEmpty {
Text("No models downloaded")
.foregroundColor(.secondary)
}
Divider()
Button("Manage Models") {
@ -107,7 +102,7 @@ struct MenuBarView: View {
}
} label: {
HStack {
Text("Model: \(PredefinedModels.models.first { $0.name == whisperState.currentModel?.name }?.displayName ?? "None")")
Text("Model: \(whisperState.currentTranscriptionModel?.displayName ?? "None")")
Image(systemName: "chevron.up.chevron.down")
.font(.system(size: 10))
}

View File

@ -2,7 +2,49 @@ import SwiftUI
import AppKit
struct ModelCardRowView: View {
let model: PredefinedModel
let model: any TranscriptionModel
let isDownloaded: Bool
let isCurrent: Bool
let downloadProgress: [String: Double]
let modelURL: URL?
// Actions
var deleteAction: () -> Void
var setDefaultAction: () -> Void
var downloadAction: () -> Void
var body: some View {
Group {
switch model.provider {
case .local:
if let localModel = model as? LocalModel {
LocalModelCardView(
model: localModel,
isDownloaded: isDownloaded,
isCurrent: isCurrent,
downloadProgress: downloadProgress,
modelURL: modelURL,
deleteAction: deleteAction,
setDefaultAction: setDefaultAction,
downloadAction: downloadAction
)
}
case .groq, .elevenLabs:
if let cloudModel = model as? CloudModel {
CloudModelCardView(
model: cloudModel,
isCurrent: isCurrent,
setDefaultAction: setDefaultAction
)
}
}
}
}
}
// MARK: - Local Model Card View
struct LocalModelCardView: View {
let model: LocalModel
let isDownloaded: Bool
let isCurrent: Bool
let downloadProgress: [String: Double]
@ -36,8 +78,6 @@ struct ModelCardRowView: View {
.background(CardBackground(isSelected: isCurrent, useAccentGradientWhenSelected: isCurrent))
}
// MARK: - Components
private var headerSection: some View {
HStack(alignment: .firstTextBaseline) {
Text(model.displayName)
@ -162,14 +202,210 @@ struct ModelCardRowView: View {
Label("Delete Model", systemImage: "trash")
}
if isDownloaded {
Button {
if let modelURL = modelURL {
NSWorkspace.shared.selectFile(modelURL.path, inFileViewerRootedAtPath: "")
}
} label: {
Label("Show in Finder", systemImage: "folder")
Button {
if let modelURL = modelURL {
NSWorkspace.shared.selectFile(modelURL.path, inFileViewerRootedAtPath: "")
}
} label: {
Label("Show in Finder", systemImage: "folder")
}
} label: {
Image(systemName: "ellipsis.circle")
.font(.system(size: 14))
}
.menuStyle(.borderlessButton)
.menuIndicator(.hidden)
.frame(width: 20, height: 20)
}
}
}
}
// MARK: - Cloud Model Card View
struct CloudModelCardView: View {
let model: CloudModel
let isCurrent: Bool
var setDefaultAction: () -> Void
@StateObject private var aiService = AIService()
@State private var isExpanded = false
@State private var apiKey = ""
@State private var isVerifying = false
@State private var verificationStatus: VerificationStatus = .none
enum VerificationStatus {
case none, verifying, success, failure
}
private var isConfigured: Bool {
guard let savedKey = UserDefaults.standard.string(forKey: "\(providerKey)APIKey") else {
return false
}
return !savedKey.isEmpty
}
private var providerKey: String {
switch model.provider {
case .groq:
return "GROQ"
case .elevenLabs:
return "ElevenLabs"
default:
return model.provider.rawValue
}
}
var body: some View {
VStack(alignment: .leading, spacing: 0) {
// Main card content
HStack(alignment: .top, spacing: 16) {
VStack(alignment: .leading, spacing: 6) {
headerSection
metadataSection
descriptionSection
}
.frame(maxWidth: .infinity, alignment: .leading)
actionSection
}
.padding(16)
// Expandable configuration section
if isExpanded {
Divider()
.padding(.horizontal, 16)
configurationSection
.padding(16)
}
}
.background(CardBackground(isSelected: isCurrent, useAccentGradientWhenSelected: isCurrent))
.onAppear {
loadSavedAPIKey()
}
}
private var headerSection: some View {
HStack(alignment: .firstTextBaseline) {
Text(model.displayName)
.font(.system(size: 13, weight: .semibold))
.foregroundColor(Color(.labelColor))
statusBadge
Spacer()
}
}
private var statusBadge: some View {
Group {
if isCurrent {
Text("Default")
.font(.system(size: 11, weight: .medium))
.padding(.horizontal, 6)
.padding(.vertical, 2)
.background(Capsule().fill(Color.accentColor))
.foregroundColor(.white)
} else if isConfigured {
Text("Configured")
.font(.system(size: 11, weight: .medium))
.padding(.horizontal, 6)
.padding(.vertical, 2)
.background(Capsule().fill(Color(.systemGreen).opacity(0.2)))
.foregroundColor(Color(.systemGreen))
} else {
Text("Setup Required")
.font(.system(size: 11, weight: .medium))
.padding(.horizontal, 6)
.padding(.vertical, 2)
.background(Capsule().fill(Color(.systemOrange).opacity(0.2)))
.foregroundColor(Color(.systemOrange))
}
}
}
private var metadataSection: some View {
HStack(spacing: 16) {
// Provider
Label(model.provider.rawValue, systemImage: "cloud")
.font(.system(size: 11))
.foregroundColor(Color(.secondaryLabelColor))
// Language
Label(model.language, systemImage: "globe")
.font(.system(size: 11))
.foregroundColor(Color(.secondaryLabelColor))
// Speed
HStack(spacing: 4) {
Text("Speed")
.font(.system(size: 11, weight: .medium))
.foregroundColor(Color(.secondaryLabelColor))
progressDotsWithNumber(value: model.speed * 10)
}
// Accuracy
HStack(spacing: 4) {
Text("Accuracy")
.font(.system(size: 11, weight: .medium))
.foregroundColor(Color(.secondaryLabelColor))
progressDotsWithNumber(value: model.accuracy * 10)
}
}
}
private var descriptionSection: some View {
Text(model.description)
.font(.system(size: 11))
.foregroundColor(Color(.secondaryLabelColor))
.lineLimit(2)
.fixedSize(horizontal: false, vertical: true)
.padding(.top, 4)
}
private var actionSection: some View {
HStack(spacing: 8) {
if isCurrent {
Text("Default Model")
.font(.system(size: 12))
.foregroundColor(Color(.secondaryLabelColor))
} else if isConfigured {
Button(action: setDefaultAction) {
Text("Set as Default")
.font(.system(size: 12))
}
.buttonStyle(.bordered)
.controlSize(.small)
} else {
Button(action: {
withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) {
isExpanded.toggle()
}
}) {
HStack(spacing: 4) {
Text("Configure")
.font(.system(size: 12, weight: .medium))
Image(systemName: "gear")
.font(.system(size: 12, weight: .medium))
}
.foregroundColor(.white)
.padding(.horizontal, 12)
.padding(.vertical, 6)
.background(
Capsule()
.fill(Color(.controlAccentColor))
.shadow(color: Color(.controlAccentColor).opacity(0.2), radius: 2, x: 0, y: 1)
)
}
.buttonStyle(.plain)
}
if isConfigured {
Menu {
Button {
clearAPIKey()
} label: {
Label("Remove API Key", systemImage: "trash")
}
} label: {
Image(systemName: "ellipsis.circle")
@ -182,40 +418,134 @@ struct ModelCardRowView: View {
}
}
// MARK: - Helpers
private var downloadComponents: [(String, Double)] {
[
("Model", downloadProgress[model.name + "_main"] ?? 0),
("CoreML", downloadProgress[model.name + "_coreml"] ?? 0)
].filter { $0.1 > 0 }
}
private func progressDotsWithNumber(value: Double) -> some View {
HStack(spacing: 4) {
progressDots(value: value)
Text(String(format: "%.1f", value))
.font(.system(size: 10, weight: .medium, design: .monospaced))
.foregroundColor(Color(.secondaryLabelColor))
}
}
private func progressDots(value: Double) -> some View {
HStack(spacing: 2) {
ForEach(0..<5) { index in
Circle()
.fill(index < Int(value / 2) ? performanceColor(value: value / 10) : Color(.quaternaryLabelColor))
.frame(width: 6, height: 6)
private var configurationSection: some View {
VStack(alignment: .leading, spacing: 12) {
Text("API Key Configuration")
.font(.system(size: 13, weight: .semibold))
.foregroundColor(Color(.labelColor))
HStack(spacing: 8) {
SecureField("Enter your \(model.provider.rawValue) API key", text: $apiKey)
.textFieldStyle(.roundedBorder)
.disabled(isVerifying)
Button(action: verifyAPIKey) {
HStack(spacing: 4) {
if isVerifying {
ProgressView()
.scaleEffect(0.7)
.frame(width: 12, height: 12)
} else {
Image(systemName: verificationStatus == .success ? "checkmark" : "checkmark.shield")
.font(.system(size: 12, weight: .medium))
}
Text(isVerifying ? "Verifying..." : "Verify")
.font(.system(size: 12, weight: .medium))
}
.foregroundColor(.white)
.padding(.horizontal, 12)
.padding(.vertical, 6)
.background(
Capsule()
.fill(verificationStatus == .success ? Color(.systemGreen) : Color(.controlAccentColor))
)
}
.buttonStyle(.plain)
.disabled(apiKey.isEmpty || isVerifying)
}
if verificationStatus == .failure {
Text("Invalid API key. Please check your key and try again.")
.font(.caption)
.foregroundColor(Color(.systemRed))
} else if verificationStatus == .success {
Text("API key verified successfully!")
.font(.caption)
.foregroundColor(Color(.systemGreen))
}
}
}
private func performanceColor(value: Double) -> Color {
switch value {
case 0.8...1.0: return Color(.systemGreen)
case 0.6..<0.8: return Color(.systemYellow)
case 0.4..<0.6: return Color(.systemOrange)
default: return Color(.systemRed)
private func loadSavedAPIKey() {
if let savedKey = UserDefaults.standard.string(forKey: "\(providerKey)APIKey") {
apiKey = savedKey
verificationStatus = .success
}
}
private func verifyAPIKey() {
guard !apiKey.isEmpty else { return }
isVerifying = true
verificationStatus = .verifying
// Set the provider in AIService temporarily for verification
let originalProvider = aiService.selectedProvider
if model.provider == .groq {
aiService.selectedProvider = .groq
} else if model.provider == .elevenLabs {
aiService.selectedProvider = .elevenLabs
}
aiService.verifyAPIKey(apiKey) { [self] isValid in
DispatchQueue.main.async {
self.isVerifying = false
if isValid {
self.verificationStatus = .success
// Save the API key
UserDefaults.standard.set(self.apiKey, forKey: "\(self.providerKey)APIKey")
// Collapse the configuration section after successful verification
withAnimation(.easeInOut(duration: 0.3)) {
self.isExpanded = false
}
} else {
self.verificationStatus = .failure
}
// Restore original provider
aiService.selectedProvider = originalProvider
}
}
}
private func clearAPIKey() {
UserDefaults.standard.removeObject(forKey: "\(providerKey)APIKey")
apiKey = ""
verificationStatus = .none
withAnimation(.easeInOut(duration: 0.3)) {
isExpanded = false
}
}
}
// MARK: - Helper Views and Functions
private func progressDotsWithNumber(value: Double) -> some View {
HStack(spacing: 4) {
progressDots(value: value)
Text(String(format: "%.1f", value))
.font(.system(size: 10, weight: .medium, design: .monospaced))
.foregroundColor(Color(.secondaryLabelColor))
}
}
private func progressDots(value: Double) -> some View {
HStack(spacing: 2) {
ForEach(0..<5) { index in
Circle()
.fill(index < Int(value / 2) ? performanceColor(value: value / 10) : Color(.quaternaryLabelColor))
.frame(width: 6, height: 6)
}
}
}
private func performanceColor(value: Double) -> Color {
switch value {
case 0.8...1.0: return Color(.systemGreen)
case 0.6..<0.8: return Color(.systemYellow)
case 0.4..<0.6: return Color(.systemOrange)
default: return Color(.systemRed)
}
}

View File

@ -39,9 +39,7 @@ struct ModelManagementView: View {
Text("Default Model")
.font(.headline)
.foregroundColor(.secondary)
Text(whisperState.currentModel.flatMap { model in
PredefinedModels.models.first { $0.name == model.name }?.displayName
} ?? "No model selected")
Text(whisperState.currentTranscriptionModel?.displayName ?? "No model selected")
.font(.title2)
.fontWeight(.bold)
}
@ -62,7 +60,7 @@ struct ModelManagementView: View {
.font(.title3)
.fontWeight(.semibold)
Text("(\(whisperState.predefinedModels.count))")
Text("(\(whisperState.allAvailableModels.count))")
.foregroundColor(.secondary)
.font(.subheadline)
@ -70,11 +68,11 @@ struct ModelManagementView: View {
}
VStack(spacing: 12) {
ForEach(whisperState.predefinedModels) { model in
ForEach(whisperState.allAvailableModels, id: \.id) { model in
ModelCardRowView(
model: model,
isDownloaded: whisperState.availableModels.contains { $0.name == model.name },
isCurrent: whisperState.currentModel?.name == model.name,
isCurrent: whisperState.currentTranscriptionModel?.name == model.name,
downloadProgress: whisperState.downloadProgress,
modelURL: whisperState.availableModels.first { $0.name == model.name }?.url,
deleteAction: {
@ -83,15 +81,15 @@ struct ModelManagementView: View {
}
},
setDefaultAction: {
if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) {
Task {
await whisperState.setDefaultModel(downloadedModel)
}
Task {
await whisperState.setDefaultTranscriptionModel(model)
}
},
downloadAction: {
Task {
await whisperState.downloadModel(model)
if let localModel = model as? LocalModel {
Task {
await whisperState.downloadModel(localModel)
}
}
}
)

View File

@ -9,7 +9,7 @@ struct OnboardingModelDownloadView: View {
@State private var isModelSet = false
@State private var showTutorial = false
private let turboModel = PredefinedModels.models.first { $0.name == "ggml-large-v3-turbo-q5_0" }!
private let turboModel = PredefinedModels.models.first { $0.name == "ggml-large-v3-turbo-q5_0" } as! LocalModel
var body: some View {
ZStack {

View File

@ -122,7 +122,7 @@ extension WhisperState {
}
// Shows an alert about Core ML support and first-run optimization
private func showCoreMLAlert(for model: PredefinedModel, completion: @escaping () -> Void) {
private func showCoreMLAlert(for model: LocalModel, completion: @escaping () -> Void) {
Task { @MainActor in
let alert = NSAlert()
alert.messageText = "Core ML Support for \(model.displayName) Model"
@ -138,7 +138,7 @@ extension WhisperState {
}
}
func downloadModel(_ model: PredefinedModel) async {
func downloadModel(_ model: LocalModel) async {
guard let url = URL(string: model.downloadURL) else { return }
// Check if model supports Core ML (non-quantized models)
@ -160,7 +160,7 @@ extension WhisperState {
}
}
private func performModelDownload(_ model: PredefinedModel, _ url: URL) async {
private func performModelDownload(_ model: LocalModel, _ url: URL) async {
do {
let whisperModel = try await downloadMainModel(model, from: url)
@ -176,7 +176,7 @@ extension WhisperState {
}
}
private func downloadMainModel(_ model: PredefinedModel, from url: URL) async throws -> WhisperModel {
private func downloadMainModel(_ model: LocalModel, from url: URL) async throws -> WhisperModel {
let progressKeyMain = model.name + "_main"
let data = try await downloadFileWithProgress(from: url, progressKey: progressKeyMain)
@ -232,8 +232,7 @@ extension WhisperState {
return model
}
private func handleModelDownloadError(_ model: PredefinedModel, _ error: Error) {
currentError = .modelDownloadFailed
private func handleModelDownloadError(_ model: LocalModel, _ error: Error) {
self.downloadProgress.removeValue(forKey: model.name + "_main")
self.downloadProgress.removeValue(forKey: model.name + "_coreml")
}
@ -262,7 +261,6 @@ extension WhisperState {
}
} catch {
logError("Error deleting model: \(model.name)", error)
currentError = .modelDeletionFailed
}
}

View File

@ -0,0 +1,18 @@
import Foundation
extension WhisperState {
var usableModels: [any TranscriptionModel] {
allAvailableModels.filter { model in
switch model.provider {
case .local:
return availableModels.contains { $0.name == model.name }
case .groq:
guard let key = UserDefaults.standard.string(forKey: "GROQAPIKey") else { return false }
return !key.isEmpty
case .elevenLabs:
guard let key = UserDefaults.standard.string(forKey: "ElevenLabsAPIKey") else { return false }
return !key.isEmpty
}
}
}
}
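
The UserDefaults key prefix ("GROQ", "ElevenLabs") is spelled out separately here, in CloudTranscriptionService.getAPIConfig, and in CloudModelCardView.providerKey; a single source of truth on ModelProvider could remove the duplication. A sketch (hypothetical, not in this commit):

import Foundation

extension ModelProvider {
    // Hypothetical: centralize the "<prefix>APIKey" UserDefaults convention.
    var apiKeyDefaultsKey: String? {
        switch self {
        case .local: return nil
        case .groq: return "GROQAPIKey"
        case .elevenLabs: return "ElevenLabsAPIKey"
        }
    }

    var hasAPIKey: Bool {
        guard let key = apiKeyDefaultsKey,
              let value = UserDefaults.standard.string(forKey: key) else { return false }
        return !value.isEmpty
    }
}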

View File

@ -12,9 +12,10 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
@Published var canTranscribe = false
@Published var isRecording = false
@Published var currentModel: WhisperModel?
@Published var currentTranscriptionModel: (any TranscriptionModel)?
@Published var isModelLoading = false
@Published var availableModels: [WhisperModel] = []
@Published var predefinedModels: [PredefinedModel] = PredefinedModels.models
@Published var allAvailableModels: [any TranscriptionModel] = PredefinedModels.models
@Published var clipboardMessage = ""
@Published var miniRecorderError: String?
@Published var isProcessing = false
@ -53,6 +54,10 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
let modelContext: ModelContext
// Transcription Services
private var localTranscriptionService: LocalTranscriptionService
private let cloudTranscriptionService = CloudTranscriptionService()
private var modelUrl: URL? {
let possibleURLs = [
Bundle.main.url(forResource: "ggml-base.en", withExtension: "bin", subdirectory: "Models"),
@ -85,21 +90,28 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
init(modelContext: ModelContext, enhancementService: AIEnhancementService? = nil) {
self.modelContext = modelContext
self.modelsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
let appSupportDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
.appendingPathComponent("WhisperModels")
self.recordingsDirectory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0]
.appendingPathComponent("com.prakashjoshipax.VoiceInk")
.appendingPathComponent("Recordings")
self.modelsDirectory = appSupportDirectory.appendingPathComponent("WhisperModels")
self.recordingsDirectory = appSupportDirectory.appendingPathComponent("Recordings")
// Initialize services without whisperState reference first
self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: self.modelsDirectory)
self.enhancementService = enhancementService
self.licenseViewModel = LicenseViewModel()
super.init()
// Set the whisperState reference after super.init()
self.localTranscriptionService = LocalTranscriptionService(modelsDirectory: self.modelsDirectory, whisperState: self)
setupNotifications()
createModelsDirectoryIfNeeded()
createRecordingsDirectoryIfNeeded()
loadAvailableModels()
loadCurrentTranscriptionModel()
if let savedModelName = UserDefaults.standard.string(forKey: "CurrentModel"),
let savedModel = availableModels.first(where: { $0.name == savedModelName }) {
@ -139,11 +151,11 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
logger.error("❌ No recorded file found after stopping recording")
}
} else {
guard currentModel != nil else {
guard currentTranscriptionModel != nil else {
await MainActor.run {
let alert = NSAlert()
alert.messageText = "No Whisper Model Selected"
alert.informativeText = "Please select a default whisper model in AI Models tab before recording."
alert.messageText = "No AI Model Selected"
alert.informativeText = "Please select a default AI model in AI Models tab before recording."
alert.alertStyle = .warning
alert.addButton(withTitle: "OK")
alert.runModal()
@ -174,11 +186,15 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
await ActiveWindowService.shared.applyConfigurationForCurrentApp()
if let currentModel = await self.currentModel, await self.whisperContext == nil {
do {
try await self.loadModel(currentModel)
} catch {
self.logger.error("❌ Model loading failed: \(error.localizedDescription)")
// Only load model if it's a local model and not already loaded
if let model = self.currentTranscriptionModel, model.provider == .local {
if let localWhisperModel = self.availableModels.first(where: { $0.name == model.name }),
self.whisperContext == nil {
do {
try await self.loadModel(localWhisperModel)
} catch {
self.logger.error("❌ Model loading failed: \(error.localizedDescription)")
}
}
}
@ -256,11 +272,13 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
await cleanupModelResources()
return
}
await MainActor.run {
isProcessing = true
isTranscribing = true
canTranscribe = false
}
defer {
if shouldCancelRecording {
Task {
@ -268,56 +286,33 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
}
}
}
guard let currentModel = currentModel else {
guard let model = currentTranscriptionModel else {
logger.error("❌ Cannot transcribe: No model selected")
currentError = .modelLoadFailed
return
}
if whisperContext == nil {
logger.notice("🔄 Model not loaded yet, attempting to load now: \(currentModel.name)")
do {
try await loadModel(currentModel)
} catch {
logger.error("❌ Failed to load model: \(currentModel.name) - \(error.localizedDescription)")
currentError = .modelLoadFailed
return
}
}
guard let whisperContext = whisperContext else {
logger.error("❌ Cannot transcribe: Model could not be loaded")
currentError = .modelLoadFailed
return
}
logger.notice("🔄 Starting transcription with model: \(currentModel.name)")
logger.notice("🔄 Starting transcription with model: \(model.displayName)")
do {
let permanentURL = try saveRecordingPermanently(url)
let permanentURLString = permanentURL.absoluteString
if shouldCancelRecording { return }
let data = try readAudioSamples(url)
if shouldCancelRecording { return }
// Get the actual audio duration from the file
let audioAsset = AVURLAsset(url: url)
let actualDuration = CMTimeGetSeconds(try await audioAsset.load(.duration))
logger.notice("📊 Audio file duration: \(actualDuration) seconds")
// Ensure we're using the most recent prompt from UserDefaults
let currentPrompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? whisperPrompt.transcriptionPrompt
await whisperContext.setPrompt(currentPrompt)
if shouldCancelRecording { return }
await whisperContext.fullTranscribe(samples: data)
if shouldCancelRecording { return }
var text = await whisperContext.getTranscription()
// --- Core Transcription Logic ---
let transcriptionService: TranscriptionService = (model.provider == .local) ? localTranscriptionService : cloudTranscriptionService
var text = try await transcriptionService.transcribe(audioURL: url, model: model)
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
logger.notice("✅ Transcription completed successfully, length: \(text.count) characters")
// --- Post-processing and Saving ---
let permanentURL = try saveRecordingPermanently(url)
if UserDefaults.standard.bool(forKey: "IsWordReplacementEnabled") {
text = WordReplacementService.shared.applyReplacements(to: text)
logger.notice("✅ Word replacements applied")
}
let audioAsset = AVURLAsset(url: url)
let actualDuration = CMTimeGetSeconds(try await audioAsset.load(.duration))
var promptDetectionResult: PromptDetectionService.PromptDetectionResult? = nil
let originalText = text
if let enhancementService = enhancementService, enhancementService.isConfigured {
let detectionResult = promptDetectionService.analyzeText(text, with: enhancementService)
@@ -330,23 +325,22 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
enhancementService.isConfigured {
do {
if shouldCancelRecording { return }
// Use processed text (without trigger words) for AI enhancement
let textForAI = promptDetectionResult?.processedText ?? text
let enhancedText = try await enhancementService.enhance(textForAI)
let newTranscription = Transcription(
text: originalText,
duration: actualDuration,
enhancedText: enhancedText,
audioFileURL: permanentURLString
audioFileURL: permanentURL.absoluteString
)
modelContext.insert(newTranscription)
try? modelContext.save()
text = enhancedText
} catch {
let newTranscription = Transcription(
text: originalText,
duration: actualDuration,
audioFileURL: permanentURLString
audioFileURL: permanentURL.absoluteString
)
modelContext.insert(newTranscription)
try? modelContext.save()
@@ -355,11 +349,12 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
let newTranscription = Transcription(
text: originalText,
duration: actualDuration,
audioFileURL: permanentURLString
audioFileURL: permanentURL.absoluteString
)
modelContext.insert(newTranscription)
try? modelContext.save()
}
if case .trialExpired = licenseViewModel.licenseState {
text = """
Your trial has expired. Upgrade to VoiceInk Pro at tryvoiceink.com/buy
@@ -367,7 +362,6 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
"""
}
// Add a space to the end of the text
text += " "
SoundManager.shared.playStopSound()
@@ -386,8 +380,8 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
}
try? FileManager.default.removeItem(at: url)
if let result = promptDetectionResult,
let enhancementService = enhancementService,
result.shouldEnableAI {
await promptDetectionService.restoreOriginalSettings(result, to: enhancementService)
}
@@ -396,39 +390,64 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
await cleanupModelResources()
} catch {
currentError = .transcriptionFailed
logger.error("❌ Transcription failed: \(error.localizedDescription)")
await cleanupModelResources()
await dismissMiniRecorder()
}
}
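// A minimal sketch of the `TranscriptionService` abstraction used above. The
// real protocol is defined elsewhere in this commit; this shape is inferred
// purely from the call site `transcriptionService.transcribe(audioURL:model:)`:
//
// protocol TranscriptionService {
//     func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String
// }
//
// Keeping the local and cloud services behind one protocol is what lets the
// ternary above select an implementation by `model.provider` alone.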
private func readAudioSamples(_ url: URL) throws -> [Float] {
return try decodeWaveFile(url)
}
private func decodeWaveFile(_ url: URL) throws -> [Float] {
let data = try Data(contentsOf: url)
// Skip the canonical 44-byte WAV header, then read the payload as 16-bit
// little-endian PCM and normalize each sample into the range [-1.0, 1.0].
let floats = stride(from: 44, to: data.count, by: 2).map {
return data[$0..<$0 + 2].withUnsafeBytes {
let short = Int16(littleEndian: $0.load(as: Int16.self))
return max(-1.0, min(Float(short) / 32767.0, 1.0))
}
}
return floats
}
@Published var currentError: WhisperStateError?
func getEnhancementService() -> AIEnhancementService? {
return enhancementService
}
private func saveRecordingPermanently(_ tempURL: URL) throws -> URL {
let fileName = "\(UUID().uuidString).wav"
let permanentURL = recordingsDirectory.appendingPathComponent(fileName)
try FileManager.default.copyItem(at: tempURL, to: permanentURL)
return permanentURL
}
private func loadCurrentTranscriptionModel() {
if let savedModelName = UserDefaults.standard.string(forKey: "CurrentTranscriptionModel"),
let savedModel = allAvailableModels.first(where: { $0.name == savedModelName }) {
currentTranscriptionModel = savedModel
// If it's a local model, also set it as currentModel for backward compatibility
if let localModel = savedModel as? LocalModel,
let whisperModel = availableModels.first(where: { $0.name == localModel.name }) {
currentModel = whisperModel
}
}
}
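// Hypothetical helper, not part of this commit: a one-time migration that
// seeds the new "CurrentTranscriptionModel" key from the legacy "CurrentModel"
// key, so an upgrading user keeps their previously selected local model.
private func migrateLegacyModelSelectionIfNeeded() {
let defaults = UserDefaults.standard
if defaults.string(forKey: "CurrentTranscriptionModel") == nil,
let legacyName = defaults.string(forKey: "CurrentModel") {
defaults.set(legacyName, forKey: "CurrentTranscriptionModel")
}
}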
// Function to set any transcription model as default
func setDefaultTranscriptionModel(_ model: any TranscriptionModel) async {
await MainActor.run {
self.currentTranscriptionModel = model
UserDefaults.standard.set(model.name, forKey: "CurrentTranscriptionModel")
// If it's a local model, also update currentModel for backward compatibility
if let localModel = model as? LocalModel,
let whisperModel = self.availableModels.first(where: { $0.name == localModel.name }) {
self.currentModel = whisperModel
UserDefaults.standard.set(whisperModel.name, forKey: "CurrentModel")
} else {
// For cloud models, clear the old currentModel
self.currentModel = nil
}
// Enable transcription for cloud models immediately since they don't need loading
if model.provider != .local {
self.canTranscribe = true
self.isModelLoaded = true
}
}
logger.info("Default transcription model set to: \(model.name) (\(model.provider.rawValue))")
// Post notification about the model change
NotificationCenter.default.post(name: .didChangeModel, object: nil, userInfo: ["modelName": model.name])
}
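// Usage sketch (hypothetical call site): select the first available cloud
// model as the default. Assumes `allAvailableModels` is the combined
// local + cloud list referenced in loadCurrentTranscriptionModel().
//
// Task {
//     if let cloudModel = whisperState.allAvailableModels.first(where: { $0.provider != .local }) {
//         await whisperState.setDefaultTranscriptionModel(cloudModel)
//     }
// }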
func getEnhancementService() -> AIEnhancementService? {
return enhancementService
}
}
struct WhisperModel: Identifiable {
@@ -473,4 +492,5 @@ private class TaskDelegate: NSObject, URLSessionTaskDelegate {
extension Notification.Name {
static let toggleMiniRecorder = Notification.Name("toggleMiniRecorder")
static let didChangeModel = Notification.Name("didChangeModel")
}
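// Sketch, assuming a long-lived observer kept elsewhere: reacting to the
// `.didChangeModel` notification posted by setDefaultTranscriptionModel(_:).
// The "modelName" key mirrors the userInfo dictionary used above.
let modelChangeObserver = NotificationCenter.default.addObserver(
forName: .didChangeModel,
object: nil,
queue: .main
) { notification in
if let modelName = notification.userInfo?["modelName"] as? String {
print("Transcription model changed to: \(modelName)")
}
}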