/// Minimum segment length, in seconds, used to derive the sample-count threshold
/// below which a segment is not embedded during speaker reassignment.
private let embeddingReassignmentMinDurationSeconds: Float = 1.0

/// Sample rate, in Hz, passed to the slicing helper and the embedding model
/// during speaker reassignment.
private let embeddingReassignmentSampleRate = 16000

/// Returns the samples of `audio` that fall inside `[startSeconds, endSeconds]`.
///
/// The start position is rounded down and the end position is rounded up to whole
/// sample indices, and both are clamped to `0...audio.count`.
///
/// - Parameters:
///   - audio: Sample buffer to slice.
///   - sampleRate: Samples per second used to convert seconds to indices.
///   - startSeconds: Start of the window; values below zero are treated as zero.
///   - endSeconds: End of the window; values before the start yield an empty slice.
/// - Returns: The selected samples, or an empty array when `audio` is empty,
///   `sampleRate` is not positive, or the window contains no samples.
private func sliceAudioSamples(
    _ audio: [Float],
    sampleRate: Int,
    startSeconds: Float,
    endSeconds: Float
) -> [Float] {
    guard !audio.isEmpty, sampleRate > 0 else {
        return []
    }

    let sampleCount = Double(audio.count)

    // Clamp while still in floating point: `Int(_:)` traps on infinite, NaN, or
    // out-of-range values, so the conversion must only ever see 0...audio.count.
    func clampedIndex(_ position: Double) -> Int {
        guard position > 0 else {
            return 0
        }
        return Int(min(position, sampleCount))
    }

    let clampedStart = max(0, startSeconds)
    let clampedEnd = max(clampedStart, endSeconds)
    let rate = Double(sampleRate)

    let startIndex = clampedIndex((Double(clampedStart) * rate).rounded(.down))
    let endIndex = max(startIndex, clampedIndex((Double(clampedEnd) * rate).rounded(.up)))
    guard endIndex > startIndex else {
        return []
    }

    return Array(audio[startIndex..<endIndex])
}
// NOTE(review): the declaration line of this function was lost in extraction. The
// signature below is reconstructed from its call site in `diarizeAudioFile` and from
// how the body uses each parameter — confirm against the real file.

