diff --git a/VoiceInk/Whisper/WhisperState+LocalModelManager.swift b/VoiceInk/Whisper/WhisperState+LocalModelManager.swift index d3b65a9..3fe94e9 100644 --- a/VoiceInk/Whisper/WhisperState+LocalModelManager.swift +++ b/VoiceInk/Whisper/WhisperState+LocalModelManager.swift @@ -2,6 +2,7 @@ import Foundation import os import Zip import SwiftUI +import Atomics struct WhisperModel: Identifiable { @@ -34,13 +35,17 @@ struct WhisperModel: Identifiable { private class TaskDelegate: NSObject, URLSessionTaskDelegate { private let continuation: CheckedContinuation - + private let finished = ManagedAtomic(false) + init(_ continuation: CheckedContinuation) { self.continuation = continuation } - + func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { - continuation.resume() + // Ensure continuation is resumed only once, even if called multiple times + if finished.exchange(true, ordering: .acquiring) == false { + continuation.resume() + } } } @@ -98,59 +103,72 @@ extension WhisperState { /// Helper function to download a file from a URL with progress tracking private func downloadFileWithProgress(from url: URL, progressKey: String) async throws -> Data { let destinationURL = modelsDirectory.appendingPathComponent(UUID().uuidString) - + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + // Guard to prevent double resume + let finished = ManagedAtomic(false) + + func finishOnce(_ result: Result) { + if finished.exchange(true, ordering: .acquiring) == false { + continuation.resume(with: result) + } + } + let task = URLSession.shared.downloadTask(with: url) { tempURL, response, error in if let error = error { - continuation.resume(throwing: error) + finishOnce(.failure(error)) return } - + guard let httpResponse = response as? HTTPURLResponse, (200...299).contains(httpResponse.statusCode), let tempURL = tempURL else { - continuation.resume(throwing: URLError(.badServerResponse)) + finishOnce(.failure(URLError(.badServerResponse))) return } - + do { // Move the downloaded file to the final destination try FileManager.default.moveItem(at: tempURL, to: destinationURL) - + // Read the file in chunks to avoid memory pressure let data = try Data(contentsOf: destinationURL, options: .mappedIfSafe) - continuation.resume(returning: data) - + finishOnce(.success(data)) + // Clean up the temporary file try? FileManager.default.removeItem(at: destinationURL) } catch { - continuation.resume(throwing: error) + finishOnce(.failure(error)) } } - + task.resume() - + var lastUpdateTime = Date() var lastProgressValue: Double = 0 - + let observation = task.progress.observe(\.fractionCompleted) { progress, _ in let currentTime = Date() let timeSinceLastUpdate = currentTime.timeIntervalSince(lastUpdateTime) let currentProgress = round(progress.fractionCompleted * 100) / 100 - + if timeSinceLastUpdate >= 0.5 && abs(currentProgress - lastProgressValue) >= 0.01 { lastUpdateTime = currentTime lastProgressValue = currentProgress - + DispatchQueue.main.async { self.downloadProgress[progressKey] = currentProgress } } } - + Task { await withTaskCancellationHandler { observation.invalidate() + // Also ensure continuation is resumed with cancellation if task is cancelled + if finished.exchange(true, ordering: .acquiring) == false { + continuation.resume(throwing: CancellationError()) + } } operation: { await withCheckedContinuation { (_: CheckedContinuation) in } } @@ -211,13 +229,21 @@ extension WhisperState { } private func unzipCoreMLFile(_ zipPath: URL, to destination: URL) async throws { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let finished = ManagedAtomic(false) + + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + func finishOnce(_ result: Result) { + if finished.exchange(true, ordering: .acquiring) == false { + continuation.resume(with: result) + } + } + do { try FileManager.default.createDirectory(at: destination, withIntermediateDirectories: true) try Zip.unzipFile(zipPath, destination: destination, overwrite: true, password: nil) - continuation.resume() + finishOnce(.success(())) } catch { - continuation.resume(throwing: error) + finishOnce(.failure(error)) } } }