From 46be017e9f4b076f2d0842cf78175ac42d894b0a Mon Sep 17 00:00:00 2001
From: Tom Nickson
Date: Fri, 5 Sep 2025 19:55:03 -0700
Subject: [PATCH 1/9] Add bidirectional attention support to Gemma3Text

- Add useBidirectionalAttention config parameter
- Apply sliding window size adjustment for bidirectional mode
- Implement createBidirectionalSlidingWindowMask function
- Update mask creation logic to support both causal and bidirectional attention
- Based on patches 40694 and 40700 for EmbeddingGemma support
---
 Libraries/MLXLLM/Models/Gemma3Text.swift | 50 ++++++++++++++++++++++--
 1 file changed, 46 insertions(+), 4 deletions(-)

diff --git a/Libraries/MLXLLM/Models/Gemma3Text.swift b/Libraries/MLXLLM/Models/Gemma3Text.swift
index eb953ec4..bc741eb7 100644
--- a/Libraries/MLXLLM/Models/Gemma3Text.swift
+++ b/Libraries/MLXLLM/Models/Gemma3Text.swift
@@ -14,6 +14,24 @@ import MLXLLM
 import MLXLMCommon
 import MLXNN
 
+/// Create a bidirectional sliding window mask where tokens can attend to others within the sliding window distance
+func createBidirectionalSlidingWindowMask(
+    n: Int,
+    offset: Int,
+    windowSize: Int
+) -> MLXArray {
+    let rinds = MLXArray(Int32(0) ..< Int32(offset + n))
+    var linds = offset != 0 ? MLXArray(Int32(offset) ..< Int32(offset + n)) : rinds
+    linds = linds[0..., .newAxis]
+    let rindsBcast = rinds[.newAxis]
+
+    // Create mask where abs(q_idx - kv_idx) < windowSize (bidirectional window)
+    let distance = abs(linds - rindsBcast)
+    let mask = distance .< windowSize
+
+    return mask
+}
+
 public struct Gemma3TextConfiguration: Codable {
     let modelType: String
     let hiddenSize: Int
@@ -30,6 +48,7 @@ public struct Gemma3TextConfiguration: Codable {
     let queryPreAttnScalar: Float
     let slidingWindow: Int
     let slidingWindowPattern: Int
+    let useBidirectionalAttention: Bool
 
     enum CodingKeys: String, CodingKey {
         case modelType = "model_type"
@@ -47,6 +66,7 @@ public struct Gemma3TextConfiguration: Codable {
         case queryPreAttnScalar = "query_pre_attn_scalar"
         case slidingWindow = "sliding_window"
         case slidingWindowPattern = "sliding_window_pattern"
+        case useBidirectionalAttention = "use_bidirectional_attention"
     }
 
     enum VLMCodingKeys: String, CodingKey {
@@ -82,7 +102,12 @@ public struct Gemma3TextConfiguration: Codable {
             try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false
         queryPreAttnScalar =
             try container.decodeIfPresent(Float.self, forKey: .queryPreAttnScalar) ?? 256
-        slidingWindow = try container.decodeIfPresent(Int.self, forKey: .slidingWindow) ?? 512
+        useBidirectionalAttention =
+            try container.decodeIfPresent(Bool.self, forKey: .useBidirectionalAttention) ?? false
+
+        let rawSlidingWindow = try container.decodeIfPresent(Int.self, forKey: .slidingWindow) ?? 512
+        // Apply sliding window adjustment for bidirectional attention (from patch: (sliding_window // 2) + 1)
+        slidingWindow = useBidirectionalAttention ? (rawSlidingWindow / 2) + 1 : rawSlidingWindow
         slidingWindowPattern =
             try container.decodeIfPresent(Int.self, forKey: .slidingWindowPattern) ?? 6
     }
@@ -98,6 +123,7 @@ private class Attention: Module {
     let isSliding: Bool
     let slidingWindow: Int
     let slidingWindowPattern: Int
+    let useBidirectionalAttention: Bool
 
     @ModuleInfo(key: "q_proj") var queryProj: Linear
     @ModuleInfo(key: "k_proj") var keyProj: Linear
@@ -118,6 +144,7 @@ private class Attention: Module {
         self.layerIdx = layerIdx
         self.slidingWindow = config.slidingWindow
         self.slidingWindowPattern = config.slidingWindowPattern
+        self.useBidirectionalAttention = config.useBidirectionalAttention
 
         self.scale = pow(config.queryPreAttnScalar, -0.5)
 
@@ -307,9 +334,24 @@ private class Gemma3Model: Module {
             } else {
                 globalLayerCache = []
             }
-            fullMask = createAttentionMask(h: h, cache: globalLayerCache)
-            let allCaches = layerCache?.compactMap { $0 } ?? []
-            slidingWindowMask = createAttentionMask(h: h, cache: allCaches)
+
+            if config.useBidirectionalAttention {
+                // For bidirectional attention: full attention for global layers, bidirectional sliding window for others
+                fullMask = .array(MLXArray.ones([h.dim(1), h.dim(1)], dtype: .bool))
+
+                let t = h.dim(1)
+                var offset = 0
+                if let cache = layerCache?.compactMap({ $0 }).first {
+                    offset = cache.offset
+                }
+                slidingWindowMask = .array(createBidirectionalSlidingWindowMask(
+                    n: t, offset: offset, windowSize: config.slidingWindow))
+            } else {
+                // Standard causal attention
+                fullMask = createAttentionMask(h: h, cache: globalLayerCache)
+                let allCaches = layerCache?.compactMap { $0 } ?? []
+                slidingWindowMask = createAttentionMask(h: h, cache: allCaches)
+            }
         }
         for (i, layer) in layers.enumerated() {
             let isGlobal = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)
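The window predicate admits exactly the index pairs with abs(q_idx - kv_idx) < windowSize, looking both backward and forward from each query. A minimal sanity check of the new function (a sketch; it assumes the function is visible at the call site, which patch 2 below makes explicit):

    import MLX

    // 4 tokens, no cache offset, window of 2: token i may attend to j iff |i - j| < 2.
    let mask = createBidirectionalSlidingWindowMask(n: 4, offset: 0, windowSize: 2)
    print(mask)
    // Expected band pattern (true = may attend):
    // [[true,  true,  false, false],
    //  [true,  true,  true,  false],
    //  [false, true,  true,  true],
    //  [false, false, true,  true]]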
From 8dc179ccc21b26fb0856016ec9f2b7d5792979e0 Mon Sep 17 00:00:00 2001
From: Tom Nickson
Date: Fri, 5 Sep 2025 20:00:30 -0700
Subject: [PATCH 2/9] Make createBidirectionalSlidingWindowMask function public for testing

---
 Libraries/MLXLLM/Models/Gemma3Text.swift | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Libraries/MLXLLM/Models/Gemma3Text.swift b/Libraries/MLXLLM/Models/Gemma3Text.swift
index bc741eb7..6c1817ec 100644
--- a/Libraries/MLXLLM/Models/Gemma3Text.swift
+++ b/Libraries/MLXLLM/Models/Gemma3Text.swift
@@ -15,7 +15,7 @@ import MLXLMCommon
 import MLXNN
 
 /// Create a bidirectional sliding window mask where tokens can attend to others within the sliding window distance
-func createBidirectionalSlidingWindowMask(
+public func createBidirectionalSlidingWindowMask(
     n: Int,
     offset: Int,
     windowSize: Int

From 733e142542cfaf85ca0304d37f908b176c54edfc Mon Sep 17 00:00:00 2001
From: Tom Nickson
Date: Fri, 5 Sep 2025 23:30:20 -0700
Subject: [PATCH 3/9] Add getHiddenStates method to Gemma3TextModel for embedding use cases

---
 Libraries/MLXLLM/Models/Gemma3Text.swift | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/Libraries/MLXLLM/Models/Gemma3Text.swift b/Libraries/MLXLLM/Models/Gemma3Text.swift
index 6c1817ec..adb95e64 100644
--- a/Libraries/MLXLLM/Models/Gemma3Text.swift
+++ b/Libraries/MLXLLM/Models/Gemma3Text.swift
@@ -390,6 +390,11 @@ public class Gemma3TextModel: Module, LLMModel {
         out = lmHead(out)
         return out
     }
+
+    /// Get hidden states before the language modeling head for embedding use cases
+    public func getHiddenStates(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
+        return model(inputs, mask: nil, cache: cache)
+    }
 
     public func sanitize(weights: [String: MLXArray])
         -> [String: MLXArray]
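A sketch of how the new accessor can be consumed for embeddings. The mean pooling here is illustrative only; the padding-aware pooling EmbeddingGemma actually uses arrives in patch 6:

    import MLX
    import MLXLLM

    func meanEmbedding(model: Gemma3TextModel, tokens: MLXArray) -> MLXArray {
        // [batch, seq, hidden] hidden states, bypassing the LM head entirely
        let hidden = model.getHiddenStates(tokens)
        // collapse the sequence axis; no padding handling in this sketch
        return hidden.mean(axis: 1)
    }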
From 0047c969a15be351af50ed85fec53c043075bece Mon Sep 17 00:00:00 2001
From: Tom Nickson
Date: Fri, 5 Sep 2025 23:32:46 -0700
Subject: [PATCH 4/9] Make Gemma3TextConfiguration fields public for external testing

---
 Libraries/MLXLLM/Models/Gemma3Text.swift | 32 ++++++++++++------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/Libraries/MLXLLM/Models/Gemma3Text.swift b/Libraries/MLXLLM/Models/Gemma3Text.swift
index adb95e64..9db3c09a 100644
--- a/Libraries/MLXLLM/Models/Gemma3Text.swift
+++ b/Libraries/MLXLLM/Models/Gemma3Text.swift
@@ -33,22 +33,22 @@ public func createBidirectionalSlidingWindowMask(
 }
 
 public struct Gemma3TextConfiguration: Codable {
-    let modelType: String
-    let hiddenSize: Int
-    let hiddenLayers: Int
-    let intermediateSize: Int
-    let attentionHeads: Int
-    let headDim: Int
-    let rmsNormEps: Float
-    let vocabularySize: Int
-    let kvHeads: Int
-    let ropeGlobalBaseFreq: Float
-    let ropeLocalBaseFreq: Float
-    let ropeTraditional: Bool
-    let queryPreAttnScalar: Float
-    let slidingWindow: Int
-    let slidingWindowPattern: Int
-    let useBidirectionalAttention: Bool
+    public let modelType: String
+    public let hiddenSize: Int
+    public let hiddenLayers: Int
+    public let intermediateSize: Int
+    public let attentionHeads: Int
+    public let headDim: Int
+    public let rmsNormEps: Float
+    public let vocabularySize: Int
+    public let kvHeads: Int
+    public let ropeGlobalBaseFreq: Float
+    public let ropeLocalBaseFreq: Float
+    public let ropeTraditional: Bool
+    public let queryPreAttnScalar: Float
+    public let slidingWindow: Int
+    public let slidingWindowPattern: Int
+    public let useBidirectionalAttention: Bool
 
     enum CodingKeys: String, CodingKey {
         case modelType = "model_type"
From 878d03ad2033ce52eb30d2dd97de22a0b489dc66 Mon Sep 17 00:00:00 2001
From: Tom Nickson
Date: Sun, 7 Sep 2025 21:17:12 -0700
Subject: [PATCH 5/9] Make Gemma3Text work for EmbeddingGemma

---
 Libraries/MLXLLM/Models/Gemma3Text.swift | 118 +++++++++++++++--------
 1 file changed, 79 insertions(+), 39 deletions(-)

diff --git a/Libraries/MLXLLM/Models/Gemma3Text.swift b/Libraries/MLXLLM/Models/Gemma3Text.swift
index 9db3c09a..77891295 100644
--- a/Libraries/MLXLLM/Models/Gemma3Text.swift
+++ b/Libraries/MLXLLM/Models/Gemma3Text.swift
@@ -32,6 +32,13 @@ public func createBidirectionalSlidingWindowMask(
     return mask
 }
 
+func simpleSDPA(queries: MLXArray, keys: MLXArray, values: MLXArray, mask: MLXArray, scale: Float) -> MLXArray {
+    var attn = matmul(queries, keys.transposed(0, 1, 3, 2))
+    attn = attn - (1 - mask) * 1e6
+    let weights = softmax(scale * attn, axis: -1)
+    return matmul(weights, values)
+}
+
 public struct Gemma3TextConfiguration: Codable {
     public let modelType: String
     public let hiddenSize: Int
@@ -73,6 +80,25 @@ public struct Gemma3TextConfiguration: Codable {
         case textConfig = "text_config"
     }
 
+    public init(modelType: String, hiddenSize: Int, hiddenLayers: Int, intermediateSize: Int, attentionHeads: Int, headDim: Int, rmsNormEps: Float, vocabularySize: Int, kvHeads: Int, ropeGlobalBaseFreq: Float, ropeLocalBaseFreq: Float, ropeTraditional: Bool, queryPreAttnScalar: Float, slidingWindow: Int, slidingWindowPattern: Int, useBidirectionalAttention: Bool) {
+        self.modelType = modelType
+        self.hiddenSize = hiddenSize
+        self.hiddenLayers = hiddenLayers
+        self.intermediateSize = intermediateSize
+        self.attentionHeads = attentionHeads
+        self.headDim = headDim
+        self.rmsNormEps = rmsNormEps
+        self.vocabularySize = vocabularySize
+        self.kvHeads = kvHeads
+        self.ropeGlobalBaseFreq = ropeGlobalBaseFreq
+        self.ropeLocalBaseFreq = ropeLocalBaseFreq
+        self.ropeTraditional = ropeTraditional
+        self.queryPreAttnScalar = queryPreAttnScalar
+        self.slidingWindow = slidingWindow
+        self.slidingWindowPattern = slidingWindowPattern
+        self.useBidirectionalAttention = useBidirectionalAttention
+    }
+
     public init(from decoder: Decoder) throws {
         let nestedContainer = try decoder.container(keyedBy: VLMCodingKeys.self)
 
@@ -205,14 +231,17 @@ private class Attention: Module {
             }
         }
 
-        let output = attentionWithCacheUpdate(
-            queries: queries,
-            keys: keys,
-            values: values,
-            cache: cache,
-            scale: scale,
-            mask: finalMask
-        )
+        let maskArr: MLXArray
+        if case .array(let maskArray) = finalMask {
+            maskArr = maskArray
+        } else {
+            fatalError("simpleSDPA requires a materialized mask array")
+        }
+        let output = simpleSDPA(queries: queries,
+                                keys: keys,
+                                values: values,
+                                mask: maskArr,
+                                scale: scale)
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)
         return outputProj(output)
@@ -317,7 +346,7 @@ private class Gemma3Model: Module {
     {
         var h: MLXArray
         h = embedTokens(inputs)
-        let scale = MLXArray(sqrt(Float(config.hiddenSize)), dtype: .bfloat16)
+        let scale = MLXArray(sqrt(Float(config.hiddenSize)), dtype: .float32)
         h = h * scale.asType(h.dtype)
         var layerCache = cache
         if layerCache == nil {
@@ -326,40 +355,45 @@ private class Gemma3Model: Module {
         // Create attention masks
         var fullMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
         var slidingWindowMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
-        if mask == nil {
-            let j = config.slidingWindowPattern
-            let globalLayerCache: [KVCache]
-            if j > 0 && j <= (layerCache?.count ?? 0), let globalCache = layerCache?[j - 1] {
-                globalLayerCache = [globalCache]
-            } else {
-                globalLayerCache = []
+        let j = config.slidingWindowPattern
+        let globalLayerCache: [KVCache]
+        if j > 0 && j <= (layerCache?.count ?? 0), let globalCache = layerCache?[j - 1] {
+            globalLayerCache = [globalCache]
+        } else {
+            globalLayerCache = []
+        }
+
+        if config.useBidirectionalAttention {
+            // For bidirectional attention: full attention for global layers, bidirectional sliding window for others
+            var fullMaskArray = MLXArray.ones([h.dim(1), h.dim(1)], dtype: .bool)
+            if case .array(let maskArray) = mask {
+                fullMaskArray = fullMaskArray & maskArray
             }
+            fullMask = .array(fullMaskArray)
 
-            if config.useBidirectionalAttention {
-                // For bidirectional attention: full attention for global layers, bidirectional sliding window for others
-                fullMask = .array(MLXArray.ones([h.dim(1), h.dim(1)], dtype: .bool))
-
-                let t = h.dim(1)
-                var offset = 0
-                if let cache = layerCache?.compactMap({ $0 }).first {
-                    offset = cache.offset
-                }
-                slidingWindowMask = .array(createBidirectionalSlidingWindowMask(
-                    n: t, offset: offset, windowSize: config.slidingWindow))
-            } else {
-                // Standard causal attention
-                fullMask = createAttentionMask(h: h, cache: globalLayerCache)
-                let allCaches = layerCache?.compactMap { $0 } ?? []
-                slidingWindowMask = createAttentionMask(h: h, cache: allCaches)
+            let t = h.dim(1)
+            var offset = 0
+            if let cache = layerCache?.compactMap({ $0 }).first {
+                offset = cache.offset
             }
+            var slidingWindowMaskArray = createBidirectionalSlidingWindowMask(
+                n: t, offset: offset, windowSize: config.slidingWindow)
+            if case .array(let maskArray) = mask {
+                slidingWindowMaskArray = slidingWindowMaskArray & maskArray
+            }
+            slidingWindowMask = .array(slidingWindowMaskArray)
+        } else {
+            // Standard causal attention
+            // TODO: probably need to merge the custom mask in
+            fullMask = createAttentionMask(h: h, cache: globalLayerCache)
+            let allCaches = layerCache?.compactMap { $0 } ?? []
+            slidingWindowMask = createAttentionMask(h: h, cache: allCaches)
         }
         for (i, layer) in layers.enumerated() {
             let isGlobal = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)
             let localMask: MLXFast.ScaledDotProductAttentionMaskMode
-            if let mask {
-                localMask = mask
-            } else if isGlobal {
+            if isGlobal {
                 localMask = fullMask
             } else {
                 localMask = slidingWindowMask
@@ -385,15 +419,15 @@ public class Gemma3TextModel: Module, LLMModel {
         super.init()
     }
 
-    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
-        var out = model(inputs, mask: nil, cache: cache)
+    public func callAsFunction(_ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil, cache: [KVCache]? = nil) -> MLXArray {
+        var out = model(inputs, mask: mask, cache: cache)
         out = lmHead(out)
         return out
     }
 
     /// Get hidden states before the language modeling head for embedding use cases
-    public func getHiddenStates(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
-        return model(inputs, mask: nil, cache: cache)
+    public func getHiddenStates(_ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil, cache: [KVCache]? = nil) -> MLXArray {
+        return model(inputs, mask: mask, cache: cache)
     }
 
     public func sanitize(weights: [String: MLXArray])
@@ -415,6 +449,12 @@ public class Gemma3TextModel: Module, LLMModel {
             }
         }
+
+        // EmbeddingGemma contains additional 'dense' weights, which
+        // confuse loading into a normal Gemma3
+        processedWeights["dense.0.weight"] = nil
+        processedWeights["dense.1.weight"] = nil
+
         return processedWeights
     }
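The replacement attention path folds the boolean mask into the logits as a large negative bias instead of using the fused attentionWithCacheUpdate kernel (note it performs no KV-cache update, which is acceptable for single-pass embedding). A small numeric illustration of that masking trick, mirroring the `(1 - mask) * 1e6` line above:

    import MLX

    let mask = MLXArray([1, 1, 0] as [Float])     // 1 = attend, 0 = blocked
    let logits = MLXArray([2.0, 1.0, 5.0] as [Float])
    let biased = logits - (1 - mask) * 1e6        // the blocked logit drops to roughly -1e6
    print(softmax(biased, axis: -1))              // softmax weight on the blocked key is ~0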
From 86bb1265168363cc5096b8df5f82075a5702ef2e Mon Sep 17 00:00:00 2001
From: Tom Nickson
Date: Thu, 25 Sep 2025 01:41:19 -0700
Subject: [PATCH 6/9] Add EmbeddingGemma embedder model

---
 Libraries/Embedders/Configuration.swift  |   8 ++
 Libraries/Embedders/EmbeddingGemma.swift | 104 +++++++++++++++++++++++
 Libraries/Embedders/EmbeddingModel.swift |   4 +-
 Libraries/Embedders/Models.swift         |   3 +
 Libraries/MLXLLM/Models/Gemma3Text.swift |  89 +++++++++++++++----
 5 files changed, 190 insertions(+), 18 deletions(-)
 create mode 100644 Libraries/Embedders/EmbeddingGemma.swift

diff --git a/Libraries/Embedders/Configuration.swift b/Libraries/Embedders/Configuration.swift
index 358dcf82..b36021ea 100644
--- a/Libraries/Embedders/Configuration.swift
+++ b/Libraries/Embedders/Configuration.swift
@@ -1,6 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
 import Foundation
+import MLXLLM
 
 public enum StringOrNumber: Codable, Equatable, Sendable {
     case string(String)
@@ -69,6 +70,13 @@ private class ModelTypeRegistry: @unchecked Sendable {
             let model = NomicBertModel(configuration)
             return model
         },
+        "EmbeddingGemma": {
+            url in
+            let configuration = try JSONDecoder().decode(
+                Gemma3TextConfiguration.self, from: Data(contentsOf: url))
+            let model = EmbeddingGemma(configuration)
+            return model
+        },
     ]
 
     public func registerModelType(
diff --git a/Libraries/Embedders/EmbeddingGemma.swift b/Libraries/Embedders/EmbeddingGemma.swift
new file mode 100644
index 00000000..da2f5386
--- /dev/null
+++ b/Libraries/Embedders/EmbeddingGemma.swift
@@ -0,0 +1,104 @@
+import MLX
+import MLXNN
+import MLXLLM
+import MLXLMCommon
+
+public class EmbeddingGemma: Module, EmbeddingModel {
+    @ModuleInfo private var model: Gemma3TextModel
+    @ModuleInfo private var dense: [Module]
+
+    public let config: Gemma3TextConfiguration
+    public var vocabularySize: Int { config.vocabularySize }
+
+    public init(_ config: Gemma3TextConfiguration) {
+        self.config = config
+        self.model = Gemma3TextModel(config)
+        self.dense = [
+            Linear(768, 3072, bias: false), Linear(3072, 768, bias: false)
+        ]
+    }
+
+    public func callAsFunction(
+        _ inputs: MLXArray, positionIds: MLXArray?, tokenTypeIds: MLXArray?,
+        attentionMask: MLXArray?
+    ) -> EmbeddingModelOutput {
+        var out = model.getHiddenStates(inputs, mask: nil, cache: nil)
+
+        // mean pooling
+        let notPadding = inputs .!= 0
+        let sum = (out * notPadding[.ellipsis, .newAxis]).sum(axis: 1)
+        let nonMasked = notPadding.sum(axis: -1, keepDims: true)
+        out = sum / nonMasked
+
+        for dense in self.dense {
+            if let dense = dense as? Linear {
+                out = dense(out)
+            } else if let dense = dense as? QuantizedLinear {
+                out = dense(out)
+            }
+        }
+
+        // normalize
+        out = out.asType(Float32.self)
+        let norm = maximum(norm(out, ord: 2.0, axis: -1, keepDims: true), MLXArray(1e-6))
+        let pooledOutput = out / norm
+
+        return EmbeddingModelOutput(hiddenStates: out, pooledOutput: pooledOutput)
+    }
+
+    /// Get hidden states before the dense projection head
+    public func getHiddenStates(_ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil, cache: [KVCache]? = nil) -> MLXArray {
+        return model(inputs, mask: mask, cache: cache)
+    }
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        sanitize(weights: weights, quantizationConfig: nil)
+    }
+
+    public func sanitize(weights: [String: MLXArray],
+        quantizationConfig: QuantizationConfig? = nil)
+        -> [String: MLXArray]
+    {
+        var processedWeights = model.sanitize(weights: weights, quantizationConfig: quantizationConfig)
+
+        // 1. Add a model. prefix to all model. weights
+        processedWeights = Dictionary(uniqueKeysWithValues: processedWeights.map { key, value in
+            if key.hasPrefix("model.") || key.hasPrefix("lm_head.") {
+                return ("model.\(key)", value)
+            } else {
+                return (key, value)
+            }
+        })
+
+        // 2. Apply quantization to dense layers, if needed
+        let hasQuantizedDense = hasQuantizedWeights(layerPath: "dense.0", in: processedWeights)
+        if hasQuantizedDense {
+            let groupSize = quantizationConfig?.groupSize ?? 64
+            let bits = quantizationConfig?.bits ?? 4
+
+            quantize(model: self) { path, module in
+                if hasQuantizedWeights(layerPath: path, in: processedWeights) {
+                    return (groupSize, bits)
+                }
+                return nil
+            }
+        }
+
+        return processedWeights.filter { key, _ in
+            !key.contains("self_attn.rotary_emb.inv_freq")
+        }
+    }
+
+    /// Check if a layer has quantized weights
+    private func hasQuantizedWeights(layerPath: String, in weights: [String: MLXArray]) -> Bool {
+        let scalesKey = "\(layerPath).scales"
+        let biasesKey = "\(layerPath).biases"
+        let weightKey = "\(layerPath).weight"
+
+        let hasScales = weights[scalesKey] != nil
+        let hasBiases = weights[biasesKey] != nil
+        let hasWeight = weights[weightKey]?.dtype == .uint32
+
+        return hasScales && hasBiases && hasWeight
+    }
+}
diff --git a/Libraries/Embedders/EmbeddingModel.swift b/Libraries/Embedders/EmbeddingModel.swift
index 3c4fbed7..119ccc8f 100644
--- a/Libraries/Embedders/EmbeddingModel.swift
+++ b/Libraries/Embedders/EmbeddingModel.swift
@@ -87,8 +87,8 @@ extension Module {
 }
 
 public struct EmbeddingModelOutput {
-    let hiddenStates: MLXArray?
-    let pooledOutput: MLXArray?
+    public let hiddenStates: MLXArray?
+    public let pooledOutput: MLXArray?
 }
 
 public protocol EmbeddingModel: Module {
diff --git a/Libraries/Embedders/Models.swift b/Libraries/Embedders/Models.swift
index 3c47a83c..e8b21c4d 100644
--- a/Libraries/Embedders/Models.swift
+++ b/Libraries/Embedders/Models.swift
@@ -108,6 +108,8 @@ extension ModelConfiguration {
     public static let bge_m3 = ModelConfiguration(id: "BAAI/bge-m3")
     public static let mixedbread_large = ModelConfiguration(
         id: "mixedbread-ai/mxbai-embed-large-v1")
+    public static let embeddinggemma_300m = ModelConfiguration(
+        id: "mlx-community/embeddinggemma-300m-bf16")
 
     private enum BootstrapState: Sendable {
         case idle
@@ -138,6 +140,7 @@ extension ModelConfiguration {
                 snowflake_lg,
                 bge_m3,
                 mixedbread_large,
+                embeddinggemma_300m,
             ])
 
             bootstrapState = .bootstrapped
diff --git a/Libraries/MLXLLM/Models/Gemma3Text.swift b/Libraries/MLXLLM/Models/Gemma3Text.swift
index 9db3c09a..77891295 100644
--- a/Libraries/MLXLLM/Models/Gemma3Text.swift
+++ b/Libraries/MLXLLM/Models/Gemma3Text.swift
@@ -56,6 +56,7 @@ public struct Gemma3TextConfiguration: Codable {
     public let slidingWindow: Int
     public let slidingWindowPattern: Int
     public let useBidirectionalAttention: Bool
+    public let quantizationConfig: QuantizationConfig?
 
     enum CodingKeys: String, CodingKey {
         case modelType = "model_type"
@@ -74,13 +75,14 @@ public struct Gemma3TextConfiguration: Codable {
         case slidingWindow = "sliding_window"
         case slidingWindowPattern = "sliding_window_pattern"
         case useBidirectionalAttention = "use_bidirectional_attention"
+        case quantizationConfig = "quantization"
     }
 
     enum VLMCodingKeys: String, CodingKey {
         case textConfig = "text_config"
     }
 
-    public init(modelType: String, hiddenSize: Int, hiddenLayers: Int, intermediateSize: Int, attentionHeads: Int, headDim: Int, rmsNormEps: Float, vocabularySize: Int, kvHeads: Int, ropeGlobalBaseFreq: Float, ropeLocalBaseFreq: Float, ropeTraditional: Bool, queryPreAttnScalar: Float, slidingWindow: Int, slidingWindowPattern: Int, useBidirectionalAttention: Bool) {
+    public init(modelType: String, hiddenSize: Int, hiddenLayers: Int, intermediateSize: Int, attentionHeads: Int, headDim: Int, rmsNormEps: Float, vocabularySize: Int, kvHeads: Int, ropeGlobalBaseFreq: Float, ropeLocalBaseFreq: Float, ropeTraditional: Bool, queryPreAttnScalar: Float, slidingWindow: Int, slidingWindowPattern: Int, useBidirectionalAttention: Bool, quantizationConfig: QuantizationConfig? = nil) {
         self.modelType = modelType
         self.hiddenSize = hiddenSize
         self.hiddenLayers = hiddenLayers
@@ -97,6 +99,7 @@ public struct Gemma3TextConfiguration: Codable {
         self.slidingWindow = slidingWindow
         self.slidingWindowPattern = slidingWindowPattern
         self.useBidirectionalAttention = useBidirectionalAttention
+        self.quantizationConfig = quantizationConfig
     }
 
     public init(from decoder: Decoder) throws {
@@ -136,6 +139,20 @@ public struct Gemma3TextConfiguration: Codable {
         slidingWindow = useBidirectionalAttention ? (rawSlidingWindow / 2) + 1 : rawSlidingWindow
         slidingWindowPattern =
             try container.decodeIfPresent(Int.self, forKey: .slidingWindowPattern) ?? 6
+
+        quantizationConfig = try container.decodeIfPresent(QuantizationConfig.self, forKey: .quantizationConfig)
+    }
+}
+
+// MARK: - Quantization Configuration
+
+public struct QuantizationConfig: Codable, Sendable {
+    public let groupSize: Int
+    public let bits: Int
+
+    enum CodingKeys: String, CodingKey {
+        case groupSize = "group_size"
+        case bits
     }
 }
 
@@ -407,7 +424,7 @@ private class Gemma3Model: Module {
 public class Gemma3TextModel: Module, LLMModel {
 
     @ModuleInfo private var model: Gemma3Model
-    @ModuleInfo(key: "lm_head") var lmHead: Linear
+    @ModuleInfo(key: "lm_head") var lmHead: Module  // Can be Linear or QuantizedLinear
 
     public let config: Gemma3TextConfiguration
     public var vocabularySize: Int { config.vocabularySize }
@@ -421,7 +438,16 @@ public class Gemma3TextModel: Module, LLMModel {
     public func callAsFunction(_ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil, cache: [KVCache]? = nil) -> MLXArray {
         var out = model(inputs, mask: mask, cache: cache)
-        out = lmHead(out)
+
+        // Call the lmHead (works whether it's Linear or QuantizedLinear)
+        if let linear = lmHead as? Linear {
+            out = linear(out)
+        } else if let quantized = lmHead as? QuantizedLinear {
+            out = quantized(out)
+        } else {
+            fatalError("lmHead must be Linear or QuantizedLinear")
+        }
+
         return out
     }
@@ -430,32 +456,63 @@ public class Gemma3TextModel: Module, LLMModel {
         return model(inputs, mask: mask, cache: cache)
     }
 
-    public func sanitize(weights: [String: MLXArray])
-        -> [String: MLXArray]
-    {
+    public func sanitize(
+        weights: [String: MLXArray],
+        quantizationConfig: QuantizationConfig? = nil
+    ) -> [String: MLXArray] {
         var processedWeights = weights
 
-        // VLM models converted using mlx_vlm.convert will still have
-        // the weights are under a language_model key
+        // 1. Handle VLM weight extraction first - VLM models converted using mlx_vlm.convert
+        //    will still have the weights under a language_model key
         let unflattened = ModuleParameters.unflattened(weights)
         if let lm = unflattened["language_model"] {
            processedWeights = Dictionary(uniqueKeysWithValues: lm.flattened())
         }
 
+        // 2. Handle weight sharing (works for both regular and quantized)
+        // Copy embedding weights to lm_head if lm_head weights don't exist (weight tying)
         if processedWeights["lm_head.weight"] == nil {
-            ["weight", "scales", "biases"].forEach { key in
-                if let embedWeight = processedWeights["model.embed_tokens.\(key)"] {
-                    processedWeights["lm_head.\(key)"] = embedWeight
+            for suffix in ["weight", "scales", "biases"] {
+                let embedKey = "model.embed_tokens.\(suffix)"
+                let lmHeadKey = "lm_head.\(suffix)"
+
+                if let embedWeight = processedWeights[embedKey] {
+                    processedWeights[lmHeadKey] = embedWeight
                 }
             }
         }
 
-        // EmbeddingGemma contains additional 'dense' weights, which
-        // confuse loading into a normal Gemma3
-        processedWeights["dense.0.weight"] = nil
-        processedWeights["dense.1.weight"] = nil
+        // 3. Apply quantization if needed
+        let hasQuantizedLmHead = hasQuantizedWeights(layerPath: "lm_head", in: processedWeights)
+        if hasQuantizedLmHead {
+            let groupSize = quantizationConfig?.groupSize ?? 64
+            let bits = quantizationConfig?.bits ?? 4
+
+            quantize(model: self) { path, module in
+                if hasQuantizedWeights(layerPath: path, in: processedWeights) {
+                    return (groupSize, bits)
+                }
+                return nil
+            }
+        }
+
+        // Remove unused precomputed rotary freqs
+        return processedWeights.filter { key, _ in
+            !key.contains("self_attn.rotary_emb.inv_freq")
+        }
+    }
+
+    /// Check if a layer has quantized weights
+    private func hasQuantizedWeights(layerPath: String, in weights: [String: MLXArray]) -> Bool {
+        let scalesKey = "\(layerPath).scales"
+        let biasesKey = "\(layerPath).biases"
+        let weightKey = "\(layerPath).weight"
+
+        let hasScales = weights[scalesKey] != nil
+        let hasBiases = weights[biasesKey] != nil
+        let hasWeight = weights[weightKey]?.dtype == .uint32
 
-        return processedWeights
+        return hasScales && hasBiases && hasWeight
     }
 
     public func newCache(parameters: GenerateParameters? = nil) -> [KVCache] {
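With the pieces above, the embedder's forward pass is: hidden states → padding-aware mean pooling → two dense projections → L2 normalization, so pooledOutput rows are unit vectors and their dot products are cosine similarities. A usage sketch against the new class (loader wiring lands in the next patch; `tokens` uses id 0 for padding, matching the `inputs .!= 0` convention above):

    import MLX
    import MLXEmbedders

    func similarityMatrix(model: EmbeddingGemma, tokens: MLXArray) -> MLXArray {
        // tokens: [batch, seq] token ids, 0 = padding
        let out = model(tokens, positionIds: nil, tokenTypeIds: nil, attentionMask: nil)
        let e = out.pooledOutput!            // mean-pooled, unit-norm embeddings
        return matmul(e, e.transposed())     // pairwise cosine similarities
    }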
From d44e2c3d6d5365655aa0e179432cf3548ecd17d4 Mon Sep 17 00:00:00 2001
From: Tom Nickson
Date: Thu, 25 Sep 2025 02:30:53 -0700
Subject: [PATCH 7/9] Wire quantization config through embedder loading

---
 Libraries/Embedders/Bert.swift           |  5 ++
 Libraries/Embedders/Configuration.swift  |  2 +-
 Libraries/Embedders/EmbeddingGemma.swift | 10 ++--
 Libraries/Embedders/EmbeddingModel.swift |  2 +
 Libraries/Embedders/Load.swift           |  5 +-
 Libraries/Embedders/Models.swift         |  9 +++
 Libraries/Embedders/NomicBert.swift      |  5 ++
 Libraries/MLXLLM/Models/Gemma3Text.swift |  7 +--
 Package.swift                            | 16 ++++++
 Tools/TestEmbeddingGemma/main.swift      | 72 ++++++++++++++++++++++++
 10 files changed, 120 insertions(+), 13 deletions(-)
 create mode 100644 Tools/TestEmbeddingGemma/main.swift

diff --git a/Libraries/Embedders/Bert.swift b/Libraries/Embedders/Bert.swift
index aa9cdc5a..0bba6502 100644
--- a/Libraries/Embedders/Bert.swift
+++ b/Libraries/Embedders/Bert.swift
@@ -2,6 +2,7 @@
 
 import MLX
 import MLXNN
+import MLXLMCommon
 
 extension MLXArray {
     public static func arange(_ size: Int) -> MLXArray {
@@ -196,6 +197,10 @@ public class BertModel: Module, EmbeddingModel {
                 result[key] = item.value
             }.filter { key, _ in key != "embeddings.position_ids" }
     }
+
+    public func sanitize(weights: [String : MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?) -> [String : MLXArray] {
+        fatalError("Bert does not support quantization")
+    }
 }
 
 public class DistilBertModel: BertModel {
diff --git a/Libraries/Embedders/Configuration.swift b/Libraries/Embedders/Configuration.swift
index b36021ea..473c7f07 100644
--- a/Libraries/Embedders/Configuration.swift
+++ b/Libraries/Embedders/Configuration.swift
@@ -70,7 +70,7 @@ private class ModelTypeRegistry: @unchecked Sendable {
             let model = NomicBertModel(configuration)
             return model
         },
-        "EmbeddingGemma": {
+        "gemma3_text": {
             url in
             let configuration = try JSONDecoder().decode(
                 Gemma3TextConfiguration.self, from: Data(contentsOf: url))
diff --git a/Libraries/Embedders/EmbeddingGemma.swift b/Libraries/Embedders/EmbeddingGemma.swift
index da2f5386..75235bc3 100644
--- a/Libraries/Embedders/EmbeddingGemma.swift
+++ b/Libraries/Embedders/EmbeddingGemma.swift
@@ -51,12 +51,8 @@ public class EmbeddingGemma: Module, EmbeddingModel {
         return model(inputs, mask: mask, cache: cache)
     }
 
-    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
-        sanitize(weights: weights, quantizationConfig: nil)
-    }
-
-    public func sanitize(weights: [String: MLXArray],
-        quantizationConfig: QuantizationConfig? = nil)
+    public func sanitize(weights: [String: MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization? = nil)
         -> [String: MLXArray]
     {
         var processedWeights = model.sanitize(weights: weights, quantizationConfig: quantizationConfig)
@@ -89,6 +85,10 @@ public class EmbeddingGemma: Module, EmbeddingModel {
         }
     }
 
+    public func sanitize(weights: [String : MLXArray]) -> [String : MLXArray] {
+        sanitize(weights: weights, quantizationConfig: nil)
+    }
+
     /// Check if a layer has quantized weights
     private func hasQuantizedWeights(layerPath: String, in weights: [String: MLXArray]) -> Bool {
         let scalesKey = "\(layerPath).scales"
diff --git a/Libraries/Embedders/EmbeddingModel.swift b/Libraries/Embedders/EmbeddingModel.swift
index 119ccc8f..ab10e0c9 100644
--- a/Libraries/Embedders/EmbeddingModel.swift
+++ b/Libraries/Embedders/EmbeddingModel.swift
@@ -4,6 +4,7 @@ import Foundation
 @preconcurrency import Hub
 import MLX
 import MLXNN
+import MLXLMCommon
 import Tokenizers
 
 /// Container for models that guarantees single threaded access.
@@ -99,6 +100,7 @@ public protocol EmbeddingModel: Module {
     ) -> EmbeddingModelOutput
     /// Optionally preprocess the weights and modify / remove values as needed.
     func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
+    func sanitize(weights: [String: MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?) -> [String: MLXArray]
 }
 
 extension EmbeddingModel {
diff --git a/Libraries/Embedders/Load.swift b/Libraries/Embedders/Load.swift
index 868cc6c4..ecab22d0 100644
--- a/Libraries/Embedders/Load.swift
+++ b/Libraries/Embedders/Load.swift
@@ -4,6 +4,7 @@ import Foundation
 @preconcurrency import Hub
 import MLX
 import MLXNN
+import MLXLMCommon
 import Tokenizers
 
 struct EmbedderError: Error {
@@ -60,6 +61,8 @@ func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
     let configurationURL = modelDirectory.appending(component: "config.json")
     let baseConfig = try JSONDecoder().decode(
         BaseConfiguration.self, from: Data(contentsOf: configurationURL))
+    let commonBaseConfig = try JSONDecoder().decode(
+        MLXLMCommon.BaseConfiguration.self, from: Data(contentsOf: configurationURL))
 
     let modelType = ModelType(rawValue: baseConfig.modelType)
     let model = try modelType.createModel(configuration: configurationURL)
@@ -78,7 +81,7 @@ func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
     }
 
     // per-model cleanup
-    weights = model.sanitize(weights: weights)
+    weights = model.sanitize(weights: weights, quantizationConfig: commonBaseConfig.quantization)
 
     // quantize if needed
     if let perLayerQuantization = baseConfig.perLayerQuantization {
diff --git a/Libraries/Embedders/Models.swift b/Libraries/Embedders/Models.swift
index e8b21c4d..fbc2dfd6 100644
--- a/Libraries/Embedders/Models.swift
+++ b/Libraries/Embedders/Models.swift
@@ -110,6 +110,12 @@ extension ModelConfiguration {
         id: "mixedbread-ai/mxbai-embed-large-v1")
     public static let embeddinggemma_300m = ModelConfiguration(
         id: "mlx-community/embeddinggemma-300m-bf16")
+    public static let embeddinggemma_300m_8bit = ModelConfiguration(
+        id: "mlx-community/embeddinggemma-300m-8bit")
+    public static let embeddinggemma_300m_6bit = ModelConfiguration(
+        id: "mlx-community/embeddinggemma-300m-6bit")
+    public static let embeddinggemma_300m_4bit = ModelConfiguration(
+        id: "mlx-community/embeddinggemma-300m-4bit")
 
     private enum BootstrapState: Sendable {
         case idle
@@ -140,6 +147,9 @@ extension ModelConfiguration {
                 bge_m3,
                 mixedbread_large,
                 embeddinggemma_300m,
+                embeddinggemma_300m_8bit,
+                embeddinggemma_300m_6bit,
+                embeddinggemma_300m_4bit,
             ])
 
             bootstrapState = .bootstrapped
diff --git a/Libraries/Embedders/NomicBert.swift b/Libraries/Embedders/NomicBert.swift
index f98a69be..5107570d 100644
--- a/Libraries/Embedders/NomicBert.swift
+++ b/Libraries/Embedders/NomicBert.swift
@@ -3,6 +3,7 @@
 import Foundation
 import MLX
 import MLXNN
+import MLXLMCommon
 
 class NomicEmbedding: Module {
 
@@ -390,6 +391,10 @@ public class NomicBertModel: Module, EmbeddingModel {
             result[key] = item.value
         }
     }
+
+    public func sanitize(weights: [String : MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?) -> [String : MLXArray] {
+        fatalError("Nomic does not support quantization")
+    }
 }
 
 public struct NomicBertConfiguration: Decodable, Sendable {
diff --git a/Libraries/MLXLLM/Models/Gemma3Text.swift b/Libraries/MLXLLM/Models/Gemma3Text.swift
index b9dec45b..9d2f8b7b 100644
--- a/Libraries/MLXLLM/Models/Gemma3Text.swift
+++ b/Libraries/MLXLLM/Models/Gemma3Text.swift
@@ -56,7 +56,6 @@ public struct Gemma3TextConfiguration: Codable {
     public let slidingWindow: Int
     public let slidingWindowPattern: Int
     public let useBidirectionalAttention: Bool
-    public let quantizationConfig: QuantizationConfig?
 
     enum CodingKeys: String, CodingKey {
         case modelType = "model_type"
@@ -75,7 +74,6 @@ public struct Gemma3TextConfiguration: Codable {
         case slidingWindow = "sliding_window"
         case slidingWindowPattern = "sliding_window_pattern"
         case useBidirectionalAttention = "use_bidirectional_attention"
-        case quantizationConfig = "quantization"
     }
 
     enum VLMCodingKeys: String, CodingKey {
         case textConfig = "text_config"
     }
@@ -99,7 +97,6 @@ public struct Gemma3TextConfiguration: Codable {
         self.slidingWindow = slidingWindow
         self.slidingWindowPattern = slidingWindowPattern
         self.useBidirectionalAttention = useBidirectionalAttention
-        self.quantizationConfig = quantizationConfig
     }
 
     public init(from decoder: Decoder) throws {
@@ -139,8 +136,6 @@ public struct Gemma3TextConfiguration: Codable {
         slidingWindow = useBidirectionalAttention ? (rawSlidingWindow / 2) + 1 : rawSlidingWindow
         slidingWindowPattern =
             try container.decodeIfPresent(Int.self, forKey: .slidingWindowPattern) ?? 6
-
-        quantizationConfig = try container.decodeIfPresent(QuantizationConfig.self, forKey: .quantizationConfig)
     }
 }
@@ -458,7 +453,7 @@ public class Gemma3TextModel: Module, LLMModel {
 
     public func sanitize(
         weights: [String: MLXArray],
-        quantizationConfig: QuantizationConfig? = nil
+        quantizationConfig: BaseConfiguration.Quantization? = nil
     ) -> [String: MLXArray] {
         var processedWeights = weights
 
diff --git a/Package.swift b/Package.swift
index 365d7814..b58bd52c 100644
--- a/Package.swift
+++ b/Package.swift
@@ -25,6 +25,9 @@ let package = Package(
         .library(
             name: "StableDiffusion",
             targets: ["StableDiffusion"]),
+        .executable(
+            name: "test-embedding-gemma",
+            targets: ["TestEmbeddingGemma"]),
     ],
     dependencies: [
         .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.25.5")),
@@ -32,6 +35,7 @@ let package = Package(
             url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.23")
         ),
         .package(url: "https://github.com/1024jp/GzipSwift", "6.0.1" ... "6.0.1"),  // Only needed by MLXMNIST
+        .package(url: "https://github.com/apple/swift-argument-parser", from: "1.3.0"),
     ],
     targets: [
         .target(
@@ -113,6 +117,7 @@ let package = Package(
         .target(
             name: "MLXEmbedders",
             dependencies: [
+                "MLXLLM",
                 .product(name: "MLX", package: "mlx-swift"),
                 .product(name: "MLXFast", package: "mlx-swift"),
                 .product(name: "MLXNN", package: "mlx-swift"),
@@ -159,6 +164,17 @@ let package = Package(
                 .enableExperimentalFeature("StrictConcurrency")
             ]
         ),
+        .executableTarget(
+            name: "TestEmbeddingGemma",
+            dependencies: [
+                "MLXEmbedders",
+                "MLXLMCommon",
+                .product(name: "MLX", package: "mlx-swift"),
+                .product(name: "Transformers", package: "swift-transformers"),
+                .product(name: "ArgumentParser", package: "swift-argument-parser"),
+            ],
+            path: "Tools/TestEmbeddingGemma"
+        ),
     ]
 )
diff --git a/Tools/TestEmbeddingGemma/main.swift b/Tools/TestEmbeddingGemma/main.swift
new file mode 100644
index 00000000..6268e0bb
--- /dev/null
+++ b/Tools/TestEmbeddingGemma/main.swift
@@ -0,0 +1,72 @@
+import MLX
+import MLXEmbedders
+import MLXLMCommon
+import MLXLLM
+import Foundation
+import Hub
+import Hub
+import MLX
+import MLXNN
+import MLXLLM
+import MLXLMCommon
+import Cmlx
+import ArgumentParser
+@preconcurrency import Tokenizers
+
+import Foundation
+
+extension Tokenizer {
+    func tokenize(_ strings: [String]) -> (MLXArray, MLXArray) {
+        let tokenized = strings.map { self($0) }
+        let maxCount = tokenized.map(\.count).max()!
+        let padded = stacked(tokenized.map {
+            MLXArray($0 + Array(repeating: 0, count: maxCount - $0.count))
+        })
+        let mask = stacked(tokenized.map {
+            let basicMask = MLXArray.zeros([maxCount, maxCount], dtype: .bool)
+            basicMask[0..., ..<$0.count] = MLXArray(true)
+            return basicMask
+        }).reshaped(padded.shape[0], 1, padded.shape[1], padded.shape[1])
+        return (padded, mask)
+    }
+}
+
+@main
+struct Run: AsyncParsableCommand {
+    mutating func run() async throws {
+        let configurations = [
+            ModelConfiguration.embeddinggemma_300m,
+            ModelConfiguration.embeddinggemma_300m_8bit,
+            ModelConfiguration.embeddinggemma_300m_6bit,
+            ModelConfiguration.embeddinggemma_300m_4bit
+        ]
+
+        for config in configurations {
+            print("Testing \(config.name)...")
+            let (model, tokenizer) = try await load(configuration: config)
+
+            let (tokens, mask) = tokenizer.tokenize([
+                "the cat smells of farts",
+                "the dog smells the cat",
+                "the dog smells like the cat",
+                "the car is not a train"
+            ])
+
+            let out: EmbeddingModelOutput = model(tokens, positionIds: nil, tokenTypeIds: nil, attentionMask: mask)
+            let sim = matmul(out.pooledOutput!, out.pooledOutput!.transposed())
+            let time = ContinuousClock().measure {
+                for _ in 0..<100 {
+                    let a = model(tokens, positionIds: nil, tokenTypeIds: nil, attentionMask: mask)
+                    eval(a.pooledOutput)
+                }
+            }
+            print(sim)
+
+            if #available(macOS 15, *) {
+                print(Double(time.attoseconds) / (100 * 1e18))
+            }
+        }
+    }
+}
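A minimal load-and-embed sketch against the configurations registered above; it assumes the `tokenize` helper from the test tool is in scope, since that extension lives in the executable target:

    import MLX
    import MLXEmbedders

    func embed(_ texts: [String]) async throws -> MLXArray {
        let (model, tokenizer) = try await load(configuration: .embeddinggemma_300m_4bit)
        let (tokens, mask) = tokenizer.tokenize(texts)  // padded ids + [B, 1, L, L] padding mask
        return model(tokens, positionIds: nil, tokenTypeIds: nil, attentionMask: mask).pooledOutput!
    }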
From 0122a15e4d7586362a206353782348d46289d595 Mon Sep 17 00:00:00 2001
From: Tom Nickson
Date: Thu, 25 Sep 2025 10:54:04 -0700
Subject: [PATCH 8/9] Back out test tool from Package.swift

---
 Package.swift | 15 ---------------
 1 file changed, 15 deletions(-)

diff --git a/Package.swift b/Package.swift
index b58bd52c..29f28885 100644
--- a/Package.swift
+++ b/Package.swift
@@ -25,9 +25,6 @@ let package = Package(
         .library(
             name: "StableDiffusion",
             targets: ["StableDiffusion"]),
-        .executable(
-            name: "test-embedding-gemma",
-            targets: ["TestEmbeddingGemma"]),
     ],
     dependencies: [
         .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.25.5")),
@@ -35,7 +32,6 @@ let package = Package(
             url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.23")
         ),
         .package(url: "https://github.com/1024jp/GzipSwift", "6.0.1" ... "6.0.1"),  // Only needed by MLXMNIST
-        .package(url: "https://github.com/apple/swift-argument-parser", from: "1.3.0"),
     ],
     targets: [
         .target(
@@ -164,17 +160,6 @@ let package = Package(
                 .enableExperimentalFeature("StrictConcurrency")
             ]
         ),
-        .executableTarget(
-            name: "TestEmbeddingGemma",
-            dependencies: [
-                "MLXEmbedders",
-                "MLXLMCommon",
-                .product(name: "MLX", package: "mlx-swift"),
-                .product(name: "Transformers", package: "swift-transformers"),
-                .product(name: "ArgumentParser", package: "swift-argument-parser"),
-            ],
-            path: "Tools/TestEmbeddingGemma"
-        ),
     ]
 )

From 96ee882cd7c6fd3573b034686d3f3c5afe1ee04a Mon Sep 17 00:00:00 2001
From: Tom Nickson
Date: Thu, 25 Sep 2025 17:16:58 -0700
Subject: [PATCH 9/9] Remove TestEmbeddingGemma tool sources

---
 Tools/TestEmbeddingGemma/main.swift | 72 -----------------------------
 1 file changed, 72 deletions(-)
 delete mode 100644 Tools/TestEmbeddingGemma/main.swift

diff --git a/Tools/TestEmbeddingGemma/main.swift b/Tools/TestEmbeddingGemma/main.swift
deleted file mode 100644
index 6268e0bb..00000000
--- a/Tools/TestEmbeddingGemma/main.swift
+++ /dev/null
@@ -1,72 +0,0 @@
-import MLX
-import MLXEmbedders
-import MLXLMCommon
-import MLXLLM
-import Foundation
-import Hub
-import Hub
-import MLX
-import MLXNN
-import MLXLLM
-import MLXLMCommon
-import Cmlx
-import ArgumentParser
-@preconcurrency import Tokenizers
-
-import Foundation
-
-extension Tokenizer {
-    func tokenize(_ strings: [String]) -> (MLXArray, MLXArray) {
-        let tokenized = strings.map { self($0) }
-        let maxCount = tokenized.map(\.count).max()!
-        let padded = stacked(tokenized.map {
-            MLXArray($0 + Array(repeating: 0, count: maxCount - $0.count))
-        })
-        let mask = stacked(tokenized.map {
-            let basicMask = MLXArray.zeros([maxCount, maxCount], dtype: .bool)
-            basicMask[0..., ..<$0.count] = MLXArray(true)
-            return basicMask
-        }).reshaped(padded.shape[0], 1, padded.shape[1], padded.shape[1])
-        return (padded, mask)
-    }
-}
-
-@main
-struct Run: AsyncParsableCommand {
-    mutating func run() async throws {
-        let configurations = [
-            ModelConfiguration.embeddinggemma_300m,
-            ModelConfiguration.embeddinggemma_300m_8bit,
-            ModelConfiguration.embeddinggemma_300m_6bit,
-            ModelConfiguration.embeddinggemma_300m_4bit
-        ]
-
-        for config in configurations {
-            print("Testing \(config.name)...")
-            let (model, tokenizer) = try await load(configuration: config)
-
-            let (tokens, mask) = tokenizer.tokenize([
-                "the cat smells of farts",
-                "the dog smells the cat",
-                "the dog smells like the cat",
-                "the car is not a train"
-            ])
-
-            let out: EmbeddingModelOutput = model(tokens, positionIds: nil, tokenTypeIds: nil, attentionMask: mask)
-            let sim = matmul(out.pooledOutput!, out.pooledOutput!.transposed())
-            let time = ContinuousClock().measure {
-                for _ in 0..<100 {
-                    let a = model(tokens, positionIds: nil, tokenTypeIds: nil, attentionMask: mask)
-                    eval(a.pooledOutput)
-                }
-            }
-            print(sim)
-
-            if #available(macOS 15, *) {
-                print(Double(time.attoseconds) / (100 * 1e18))
-            }
-        }
-    }
-}