Skip to content
This repository was archived by the owner on Apr 28, 2026. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 213 additions & 4 deletions src-tauri/swift-permissions/src/speech_bridge.swift
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,164 @@ private func compactDiarizedSpeakerIds(_ segments: [DiarizedSegment]) -> [Diariz
}
}

/// Shortest stretch of audio, in seconds, for which a speaker embedding is
/// computed when a dropped speaker's segment is being reassigned.
private let embeddingReassignmentMinDurationSeconds = Float(1)

/// Sample rate, in Hz, used both to slice the audio buffer and as the rate
/// passed to the embedding model during reassignment.
private let embeddingReassignmentSampleRate: Int = 16_000

/// Returns the samples of `audio` that fall inside a time window.
///
/// The window is widened to whole samples: the start index is rounded down and
/// the end index is rounded up, then both are clamped to the bounds of `audio`.
///
/// - Parameters:
///   - audio: Sample buffer to slice.
///   - sampleRate: Samples per second of `audio`. Non-positive values yield `[]`.
///   - startSeconds: Window start. Negative or NaN values are treated as `0`.
///   - endSeconds: Window end. Values before the start (or NaN) collapse the
///     window to the start position.
/// - Returns: The selected samples, or `[]` when the window selects nothing.
private func sliceAudioSamples(
    _ audio: [Float],
    sampleRate: Int,
    startSeconds: Float,
    endSeconds: Float
) -> [Float] {
    guard !audio.isEmpty, sampleRate > 0 else {
        return []
    }

    let sampleCount = Double(audio.count)
    let samplesPerSecond = Double(sampleRate)

    let windowStart = startSeconds.isNaN ? 0 : max(0, Double(startSeconds))
    let windowEnd = endSeconds.isNaN ? windowStart : max(windowStart, Double(endSeconds))

    // Clamp to the buffer length while still in floating point: converting an
    // infinite or out-of-range Double to Int traps at runtime, so the bound
    // must be applied before the conversion, not after it.
    let startIndex = Int(min(sampleCount, (windowStart * samplesPerSecond).rounded(.down)))
    let endIndex = Int(min(sampleCount, (windowEnd * samplesPerSecond).rounded(.up)))

    guard endIndex > startIndex else {
        return []
    }

    return Array(audio[startIndex..<endIndex])
}

