diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift
index ea8aedba..c30e9e6c 100644
--- a/Libraries/MLXLLM/LLMModelFactory.swift
+++ b/Libraries/MLXLLM/LLMModelFactory.swift
@@ -61,6 +61,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
"olmoe": create(OlmoEConfiguration.self, OlmoEModel.init),
"olmo2": create(Olmo2Configuration.self, Olmo2Model.init),
"bailing_moe": create(BailingMoeConfiguration.self, BailingMoeModel.init),
+ "mamba": create(MambaConfiguration.self, MambaModel.init),
]
}
}
diff --git a/Libraries/MLXLLM/Models/Mamba.swift b/Libraries/MLXLLM/Models/Mamba.swift
new file mode 100644
index 00000000..d2955c2a
--- /dev/null
+++ b/Libraries/MLXLLM/Models/Mamba.swift
@@ -0,0 +1,339 @@
+// Copyright © 2025 Apple Inc.
+
+import Foundation
+import MLX
+import MLXFast
+import MLXLMCommon
+import MLXNN
+
+// port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/mamba.py
+
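+// Ad-hoc CodingKey for reading JSON keys that are not listed in a CodingKeys enum.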
+struct StringKey: CodingKey, ExpressibleByStringLiteral {
+ var intValue: Int? = nil
+ var stringValue: String
+ init?(intValue: Int) { return nil }
+ init?(stringValue: String) { self.stringValue = stringValue }
+ init(stringLiteral: StringLiteralType) {
+ self.stringValue = stringLiteral
+ }
+}
+
+public struct MambaConfiguration: Codable, Sendable {
+ var modelType: String
+ var vocabSize: Int
+ var hiddenSize: Int
+ var intermediateSize: Int
+ var stateSize: Int
+ var numHiddenLayers: Int
+ var convKernel: Int
+ var useBias: Bool
+ var useConvBias: Bool
+ var timeStepRank: Int
+ var tieWordEmbeddings: Bool
+ var useBcdtRms: Bool
+ var mixerRmsEps: Float
+
+ enum CodingKeys: String, CodingKey {
+ case modelType = "model_type"
+ case vocabSize = "vocab_size"
+ case hiddenSize = "hidden_size"
+ case intermediateSize = "intermediate_size"
+ case stateSize = "state_size"
+ case numHiddenLayers = "num_hidden_layers"
+ case convKernel = "conv_kernel"
+ case useBias = "use_bias"
+ case useConvBias = "use_conv_bias"
+ case timeStepRank = "time_step_rank"
+ case tieWordEmbeddings = "tie_word_embeddings"
+ case useBcdtRms = "use_bcdt_rms"
+ case mixerRmsEps = "mixer_rms_eps"
+ }
+
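+    // Some checkpoints use the original Mamba key names (d_model, d_inner,
+    // d_state, n_layer / n_layers, d_conv, bias, conv_bias) instead of the
+    // Hugging Face names in CodingKeys, so each field falls back to the
+    // alternate spelling via the untyped StringKey container.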
+ public init(from decoder: Decoder) throws {
+ let container = try decoder.container(keyedBy: CodingKeys.self)
+ let fallback = try decoder.container(keyedBy: StringKey.self)
+
+ modelType = try container.decode(String.self, forKey: .modelType)
+ vocabSize = try container.decode(Int.self, forKey: .vocabSize)
+ hiddenSize =
+ try container
+ .decodeIfPresent(Int.self, forKey: .hiddenSize)
+ ?? fallback
+ .decode(Int.self, forKey: "d_model")
+ intermediateSize =
+ try container
+ .decodeIfPresent(Int.self, forKey: .intermediateSize)
+ ?? fallback
+ .decode(Int.self, forKey: "d_inner")
+ stateSize =
+ try container
+ .decodeIfPresent(Int.self, forKey: .stateSize)
+ ?? fallback
+ .decode(Int.self, forKey: "d_state")
+ numHiddenLayers =
+ try container
+ .decodeIfPresent(Int.self, forKey: .numHiddenLayers)
+ ?? fallback
+ .decodeIfPresent(Int.self, forKey: "n_layer")
+ ?? fallback
+ .decode(Int.self, forKey: "n_layers")
+ convKernel =
+ try container
+ .decodeIfPresent(Int.self, forKey: .convKernel)
+ ?? fallback
+ .decode(Int.self, forKey: "d_conv")
+ useBias =
+ try container
+ .decodeIfPresent(Bool.self, forKey: .useBias)
+ ?? fallback
+ .decode(Bool.self, forKey: "bias")
+ useConvBias =
+ try container
+ .decodeIfPresent(Bool.self, forKey: .useConvBias)
+ ?? fallback
+ .decode(Bool.self, forKey: "conv_bias")
+
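+        // time_step_rank may be the string "auto", which resolves to
+        // ceil(hidden_size / 16) as in the reference implementation.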
+ if let timeStepRankAuto = try? container.decode(String.self, forKey: .timeStepRank),
+ timeStepRankAuto == "auto"
+ {
+ timeStepRank = (hiddenSize + 15) / 16
+ } else {
+ timeStepRank = try container.decode(Int.self, forKey: .timeStepRank)
+ }
+
+ tieWordEmbeddings =
+ try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? true
+ useBcdtRms = try container.decodeIfPresent(Bool.self, forKey: .useBcdtRms) ?? false
+ mixerRmsEps = try container.decodeIfPresent(Float.self, forKey: .mixerRmsEps) ?? 1e-6
+
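+        // FalconMamba checkpoints always RMS-normalize the B, C and delta
+        // projections, regardless of the configured value.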
+ if modelType == "falcon_mamba" {
+ useBcdtRms = true
+ }
+ }
+}
+
+private class MambaBlock: Module {
+
+ let args: MambaConfiguration
+
+ var _mixerNorm: ((MLXArray) -> MLXArray)? = nil
+
+ @ModuleInfo(key: "in_proj") var inProj: Linear
+ @ModuleInfo(key: "conv1d") var conv1d: Conv1d
+ @ModuleInfo(key: "x_proj") var xProj: Linear
+ @ModuleInfo(key: "dt_proj") var dtProj: Linear
+
+ @ParameterInfo(key: "A_log") var aLog: MLXArray
+ @ParameterInfo(key: "D") var d: MLXArray
+
+ @ModuleInfo(key: "out_proj") var outProj: Linear
+
+ public init(_ args: MambaConfiguration) {
+ self.args = args
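+        // With use_bcdt_rms, the delta, B and C projections are RMS-normalized
+        // with a fixed all-ones weight (as FalconMamba does).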
+ if args.useBcdtRms {
+ self._mixerNorm = {
+ MLXFast.rmsNorm(
+ $0,
+ weight: MLX.ones([$0.dim(-1)], dtype: $0.dtype),
+ eps: args.mixerRmsEps)
+ }
+ }
+
+ self._inProj.wrappedValue = Linear(
+ args.hiddenSize, args.intermediateSize * 2, bias: args.useBias)
+
+ self._conv1d.wrappedValue = Conv1d(
+ inputChannels: args.intermediateSize,
+ outputChannels: args.intermediateSize,
+ kernelSize: args.convKernel,
+ padding: 0,
+ groups: args.intermediateSize,
+ bias: args.useConvBias
+ )
+
+ self._xProj.wrappedValue = Linear(
+ args.intermediateSize,
+ args.timeStepRank + 2 * args.stateSize,
+ bias: false
+ )
+
+ self._dtProj.wrappedValue = Linear(
+ args.timeStepRank, args.intermediateSize, bias: true)
+
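+        // A_log stores log([1 ... state_size]) per channel; the forward pass
+        // uses A = -exp(A_log), which keeps A strictly negative.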
+ let A = repeated(
+ MLXArray(1 ..< args.stateSize + 1, [1, args.stateSize]),
+ count: args.intermediateSize,
+ axis: 0
+ )
+
+ self._aLog.wrappedValue = log(A)
+ self._d.wrappedValue = ones([args.intermediateSize])
+
+ self._outProj.wrappedValue = Linear(
+ args.intermediateSize, args.hiddenSize, bias: args.useBias)
+ }
+
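+    // One step of the selective state-space recurrence:
+    //   state_t = exp(delta_t * A) * state_{t-1} + (delta_t * x_t) outer B_t
+    //   y_t     = state_t . C_t + D * x_t
+    // where delta_t, B_t and C_t are input-dependent projections of x_t.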
+ func ssmStep(_ x: MLXArray, _ A: MLXArray, state: MLXArray?) -> (MLXArray, MLXArray) {
+        let deltaBC = self.xProj(x)
+        // Apply the mixer RMS norm (when enabled) exactly once to delta, B
+        // and C, matching the reference implementation.
+        var deltaBCParts = split(
+            deltaBC,
+            indices: [self.args.timeStepRank, self.args.timeStepRank + self.args.stateSize],
+            axis: -1
+        )
+        if self.args.useBcdtRms, let mixerNorm = self._mixerNorm {
+            deltaBCParts = deltaBCParts.map { mixerNorm($0) }
+        }
+ var delta = deltaBCParts[0]
+ let B = deltaBCParts[1]
+ let C = deltaBCParts[2]
+
+ delta = softplus(self.dtProj(delta))
+ var newState = expandedDimensions(delta * x, axis: -1) * expandedDimensions(B, axis: 1)
+ if let state {
+ newState += state * exp(expandedDimensions(delta, axis: -1) * A)
+ }
+ var y = newState.matmul(expandedDimensions(C, axis: -1)).squeezed(axis: 2)
+        y = y + self.d * x
+ return (y, newState)
+ }
+
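+    // Runs the block over a whole sequence: the input is projected, passed
+    // through a depthwise causal convolution (caching the trailing K - 1
+    // inputs between calls), then scanned token by token with ssmStep.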
+ func processSequence(_ x: MLXArray, convCache: MLXArray?, stateCache: MLXArray?)
+ -> (MLXArray, (MLXArray, MLXArray?))
+ {
+ let T = x.dim(1)
+ let xz = self.inProj(x)
+ var (x, z) = xz.split(axis: -1)
+ let K = self.args.convKernel
+ var xFull: MLXArray
+ if let convCache {
+ xFull = concatenated([convCache, x], axis: 1)
+ } else {
+ xFull = padded(
+ x,
+ widths: [
+ .init((0, 0)),
+ .init((K - 1, 0)),
+ .init((0, 0)),
+ ])
+ }
+ let convOut = conv1d(xFull)
+        // Keep the trailing K - 1 time steps of the (padded) input as the
+        // convolution cache for the next call; `(1 - K)...` is the Swift
+        // spelling of Python's `-(K - 1):` slice.
+        let newConvCache = xFull[0..., (1 - K)..., 0...]
+ x = silu(convOut)
+ let A = -exp(self.aLog)
+ var currentState = stateCache
+ var y: [MLXArray] = []
+ var yT: MLXArray
+ for t in 0 ..< T {
+ (yT, currentState) = self.ssmStep(x[0..., t], A, state: currentState)
+ y.append(yT)
+ }
+ z = self.outProj(silu(z) * stacked(y, axis: 1))
+ return (z, (newConvCache, currentState))
+ }
+
+ public func callAsFunction(_ inputs: MLXArray, cache: MambaCache? = nil) -> MLXArray {
+ let (output, (newConvCache, newStateCache)) = self.processSequence(
+ inputs, convCache: cache?[0], stateCache: cache?[1]
+ )
+        if let cache {
+            cache[0] = newConvCache
+            cache[1] = newStateCache
+        }
+ return output
+ }
+
+}
+
+private class ResidualBlock: Module {
+ @ModuleInfo var mixer: MambaBlock
+    @ModuleInfo var norm: RMSNorm
+
+    public init(_ args: MambaConfiguration) {
+        self._mixer.wrappedValue = MambaBlock(args)
+        self._norm.wrappedValue = RMSNorm(dimensions: args.hiddenSize)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: MambaCache? = nil) -> MLXArray {
+ return mixer(norm(inputs), cache: cache) + inputs
+ }
+}
+
+// maps to mamba.Mamba
+private class MambaModelInner: Module {
+ @ModuleInfo var embeddings: Embedding
+ @ModuleInfo var layers: [ResidualBlock]
+ @ModuleInfo(key: "norm_f") var normF: RMSNorm
+
+ public init(_ args: MambaConfiguration) {
+ self._embeddings.wrappedValue = Embedding(
+ embeddingCount: args.vocabSize, dimensions: args.hiddenSize)
+ self._layers.wrappedValue = (0 ..< args.numHiddenLayers).map { _ in
+ ResidualBlock(args)
+ }
+ self._normF.wrappedValue = RMSNorm(dimensions: args.hiddenSize)
+ }
+
+ public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
+ var x = embeddings(inputs)
+ for (i, layer) in layers.enumerated() {
+ x = layer(x, cache: (cache?[i] as? MambaCache))
+ }
+ return normF(x)
+ }
+}
+
+// maps to mamba.Model
+public class MambaModel: Module, LLMModel {
+ let args: MambaConfiguration
+ let modelType: String
+ @ModuleInfo private var backbone: MambaModelInner
+ @ModuleInfo(key: "lm_head") var lmHead: Linear? = nil
+
+ public init(_ args: MambaConfiguration) {
+ self.args = args
+ self.modelType = args.modelType
+ self._backbone.wrappedValue = MambaModelInner(args)
+ if !args.tieWordEmbeddings {
+ self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabSize, bias: false)
+ }
+ }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
+ let x = self.backbone(inputs, cache: cache)
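+        // With tie_word_embeddings there is no separate lm_head; the embedding
+        // matrix is reused as the output projection.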
+ var logits: MLXArray
+ if let lmHead {
+ logits = lmHead(x)
+ } else {
+ logits = self.backbone.embeddings.asLinear(x)
+ }
+ return logits
+ }
+
+ public func newCache(parameters: MLXLMCommon.GenerateParameters?)
+ -> [any MLXLMCommon.KVCache]
+ {
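+        // Mamba has no attention KV cache; each layer instead gets a MambaCache
+        // holding [convolution state, SSM hidden state].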
+ return (0 ..< args.numHiddenLayers).map { _ in MambaCache() }
+ }
+
+ public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+ var processedWeights = weights
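+        // PyTorch stores depthwise conv1d weights as [channels, 1, kernel];
+        // MLX expects [channels, kernel, 1], so swap the trailing axes unless
+        // the checkpoint was already converted (last dim == 1).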
+ for (key, value) in weights {
+ if key.contains("conv1d.weight") && value.dim(-1) != 1 {
+ processedWeights[key] = value.movedAxis(source: 2, destination: 1)
+ }
+ }
+ return processedWeights
+ }
+
+ public func loraLinearLayers() -> MLXLMCommon.LoRALinearLayers {
+ return []
+ }
+
+}
diff --git a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
index a092bee2..087c9018 100644
--- a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
+++ b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
@@ -111,6 +111,10 @@
argument = "--model mlx-community/Llama-3.2-1B-Instruct-4bit"
isEnabled = "NO">
+
+