/// Reduces a diarization result to at most `requestedSpeakerCount` speakers, using
/// speaker embeddings to decide which retained speaker absorbs each dropped segment.
///
/// The speakers with the largest total segment duration are retained (ties go to the
/// lower speaker id). For each retained speaker a centroid is built from the embeddings
/// of its sufficiently long segments. Every segment of a non-retained speaker is then
/// reassigned to the retained speaker whose centroid has the highest cosine similarity
/// to the segment's embedding. When no usable embedding or centroid exists, the segment
/// takes the speaker of the retained segment with the smallest
/// `diarizedSegmentDistance`, and as a last resort the lowest retained speaker id.
///
/// - Parameters:
///   - segments: Diarized segments to constrain.
///   - requestedSpeakerCount: Desired maximum number of speakers. `nil` or a
///     non-positive value returns `segments` unchanged; `1` maps every segment to
///     speaker `0`.
///   - audio: Samples the segments refer to, sliced at `embeddingReassignmentSampleRate`.
///   - embeddingModel: Model used to embed audio slices.
/// - Returns: The segments with speaker ids remapped and passed through
///   `compactDiarizedSpeakerIds`.
private func constrainDiarizedSegmentsUsingEmbeddings(
    _ segments: [DiarizedSegment],
    requestedSpeakerCount: Int?,
    audio: [Float],
    embeddingModel: WeSpeakerModel
) -> [DiarizedSegment] {
    guard
        let requestedSpeakerCount,
        requestedSpeakerCount > 0,
        !segments.isEmpty
    else {
        return segments
    }

    if requestedSpeakerCount == 1 {
        return segments.map { segment in
            DiarizedSegment(
                startTime: segment.startTime,
                endTime: segment.endTime,
                speakerId: 0
            )
        }
    }

    // Group once; the grouping is reused for both durations and centroid building.
    let segmentsBySpeaker = Dictionary(grouping: segments, by: \.speakerId)
    let speakerDurations = segmentsBySpeaker.mapValues { speakerSegments in
        speakerSegments.reduce(Float.zero) { partialResult, segment in
            partialResult + segment.duration
        }
    }

    if speakerDurations.count <= requestedSpeakerCount {
        return compactDiarizedSpeakerIds(segments)
    }

    let retainedSpeakerIds = Set(
        speakerDurations
            .sorted { lhs, rhs in
                if lhs.value == rhs.value {
                    return lhs.key < rhs.key
                }

                return lhs.value > rhs.value
            }
            .prefix(requestedSpeakerCount)
            .map(\.key)
    )

    // Single threshold shared by centroid building and reassignment.
    let minSamples = Int(
        embeddingReassignmentMinDurationSeconds * Float(embeddingReassignmentSampleRate)
    )

    // Embeds one segment, or returns nil when the slice is shorter than the threshold
    // or the model returns an all-zero vector.
    func usableEmbedding(for segment: DiarizedSegment) -> [Float]? {
        let samples = sliceAudioSamples(
            audio,
            sampleRate: embeddingReassignmentSampleRate,
            startSeconds: segment.startTime,
            endSeconds: segment.endTime
        )
        guard samples.count >= minSamples else {
            return nil
        }

        let embedding = embeddingModel.embed(
            audio: samples,
            sampleRate: embeddingReassignmentSampleRate
        )
        guard embedding.contains(where: { $0 != 0 }) else {
            return nil
        }

        return embedding
    }

    var retainedCentroids = [Int: [Float]]()
    for speakerId in retainedSpeakerIds {
        let speakerSegments = segmentsBySpeaker[speakerId] ?? []
        let embeddings = speakerSegments.compactMap { usableEmbedding(for: $0) }

        let centroid = normalizedEmbeddingCentroid(embeddings)
        guard !centroid.isEmpty else { continue }
        retainedCentroids[speakerId] = centroid
    }

    let fallbackSpeakerId = retainedSpeakerIds.min() ?? 0

    // Dictionary iteration order is unspecified, so compare centroids in ascending
    // speaker-id order; equal similarities then always resolve to the lowest id.
    let orderedCentroids = retainedCentroids.sorted { $0.key < $1.key }

    // Loop-invariant: computed once instead of once per reassigned segment.
    let retainedSegments = segments.filter { retainedSpeakerIds.contains($0.speakerId) }

    let remapped = segments.map { segment -> DiarizedSegment in
        if retainedSpeakerIds.contains(segment.speakerId) {
            return segment
        }

        if !orderedCentroids.isEmpty, let embedding = usableEmbedding(for: segment) {
            var bestSpeakerId = fallbackSpeakerId
            var bestSimilarity = -Float.infinity
            for (speakerId, centroid) in orderedCentroids {
                let similarity = WeSpeakerModel.cosineSimilarity(embedding, centroid)
                if similarity > bestSimilarity {
                    bestSimilarity = similarity
                    bestSpeakerId = speakerId
                }
            }

            return DiarizedSegment(
                startTime: segment.startTime,
                endTime: segment.endTime,
                speakerId: bestSpeakerId
            )
        }

        // No usable embedding: borrow the speaker of the closest retained segment.
        let temporalFallback = retainedSegments.min(by: { lhs, rhs in
            diarizedSegmentDistance(from: segment, to: lhs)
                < diarizedSegmentDistance(from: segment, to: rhs)
        })?.speakerId ?? fallbackSpeakerId

        return DiarizedSegment(
            startTime: segment.startTime,
            endTime: segment.endTime,
            speakerId: temporalFallback
        )
    }

    return compactDiarizedSpeakerIds(remapped)
}
/// Returns the speaker-embedding model, loading it on first use.
///
/// An already-loaded model is returned directly. If a load is in flight, this call
/// awaits that task instead of starting another. Otherwise it starts a new load via
/// `WeSpeakerModel.fromPretrained(engine: .coreml)` and records the task so later
/// callers can join it.
///
/// - Returns: The loaded model, which is also stored in `embeddingModel`.
/// - Throws: Whatever error the loading task produces. The call that started the
///   task clears `embeddingModelTask` on both success and failure, so a later call
///   can start a new load after a failure.
private func ensureEmbeddingModelLoaded() async throws -> WeSpeakerModel {
    if let cached = embeddingModel {
        return cached
    }

    if let inFlight = embeddingModelTask {
        let loaded = try await inFlight.value
        embeddingModel = loaded
        return loaded
    }

    let loader = Task {
        try await WeSpeakerModel.fromPretrained(engine: .coreml)
    }
    embeddingModelTask = loader
    // Runs on both the success and the throwing exit path.
    defer { embeddingModelTask = nil }

    let loaded = try await loader.value
    embeddingModel = loaded
    return loaded
}