/// Reduces a diarization result to at most `requestedSpeakerCount` speakers,
/// using speaker embeddings to decide where the surplus speakers' segments go.
///
/// Behaviour, in order:
/// 1. With no positive `requestedSpeakerCount`, or no segments, the input is
///    returned untouched.
/// 2. With a requested count of `1`, every segment is relabelled as speaker `0`.
/// 3. If the input already has no more distinct speakers than requested, only
///    `compactDiarizedSpeakerIds` is applied.
/// 4. Otherwise the speakers with the most total speaking time are retained
///    (ties broken by the lower speaker id). Each segment of a dropped speaker
///    is reassigned to the retained speaker whose embedding centroid is most
///    similar; when no usable embedding exists it falls back to the retained
///    segment that minimises `diarizedSegmentDistance`.
///
/// - Parameters:
///   - segments: Diarized segments to constrain.
///   - requestedSpeakerCount: Upper bound on distinct speakers, or `nil` for none.
///   - audio: Samples the segment times refer to; sliced at
///     `embeddingReassignmentSampleRate` (assumes the caller loaded the audio at
///     that rate — confirm at call sites).
///   - embeddingModel: Model used to embed individual segments.
/// - Returns: The relabelled segments, passed through `compactDiarizedSpeakerIds`
///   except in cases 1 and 2 above.
private func constrainDiarizedSegmentsUsingEmbeddings(
    _ segments: [DiarizedSegment],
    requestedSpeakerCount: Int?,
    audio: [Float],
    embeddingModel: WeSpeakerModel
) -> [DiarizedSegment] {
    guard
        let requestedSpeakerCount,
        requestedSpeakerCount > 0,
        !segments.isEmpty
    else {
        return segments
    }

    if requestedSpeakerCount == 1 {
        return segments.map { segment in
            DiarizedSegment(
                startTime: segment.startTime,
                endTime: segment.endTime,
                speakerId: 0
            )
        }
    }

    // Grouping keeps each speaker's segments in their original order, so it can
    // serve both the duration totals and the per-speaker centroid pass below.
    let segmentsBySpeaker = Dictionary(grouping: segments, by: \.speakerId)
    let speakerDurations = segmentsBySpeaker.mapValues { speakerSegments in
        speakerSegments.reduce(Float.zero) { partialResult, segment in
            partialResult + segment.duration
        }
    }

    if speakerDurations.count <= requestedSpeakerCount {
        return compactDiarizedSpeakerIds(segments)
    }

    let retainedSpeakerIds = Set(
        speakerDurations
            .sorted { lhs, rhs in
                if lhs.value == rhs.value {
                    return lhs.key < rhs.key
                }

                return lhs.value > rhs.value
            }
            .prefix(requestedSpeakerCount)
            .map(\.key)
    )

    // One threshold for both the centroid pass and the reassignment pass, so
    // the two cannot drift apart if the minimum duration constant is changed.
    let minSamples = Int(
        embeddingReassignmentMinDurationSeconds * Float(embeddingReassignmentSampleRate)
    )

    // Embeds a single segment; nil when the slice is shorter than `minSamples`
    // or the model returns an all-zero vector.
    func usableEmbedding(for segment: DiarizedSegment) -> [Float]? {
        let samples = sliceAudioSamples(
            audio,
            sampleRate: embeddingReassignmentSampleRate,
            startSeconds: segment.startTime,
            endSeconds: segment.endTime
        )
        guard samples.count >= minSamples else {
            return nil
        }

        let embedding = embeddingModel.embed(
            audio: samples,
            sampleRate: embeddingReassignmentSampleRate
        )
        guard embedding.contains(where: { $0 != 0 }) else {
            return nil
        }

        return embedding
    }

    // Centroids are kept in ascending speaker-id order so that the similarity
    // search below resolves exact ties the same way on every run.
    let sortedRetainedSpeakerIds = retainedSpeakerIds.sorted()
    var retainedCentroids: [(speakerId: Int, centroid: [Float])] = []
    for speakerId in sortedRetainedSpeakerIds {
        let speakerSegments = segmentsBySpeaker[speakerId] ?? []
        let embeddings = speakerSegments.compactMap { usableEmbedding(for: $0) }

        let centroid = normalizedEmbeddingCentroid(embeddings)
        guard !centroid.isEmpty else { continue }
        retainedCentroids.append((speakerId: speakerId, centroid: centroid))
    }

    let fallbackSpeakerId = sortedRetainedSpeakerIds.first ?? 0

    // Loop-invariant: computed once rather than once per dropped segment.
    let retainedSegments = segments.filter { retainedSpeakerIds.contains($0.speakerId) }

    let remapped = segments.map { segment -> DiarizedSegment in
        if retainedSpeakerIds.contains(segment.speakerId) {
            return segment
        }

        if !retainedCentroids.isEmpty, let embedding = usableEmbedding(for: segment) {
            var bestSpeakerId = fallbackSpeakerId
            var bestSimilarity = -Float.infinity
            for (speakerId, centroid) in retainedCentroids {
                let similarity = WeSpeakerModel.cosineSimilarity(embedding, centroid)
                // Strict comparison: on an exact tie the lower speaker id wins.
                if similarity > bestSimilarity {
                    bestSimilarity = similarity
                    bestSpeakerId = speakerId
                }
            }

            return DiarizedSegment(
                startTime: segment.startTime,
                endTime: segment.endTime,
                speakerId: bestSpeakerId
            )
        }

        // No usable embedding: adopt the speaker of the retained segment with
        // the smallest `diarizedSegmentDistance` to this one.
        let temporalFallback = retainedSegments.min(by: { lhs, rhs in
            diarizedSegmentDistance(from: segment, to: lhs)
                < diarizedSegmentDistance(from: segment, to: rhs)
        })?.speakerId ?? fallbackSpeakerId

        return DiarizedSegment(
            startTime: segment.startTime,
            endTime: segment.endTime,
            speakerId: temporalFallback
        )
    }

    return compactDiarizedSpeakerIds(remapped)
}

