Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Libraries/Embedders/Bert.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import MLX
import MLXNN
import MLXLMCommon

extension MLXArray {
public static func arange(_ size: Int) -> MLXArray {
Expand Down Expand Up @@ -196,6 +197,10 @@ public class BertModel: Module, EmbeddingModel {
result[key] = item.value
}.filter { key, _ in key != "embeddings.position_ids" }
}

/// Quantization-aware variant of `sanitize(weights:)`.
///
/// Bert has no quantized-layer support, so a non-nil quantization config is a
/// fatal misconfiguration. With no config (the common, unquantized case) we
/// fall back to the plain weight cleanup so existing Bert checkpoints still
/// load — previously this overload crashed unconditionally even though the
/// loader calls it for every model.
public func sanitize(weights: [String : MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?) -> [String : MLXArray] {
    guard quantizationConfig == nil else {
        fatalError("Bert does not support quantization")
    }
    return sanitize(weights: weights)
}
}

public class DistilBertModel: BertModel {
Expand Down
8 changes: 8 additions & 0 deletions Libraries/Embedders/Configuration.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.

import Foundation
import MLXLLM

public enum StringOrNumber: Codable, Equatable, Sendable {
case string(String)
Expand Down Expand Up @@ -69,6 +70,13 @@ private class ModelTypeRegistry: @unchecked Sendable {
let model = NomicBertModel(configuration)
return model
},
// EmbeddingGemma: decode the Gemma3 text configuration from the model's
// config.json and wrap the text backbone in the embedding head.
"gemma3_text": {
    url in
    let configuration = try JSONDecoder().decode(
        Gemma3TextConfiguration.self, from: Data(contentsOf: url))
    let model = EmbeddingGemma(configuration)
    return model
},
]

public func registerModelType(
Expand Down
104 changes: 104 additions & 0 deletions Libraries/Embedders/EmbeddingGemma.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import MLX
import MLXNN
import MLXLLM
import MLXLMCommon

/// Embedding model built on the Gemma3 text backbone: mean-pools the
/// backbone's hidden states, applies a dense projection head, and
/// L2-normalizes the result.
public class EmbeddingGemma: Module, EmbeddingModel {
    /// Gemma3 text backbone producing per-token hidden states.
    @ModuleInfo private var model: Gemma3TextModel
    /// Projection head applied after pooling. Held as `[Module]` so the
    /// `Linear` layers can be replaced by `QuantizedLinear` in `sanitize`.
    @ModuleInfo private var dense: [Module]

    public let config: Gemma3TextConfiguration
    public var vocabularySize: Int { config.vocabularySize }

    public init(_ config: Gemma3TextConfiguration) {
        self.config = config
        self.model = Gemma3TextModel(config)
        // 768 -> 3072 -> 768; shapes must match the published
        // embeddinggemma-300m checkpoints since weights are loaded by key.
        self.dense = [
            Linear(768, 3072, bias: false), Linear(3072, 768, bias: false)
        ]
    }

    /// Embeds `inputs` (token ids) by mean pooling the backbone hidden
    /// states over valid positions and projecting through the dense head.
    ///
    /// - Parameters:
    ///   - inputs: token ids; token id 0 is treated as padding when no
    ///     `attentionMask` is supplied
    ///   - positionIds: unused by this model
    ///   - tokenTypeIds: unused by this model
    ///   - attentionMask: optional mask of valid positions (non-zero = keep);
    ///     previously ignored, now honored when provided
    /// - Returns: `hiddenStates` is the un-normalized head output;
    ///   `pooledOutput` is the L2-normalized embedding.
    public func callAsFunction(
        _ inputs: MLXArray, positionIds: MLXArray?, tokenTypeIds: MLXArray?,
        attentionMask: MLXArray?
    ) -> EmbeddingModelOutput {
        var out = model.getHiddenStates(inputs, mask: nil, cache: nil)

        // Mean pooling over non-padding positions. Prefer the
        // caller-supplied attention mask; otherwise infer padding from
        // token id 0.
        let notPadding = attentionMask.map { $0 .!= 0 } ?? (inputs .!= 0)
        let sum = (out * notPadding[.ellipsis, .newAxis]).sum(axis: 1)
        let nonMasked = notPadding.sum(axis: -1, keepDims: true)
        out = sum / nonMasked

        // Projection head; each layer is either Linear or, after
        // quantization in `sanitize`, QuantizedLinear.
        for dense in self.dense {
            if let dense = dense as? Linear {
                out = dense(out)
            } else if let dense = dense as? QuantizedLinear {
                out = dense(out)
            }
        }

        // L2-normalize in float32, clamping the norm to avoid division by
        // zero on degenerate inputs.
        out = out.asType(Float32.self)
        let norm = maximum(norm(out, ord: 2.0, axis: -1, keepDims: true), MLXArray(1e-6))
        let pooledOutput = out / norm

        return EmbeddingModelOutput(hiddenStates: out, pooledOutput: pooledOutput)
    }

    /// Get hidden states before the dense projection head.
    public func getHiddenStates(
        _ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil,
        cache: [KVCache]? = nil
    ) -> MLXArray {
        // Delegate to the backbone's hidden-state accessor, matching
        // `callAsFunction` above. Calling `model(...)` directly would run
        // the lm_head and return logits, contradicting this method's
        // contract.
        model.getHiddenStates(inputs, mask: mask, cache: cache)
    }

    /// Cleans up checkpoint weights and, when the checkpoint ships quantized
    /// dense layers, quantizes the matching modules on this instance so the
    /// weights can be loaded.
    public func sanitize(
        weights: [String: MLXArray],
        quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization? = nil
    ) -> [String: MLXArray] {
        var processedWeights = model.sanitize(weights: weights, quantizationConfig: quantizationConfig)

        // 1. Nest backbone keys under a `model.` prefix so they line up with
        //    this wrapper's module tree (note: `lm_head.` keys are nested the
        //    same way).
        processedWeights = Dictionary(uniqueKeysWithValues: processedWeights.map { key, value in
            if key.hasPrefix("model.") || key.hasPrefix("lm_head.") {
                return ("model.\(key)", value)
            } else {
                return (key, value)
            }
        })

        // 2. Apply quantization to dense layers, if the checkpoint contains
        //    quantized tensors for them.
        let hasQuantizedDense = hasQuantizedWeights(layerPath: "dense.0", in: processedWeights)
        if hasQuantizedDense {
            let groupSize = quantizationConfig?.groupSize ?? 64
            let bits = quantizationConfig?.bits ?? 4

            quantize(model: self) { path, module in
                if hasQuantizedWeights(layerPath: path, in: processedWeights) {
                    return (groupSize, bits)
                }
                return nil
            }
        }

        // Rotary inverse frequencies are computed at runtime, not loaded.
        return processedWeights.filter { key, _ in
            !key.contains("self_attn.rotary_emb.inv_freq")
        }
    }

    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        sanitize(weights: weights, quantizationConfig: nil)
    }

    /// True when `layerPath` has the full set of quantized tensors:
    /// `.scales`, `.biases`, and a packed `uint32` `.weight`.
    private func hasQuantizedWeights(layerPath: String, in weights: [String: MLXArray]) -> Bool {
        let hasScales = weights["\(layerPath).scales"] != nil
        let hasBiases = weights["\(layerPath).biases"] != nil
        let hasWeight = weights["\(layerPath).weight"]?.dtype == .uint32
        return hasScales && hasBiases && hasWeight
    }
}
6 changes: 4 additions & 2 deletions Libraries/Embedders/EmbeddingModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Foundation
@preconcurrency import Hub
import MLX
import MLXNN
import MLXLMCommon
import Tokenizers

/// Container for models that guarantees single threaded access.
Expand Down Expand Up @@ -87,8 +88,8 @@ extension Module {
}

public struct EmbeddingModelOutput {
let hiddenStates: MLXArray?
let pooledOutput: MLXArray?
public let hiddenStates: MLXArray?
public let pooledOutput: MLXArray?
}

public protocol EmbeddingModel: Module {
Expand All @@ -99,6 +100,7 @@ public protocol EmbeddingModel: Module {
) -> EmbeddingModelOutput
/// Optionally preprocess the weights and modify / remove values as needed.
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
func sanitize(weights: [String: MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?) -> [String: MLXArray]
}

extension EmbeddingModel {
Expand Down
5 changes: 4 additions & 1 deletion Libraries/Embedders/Load.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Foundation
@preconcurrency import Hub
import MLX
import MLXNN
import MLXLMCommon
import Tokenizers

struct EmbedderError: Error {
Expand Down Expand Up @@ -60,6 +61,8 @@ func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
let configurationURL = modelDirectory.appending(component: "config.json")
let baseConfig = try JSONDecoder().decode(
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
let commonBaseConfig = try JSONDecoder().decode(
MLXLMCommon.BaseConfiguration.self, from: Data(contentsOf: configurationURL))

let modelType = ModelType(rawValue: baseConfig.modelType)
let model = try modelType.createModel(configuration: configurationURL)
Expand All @@ -78,7 +81,7 @@ func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
}

// per-model cleanup
weights = model.sanitize(weights: weights)
weights = model.sanitize(weights: weights, quantizationConfig: commonBaseConfig.quantization)

// quantize if needed
if let perLayerQuantization = baseConfig.perLayerQuantization {
Expand Down
12 changes: 12 additions & 0 deletions Libraries/Embedders/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ extension ModelConfiguration {
public static let bge_m3 = ModelConfiguration(id: "BAAI/bge-m3")
public static let mixedbread_large = ModelConfiguration(
id: "mixedbread-ai/mxbai-embed-large-v1")
// EmbeddingGemma 300m checkpoints: bf16 plus 8/6/4-bit quantized variants.
public static let embeddinggemma_300m = ModelConfiguration(
    id: "mlx-community/embeddinggemma-300m-bf16")
public static let embeddinggemma_300m_8bit = ModelConfiguration(
    id: "mlx-community/embeddinggemma-300m-8bit")
public static let embeddinggemma_300m_6bit = ModelConfiguration(
    id: "mlx-community/embeddinggemma-300m-6bit")
public static let embeddinggemma_300m_4bit = ModelConfiguration(
    id: "mlx-community/embeddinggemma-300m-4bit")

private enum BootstrapState: Sendable {
case idle
Expand Down Expand Up @@ -138,6 +146,10 @@ extension ModelConfiguration {
snowflake_lg,
bge_m3,
mixedbread_large,
embeddinggemma_300m,
embeddinggemma_300m_8bit,
embeddinggemma_300m_6bit,
embeddinggemma_300m_4bit,
])
bootstrapState = .bootstrapped

Expand Down
5 changes: 5 additions & 0 deletions Libraries/Embedders/NomicBert.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import Foundation
import MLX
import MLXNN
import MLXLMCommon

class NomicEmbedding: Module {

Expand Down Expand Up @@ -390,6 +391,10 @@ public class NomicBertModel: Module, EmbeddingModel {
result[key] = item.value
}
}

/// Quantization-aware variant of `sanitize(weights:)`.
///
/// NomicBert has no quantized-layer support, so a non-nil quantization config
/// is a fatal misconfiguration. With no config (the common, unquantized case)
/// we fall back to the plain weight cleanup so existing Nomic checkpoints
/// still load — previously this overload crashed unconditionally even though
/// the loader calls it for every model.
public func sanitize(weights: [String : MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?) -> [String : MLXArray] {
    guard quantizationConfig == nil else {
        fatalError("Nomic does not support quantization")
    }
    return sanitize(weights: weights)
}
}

public struct NomicBertConfiguration: Decodable, Sendable {
Expand Down
Loading