From 7929fc1b8b88798417d860c09a94cd8417402c09 Mon Sep 17 00:00:00 2001
From: ComputelessComputer <63365510+ComputelessComputer@users.noreply.github.com>
Date: Fri, 17 Apr 2026 13:41:31 +0900
Subject: [PATCH] Reassign excess diarized speakers by embedding similarity
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When the user requests N speakers but Sortformer returns more, the old
constrainer reassigned excess segments to the temporally-nearest retained
speaker. That ignores voice identity — a short interjection by Alice gets
merged into Bob just because Bob was speaking around the same time.

The new path loads WeSpeaker (already used by the speaker embedding
pipeline) lazily in DiarizationPipeline, builds a centroid embedding per
retained speaker, embeds each excess segment, and assigns it to the
max-cosine-similarity retained centroid. Segments shorter than 1 s fall
back to the old temporal heuristic because WeSpeaker is unreliable on
very short clips.

Only engages when the user set a specific speaker count AND Sortformer's
raw output exceeded it. Automatic mode, under-count cases, and short-audio
edge cases all take the unchanged fast path.
---
 .../swift-permissions/src/speech_bridge.swift | 217 +++++++++++++++++-
 1 file changed, 213 insertions(+), 4 deletions(-)

diff --git a/src-tauri/swift-permissions/src/speech_bridge.swift b/src-tauri/swift-permissions/src/speech_bridge.swift
index cb5744d..a0009b7 100644
--- a/src-tauri/swift-permissions/src/speech_bridge.swift
+++ b/src-tauri/swift-permissions/src/speech_bridge.swift
@@ -508,6 +508,164 @@ private func compactDiarizedSpeakerIds(_ segments: [DiarizedSegment]) -> [Diariz
     }
 }
 