private func encodeJSON<T: Encodable>(_ value: T) -> String {
guard let data = try? JSONEncoder().encode(value),
let string = String(data: data, encoding: .utf8)
Expand Down Expand Up @@ -1117,6 +1275,8 @@ private final class LiveTranscriptionSession {
private actor DiarizationPipeline {
private var diarizer: SortformerDiarizer?
private var diarizerTask: Task<SortformerDiarizer, Error>?
private var embeddingModel: WeSpeakerModel?
private var embeddingModelTask: Task<WeSpeakerModel, Error>?

func diarizeAudioFile(
atPath path: String,
Expand All @@ -1127,10 +1287,32 @@ private actor DiarizationPipeline {
let audio = try AudioFileLoader.load(url: url, targetSampleRate: 16000)
let diarizer = try await ensureLoaded()
let result = diarizer.diarize(audio: audio, sampleRate: 16000, config: .default)
let segments = constrainDiarizedSegments(
result.segments,
requestedSpeakerCount: requestedSpeakerCount
)

let segments: [DiarizedSegment]
let distinctSortformerSpeakers = Set(result.segments.map(\.speakerId)).count
if let requestedSpeakerCount,
requestedSpeakerCount > 0,
distinctSortformerSpeakers > requestedSpeakerCount
{
if let embeddingModel = try? await ensureEmbeddingModelLoaded() {
segments = constrainDiarizedSegmentsUsingEmbeddings(
result.segments,
requestedSpeakerCount: requestedSpeakerCount,
audio: audio,
embeddingModel: embeddingModel
)
} else {
segments = constrainDiarizedSegments(
result.segments,
requestedSpeakerCount: requestedSpeakerCount
)
}
} else {
segments = constrainDiarizedSegments(
result.segments,
requestedSpeakerCount: requestedSpeakerCount
)
}

return FileDiarizationPayload(
segments: segments.map { segment in
Expand Down Expand Up @@ -1172,6 +1354,33 @@ private actor DiarizationPipeline {
throw error
}
}

/// Returns the speaker-embedding model, loading it on first use.
///
/// A cached model is returned immediately. If a load is already in flight, this
/// call awaits that same task instead of starting another. Otherwise it starts
/// the load, caches the result on success, and clears the in-flight task
/// whether the load succeeds or fails.
///
/// - Returns: The loaded `WeSpeakerModel`.
/// - Throws: Whatever error the load task finishes with.
private func ensureEmbeddingModelLoaded() async throws -> WeSpeakerModel {
    if let cachedModel = embeddingModel {
        return cachedModel
    }

    guard let pendingLoad = embeddingModelTask else {
        // Nothing cached and nothing in flight: this caller owns the load.
        let loadTask = Task<WeSpeakerModel, Error> {
            try await WeSpeakerModel.fromPretrained(engine: .coreml)
        }
        embeddingModelTask = loadTask

        do {
            let loadedModel = try await loadTask.value
            embeddingModel = loadedModel
            embeddingModelTask = nil
            return loadedModel
        } catch {
            // Drop the failed task so a later call can start a fresh load.
            embeddingModelTask = nil
            throw error
        }
    }

    // Another caller started the load; share its result.
    let sharedModel = try await pendingLoad.value
    embeddingModel = sharedModel
    return sharedModel
}
}

private actor SpeakerEmbeddingPipeline {
Expand Down
Loading