Custom Transcription model with OpenAI compatibility

This commit is contained in:
Beingpax 2025-06-17 13:50:22 +05:45
parent f0dee3069b
commit 05cfbd1a0b
12 changed files with 761 additions and 13 deletions

View File

@ -27,15 +27,19 @@ import Foundation
}
enum PredefinedModels {
static func getLanguageDictionary(isMultilingual: Bool) -> [String: String] {
if !isMultilingual {
return ["en": "English"]
} else {
return allLanguages
}
}
static let models: [any TranscriptionModel] = [
static func getLanguageDictionary(isMultilingual: Bool) -> [String: String] {
if !isMultilingual {
return ["en": "English"]
} else {
return allLanguages
}
}
static var models: [any TranscriptionModel] {
return predefinedModels + CustomModelManager.shared.customModels
}
private static let predefinedModels: [any TranscriptionModel] = [
// Local Models
LocalModel(
name: "ggml-tiny",

View File

@ -6,6 +6,7 @@ enum ModelProvider: String, Codable, Hashable, CaseIterable {
case groq = "Groq"
case elevenLabs = "ElevenLabs"
case deepgram = "Deepgram"
case custom = "Custom"
// Future providers can be added here
}
@ -55,4 +56,30 @@ struct CloudModel: TranscriptionModel {
self.isMultilingualModel = isMultilingual
self.supportedLanguages = supportedLanguages
}
}
/// A user-defined cloud transcription model served from an endpoint that the
/// user configures (API endpoint, API key and remote model name).
///
/// The type is `Codable`; `CustomModelManager` JSON-encodes an array of these
/// values for persistence.
struct CustomCloudModel: TranscriptionModel, Codable {
    let id: UUID
    let name: String
    let displayName: String
    let description: String
    let apiEndpoint: String
    let apiKey: String
    let modelName: String
    let isMultilingualModel: Bool
    let supportedLanguages: [String: String]

    /// Always `.custom`.
    ///
    /// Exposed as a computed property instead of `let provider = .custom`:
    /// a `let` with an initial value can never be overwritten by the
    /// synthesized `Decodable` initializer (the compiler warns about it), so
    /// storing it only added a redundant key to the encoded JSON. Previously
    /// encoded data that still contains a `provider` key decodes fine because
    /// keyed decoding ignores keys it does not ask for.
    var provider: ModelProvider { .custom }

    /// Creates a custom model.
    ///
    /// - Parameters:
    ///   - id: Stable identifier; a fresh `UUID` is generated when omitted.
    ///   - name: Internal name of the model.
    ///   - displayName: Name shown to the user.
    ///   - description: Short description shown to the user.
    ///   - apiEndpoint: Endpoint URL, kept as the string the user entered.
    ///   - apiKey: Key sent as the bearer token for this endpoint.
    ///   - modelName: Model identifier sent to the remote API.
    ///   - isMultilingual: Whether the model handles languages other than English.
    ///   - supportedLanguages: Explicit language table; when `nil`, the table
    ///     returned by `PredefinedModels.getLanguageDictionary(isMultilingual:)` is used.
    init(
        id: UUID = UUID(),
        name: String,
        displayName: String,
        description: String,
        apiEndpoint: String,
        apiKey: String,
        modelName: String,
        isMultilingual: Bool = true,
        supportedLanguages: [String: String]? = nil
    ) {
        self.id = id
        self.name = name
        self.displayName = displayName
        self.description = description
        self.apiEndpoint = apiEndpoint
        self.apiKey = apiKey
        self.modelName = modelName
        self.isMultilingualModel = isMultilingual
        self.supportedLanguages = supportedLanguages
            ?? PredefinedModels.getLanguageDictionary(isMultilingual: isMultilingual)
    }
}

View File

@ -487,7 +487,7 @@ struct ConfigurationView: View {
.frame(maxWidth: .infinity, alignment: .leading)
} else {
Picker("", selection: providerBinding) {
ForEach(aiService.connectedProviders.filter { $0 != .elevenLabs }, id: \.self) { provider in
ForEach(aiService.connectedProviders.filter { $0 != .elevenLabs && $0 != .deepgram }, id: \.self) { provider in
Text(provider.rawValue).tag(provider)
}
}

View File

@ -38,6 +38,7 @@ class CloudTranscriptionService: TranscriptionService {
private let groqService = GroqTranscriptionService()
private let elevenLabsService = ElevenLabsTranscriptionService()
private let deepgramService = DeepgramTranscriptionService()
private let openAICompatibleService = OpenAICompatibleTranscriptionService()
func transcribe(audioURL: URL, model: any TranscriptionModel) async throws -> String {
switch model.provider {
@ -47,6 +48,11 @@ class CloudTranscriptionService: TranscriptionService {
return try await elevenLabsService.transcribe(audioURL: audioURL, model: model)
case .deepgram:
return try await deepgramService.transcribe(audioURL: audioURL, model: model)
case .custom:
guard let customModel = model as? CustomCloudModel else {
throw CloudTranscriptionError.unsupportedProvider
}
return try await openAICompatibleService.transcribe(audioURL: audioURL, model: customModel)
default:
throw CloudTranscriptionError.unsupportedProvider
}

View File

@ -0,0 +1,104 @@
import Foundation
import os
/// Transcribes audio by uploading it as `multipart/form-data` to the endpoint
/// configured on a `CustomCloudModel` and decoding the JSON reply.
class OpenAICompatibleTranscriptionService {
    private static let crlf = "\r\n"

    private let logger = Logger(subsystem: "com.voiceink.transcription", category: "OpenAICompatibleService")

    /// Uploads the audio file to the model's endpoint and returns the `text`
    /// field of the JSON response.
    ///
    /// - Parameters:
    ///   - audioURL: Location of the audio file to upload.
    ///   - model: Custom model supplying endpoint, API key and model name.
    /// - Returns: The transcribed text.
    /// - Throws: `CloudTranscriptionError.networkError` when the endpoint string
    ///   is not a valid URL or the response is not an HTTP response;
    ///   `.audioFileNotFound` when the audio cannot be read;
    ///   `.apiRequestFailed` for non-2xx status codes;
    ///   `.noTranscriptionReturned` when the body cannot be decoded.
    ///   Errors thrown by `URLSession` propagate unchanged.
    func transcribe(audioURL: URL, model: CustomCloudModel) async throws -> String {
        // The endpoint is user-entered text loaded from persistence; a
        // force-unwrap here would crash the app on a malformed value.
        guard let endpointURL = URL(string: model.apiEndpoint) else {
            logger.error("Custom model API endpoint is not a valid URL")
            throw CloudTranscriptionError.networkError(URLError(.badURL))
        }

        let config = APIConfig(
            url: endpointURL,
            apiKey: model.apiKey,
            modelName: model.modelName
        )

        let boundary = "Boundary-\(UUID().uuidString)"
        var request = URLRequest(url: config.url)
        request.httpMethod = "POST"
        request.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type")
        request.setValue("Bearer \(config.apiKey)", forHTTPHeaderField: "Authorization")

        let body = try createOpenAICompatibleRequestBody(
            audioURL: audioURL,
            modelName: config.modelName,
            boundary: boundary
        )
        let (data, response) = try await URLSession.shared.upload(for: request, from: body)

        guard let httpResponse = response as? HTTPURLResponse else {
            throw CloudTranscriptionError.networkError(URLError(.badServerResponse))
        }

        guard (200...299).contains(httpResponse.statusCode) else {
            let errorMessage = String(data: data, encoding: .utf8) ?? "No error message"
            logger.error("OpenAI-compatible API request failed with status \(httpResponse.statusCode): \(errorMessage, privacy: .public)")
            throw CloudTranscriptionError.apiRequestFailed(statusCode: httpResponse.statusCode, message: errorMessage)
        }

        do {
            let transcriptionResponse = try JSONDecoder().decode(TranscriptionResponse.self, from: data)
            return transcriptionResponse.text
        } catch {
            logger.error("Failed to decode OpenAI-compatible API response: \(error.localizedDescription)")
            throw CloudTranscriptionError.noTranscriptionReturned
        }
    }

    /// Builds the multipart body: the audio file part, then the `model`,
    /// optional `language`, optional `prompt`, `response_format` and
    /// `temperature` fields, followed by the closing boundary.
    ///
    /// The language and prompt are read from `UserDefaults` keys
    /// `SelectedLanguage` and `TranscriptionPrompt`.
    ///
    /// - Throws: `CloudTranscriptionError.audioFileNotFound` when the file at
    ///   `audioURL` cannot be read.
    private func createOpenAICompatibleRequestBody(audioURL: URL, modelName: String, boundary: String) throws -> Data {
        guard let audioData = try? Data(contentsOf: audioURL) else {
            throw CloudTranscriptionError.audioFileNotFound
        }

        let selectedLanguage = UserDefaults.standard.string(forKey: "SelectedLanguage") ?? "auto"
        let prompt = UserDefaults.standard.string(forKey: "TranscriptionPrompt") ?? ""
        let crlf = Self.crlf

        var body = Data()

        // File part. `Data(_:.utf8)` cannot fail, unlike `.data(using:)!`.
        body.append(Data("--\(boundary)\(crlf)".utf8))
        body.append(Data("Content-Disposition: form-data; name=\"file\"; filename=\"\(audioURL.lastPathComponent)\"\(crlf)".utf8))
        body.append(Data("Content-Type: audio/wav\(crlf)\(crlf)".utf8))
        body.append(audioData)
        body.append(Data(crlf.utf8))

        appendFormField("model", value: modelName, boundary: boundary, to: &body)

        // "auto" means no language is sent at all.
        if selectedLanguage != "auto", !selectedLanguage.isEmpty {
            appendFormField("language", value: selectedLanguage, boundary: boundary, to: &body)
        }

        // Include prompt for OpenAI-compatible APIs
        if !prompt.isEmpty {
            appendFormField("prompt", value: prompt, boundary: boundary, to: &body)
        }

        appendFormField("response_format", value: "json", boundary: boundary, to: &body)
        appendFormField("temperature", value: "0", boundary: boundary, to: &body)

        body.append(Data("--\(boundary)--\(crlf)".utf8))
        return body
    }

    /// Appends one plain-text form field (boundary line, disposition header,
    /// blank line, value, CRLF) to `body`.
    private func appendFormField(_ name: String, value: String, boundary: String, to body: inout Data) {
        let crlf = Self.crlf
        body.append(Data("--\(boundary)\(crlf)".utf8))
        body.append(Data("Content-Disposition: form-data; name=\"\(name)\"\(crlf)\(crlf)".utf8))
        body.append(Data(value.utf8))
        body.append(Data(crlf.utf8))
    }

    /// Connection details resolved from a `CustomCloudModel`.
    private struct APIConfig {
        let url: URL
        let apiKey: String
        let modelName: String
    }

    /// JSON shape decoded from a successful response; only `text` is required.
    private struct TranscriptionResponse: Decodable {
        let text: String
        let language: String?
        let duration: Double?
    }
}

View File

@ -0,0 +1,138 @@
import Foundation
import os
/// Owns the list of user-defined `CustomCloudModel`s and persists it as JSON
/// in `UserDefaults` under the key `customCloudModels`.
///
/// NOTE(review): the encoded models include each model's `apiKey`, so keys end
/// up in `UserDefaults` in plain form — consider moving them to the Keychain.
class CustomModelManager: ObservableObject {
    static let shared = CustomModelManager()

    private let logger = Logger(subsystem: "com.prakashjoshipax.voiceink", category: "CustomModelManager")
    private let userDefaults = UserDefaults.standard
    private let customModelsKey = "customCloudModels"

    @Published var customModels: [CustomCloudModel] = []

    private init() {
        loadCustomModels()
    }

    // MARK: - CRUD Operations

    /// Appends `model` and persists the list.
    func addCustomModel(_ model: CustomCloudModel) {
        customModels.append(model)
        saveCustomModels()
        logger.info("Added custom model: \(model.displayName)")
    }

    /// Removes every model whose `id` matches and persists the list.
    func removeCustomModel(withId id: UUID) {
        customModels.removeAll { $0.id == id }
        saveCustomModels()
        logger.info("Removed custom model with ID: \(id)")
    }

    /// Replaces the stored model that has the same `id` as `updatedModel`.
    /// Does nothing when no such model exists.
    func updateCustomModel(_ updatedModel: CustomCloudModel) {
        guard let index = customModels.firstIndex(where: { $0.id == updatedModel.id }) else {
            return
        }
        customModels[index] = updatedModel
        saveCustomModels()
        logger.info("Updated custom model: \(updatedModel.displayName)")
    }

    // MARK: - Persistence

    /// Loads the persisted list; leaves the list empty when nothing is stored
    /// or the stored data cannot be decoded.
    private func loadCustomModels() {
        guard let data = userDefaults.data(forKey: customModelsKey) else {
            logger.info("No custom models found in UserDefaults")
            return
        }

        do {
            customModels = try JSONDecoder().decode([CustomCloudModel].self, from: data)
        } catch {
            logger.error("Failed to decode custom models: \(error.localizedDescription)")
            customModels = []
        }
    }

    /// Encodes the current list and writes it to `UserDefaults`.
    private func saveCustomModels() {
        do {
            let data = try JSONEncoder().encode(customModels)
            userDefaults.set(data, forKey: customModelsKey)
        } catch {
            logger.error("Failed to encode custom models: \(error.localizedDescription)")
        }
    }

    // MARK: - Validation

    /// Validates the fields of a new model.
    ///
    /// Kept for source compatibility; it forwards to the `excludingId` variant
    /// so the rules live in exactly one place.
    ///
    /// - Returns: Human-readable error messages; empty when the input is valid.
    func validateModel(name: String, displayName: String, apiEndpoint: String, apiKey: String, modelName: String) -> [String] {
        return validateModel(
            name: name,
            displayName: displayName,
            apiEndpoint: apiEndpoint,
            apiKey: apiKey,
            modelName: modelName,
            excludingId: nil
        )
    }

    /// Validates model fields, optionally ignoring one existing model when
    /// checking for duplicate names (used while editing that model).
    ///
    /// - Parameters:
    ///   - name: Internal name; must be non-blank and unique.
    ///   - displayName: Must be non-blank.
    ///   - apiEndpoint: Must be non-blank and parse as a URL with scheme and host.
    ///   - apiKey: Must be non-blank.
    ///   - modelName: Must be non-blank.
    ///   - excludingId: Identifier of a model to skip in the duplicate-name check.
    /// - Returns: Human-readable error messages in field order; empty when valid.
    func validateModel(name: String, displayName: String, apiEndpoint: String, apiKey: String, modelName: String, excludingId: UUID? = nil) -> [String] {
        var errors: [String] = []

        if isBlank(name) {
            errors.append("Name cannot be empty")
        }

        if isBlank(displayName) {
            errors.append("Display name cannot be empty")
        }

        if isBlank(apiEndpoint) {
            errors.append("API endpoint cannot be empty")
        } else if !isValidURL(apiEndpoint) {
            errors.append("API endpoint must be a valid URL")
        }

        if isBlank(apiKey) {
            errors.append("API key cannot be empty")
        }

        if isBlank(modelName) {
            errors.append("Model name cannot be empty")
        }

        // Check for duplicate names, excluding the specified ID.
        // With `excludingId == nil` the `id` comparison is always true, so
        // every existing model takes part in the check.
        if customModels.contains(where: { $0.name == name && $0.id != excludingId }) {
            errors.append("A model with this name already exists")
        }

        return errors
    }

    /// Returns `true` when `value` is empty after trimming whitespace and newlines.
    private func isBlank(_ value: String) -> Bool {
        return value.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
    }

    /// Returns `true` when `string` parses as a URL that has both a scheme and a host.
    private func isValidURL(_ string: String) -> Bool {
        guard let url = URL(string: string) else {
            return false
        }
        return url.scheme != nil && url.host != nil
    }
}

View File

@ -44,7 +44,7 @@ struct APIKeyManagementView: View {
// Provider Selection
Picker("AI Provider", selection: $aiService.selectedProvider) {
ForEach(AIProvider.allCases.filter { $0 != .elevenLabs }, id: \.self) { provider in
ForEach(AIProvider.allCases.filter { $0 != .elevenLabs && $0 != .deepgram }, id: \.self) { provider in
Text(provider.rawValue).tag(provider)
}
}

View File

@ -0,0 +1,290 @@
import SwiftUI
/// Card that is collapsed to a single button and expands into a form for
/// adding a custom model or, when `editingModel` is set, editing it.
///
/// NOTE(review): closing or cancelling the form does not invoke
/// `onModelAdded`, so a parent that passes `editingModel` is not told that an
/// edit was abandoned — confirm this is intended.
struct AddCustomModelCardView: View {
    @ObservedObject var customModelManager: CustomModelManager
    var onModelAdded: () -> Void
    var editingModel: CustomCloudModel? = nil

    @State private var isExpanded = false
    @State private var displayName = ""
    @State private var apiEndpoint = ""
    @State private var apiKey = ""
    @State private var modelName = ""
    @State private var isMultilingual = true
    @State private var validationErrors: [String] = []
    @State private var showingAlert = false
    @State private var isSaving = false

    var body: some View {
        VStack(spacing: 0) {
            // Simple Add Model Button
            if !isExpanded {
                expandButton
            }

            // Expandable Form Section
            if isExpanded {
                expandedForm
            }
        }
        .alert("Validation Errors", isPresented: $showingAlert) {
            Button("OK") { }
        } message: {
            Text(validationErrors.joined(separator: "\n"))
        }
        .onChange(of: editingModel) { _, newValue in
            // Only a newly selected model opens the form; clearing it does nothing.
            guard let editing = newValue else { return }
            withAnimation(cardAnimation) {
                isExpanded = true
                // Pre-fill values from editing model
                fillForm(from: editing)
            }
        }
    }

    // MARK: - Sub-views

    /// Collapsed state: a full-width accent button that opens the form.
    private var expandButton: some View {
        Button(action: expand) {
            HStack(spacing: 8) {
                Image(systemName: "plus")
                    .font(.system(size: 14, weight: .medium))
                Text(editingModel != nil ? "Edit Model" : "Add Model")
                    .font(.system(size: 14, weight: .semibold))
            }
            .foregroundColor(.white)
            .frame(maxWidth: .infinity)
            .padding(.vertical, 12)
            .background(Color.accentColor)
            .cornerRadius(12)
        }
        .buttonStyle(.plain)
        .shadow(color: Color.accentColor.opacity(0.3), radius: 8, y: 4)
    }

    /// Expanded state: header, disclaimer, input fields and action buttons.
    private var expandedForm: some View {
        VStack(alignment: .leading, spacing: 20) {
            formHeader
            disclaimer
            formFields
            actionButtons
        }
        .padding(16)
        .background(
            RoundedRectangle(cornerRadius: 12)
                .fill(Color(.windowBackgroundColor))
                .overlay(
                    RoundedRectangle(cornerRadius: 12)
                        .stroke(Color(.separatorColor), lineWidth: 1)
                )
        )
    }

    /// Title plus the close ("xmark") button.
    private var formHeader: some View {
        HStack {
            Text(editingModel != nil ? "Edit Custom Model" : "Add Custom Model")
                .font(.headline)
                .foregroundColor(.primary)

            Spacer()

            Button(action: collapseAndClear) {
                Image(systemName: "xmark")
                    .font(.system(size: 12, weight: .medium))
                    .foregroundColor(.secondary)
            }
            .buttonStyle(.plain)
        }
    }

    /// Orange notice about the supported API flavour.
    private var disclaimer: some View {
        HStack(spacing: 8) {
            Image(systemName: "exclamationmark.triangle.fill")
                .foregroundColor(.orange)
                .font(.caption)
            Text("Only OpenAI-compatible transcription APIs are supported")
                .font(.caption)
                .foregroundColor(.secondary)
        }
        .padding(.horizontal, 12)
        .padding(.vertical, 8)
        .background(Color.orange.opacity(0.1))
        .cornerRadius(8)
    }

    /// The editable inputs bound to the view's form state.
    private var formFields: some View {
        VStack(alignment: .leading, spacing: 16) {
            FormField(title: "Display Name", text: $displayName, placeholder: "My Custom Model")
            FormField(title: "API Endpoint", text: $apiEndpoint, placeholder: "https://api.example.com/v1/audio/transcriptions")
            FormField(title: "API Key", text: $apiKey, placeholder: "your-api-key", isSecure: true)
            FormField(title: "Model Name", text: $modelName, placeholder: "whisper-1")

            Toggle("Multilingual Model", isOn: $isMultilingual)
        }
    }

    /// Cancel and save buttons.
    private var actionButtons: some View {
        HStack(spacing: 12) {
            Button(action: collapseAndClear) {
                Text("Cancel")
                    .font(.system(size: 13, weight: .medium))
                    .foregroundColor(.secondary)
                    .frame(maxWidth: .infinity)
                    .padding(.vertical, 10)
                    .background(Color.secondary.opacity(0.1))
                    .cornerRadius(8)
            }
            .buttonStyle(.plain)

            Button(action: addModel) {
                HStack(spacing: 6) {
                    if isSaving {
                        ProgressView()
                            .scaleEffect(0.8)
                            .frame(width: 14, height: 14)
                    } else {
                        Image(systemName: editingModel != nil ? "checkmark.circle.fill" : "plus.circle.fill")
                            .font(.system(size: 14))
                    }
                    Text(editingModel != nil ? "Update Model" : "Add Model")
                        .font(.system(size: 13, weight: .medium))
                }
                .foregroundColor(.white)
                .frame(maxWidth: .infinity)
                .padding(.vertical, 10)
                .background(
                    RoundedRectangle(cornerRadius: 8)
                        .fill(saveButtonTint)
                        .shadow(color: saveButtonTint.opacity(0.2), radius: 2, x: 0, y: 1)
                )
            }
            .buttonStyle(.plain)
            .disabled(!isFormValid || isSaving)
        }
    }

    // MARK: - Derived state

    /// Spring used for every expand/collapse transition of the card.
    private var cardAnimation: Animation {
        .interpolatingSpring(stiffness: 170, damping: 20)
    }

    /// Accent colour while the form can be submitted, secondary otherwise.
    private var saveButtonTint: Color {
        isFormValid ? Color(.controlAccentColor) : Color.secondary
    }

    /// `true` when all four text inputs are non-blank after trimming.
    private var isFormValid: Bool {
        let inputs = [displayName, apiEndpoint, apiKey, modelName]
        return inputs.allSatisfy { !$0.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty }
    }

    // MARK: - Actions

    /// Opens the form, pre-filled from `editingModel` or with default values.
    private func expand() {
        withAnimation(cardAnimation) {
            isExpanded = true

            // Pre-fill values - either from editing model or defaults
            if let editing = editingModel {
                fillForm(from: editing)
            } else {
                // Pre-fill some default values when adding new
                if apiEndpoint.isEmpty {
                    apiEndpoint = "https://api.example.com/v1/audio/transcriptions"
                }
                if modelName.isEmpty {
                    modelName = "large-v3-turbo"
                }
            }
        }
    }

    /// Closes the form and wipes its inputs (close and cancel buttons).
    private func collapseAndClear() {
        withAnimation(cardAnimation) {
            isExpanded = false
            clearForm()
        }
    }

    /// Copies the editable values of `model` into the form state.
    private func fillForm(from model: CustomCloudModel) {
        displayName = model.displayName
        apiEndpoint = model.apiEndpoint
        apiKey = model.apiKey
        modelName = model.modelName
        isMultilingual = model.isMultilingualModel
    }

    /// Resets every input to its initial value.
    private func clearForm() {
        displayName = ""
        apiEndpoint = ""
        apiKey = ""
        modelName = ""
        isMultilingual = true
    }

    /// Validates the form and, after a short delay, adds a new model or
    /// updates `editingModel`, then notifies the parent and collapses the card.
    private func addModel() {
        let trimmedDisplayName = displayName.trimmingCharacters(in: .whitespacesAndNewlines)
        let trimmedApiEndpoint = apiEndpoint.trimmingCharacters(in: .whitespacesAndNewlines)
        let trimmedApiKey = apiKey.trimmingCharacters(in: .whitespacesAndNewlines)
        let trimmedModelName = modelName.trimmingCharacters(in: .whitespacesAndNewlines)

        // Generate a name from display name (lowercase, no spaces)
        let generatedName = trimmedDisplayName.lowercased().replacingOccurrences(of: " ", with: "-")

        validationErrors = customModelManager.validateModel(
            name: generatedName,
            displayName: trimmedDisplayName,
            apiEndpoint: trimmedApiEndpoint,
            apiKey: trimmedApiKey,
            modelName: trimmedModelName,
            excludingId: editingModel?.id
        )

        guard validationErrors.isEmpty else {
            showingAlert = true
            return
        }

        isSaving = true

        // Simulate a brief save operation for better UX
        DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) {
            // Reuse the edited model's id so the manager replaces it in place;
            // a brand-new model gets a fresh identifier.
            let model = CustomCloudModel(
                id: editingModel?.id ?? UUID(),
                name: generatedName,
                displayName: trimmedDisplayName,
                description: "Custom transcription model",
                apiEndpoint: trimmedApiEndpoint,
                apiKey: trimmedApiKey,
                modelName: trimmedModelName,
                isMultilingual: isMultilingual
            )

            if editingModel != nil {
                customModelManager.updateCustomModel(model)
            } else {
                customModelManager.addCustomModel(model)
            }

            onModelAdded()

            // Reset form and collapse
            withAnimation(cardAnimation) {
                isExpanded = false
                clearForm()
                isSaving = false
            }
        }
    }
}
/// A titled, rounded-border text input; uses a `SecureField` when `isSecure`
/// is `true` and a plain `TextField` otherwise.
struct FormField: View {
    let title: String
    @Binding var text: String
    let placeholder: String
    var isSecure: Bool = false

    var body: some View {
        VStack(alignment: .leading, spacing: 6) {
            caption
            input
        }
    }

    /// The field's title line.
    private var caption: some View {
        Text(title)
            .font(.subheadline)
            .fontWeight(.medium)
            .foregroundColor(.primary)
    }

    /// The input control matching `isSecure`.
    @ViewBuilder
    private var input: some View {
        if isSecure {
            SecureField(placeholder, text: $text)
                .textFieldStyle(.roundedBorder)
        } else {
            TextField(placeholder, text: $text)
                .textFieldStyle(.roundedBorder)
        }
    }
}

View File

@ -12,6 +12,7 @@ struct ModelCardRowView: View {
var deleteAction: () -> Void
var setDefaultAction: () -> Void
var downloadAction: () -> Void
var editAction: ((CustomCloudModel) -> Void)?
var body: some View {
Group {
@ -37,6 +38,16 @@ struct ModelCardRowView: View {
setDefaultAction: setDefaultAction
)
}
case .custom:
if let customModel = model as? CustomCloudModel {
CustomModelCardView(
model: customModel,
isCurrent: isCurrent,
setDefaultAction: setDefaultAction,
deleteAction: deleteAction,
editAction: editAction ?? { _ in }
)
}
}
}
}
@ -576,4 +587,132 @@ private func performanceColor(value: Double) -> Color {
case 0.4..<0.6: return Color(.systemOrange)
default: return Color(.systemRed)
}
}
// MARK: - Custom Model Card View
/// Row card for a custom cloud model: name, badges, metadata, description and
/// the set-default / edit / delete controls.
struct CustomModelCardView: View {
    let model: CustomCloudModel
    let isCurrent: Bool
    var setDefaultAction: () -> Void
    var deleteAction: () -> Void
    var editAction: (CustomCloudModel) -> Void

    var body: some View {
        VStack(alignment: .leading, spacing: 0) {
            // Main card content
            HStack(alignment: .top, spacing: 16) {
                VStack(alignment: .leading, spacing: 6) {
                    titleRow
                    metadataRow
                    descriptionText
                }
                .frame(maxWidth: .infinity, alignment: .leading)

                controls
            }
            .padding(16)
        }
        .background(CardBackground(isSelected: isCurrent, useAccentGradientWhenSelected: isCurrent))
    }

    // MARK: - Sections

    /// Display name followed by the status badge.
    private var titleRow: some View {
        HStack(alignment: .firstTextBaseline) {
            Text(model.displayName)
                .font(.system(size: 13, weight: .semibold))
                .foregroundColor(Color(.labelColor))

            badge

            Spacer()
        }
    }

    /// "Default" capsule for the current model, "Custom" capsule otherwise.
    @ViewBuilder
    private var badge: some View {
        if isCurrent {
            capsule(Text("Default"), fill: Color.accentColor, foreground: .white)
        } else {
            capsule(Text("Custom"), fill: Color.orange.opacity(0.2), foreground: Color.orange)
        }
    }

    /// Provider, language and compatibility labels on a single line.
    private var metadataRow: some View {
        HStack(spacing: 12) {
            // Provider
            metadataItem(Label("Custom Provider", systemImage: "cloud"))

            // Language
            metadataItem(Label(model.language, systemImage: "globe"))

            // OpenAI Compatible
            metadataItem(Label("OpenAI Compatible", systemImage: "checkmark.seal"))
        }
        .lineLimit(1)
    }

    /// The model's description, limited to two lines.
    private var descriptionText: some View {
        Text(model.description)
            .font(.system(size: 11))
            .foregroundColor(Color(.secondaryLabelColor))
            .lineLimit(2)
            .fixedSize(horizontal: false, vertical: true)
            .padding(.top, 4)
    }

    /// Default indicator or "Set as Default" button, plus the overflow menu.
    private var controls: some View {
        HStack(spacing: 8) {
            if isCurrent {
                Text("Default Model")
                    .font(.system(size: 12))
                    .foregroundColor(Color(.secondaryLabelColor))
            } else {
                Button(action: setDefaultAction) {
                    Text("Set as Default")
                        .font(.system(size: 12))
                }
                .buttonStyle(.bordered)
                .controlSize(.small)
            }

            Menu {
                Button(action: { editAction(model) }) {
                    Label("Edit Model", systemImage: "pencil")
                }

                Button(role: .destructive, action: deleteAction) {
                    Label("Delete Model", systemImage: "trash")
                }
            } label: {
                Image(systemName: "ellipsis.circle")
                    .font(.system(size: 14))
            }
            .menuStyle(.borderlessButton)
            .menuIndicator(.hidden)
            .frame(width: 20, height: 20)
        }
    }

    // MARK: - Styling helpers

    /// Wraps `text` in the small capsule shared by both badges.
    /// Takes a ready-made `Text` so each call site keeps its own initializer.
    private func capsule(_ text: Text, fill: Color, foreground: Color) -> some View {
        text
            .font(.system(size: 11, weight: .medium))
            .padding(.horizontal, 6)
            .padding(.vertical, 2)
            .background(Capsule().fill(fill))
            .foregroundColor(foreground)
    }

    /// Applies the shared metadata styling to an already-built label.
    private func metadataItem<Content: View>(_ content: Content) -> some View {
        content
            .font(.system(size: 11))
            .foregroundColor(Color(.secondaryLabelColor))
            .lineLimit(1)
    }
}

View File

@ -4,10 +4,14 @@ import SwiftData
struct ModelManagementView: View {
@ObservedObject var whisperState: WhisperState
@State private var modelToDelete: WhisperModel?
@State private var customModelToDelete: CustomCloudModel?
@State private var customModelToEdit: CustomCloudModel?
@StateObject private var aiService = AIService()
@StateObject private var customModelManager = CustomModelManager.shared
@EnvironmentObject private var enhancementService: AIEnhancementService
@Environment(\.modelContext) private var modelContext
@StateObject private var whisperPrompt = WhisperPrompt()
var body: some View {
ScrollView {
@ -32,6 +36,19 @@ struct ModelManagementView: View {
secondaryButton: .cancel()
)
}
.alert(item: $customModelToDelete) { model in
Alert(
title: Text("Delete Custom Model"),
message: Text("Are you sure you want to delete the custom model '\(model.displayName)'?"),
primaryButton: .destructive(Text("Delete")) {
customModelManager.removeCustomModel(withId: model.id)
// Update whisperState to refresh the UI
whisperState.refreshAllAvailableModels()
},
secondaryButton: .cancel()
)
}
}
private var defaultModelSection: some View {
@ -76,7 +93,9 @@ struct ModelManagementView: View {
downloadProgress: whisperState.downloadProgress,
modelURL: whisperState.availableModels.first { $0.name == model.name }?.url,
deleteAction: {
if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) {
if model.provider == .custom, let customModel = model as? CustomCloudModel {
customModelToDelete = customModel
} else if let downloadedModel = whisperState.availableModels.first(where: { $0.name == model.name }) {
modelToDelete = downloadedModel
}
},
@ -91,9 +110,22 @@ struct ModelManagementView: View {
await whisperState.downloadModel(localModel)
}
}
}
},
editAction: model.provider == .custom ? { customModel in
customModelToEdit = customModel
} : nil
)
}
// Add Custom Model Card at the bottom
AddCustomModelCardView(
customModelManager: customModelManager,
editingModel: customModelToEdit
) {
// Refresh the models when a new custom model is added
whisperState.refreshAllAvailableModels()
customModelToEdit = nil // Clear editing state
}
}
}
.padding()

View File

@ -15,6 +15,9 @@ extension WhisperState {
case .deepgram:
let key = UserDefaults.standard.string(forKey: "DeepgramAPIKey")
return key != nil && !key!.isEmpty
case .custom:
// Custom models are always usable since they contain their own API keys
return true
}
}
}

View File

@ -114,6 +114,7 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
createRecordingsDirectoryIfNeeded()
loadAvailableModels()
loadCurrentTranscriptionModel()
refreshAllAvailableModels()
}
private func createRecordingsDirectoryIfNeeded() {
@ -512,6 +513,10 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
func getEnhancementService() -> AIEnhancementService? {
return enhancementService
}
/// Replaces `allAvailableModels` with the current value of `PredefinedModels.models`.
func refreshAllAvailableModels() {
allAvailableModels = PredefinedModels.models
}
}
struct WhisperModel: Identifiable {