+private let embeddingReassignmentMinDurationSeconds: Float = 1.0
+private let embeddingReassignmentSampleRate = 16000
+
+private func sliceAudioSamples(
+    _ audio: [Float],
+    sampleRate: Int,
+    startSeconds: Float,
+    endSeconds: Float
+) -> [Float] {
+    guard !audio.isEmpty, sampleRate > 0 else {
+        return []
+    }
+
+    let clampedStart = max(0, startSeconds)
+    let clampedEnd = max(clampedStart, endSeconds)
+    let startIndex = min(audio.count, max(0, Int(floor(Double(clampedStart) * Double(sampleRate)))))
+    let endIndex = min(audio.count, max(startIndex, Int(ceil(Double(clampedEnd) * Double(sampleRate)))))
+    guard endIndex > startIndex else {
+        return []
+    }
+
+    return Array(audio[startIndex..<endIndex])
+}
+
+private func constrainDiarizedSegmentsUsingEmbeddings(
+    _ segments: [DiarizedSegment],
+    requestedSpeakerCount: Int?,
+    audio: [Float],
+    embeddingModel: WeSpeakerModel
+) -> [DiarizedSegment] {
+    guard
+        let requestedSpeakerCount,
+        requestedSpeakerCount > 0,
+        !segments.isEmpty
+    else {
+        return segments
+    }
+
+    if requestedSpeakerCount == 1 {
+        return segments.map { segment in
+            DiarizedSegment(
+                startTime: segment.startTime,
+                endTime: segment.endTime,
+                speakerId: 0
+            )
+        }
+    }
+
+    let speakerDurations = Dictionary(grouping: segments, by: \.speakerId)
+        .mapValues { speakerSegments in
+            speakerSegments.reduce(Float.zero) { partialResult, segment in
+                partialResult + segment.duration
+            }
+        }
+
+    if speakerDurations.count <= requestedSpeakerCount {
+        return compactDiarizedSpeakerIds(segments)
+    }
+
+    let retainedSpeakerIds = Set(
+        speakerDurations
+            .sorted { lhs, rhs in
+                if lhs.value == rhs.value {
+                    return lhs.key < rhs.key
+                }
+
+                return lhs.value > rhs.value
+            }
+            .prefix(requestedSpeakerCount)
+            .map(\.key)
+    )
+
+    var retainedCentroids = [Int: [Float]]()
+    for speakerId in retainedSpeakerIds {
+        let speakerSegments = segments.filter { $0.speakerId == speakerId }
+        let embeddings = speakerSegments.compactMap { segment -> [Float]? in
+            let samples = sliceAudioSamples(
+                audio,
+                sampleRate: embeddingReassignmentSampleRate,
+                startSeconds: segment.startTime,
+                endSeconds: segment.endTime
+            )
+            guard samples.count >= embeddingReassignmentSampleRate else {
+                return nil
+            }
+
+            let embedding = embeddingModel.embed(
+                audio: samples,
+                sampleRate: embeddingReassignmentSampleRate
+            )
+            guard embedding.contains(where: { $0 != 0 }) else {
+                return nil
+            }
+
+            return embedding
+        }
+
+        let centroid = normalizedEmbeddingCentroid(embeddings)
+        guard !centroid.isEmpty else { continue }
+        retainedCentroids[speakerId] = centroid
+    }
+
+    let fallbackSpeakerId = retainedSpeakerIds.min() ?? 0
+    let remapped = segments.map { segment -> DiarizedSegment in
+        if retainedSpeakerIds.contains(segment.speakerId) {
+            return segment
+        }
+
+        let samples = sliceAudioSamples(
+            audio,
+            sampleRate: embeddingReassignmentSampleRate,
+            startSeconds: segment.startTime,
+            endSeconds: segment.endTime
+        )
+        let minSamples = Int(
+            embeddingReassignmentMinDurationSeconds * Float(embeddingReassignmentSampleRate)
+        )
+
+        if samples.count >= minSamples, !retainedCentroids.isEmpty {
+            let embedding = embeddingModel.embed(
+                audio: samples,
+                sampleRate: embeddingReassignmentSampleRate
+            )
+            if embedding.contains(where: { $0 != 0 }) {
+                var bestSpeakerId = fallbackSpeakerId
+                var bestSimilarity = -Float.infinity
+                for (speakerId, centroid) in retainedCentroids {
+                    let similarity = WeSpeakerModel.cosineSimilarity(embedding, centroid)
+                    if similarity > bestSimilarity {
+                        bestSimilarity = similarity
+                        bestSpeakerId = speakerId
+                    }
+                }
+
+                return DiarizedSegment(
+                    startTime: segment.startTime,
+                    endTime: segment.endTime,
+                    speakerId: bestSpeakerId
+                )
+            }
+        }
+
+        let retainedSegments = segments.filter { retainedSpeakerIds.contains($0.speakerId) }
+        let temporalFallback = retainedSegments.min(by: { lhs, rhs in
+            diarizedSegmentDistance(from: segment, to: lhs)
+                < diarizedSegmentDistance(from: segment, to: rhs)
+        })?.speakerId ?? fallbackSpeakerId
+
+        return DiarizedSegment(
+            startTime: segment.startTime,
+            endTime: segment.endTime,
+            speakerId: temporalFallback
+        )
+    }
+
+    return compactDiarizedSpeakerIds(remapped)
+}
+
 private func encodeJSON<T: Encodable>(_ value: T) -> String {
     guard let data = try? JSONEncoder().encode(value),
           let string = String(data: data, encoding: .utf8)
@@ -1117,6 +1275,8 @@ private final class LiveTranscriptionSession {
 private actor DiarizationPipeline {
     private var diarizer: SortformerDiarizer?
     private var diarizerTask: Task<SortformerDiarizer, Error>?
+    private var embeddingModel: WeSpeakerModel?
+    private var embeddingModelTask: Task<WeSpeakerModel, Error>?
 
     func diarizeAudioFile(
         atPath path: String,
@@ -1127,10 +1287,32 @@ private actor DiarizationPipeline {
         let audio = try AudioFileLoader.load(url: url, targetSampleRate: 16000)
         let diarizer = try await ensureLoaded()
         let result = diarizer.diarize(audio: audio, sampleRate: 16000, config: .default)
-        let segments = constrainDiarizedSegments(
-            result.segments,
-            requestedSpeakerCount: requestedSpeakerCount
-        )
+
+        let segments: [DiarizedSegment]
+        let distinctSortformerSpeakers = Set(result.segments.map(\.speakerId)).count
+        if let requestedSpeakerCount,
+           requestedSpeakerCount > 0,
+           distinctSortformerSpeakers > requestedSpeakerCount
+        {
+            if let embeddingModel = try? await ensureEmbeddingModelLoaded() {
+                segments = constrainDiarizedSegmentsUsingEmbeddings(
+                    result.segments,
+                    requestedSpeakerCount: requestedSpeakerCount,
+                    audio: audio,
+                    embeddingModel: embeddingModel
+                )
+            } else {
+                segments = constrainDiarizedSegments(
+                    result.segments,
+                    requestedSpeakerCount: requestedSpeakerCount
+                )
+            }
+        } else {
+            segments = constrainDiarizedSegments(
+                result.segments,
+                requestedSpeakerCount: requestedSpeakerCount
+            )
+        }
 
         return FileDiarizationPayload(
             segments: segments.map { segment in
@@ -1172,6 +1354,33 @@ private actor DiarizationPipeline {
             throw error
         }
     }
+
+    private func ensureEmbeddingModelLoaded() async throws -> WeSpeakerModel {
+        if let embeddingModel {
+            return embeddingModel
+        }
+
+        if let task = embeddingModelTask {
+            let model = try await task.value
+            self.embeddingModel = model
+            return model
+        }
+
+        let task = Task {
+            try await WeSpeakerModel.fromPretrained(engine: .coreml)
+        }
+        embeddingModelTask = task
+
+        do {
+            let model = try await task.value
+            self.embeddingModel = model
+            embeddingModelTask = nil
+            return model
+        } catch {
+            embeddingModelTask = nil
+            throw error
+        }
+    }
 }
 
 private actor SpeakerEmbeddingPipeline {