Commit ab6feba

Add Baichuan M1 (#355)
* Add support for Baichuan M1 model
1 parent e506d10 commit ab6feba

3 files changed: +308 -2 lines changed

ACKNOWLEDGMENTS.md

Lines changed: 1 addition & 2 deletions
@@ -6,8 +6,7 @@
 
 MLX Swift was developed with contributions from the following individuals:
 
-- [John Mai](https://github.com/johnmai-dev): Added support for multiple models (Qwen2, Starcoder2, InternLM2, Qwen3, Qwen3 MoE, GLM-4, MiMo, BitNet, SmolLM3, LFM2).
-
+- [John Mai](https://github.com/johnmai-dev): Added support for multiple models (Qwen2, Starcoder2, InternLM2, Qwen3, Qwen3 MoE, GLM-4, MiMo, BitNet, SmolLM3, LFM2, Baichuan-M1).
 
 <a href="https://github.com/ml-explore/mlx-swift-examples/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx-swift-examples&anon=0&columns=20&max=100&r=true" />

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 7 additions & 0 deletions
@@ -52,6 +52,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
             "smollm3": create(SmolLM3Configuration.self, SmolLM3Model.init),
             "ernie4_5": create(Ernie45Configuration.self, Ernie45Model.init),
             "lfm2": create(LFM2Configuration.self, LFM2Model.init),
+            "baichuan_m1": create(BaichuanM1Configuration.self, BaichuanM1Model.init),
         ]
     }

@@ -234,6 +235,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         defaultPrompt: "Why is the sky blue?"
     )

+    static public let baichuan_m1_14b_instruct_4bit = ModelConfiguration(
+        id: "mlx-community/Baichuan-M1-14B-Instruct-4bit-ft",
+        defaultPrompt: "Why is the sky blue?"
+    )
+
     static public let smollm3_3b_4bit = ModelConfiguration(
         id: "mlx-community/SmolLM3-3B-4bit",
         defaultPrompt: "Why is the sky blue?"

@@ -284,6 +290,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
             smollm3_3b_4bit,
             ernie_45_0_3BPT_bf16_ft,
             lfm2_1_2b_4bit,
+            baichuan_m1_14b_instruct_4bit,
         ]
     }
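
For reference (not part of this diff), a configuration registered this way is normally loaded through the shared factory. A minimal sketch, assuming the loadContainer(configuration:) API exposed by LLMModelFactory/MLXLMCommon:

// Hypothetical usage sketch, not part of the commit: load the Baichuan M1
// configuration registered above. `loadContainer` fetches the quantized
// weights from the Hugging Face Hub on first use.
import MLXLLM
import MLXLMCommon

func loadBaichuanM1() async throws -> ModelContainer {
    try await LLMModelFactory.shared.loadContainer(
        configuration: LLMRegistry.baichuan_m1_14b_instruct_4bit)
}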

Libraries/MLXLLM/Models/BaichuanM1.swift

Lines changed: 300 additions & 0 deletions
@@ -0,0 +1,300 @@
//
// BaichuanM1.swift
// mlx-swift-examples
//
// Created by John Mai on 2025/6/16.
//

import Foundation
import MLX
import MLXFast
import MLXLMCommon
import MLXNN
import MLXRandom

public struct BaichuanM1Configuration: Codable, Sendable {
    var vocabularySize: Int
    var hiddenSize: Int
    var intermediateSize: Int
    var hiddenLayers: Int
    var attentionHeads: Int
    var kvHeads: Int
    var ropeTheta: Float
    var slidingWindow: Int
    var slidingWindowLayers: [Int]
    var convWindow: Int
    var rmsNormEps: Float
    var swaAttentionHeads: Int?
    var swaKvHeads: Int?
    var tieWordEmbeddings: Bool = false

    enum CodingKeys: String, CodingKey {
        case vocabularySize = "vocab_size"
        case hiddenSize = "hidden_size"
        case intermediateSize = "intermediate_size"
        case hiddenLayers = "num_hidden_layers"
        case attentionHeads = "num_attention_heads"
        case kvHeads = "num_key_value_heads"
        case ropeTheta = "rope_theta"
        case slidingWindow = "sliding_window"
        case slidingWindowLayers = "sliding_window_layers"
        case convWindow = "conv_window"
        case rmsNormEps = "rms_norm_eps"
        case swaAttentionHeads = "num_swa_attention_heads"
        case swaKvHeads = "num_swa_key_value_heads"
        case tieWordEmbeddings = "tie_word_embeddings"
    }
}

private class Attention: Module {
    let config: BaichuanM1Configuration
    let layerIdx: Int
    let isSWA: Bool
    let numHeads: Int
    let numKVHeads: Int
    let hiddenSize: Int
    let headDim: Int
    let scale: Float

    @ModuleInfo(key: "W_pack") var wPack: Linear
    @ModuleInfo(key: "o_proj") var oProj: Linear
    let rope: RoPE

    @ParameterInfo(key: "conv_k") var convK: MLXArray
    @ParameterInfo(key: "conv_v") var convV: MLXArray

    init(_ config: BaichuanM1Configuration, layerIdx: Int) {
        self.config = config
        self.layerIdx = layerIdx

        self.isSWA = config.slidingWindowLayers.contains(layerIdx)
        self.numHeads =
            isSWA && config.swaAttentionHeads != nil
            ? config.swaAttentionHeads! : config.attentionHeads
        self.numKVHeads = isSWA && config.swaKvHeads != nil ? config.swaKvHeads! : config.kvHeads

        self.hiddenSize = config.hiddenSize
        self.headDim = hiddenSize / numHeads
        self.scale = pow(Float(headDim), -0.5)

        _wPack.wrappedValue = Linear(
            config.hiddenSize, config.hiddenSize + 2 * numKVHeads * headDim, bias: false)
        _oProj.wrappedValue = Linear(numHeads * headDim, config.hiddenSize, bias: false)

        self.rope = RoPE(dimensions: headDim, traditional: false, base: config.ropeTheta)

        _convK.wrappedValue = MLXArray.zeros([1, 1, numKVHeads, 1, config.convWindow])
        _convV.wrappedValue = MLXArray.zeros([1, 1, numKVHeads, 1, config.convWindow])
    }

    func customConvolution(_ u: MLXArray, _ weights: MLXArray, state: MLXArray? = nil) -> MLXArray {
        let (B, H, L, D) = (u.dim(0), u.dim(1), u.dim(2), u.dim(3))
        let reshapedWeights = weights.reshaped(1, H, config.convWindow, 1, 1)
        let w0 = reshapedWeights[0..., 0..., 0]
        let w1 = reshapedWeights[0..., 0..., 1]

        let state = state ?? MLXArray.zeros([B, H, 1, D], dtype: u.dtype)

        let uPrev: MLXArray =
            L > 1 ? concatenated([state, u[0..., 0..., ..<(L - 1), 0...]], axis: 2) : state

        return uPrev * w0 + u * w1
    }

    func callAsFunction(
        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
    ) -> MLXArray {
        let (B, L, D) = (x.dim(0), x.dim(1), x.dim(2))

        let proj = wPack(x)
        let qkv = split(proj, indices: [D, D + self.numKVHeads * self.headDim], axis: -1)

        var queries = qkv[0].reshaped(B, L, numHeads, headDim).transposed(0, 2, 1, 3)
        var keys = qkv[1].reshaped(B, L, numKVHeads, headDim).transposed(0, 2, 1, 3)
        var values = qkv[2].reshaped(B, L, numKVHeads, headDim).transposed(0, 2, 1, 3)

        var offset = 0
        var lastK: MLXArray? = nil
        var lastV: MLXArray? = nil

        if let cacheList = cache as? CacheList {
            offset = cacheList[1].offset
            if let mambaCache = cacheList[0] as? MambaCache {
                lastK = mambaCache[0]
                lastV = mambaCache[1]
            }
        }

        let kInit = keys
        let vInit = values

        keys = customConvolution(keys, convK, state: lastK)
        values = customConvolution(values, convV, state: lastV)

        queries = rope(queries, offset: offset)
        keys = rope(keys, offset: offset)

        if let cache = cache as? CacheList {
            let kvCache = cache[1]
            let (cachedKeys, cachedValues) = kvCache.update(keys: keys, values: values)
            keys = cachedKeys
            values = cachedValues

            if L > 0 {
                let convCache = cache[0] as! MambaCache
                convCache[0] = kInit[0..., 0..., (L - 1)..., 0...]
                convCache[1] = vInit[0..., 0..., (L - 1)..., 0...]
            }
        }

        let out = MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return oProj(out)
    }
}

private class MLP: Module, UnaryLayer {
    @ModuleInfo(key: "gate_proj") var gateProj: Linear
    @ModuleInfo(key: "up_proj") var upProj: Linear
    @ModuleInfo(key: "down_proj") var downProj: Linear

    init(_ config: BaichuanM1Configuration) {
        _gateProj.wrappedValue = Linear(config.hiddenSize, config.intermediateSize, bias: false)
        _upProj.wrappedValue = Linear(config.hiddenSize, config.intermediateSize, bias: false)
        _downProj.wrappedValue = Linear(config.intermediateSize, config.hiddenSize, bias: false)
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        return downProj(silu(gateProj(x)) * upProj(x))
    }
}

private class DecoderLayer: Module {
    @ModuleInfo(key: "self_attn") var attention: Attention
    let mlp: MLP
    @ModuleInfo(key: "input_layernorm") var inputLayernorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayernorm: RMSNorm

    init(_ config: BaichuanM1Configuration, layerIdx: Int) {
        _attention.wrappedValue = Attention(config, layerIdx: layerIdx)
        self.mlp = MLP(config)
        _inputLayernorm.wrappedValue = RMSNorm(
            dimensions: config.hiddenSize, eps: config.rmsNormEps)
        _postAttentionLayernorm.wrappedValue = RMSNorm(
            dimensions: config.hiddenSize, eps: config.rmsNormEps)
    }

    func callAsFunction(
        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
    ) -> MLXArray {
        var r = attention(inputLayernorm(x), mask: mask, cache: cache)
        let x = x + r
        r = mlp(postAttentionLayernorm(x))
        return x + r
    }
}

private class BaichuanM1ModelInner: Module {
    let args: BaichuanM1Configuration
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

    fileprivate let layers: [DecoderLayer]
    let norm: RMSNorm

    init(_ config: BaichuanM1Configuration) {
        self.args = config
        _embedTokens.wrappedValue = Embedding(
            embeddingCount: config.vocabularySize, dimensions: config.hiddenSize)
        self.layers = (0 ..< config.hiddenLayers).map { DecoderLayer(config, layerIdx: $0) }
        norm = RMSNorm(dimensions: config.hiddenSize, eps: config.rmsNormEps)
    }

    func callAsFunction(
        _ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil,
        cache: [KVCache]?
    ) -> MLXArray {
        var x = embedTokens(inputs)

        let mask = mask ?? createAttentionMask(h: x, cache: cache)

        for (i, layer) in layers.enumerated() {
            x = layer(x, mask: mask, cache: cache?[i])
        }

        return norm(x)
    }
}

public class BaichuanM1Model: Module, LLMModel, KVCacheDimensionProvider {

    public let vocabularySize: Int
    public let kvHeads: [Int]

    private let model: BaichuanM1ModelInner
    let configuration: BaichuanM1Configuration

    @ModuleInfo(key: "lm_head") var lmHead: Linear?

    public init(_ config: BaichuanM1Configuration) {
        self.configuration = config
        self.vocabularySize = config.vocabularySize
        self.kvHeads = Array(repeating: config.kvHeads, count: config.hiddenLayers)
        self.model = BaichuanM1ModelInner(config)

        if !config.tieWordEmbeddings {
            _lmHead.wrappedValue = Linear(config.hiddenSize, config.vocabularySize, bias: false)
        }
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        var outputs = model(inputs, cache: cache)

        if let lmHead {
            outputs = lmHead(outputs)
        }

        return outputs
    }

    public func newCache(parameters: GenerateParameters?) -> [KVCache] {
        return model.layers.enumerated().map { (i, _) in
            let isSWA = configuration.slidingWindowLayers.contains(i)
            let convCache = MambaCache()
            let kvCache: KVCache =
                isSWA ? RotatingKVCache(maxSize: configuration.slidingWindow) : KVCacheSimple()
            return CacheList(convCache, kvCache)
        }
    }

    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        var weights = weights
        let isQuantized = weights["lm_head.scales"] != nil

        if !isQuantized, let w = weights["lm_head.weight"] {
            var w = w
            if w.dtype != .float32 {
                w = w.asType(.float32)
            }

            let norm = sqrt(sum(w * w, axes: [-1], keepDims: true))
            w = (w / (norm + 1e-7)).asType(w.dtype)
            weights["lm_head.weight"] = w
        }

        if configuration.tieWordEmbeddings {
            weights["lm_head.weight"] = nil
        }

        return weights
    }
}

extension BaichuanM1Model: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["W_pack"]) }
    }
}
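
A note on the design (not part of the commit): newCache pairs a MambaCache, which keeps the last raw key/value step, with a regular or rotating KV cache per layer, so the short causal convolution over keys and values can continue across generation steps. Because customConvolution only reads taps 0 and 1 of the weights, it is effectively a two-tap causal filter, out[t] = w0 * u[t-1] + w1 * u[t]. A standalone sketch of that filter, assuming convWindow == 2 (the function name below is illustrative, not from the file):

import MLX

// Illustrative sketch: the two-tap causal filter that customConvolution applies
// along the sequence axis when convWindow == 2.
// u: [B, H, L, D]; state: the step preceding this chunk, shaped [B, H, 1, D].
func twoTapCausalFilter(_ u: MLXArray, w0: MLXArray, w1: MLXArray, state: MLXArray) -> MLXArray {
    let L = u.dim(2)
    // Shift u one step to the right along the sequence axis, prepending the cached state.
    let uPrev = L > 1 ? concatenated([state, u[0..., 0..., ..<(L - 1), 0...]], axis: 2) : state
    return uPrev * w0 + u * w1
}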
