Added support for Gemini models for transcription
This commit is contained in:
parent
49a3cd0f4a
commit
2708cc502a
@ -232,6 +232,28 @@ import Foundation
|
||||
accuracy: 0.97,
|
||||
isMultilingual: true,
|
||||
supportedLanguages: getLanguageDictionary(isMultilingual: true, provider: .mistral)
|
||||
),
|
||||
|
||||
// Gemini Models
|
||||
CloudModel(
|
||||
name: "gemini-2.5-pro",
|
||||
displayName: "Gemini 2.5 Pro",
|
||||
description: "Google's advanced multimodal model with high-quality transcription capabilities.",
|
||||
provider: .gemini,
|
||||
speed: 0.7,
|
||||
accuracy: 0.96,
|
||||
isMultilingual: true,
|
||||
supportedLanguages: getLanguageDictionary(isMultilingual: true, provider: .gemini)
|
||||
),
|
||||
CloudModel(
|
||||
name: "gemini-2.5-flash",
|
||||
displayName: "Gemini 2.5 Flash",
|
||||
description: "Google's optimized model for low-latency transcription with multimodal support.",
|
||||
provider: .gemini,
|
||||
speed: 0.9,
|
||||
accuracy: 0.94,
|
||||
isMultilingual: true,
|
||||
supportedLanguages: getLanguageDictionary(isMultilingual: true, provider: .gemini)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ enum ModelProvider: String, Codable, Hashable, CaseIterable {
|
||||
case elevenLabs = "ElevenLabs"
|
||||
case deepgram = "Deepgram"
|
||||
case mistral = "Mistral"
|
||||
case gemini = "Gemini"
|
||||
case custom = "Custom"
|
||||
case nativeApple = "Native Apple"
|
||||
// Future providers can be added here
|
||||
|
||||
@ -39,6 +39,7 @@ class CloudTranscriptionService: TranscriptionService {
|
||||
private lazy var elevenLabsService = ElevenLabsTranscriptionService()
|
||||
private lazy var deepgramService = DeepgramTranscriptionService()
|
||||
private lazy var mistralService = MistralTranscriptionService()
|
||||
private lazy var geminiService = GeminiTranscriptionService()
|
||||
private lazy var openAICompatibleService = OpenAICompatibleTranscriptionService()
|
||||
|
||||
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
|
||||
@ -53,6 +54,8 @@ class CloudTranscriptionService: TranscriptionService {
|
||||
text = try await deepgramService.transcribe(audioURL: audioURL, model: model)
|
||||
case .mistral:
|
||||
text = try await mistralService.transcribe(audioURL: audioURL, model: model)
|
||||
case .gemini:
|
||||
text = try await geminiService.transcribe(audioURL: audioURL, model: model)
|
||||
case .custom:
|
||||
guard let customModel = model as? CustomCloudModel else {
|
||||
throw CloudTranscriptionError.unsupportedProvider
|
||||
|
||||
@ -0,0 +1,158 @@
|
||||
import Foundation
|
||||
import os
|
||||
|
||||
class GeminiTranscriptionService {
|
||||
private let logger = Logger(subsystem: "com.voiceink.transcription", category: "GeminiService")
|
||||
|
||||
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
|
||||
let config = try getAPIConfig(for: model)
|
||||
|
||||
logger.notice("Starting Gemini transcription with model: \(model.name, privacy: .public)")
|
||||
|
||||
var request = URLRequest(url: config.url)
|
||||
request.httpMethod = "POST"
|
||||
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
|
||||
request.setValue(config.apiKey, forHTTPHeaderField: "x-goog-api-key")
|
||||
|
||||
guard let audioData = try? Data(contentsOf: audioURL) else {
|
||||
throw CloudTranscriptionError.audioFileNotFound
|
||||
}
|
||||
|
||||
logger.notice("Audio file loaded, size: \(audioData.count) bytes")
|
||||
|
||||
let base64AudioData = audioData.base64EncodedString()
|
||||
|
||||
let requestBody = GeminiRequest(
|
||||
contents: [
|
||||
GeminiContent(
|
||||
parts: [
|
||||
.text(GeminiTextPart(text: "Please transcribe this audio file. Provide only the transcribed text.")),
|
||||
.audio(GeminiAudioPart(
|
||||
inlineData: GeminiInlineData(
|
||||
mimeType: "audio/wav",
|
||||
data: base64AudioData
|
||||
)
|
||||
))
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
do {
|
||||
let jsonData = try JSONEncoder().encode(requestBody)
|
||||
request.httpBody = jsonData
|
||||
logger.notice("Request body encoded, sending to Gemini API")
|
||||
} catch {
|
||||
logger.error("Failed to encode Gemini request: \(error.localizedDescription)")
|
||||
throw CloudTranscriptionError.dataEncodingError
|
||||
}
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
guard let httpResponse = response as? HTTPURLResponse else {
|
||||
throw CloudTranscriptionError.networkError(URLError(.badServerResponse))
|
||||
}
|
||||
|
||||
if !(200...299).contains(httpResponse.statusCode) {
|
||||
let errorMessage = String(data: data, encoding: .utf8) ?? "No error message"
|
||||
logger.error("Gemini API request failed with status \(httpResponse.statusCode): \(errorMessage, privacy: .public)")
|
||||
throw CloudTranscriptionError.apiRequestFailed(statusCode: httpResponse.statusCode, message: errorMessage)
|
||||
}
|
||||
|
||||
do {
|
||||
let transcriptionResponse = try JSONDecoder().decode(GeminiResponse.self, from: data)
|
||||
guard let candidate = transcriptionResponse.candidates.first,
|
||||
let part = candidate.content.parts.first,
|
||||
!part.text.isEmpty else {
|
||||
logger.error("No transcript found in Gemini response")
|
||||
throw CloudTranscriptionError.noTranscriptionReturned
|
||||
}
|
||||
logger.notice("Gemini transcription successful, text length: \(part.text.count)")
|
||||
return part.text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
} catch {
|
||||
logger.error("Failed to decode Gemini API response: \(error.localizedDescription)")
|
||||
throw CloudTranscriptionError.noTranscriptionReturned
|
||||
}
|
||||
}
|
||||
|
||||
private func getAPIConfig(for model: any TranscriptionModel) throws -> APIConfig {
|
||||
guard let apiKey = UserDefaults.standard.string(forKey: "GeminiAPIKey"), !apiKey.isEmpty else {
|
||||
throw CloudTranscriptionError.missingAPIKey
|
||||
}
|
||||
|
||||
let urlString = "https://generativelanguage.googleapis.com/v1beta/models/\(model.name):generateContent"
|
||||
guard let apiURL = URL(string: urlString) else {
|
||||
throw CloudTranscriptionError.dataEncodingError
|
||||
}
|
||||
|
||||
return APIConfig(url: apiURL, apiKey: apiKey, modelName: model.name)
|
||||
}
|
||||
|
||||
private struct APIConfig {
|
||||
let url: URL
|
||||
let apiKey: String
|
||||
let modelName: String
|
||||
}
|
||||
|
||||
private struct GeminiRequest: Codable {
|
||||
let contents: [GeminiContent]
|
||||
}
|
||||
|
||||
private struct GeminiContent: Codable {
|
||||
let parts: [GeminiPart]
|
||||
}
|
||||
|
||||
private enum GeminiPart: Codable {
|
||||
case text(GeminiTextPart)
|
||||
case audio(GeminiAudioPart)
|
||||
|
||||
func encode(to encoder: Encoder) throws {
|
||||
var container = encoder.singleValueContainer()
|
||||
switch self {
|
||||
case .text(let textPart):
|
||||
try container.encode(textPart)
|
||||
case .audio(let audioPart):
|
||||
try container.encode(audioPart)
|
||||
}
|
||||
}
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.singleValueContainer()
|
||||
if let textPart = try? container.decode(GeminiTextPart.self) {
|
||||
self = .text(textPart)
|
||||
} else if let audioPart = try? container.decode(GeminiAudioPart.self) {
|
||||
self = .audio(audioPart)
|
||||
} else {
|
||||
throw DecodingError.dataCorrupted(DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Invalid part"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private struct GeminiTextPart: Codable {
|
||||
let text: String
|
||||
}
|
||||
|
||||
private struct GeminiAudioPart: Codable {
|
||||
let inlineData: GeminiInlineData
|
||||
}
|
||||
|
||||
private struct GeminiInlineData: Codable {
|
||||
let mimeType: String
|
||||
let data: String
|
||||
}
|
||||
|
||||
private struct GeminiResponse: Codable {
|
||||
let candidates: [GeminiCandidate]
|
||||
}
|
||||
|
||||
private struct GeminiCandidate: Codable {
|
||||
let content: GeminiResponseContent
|
||||
}
|
||||
|
||||
private struct GeminiResponseContent: Codable {
|
||||
let parts: [GeminiResponsePart]
|
||||
}
|
||||
|
||||
private struct GeminiResponsePart: Codable {
|
||||
let text: String
|
||||
}
|
||||
}
|
||||
@ -36,6 +36,8 @@ struct CloudModelCardView: View {
|
||||
return "Deepgram"
|
||||
case .mistral:
|
||||
return "Mistral"
|
||||
case .gemini:
|
||||
return "Gemini"
|
||||
default:
|
||||
return model.provider.rawValue
|
||||
}
|
||||
@ -277,6 +279,8 @@ struct CloudModelCardView: View {
|
||||
aiService.selectedProvider = .deepgram
|
||||
case .mistral:
|
||||
aiService.selectedProvider = .mistral
|
||||
case .gemini:
|
||||
aiService.selectedProvider = .gemini
|
||||
default:
|
||||
// This case should ideally not be hit for cloud models in this view
|
||||
print("Warning: verifyAPIKey called for unsupported provider \(model.provider.rawValue)")
|
||||
|
||||
@ -55,7 +55,7 @@ struct ModelCardRowView: View {
|
||||
setDefaultAction: setDefaultAction
|
||||
)
|
||||
}
|
||||
case .groq, .elevenLabs, .deepgram, .mistral:
|
||||
case .groq, .elevenLabs, .deepgram, .mistral, .gemini:
|
||||
if let cloudModel = model as? CloudModel {
|
||||
CloudModelCardView(
|
||||
model: cloudModel,
|
||||
|
||||
@ -218,7 +218,7 @@ struct ModelManagementView: View {
|
||||
case .local:
|
||||
return whisperState.allAvailableModels.filter { $0.provider == .local || $0.provider == .nativeApple || $0.provider == .parakeet }
|
||||
case .cloud:
|
||||
let cloudProviders: [ModelProvider] = [.groq, .elevenLabs, .deepgram, .mistral]
|
||||
let cloudProviders: [ModelProvider] = [.groq, .elevenLabs, .deepgram, .mistral, .gemini]
|
||||
return whisperState.allAvailableModels.filter { cloudProviders.contains($0.provider) }
|
||||
case .custom:
|
||||
return whisperState.allAvailableModels.filter { $0.provider == .custom }
|
||||
|
||||
@ -26,6 +26,9 @@ extension WhisperState {
|
||||
case .mistral:
|
||||
let key = UserDefaults.standard.string(forKey: "MistralAPIKey")
|
||||
return key != nil && !key!.isEmpty
|
||||
case .gemini:
|
||||
let key = UserDefaults.standard.string(forKey: "GeminiAPIKey")
|
||||
return key != nil && !key!.isEmpty
|
||||
case .custom:
|
||||
// Custom models are always usable since they contain their own API keys
|
||||
return true
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user