Custom Trancription model with openAI compatibility
This commit is contained in:
parent
f0dee3069b
commit
05cfbd1a0b
@ -27,15 +27,19 @@ import Foundation
|
||||
}
|
||||
|
||||
enum PredefinedModels {
|
||||
static func getLanguageDictionary(isMultilingual: Bool) -> [String: String] {
|
||||
if !isMultilingual {
|
||||
return ["en": "English"]
|
||||
} else {
|
||||
return allLanguages
|
||||
}
|
||||
}
|
||||
|
||||
static let models: [any TranscriptionModel] = [
|
||||
static func getLanguageDictionary(isMultilingual: Bool) -> [String: String] {
|
||||
if !isMultilingual {
|
||||
return ["en": "English"]
|
||||
} else {
|
||||
return allLanguages
|
||||
}
|
||||
}
|
||||
|
||||
static var models: [any TranscriptionModel] {
|
||||
return predefinedModels + CustomModelManager.shared.customModels
|
||||
}
|
||||
|
||||
private static let predefinedModels: [any TranscriptionModel] = [
|
||||
// Local Models
|
||||
LocalModel(
|
||||
name: "ggml-tiny",
|
||||
|
||||
@ -6,6 +6,7 @@ enum ModelProvider: String, Codable, Hashable, CaseIterable {
|
||||
case groq = "Groq"
|
||||
case elevenLabs = "ElevenLabs"
|
||||
case deepgram = "Deepgram"
|
||||
case custom = "Custom"
|
||||
// Future providers can be added here
|
||||
}
|
||||
|
||||
@ -55,4 +56,30 @@ struct CloudModel: TranscriptionModel {
|
||||
self.isMultilingualModel = isMultilingual
|
||||
self.supportedLanguages = supportedLanguages
|
||||
}
|
||||
}
|
||||
|
||||
// A new struct for custom cloud models
|
||||
struct CustomCloudModel: TranscriptionModel, Codable {
|
||||
let id: UUID
|
||||
let name: String
|
||||
let displayName: String
|
||||
let description: String
|
||||
let provider: ModelProvider = .custom
|
||||
let apiEndpoint: String
|
||||
let apiKey: String
|
||||
let modelName: String
|
||||
let isMultilingualModel: Bool
|
||||
let supportedLanguages: [String: String]
|
||||
|
||||
init(id: UUID = UUID(), name: String, displayName: String, description: String, apiEndpoint: String, apiKey: String, modelName: String, isMultilingual: Bool = true, supportedLanguages: [String: String]? = nil) {
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.displayName = displayName
|
||||
self.description = description
|
||||
self.apiEndpoint = apiEndpoint
|
||||
self.apiKey = apiKey
|
||||
self.modelName = modelName
|
||||
self.isMultilingualModel = isMultilingual
|
||||
self.supportedLanguages = supportedLanguages ?? PredefinedModels.getLanguageDictionary(isMultilingual: isMultilingual)
|
||||
}
|
||||
}
|
||||
@ -487,7 +487,7 @@ struct ConfigurationView: View {
|
||||
.frame(maxWidth: .infinity, alignment: .leading)
|
||||
} else {
|
||||
Picker("", selection: providerBinding) {
|
||||
ForEach(aiService.connectedProviders.filter { $0 != .elevenLabs }, id: \.self) { provider in
|
||||
ForEach(aiService.connectedProviders.filter { $0 != .elevenLabs && $0 != .deepgram }, id: \.self) { provider in
|
||||
Text(provider.rawValue).tag(provider)
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,6 +38,7 @@ class CloudTranscriptionService: TranscriptionService {
|
||||
private let groqService = GroqTranscriptionService()
|
||||
private let elevenLabsService = ElevenLabsTranscriptionService()
|
||||
private let deepgramService = DeepgramTranscriptionService()
|
||||
private let openAICompatibleService = OpenAICompatibleTranscriptionService()
|
||||
|
||||
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
|
||||
switch model.provider {
|
||||
@ -47,6 +48,11 @@ class CloudTranscriptionService: TranscriptionService {
|
||||
return try await elevenLabsService.transcribe(audioURL: audioURL, model: model)
|
||||
case .deepgram:
|
||||
return try await deepgramService.transcribe(audioURL: audioURL, model: model)
|
||||
case .custom:
|
||||
guard let customModel = model as? CustomCloudModel else {
|
||||
throw CloudTranscriptionError.unsupportedProvider
|
||||
}
|
||||
return try await openAICompatibleService.transcribe(audioURL: audioURL, model: customModel)
|
||||
default:
|
||||
throw CloudTranscriptionError.unsupportedProvider
|
||||
}
|
||||
|
||||
@ -0,0 +1,104 @@
|
||||
import Foundation
|
||||
import os
|
||||
|
||||
class OpenAICompatibleTranscriptionService {
|
||||
private let logger = Logger(subsystem: "com.voiceink.transcription", category: "OpenAICompatibleService")
|
||||
|
||||
func transcribe(audioURL: URL, model: CustomCloudModel) async throws -> String {
|
||||
let config = APIConfig(
|
||||
url: URL(string: model.apiEndpoint)!,
|
||||
apiKey: model.apiKey,
|
||||
modelName: model.modelName
|
||||
)
|
||||
|
||||
let boundary = "Boundary-\(UUID().uuidString)"
|
||||
var request = URLRequest(url: config.url)
|
||||
request.httpMethod = "POST"
|
||||
request.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type")
|
||||
request.setValue("Bearer \(config.apiKey)", forHTTPHeaderField: "Authorization")
|
||||
|
||||
let body = try createOpenAICompatibleRequestBody(audioURL: audioURL, modelName: config.modelName, boundary: boundary)
|
||||
|
||||
let (data, response) = try await URLSession.shared.upload(for: request, from: body)
|
||||
guard let httpResponse = response as? HTTPURLResponse else {
|
||||
throw CloudTranscriptionError.networkError(URLError(.badServerResponse))
|
||||
}
|
||||
|
||||
if !(200...299).contains(httpResponse.statusCode) {
|
||||
let errorMessage = String(data: data, encoding: .utf8) ?? "No error message"
|
||||
logger.error("OpenAI-compatible API request failed with status \(httpResponse.statusCode): \(errorMessage, privacy: .public)")
|
||||
throw CloudTranscriptionError.apiRequestFailed(statusCode: httpResponse.statusCode, message: errorMessage)
|
||||
}
|
||||
|
||||
do {
|
||||
let transcriptionResponse = try JSONDecoder().decode(TranscriptionResponse.self, from: data)
|
||||
return transcriptionResponse.text
|
||||
} catch {
|
||||
logger.error("Failed to decode OpenAI-compatible API response: \(error.localizedDescription)")
|
||||
throw CloudTranscriptionError.noTranscriptionReturned
|
||||
}
|
||||
}
|
||||
|
||||
private func createOpenAICompatibleRequestBody(audioURL: URL, modelName: String, boundary: String) throws -> Data {
|
||||
var body = Data()
|
||||
let crlf = "\r\n"
|
||||
|
||||
guard let audioData = try? Data(contentsOf: audioURL) else {
|
||||
throw CloudTranscriptionError.audioFileNotFound
|
||||
}
|
||||
|
||||
let selectedLanguage = UserDefaults.standard.string(forKey: "SelectedLanguage") ?? "auto"
|
||||
let prompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? ""
|
||||
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"file\"; filename=\"\(audioURL.lastPathComponent)\"\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Type: audio/wav\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append(audioData)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"model\"\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append(modelName.data(using: .utf8)!)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
|
||||
if selectedLanguage != "auto", !selectedLanguage.isEmpty {
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"language\"\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append(selectedLanguage.data(using: .utf8)!)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
}
|
||||
|
||||
// Include prompt for OpenAI-compatible APIs
|
||||
if !prompt.isEmpty {
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"prompt\"\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append(prompt.data(using: .utf8)!)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
}
|
||||
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"response_format\"\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append("json".data(using: .utf8)!)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
|
||||
body.append("--\(boundary)\(crlf)".data(using: .utf8)!)
|
||||
body.append("Content-Disposition: form-data; name=\"temperature\"\(crlf)\(crlf)".data(using: .utf8)!)
|
||||
body.append("0".data(using: .utf8)!)
|
||||
body.append(crlf.data(using: .utf8)!)
|
||||
body.append("--\(boundary)--\(crlf)".data(using: .utf8)!)
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
private struct APIConfig {
|
||||
let url: URL
|
||||
let apiKey: String
|
||||
let modelName: String
|
||||
}
|
||||
|
||||
private struct TranscriptionResponse: Decodable {
|
||||
let text: String
|
||||
let language: String?
|
||||
let duration: Double?
|
||||
}
|
||||
}
|
||||
138
VoiceInk/Services/CustomModelManager.swift
Normal file
138
VoiceInk/Services/CustomModelManager.swift
Normal file
@ -0,0 +1,138 @@
|
||||
import Foundation
|
||||
import os
|
||||
|
||||
class CustomModelManager: ObservableObject {
|
||||
static let shared = CustomModelManager()
|
||||
|
||||
private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "CustomModelManager")
|
||||
private let userDefaults = UserDefaults.standard
|
||||
private let customModelsKey = "customCloudModels"
|
||||
|
||||
@Published var customModels: [CustomCloudModel] = []
|
||||
|
||||
private init() {
|
||||
loadCustomModels()
|
||||
}
|
||||
|
||||
// MARK: - CRUD Operations
|
||||
|
||||
func addCustomModel(_ model: CustomCloudModel) {
|
||||
customModels.append(model)
|
||||
saveCustomModels()
|
||||
logger.info("Added custom model: \(model.displayName)")
|
||||
}
|
||||
|
||||
func removeCustomModel(withId id: UUID) {
|
||||
customModels.removeAll { $0.id == id }
|
||||
saveCustomModels()
|
||||
logger.info("Removed custom model with ID: \(id)")
|
||||
}
|
||||
|
||||
func updateCustomModel(_ updatedModel: CustomCloudModel) {
|
||||
if let index = customModels.firstIndex(where: { $0.id == updatedModel.id }) {
|
||||
customModels[index] = updatedModel
|
||||
saveCustomModels()
|
||||
logger.info("Updated custom model: \(updatedModel.displayName)")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Persistence
|
||||
|
||||
private func loadCustomModels() {
|
||||
guard let data = userDefaults.data(forKey: customModelsKey) else {
|
||||
logger.info("No custom models found in UserDefaults")
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
customModels = try JSONDecoder().decode([CustomCloudModel].self, from: data)
|
||||
} catch {
|
||||
logger.error("Failed to decode custom models: \(error.localizedDescription)")
|
||||
customModels = []
|
||||
}
|
||||
}
|
||||
|
||||
private func saveCustomModels() {
|
||||
do {
|
||||
let data = try JSONEncoder().encode(customModels)
|
||||
userDefaults.set(data, forKey: customModelsKey)
|
||||
} catch {
|
||||
logger.error("Failed to encode custom models: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Validation
|
||||
|
||||
func validateModel(name: String, displayName: String, apiEndpoint: String, apiKey: String, modelName: String) -> [String] {
|
||||
var errors: [String] = []
|
||||
|
||||
if name.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
errors.append("Name cannot be empty")
|
||||
}
|
||||
|
||||
if displayName.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
errors.append("Display name cannot be empty")
|
||||
}
|
||||
|
||||
if apiEndpoint.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
errors.append("API endpoint cannot be empty")
|
||||
} else if !isValidURL(apiEndpoint) {
|
||||
errors.append("API endpoint must be a valid URL")
|
||||
}
|
||||
|
||||
if apiKey.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
errors.append("API key cannot be empty")
|
||||
}
|
||||
|
||||
if modelName.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
errors.append("Model name cannot be empty")
|
||||
}
|
||||
|
||||
// Check for duplicate names
|
||||
if customModels.contains(where: { $0.name == name }) {
|
||||
errors.append("A model with this name already exists")
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
func validateModel(name: String, displayName: String, apiEndpoint: String, apiKey: String, modelName: String, excludingId: UUID? = nil) -> [String] {
|
||||
var errors: [String] = []
|
||||
|
||||
if name.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
errors.append("Name cannot be empty")
|
||||
}
|
||||
|
||||
if displayName.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
errors.append("Display name cannot be empty")
|
||||
}
|
||||
|
||||
if apiEndpoint.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
errors.append("API endpoint cannot be empty")
|
||||
} else if !isValidURL(apiEndpoint) {
|
||||
errors.append("API endpoint must be a valid URL")
|
||||
}
|
||||
|
||||
if apiKey.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
errors.append("API key cannot be empty")
|
||||
}
|
||||
|
||||
if modelName.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
errors.append("Model name cannot be empty")
|
||||
}
|
||||
|
||||
// Check for duplicate names, excluding the specified ID
|
||||
if customModels.contains(where: { $0.name == name && $0.id != excludingId }) {
|
||||
errors.append("A model with this name already exists")
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
private func isValidURL(_ string: String) -> Bool {
|
||||
if let url = URL(string: string) {
|
||||
return url.scheme != nil && url.host != nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
@ -44,7 +44,7 @@ struct APIKeyManagementView: View {
|
||||
|
||||
// Provider Selection
|
||||
Picker("AI Provider", selection: $aiService.selectedProvider) {
|
||||
ForEach(AIProvider.allCases.filter { $0 != .elevenLabs }, id: \.self) { provider in
|
||||
ForEach(AIProvider.allCases.filter { $0 != .elevenLabs && $0 != .deepgram }, id: \.self) { provider in
|
||||
Text(provider.rawValue).tag(provider)
|
||||
}
|
||||
}
|
||||
|
||||
290
VoiceInk/Views/AddCustomModelView.swift
Normal file
290
VoiceInk/Views/AddCustomModelView.swift
Normal file
@ -0,0 +1,290 @@
|
||||
import SwiftUI
|
||||
|
||||
struct AddCustomModelCardView: View {
|
||||
@ObservedObject var customModelManager: CustomModelManager
|
||||
var onModelAdded: () -> Void
|
||||
var editingModel: CustomCloudModel? = nil
|
||||
|
||||
@State private var isExpanded = false
|
||||
@State private var displayName = ""
|
||||
@State private var apiEndpoint = ""
|
||||
@State private var apiKey = ""
|
||||
@State private var modelName = ""
|
||||
@State private var isMultilingual = true
|
||||
|
||||
@State private var validationErrors: [String] = []
|
||||
@State private var showingAlert = false
|
||||
@State private var isSaving = false
|
||||
|
||||
var body: some View {
|
||||
VStack(spacing: 0) {
|
||||
// Simple Add Model Button
|
||||
if !isExpanded {
|
||||
Button(action: {
|
||||
withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) {
|
||||
isExpanded = true
|
||||
// Pre-fill values - either from editing model or defaults
|
||||
if let editing = editingModel {
|
||||
displayName = editing.displayName
|
||||
apiEndpoint = editing.apiEndpoint
|
||||
apiKey = editing.apiKey
|
||||
modelName = editing.modelName
|
||||
isMultilingual = editing.isMultilingualModel
|
||||
} else {
|
||||
// Pre-fill some default values when adding new
|
||||
if apiEndpoint.isEmpty {
|
||||
apiEndpoint = "https://api.example.com/v1/audio/transcriptions"
|
||||
}
|
||||
if modelName.isEmpty {
|
||||
modelName = "large-v3-turbo"
|
||||
}
|
||||
}
|
||||
}
|
||||
}) {
|
||||
HStack(spacing: 8) {
|
||||
Image(systemName: "plus")
|
||||
.font(.system(size: 14, weight: .medium))
|
||||
Text(editingModel != nil ? "Edit Model" : "Add Model")
|
||||
.font(.system(size: 14, weight: .semibold))
|
||||
}
|
||||
.foregroundColor(.white)
|
||||
.frame(maxWidth: .infinity)
|
||||
.padding(.vertical, 12)
|
||||
.background(Color.accentColor)
|
||||
.cornerRadius(12)
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.shadow(color: Color.accentColor.opacity(0.3), radius: 8, y: 4)
|
||||
}
|
||||
|
||||
// Expandable Form Section
|
||||
if isExpanded {
|
||||
VStack(alignment: .leading, spacing: 20) {
|
||||
// Header
|
||||
HStack {
|
||||
Text(editingModel != nil ? "Edit Custom Model" : "Add Custom Model")
|
||||
.font(.headline)
|
||||
.foregroundColor(.primary)
|
||||
|
||||
Spacer()
|
||||
|
||||
Button(action: {
|
||||
withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) {
|
||||
isExpanded = false
|
||||
clearForm()
|
||||
}
|
||||
}) {
|
||||
Image(systemName: "xmark")
|
||||
.font(.system(size: 12, weight: .medium))
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
}
|
||||
|
||||
// Disclaimer
|
||||
HStack(spacing: 8) {
|
||||
Image(systemName: "exclamationmark.triangle.fill")
|
||||
.foregroundColor(.orange)
|
||||
.font(.caption)
|
||||
Text("Only OpenAI-compatible transcription APIs are supported")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
.padding(.horizontal, 12)
|
||||
.padding(.vertical, 8)
|
||||
.background(Color.orange.opacity(0.1))
|
||||
.cornerRadius(8)
|
||||
|
||||
// Form fields
|
||||
VStack(alignment: .leading, spacing: 16) {
|
||||
FormField(title: "Display Name", text: $displayName, placeholder: "My Custom Model")
|
||||
FormField(title: "API Endpoint", text: $apiEndpoint, placeholder: "https://api.example.com/v1/audio/transcriptions")
|
||||
FormField(title: "API Key", text: $apiKey, placeholder: "your-api-key", isSecure: true)
|
||||
FormField(title: "Model Name", text: $modelName, placeholder: "whisper-1")
|
||||
|
||||
Toggle("Multilingual Model", isOn: $isMultilingual)
|
||||
}
|
||||
|
||||
// Action buttons
|
||||
HStack(spacing: 12) {
|
||||
Button(action: {
|
||||
withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) {
|
||||
isExpanded = false
|
||||
clearForm()
|
||||
}
|
||||
}) {
|
||||
Text("Cancel")
|
||||
.font(.system(size: 13, weight: .medium))
|
||||
.foregroundColor(.secondary)
|
||||
.frame(maxWidth: .infinity)
|
||||
.padding(.vertical, 10)
|
||||
.background(Color.secondary.opacity(0.1))
|
||||
.cornerRadius(8)
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
|
||||
Button(action: {
|
||||
addModel()
|
||||
}) {
|
||||
HStack(spacing: 6) {
|
||||
if isSaving {
|
||||
ProgressView()
|
||||
.scaleEffect(0.8)
|
||||
.frame(width: 14, height: 14)
|
||||
} else {
|
||||
Image(systemName: editingModel != nil ? "checkmark.circle.fill" : "plus.circle.fill")
|
||||
.font(.system(size: 14))
|
||||
}
|
||||
Text(editingModel != nil ? "Update Model" : "Add Model")
|
||||
.font(.system(size: 13, weight: .medium))
|
||||
}
|
||||
.foregroundColor(.white)
|
||||
.frame(maxWidth: .infinity)
|
||||
.padding(.vertical, 10)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.fill(isFormValid ? Color(.controlAccentColor) : Color.secondary)
|
||||
.shadow(color: (isFormValid ? Color(.controlAccentColor) : Color.secondary).opacity(0.2), radius: 2, x: 0, y: 1)
|
||||
)
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.disabled(!isFormValid || isSaving)
|
||||
}
|
||||
}
|
||||
.padding(16)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 12)
|
||||
.fill(Color(.windowBackgroundColor))
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 12)
|
||||
.stroke(Color(.separatorColor), lineWidth: 1)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
.alert("Validation Errors", isPresented: $showingAlert) {
|
||||
Button("OK") { }
|
||||
} message: {
|
||||
Text(validationErrors.joined(separator: "\n"))
|
||||
}
|
||||
.onChange(of: editingModel) { oldValue, newValue in
|
||||
if newValue != nil {
|
||||
withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) {
|
||||
isExpanded = true
|
||||
// Pre-fill values from editing model
|
||||
if let editing = newValue {
|
||||
displayName = editing.displayName
|
||||
apiEndpoint = editing.apiEndpoint
|
||||
apiKey = editing.apiKey
|
||||
modelName = editing.modelName
|
||||
isMultilingual = editing.isMultilingualModel
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var isFormValid: Bool {
|
||||
!displayName.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty &&
|
||||
!apiEndpoint.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty &&
|
||||
!apiKey.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty &&
|
||||
!modelName.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
|
||||
}
|
||||
|
||||
private func clearForm() {
|
||||
displayName = ""
|
||||
apiEndpoint = ""
|
||||
apiKey = ""
|
||||
modelName = ""
|
||||
isMultilingual = true
|
||||
}
|
||||
|
||||
private func addModel() {
|
||||
let trimmedDisplayName = displayName.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
let trimmedApiEndpoint = apiEndpoint.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
let trimmedApiKey = apiKey.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
let trimmedModelName = modelName.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
// Generate a name from display name (lowercase, no spaces)
|
||||
let generatedName = trimmedDisplayName.lowercased().replacingOccurrences(of: " ", with: "-")
|
||||
|
||||
validationErrors = customModelManager.validateModel(
|
||||
name: generatedName,
|
||||
displayName: trimmedDisplayName,
|
||||
apiEndpoint: trimmedApiEndpoint,
|
||||
apiKey: trimmedApiKey,
|
||||
modelName: trimmedModelName,
|
||||
excludingId: editingModel?.id
|
||||
)
|
||||
|
||||
if !validationErrors.isEmpty {
|
||||
showingAlert = true
|
||||
return
|
||||
}
|
||||
|
||||
isSaving = true
|
||||
|
||||
// Simulate a brief save operation for better UX
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) {
|
||||
if let editing = editingModel {
|
||||
// Update existing model
|
||||
let updatedModel = CustomCloudModel(
|
||||
id: editing.id,
|
||||
name: generatedName,
|
||||
displayName: trimmedDisplayName,
|
||||
description: "Custom transcription model",
|
||||
apiEndpoint: trimmedApiEndpoint,
|
||||
apiKey: trimmedApiKey,
|
||||
modelName: trimmedModelName,
|
||||
isMultilingual: isMultilingual
|
||||
)
|
||||
customModelManager.updateCustomModel(updatedModel)
|
||||
} else {
|
||||
// Add new model
|
||||
let customModel = CustomCloudModel(
|
||||
name: generatedName,
|
||||
displayName: trimmedDisplayName,
|
||||
description: "Custom transcription model",
|
||||
apiEndpoint: trimmedApiEndpoint,
|
||||
apiKey: trimmedApiKey,
|
||||
modelName: trimmedModelName,
|
||||
isMultilingual: isMultilingual
|
||||
)
|
||||
customModelManager.addCustomModel(customModel)
|
||||
}
|
||||
|
||||
onModelAdded()
|
||||
|
||||
// Reset form and collapse
|
||||
withAnimation(.interpolatingSpring(stiffness: 170, damping: 20)) {
|
||||
isExpanded = false
|
||||
clearForm()
|
||||
isSaving = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct FormField: View {
|
||||
let title: String
|
||||
@Binding var text: String
|
||||
let placeholder: String
|
||||
var isSecure: Bool = false
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
Text(title)
|
||||
.font(.subheadline)
|
||||
.fontWeight(.medium)
|
||||
.foregroundColor(.primary)
|
||||
|
||||
if isSecure {
|
||||
SecureField(placeholder, text: $text)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
} else {
|
||||
TextField(placeholder, text: $text)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -12,6 +12,7 @@ struct ModelCardRowView: View {
|
||||
var deleteAction: () -> Void
|
||||
var setDefaultAction: () -> Void
|
||||
var downloadAction: () -> Void
|
||||
var editAction: ((CustomCloudModel) -> Void)?
|
||||
|
||||
var body: some View {
|
||||
Group {
|
||||
@ -37,6 +38,16 @@ struct ModelCardRowView: View {
|
||||
setDefaultAction: setDefaultAction
|
||||
)
|
||||
}
|
||||
case .custom:
|
||||
if let customModel = model as? CustomCloudModel {
|
||||
CustomModelCardView(
|
||||
model: customModel,
|
||||
isCurrent: isCurrent,
|
||||
setDefaultAction: setDefaultAction,
|
||||
deleteAction: deleteAction,
|
||||
editAction: editAction ?? { _ in }
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -576,4 +587,132 @@ private func performanceColor(value: Double) -> Color {
|
||||
case 0.4..<0.6: return Color(.systemOrange)
|
||||
default: return Color(.systemRed)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Custom Model Card View
|
||||
struct CustomModelCardView: View {
|
||||
let model: CustomCloudModel
|
||||
let isCurrent: Bool
|
||||
var setDefaultAction: () -> Void
|
||||
var deleteAction: () -> Void
|
||||
var editAction: (CustomCloudModel) -> Void
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 0) {
|
||||
// Main card content
|
||||
HStack(alignment: .top, spacing: 16) {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
headerSection
|
||||
metadataSection
|
||||
descriptionSection
|
||||
}
|
||||
.frame(maxWidth: .infinity, alignment: .leading)
|
||||
|
||||
actionSection
|
||||
}
|
||||
.padding(16)
|
||||
}
|
||||
.background(CardBackground(isSelected: isCurrent, useAccentGradientWhenSelected: isCurrent))
|
||||
}
|
||||
|
||||
private var headerSection: some View {
|
||||
HStack(alignment: .firstTextBaseline) {
|
||||
Text(model.displayName)
|
||||
.font(.system(size: 13, weight: .semibold))
|
||||
.foregroundColor(Color(.labelColor))
|
||||
|
||||
statusBadge
|
||||
|
||||
Spacer()
|
||||
}
|
||||
}
|
||||
|
||||
private var statusBadge: some View {
|
||||
Group {
|
||||
if isCurrent {
|
||||
Text("Default")
|
||||
.font(.system(size: 11, weight: .medium))
|
||||
.padding(.horizontal, 6)
|
||||
.padding(.vertical, 2)
|
||||
.background(Capsule().fill(Color.accentColor))
|
||||
.foregroundColor(.white)
|
||||
} else {
|
||||
Text("Custom")
|
||||
.font(.system(size: 11, weight: .medium))
|
||||
.padding(.horizontal, 6)
|
||||
.padding(.vertical, 2)
|
||||
.background(Capsule().fill(Color.orange.opacity(0.2)))
|
||||
.foregroundColor(Color.orange)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var metadataSection: some View {
|
||||
HStack(spacing: 12) {
|
||||
// Provider
|
||||
Label("Custom Provider", systemImage: "cloud")
|
||||
.font(.system(size: 11))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
.lineLimit(1)
|
||||
|
||||
// Language
|
||||
Label(model.language, systemImage: "globe")
|
||||
.font(.system(size: 11))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
.lineLimit(1)
|
||||
|
||||
// OpenAI Compatible
|
||||
Label("OpenAI Compatible", systemImage: "checkmark.seal")
|
||||
.font(.system(size: 11))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
.lineLimit(1)
|
||||
}
|
||||
.lineLimit(1)
|
||||
}
|
||||
|
||||
private var descriptionSection: some View {
|
||||
Text(model.description)
|
||||
.font(.system(size: 11))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
.lineLimit(2)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
.padding(.top, 4)
|
||||
}
|
||||
|
||||
private var actionSection: some View {
|
||||
HStack(spacing: 8) {
|
||||
if isCurrent {
|
||||
Text("Default Model")
|
||||
.font(.system(size: 12))
|
||||
.foregroundColor(Color(.secondaryLabelColor))
|
||||
} else {
|
||||
Button(action: setDefaultAction) {
|
||||
Text("Set as Default")
|
||||
.font(.system(size: 12))
|
||||
}
|
||||
.buttonStyle(.bordered)
|
||||
.controlSize(.small)
|
||||
}
|
||||
|
||||
Menu {
|
||||
Button {
|
||||
editAction(model)
|
||||
} label: {
|
||||
Label("Edit Model", systemImage: "pencil")
|
||||
}
|
||||
|
||||
Button(role: .destructive) {
|
||||
deleteAction()
|
||||
} label: {
|
||||
Label("Delete Model", systemImage: "trash")
|
||||
}
|
||||
} label: {
|
||||
Image(systemName: "ellipsis.circle")
|
||||
.font(.system(size: 14))
|
||||
}
|
||||
.menuStyle(.borderlessButton)
|
||||
.menuIndicator(.hidden)
|
||||
.frame(width: 20, height: 20)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,10 +4,14 @@ import SwiftData
|
||||
struct ModelManagementView: View {
|
||||
@ObservedObject var whisperState: WhisperState
|
||||
@State private var modelToDelete: WhisperModel?
|
||||
@State private var customModelToDelete: CustomCloudModel?
|
||||
@State private var customModelToEdit: CustomCloudModel?
|
||||
@StateObject private var aiService = AIService()
|
||||
@StateObject private var customModelManager = CustomModelManager.shared
|
||||
@EnvironmentObject private var enhancementService: AIEnhancementService
|
||||
@Environment(\.modelContext) private var modelContext
|
||||
@StateObject private var whisperPrompt = WhisperPrompt()
|
||||
|
||||
|
||||
var body: some View {
|
||||
ScrollView {
|
||||
@ -32,6 +36,19 @@ struct ModelManagementView: View {
|
||||
secondaryButton: .cancel()
|
||||
)
|
||||
}
|
||||
.alert(item: $customModelToDelete) { model in
|
||||
Alert(
|
||||
title: Text("Delete Custom Model"),
|
||||
message: Text("Are you sure you want to delete the custom model '\(model.displayName)'?"),
|
||||
primaryButton: .destructive(Text("Delete")) {
|
||||
customModelManager.removeCustomModel(withId: model.id)
|
||||
// Update whisperState to refresh the UI
|
||||
whisperState.refreshAllAvailableModels()
|
||||
},
|
||||
secondaryButton: .cancel()
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private var defaultModelSection: some View {
|
||||
@ -76,7 +93,9 @@ struct ModelManagementView: View {
|
||||
downloadProgress: whisperState.downloadProgress,
|
||||
modelURL: whisperState.availableModels.first { $0.name == model.name }?.url,
|
||||
deleteAction: {
|
||||
if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) {
|
||||
if model.provider == .custom, let customModel = model as? CustomCloudModel {
|
||||
customModelToDelete = customModel
|
||||
} else if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) {
|
||||
modelToDelete = downloadedModel
|
||||
}
|
||||
},
|
||||
@ -91,9 +110,22 @@ struct ModelManagementView: View {
|
||||
await whisperState.downloadModel(localModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
editAction: model.provider == .custom ? { customModel in
|
||||
customModelToEdit = customModel
|
||||
} : nil
|
||||
)
|
||||
}
|
||||
|
||||
// Add Custom Model Card at the bottom
|
||||
AddCustomModelCardView(
|
||||
customModelManager: customModelManager,
|
||||
editingModel: customModelToEdit
|
||||
) {
|
||||
// Refresh the models when a new custom model is added
|
||||
whisperState.refreshAllAvailableModels()
|
||||
customModelToEdit = nil // Clear editing state
|
||||
}
|
||||
}
|
||||
}
|
||||
.padding()
|
||||
|
||||
@ -15,6 +15,9 @@ extension WhisperState {
|
||||
case .deepgram:
|
||||
let key = UserDefaults.standard.string(forKey: "DeepgramAPIKey")
|
||||
return key != nil && !key!.isEmpty
|
||||
case .custom:
|
||||
// Custom models are always usable since they contain their own API keys
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -114,6 +114,7 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
createRecordingsDirectoryIfNeeded()
|
||||
loadAvailableModels()
|
||||
loadCurrentTranscriptionModel()
|
||||
refreshAllAvailableModels()
|
||||
}
|
||||
|
||||
private func createRecordingsDirectoryIfNeeded() {
|
||||
@ -512,6 +513,10 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
func getEnhancementService() -> AIEnhancementService? {
|
||||
return enhancementService
|
||||
}
|
||||
|
||||
func refreshAllAvailableModels() {
|
||||
allAvailableModels = PredefinedModels.models
|
||||
}
|
||||
}
|
||||
|
||||
struct WhisperModel: Identifiable {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user