|
| 1 | +// |
| 2 | +// BaichuanM1.swift |
| 3 | +// mlx-swift-examples |
| 4 | +// |
| 5 | +// Created by John Mai on 2025/6/16. |
| 6 | +// |
| 7 | + |
| 8 | +import Foundation |
| 9 | +import MLX |
| 10 | +import MLXFast |
| 11 | +import MLXLMCommon |
| 12 | +import MLXNN |
| 13 | +import MLXRandom |
| 14 | + |
| 15 | +public struct BaichuanM1Configuration: Codable, Sendable { |
| 16 | + var vocabularySize: Int |
| 17 | + var hiddenSize: Int |
| 18 | + var intermediateSize: Int |
| 19 | + var hiddenLayers: Int |
| 20 | + var attentionHeads: Int |
| 21 | + var kvHeads: Int |
| 22 | + var ropeTheta: Float |
| 23 | + var slidingWindow: Int |
| 24 | + var slidingWindowLayers: [Int] |
| 25 | + var convWindow: Int |
| 26 | + var rmsNormEps: Float |
| 27 | + var swaAttentionHeads: Int? |
| 28 | + var swaKvHeads: Int? |
| 29 | + var tieWordEmbeddings: Bool = false |
| 30 | + |
| 31 | + enum CodingKeys: String, CodingKey { |
| 32 | + case vocabularySize = "vocab_size" |
| 33 | + case hiddenSize = "hidden_size" |
| 34 | + case intermediateSize = "intermediate_size" |
| 35 | + case hiddenLayers = "num_hidden_layers" |
| 36 | + case attentionHeads = "num_attention_heads" |
| 37 | + case kvHeads = "num_key_value_heads" |
| 38 | + case ropeTheta = "rope_theta" |
| 39 | + case slidingWindow = "sliding_window" |
| 40 | + case slidingWindowLayers = "sliding_window_layers" |
| 41 | + case convWindow = "conv_window" |
| 42 | + case rmsNormEps = "rms_norm_eps" |
| 43 | + case swaAttentionHeads = "num_swa_attention_heads" |
| 44 | + case swaKvHeads = "num_swa_key_value_heads" |
| 45 | + case tieWordEmbeddings = "tie_word_embeddings" |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +private class Attention: Module { |
| 50 | + let config: BaichuanM1Configuration |
| 51 | + let layerIdx: Int |
| 52 | + let isSWA: Bool |
| 53 | + let numHeads: Int |
| 54 | + let numKVHeads: Int |
| 55 | + let hiddenSize: Int |
| 56 | + let headDim: Int |
| 57 | + let scale: Float |
| 58 | + |
| 59 | + @ModuleInfo(key: "W_pack") var wPack: Linear |
| 60 | + @ModuleInfo(key: "o_proj") var oProj: Linear |
| 61 | + let rope: RoPE |
| 62 | + |
| 63 | + @ParameterInfo(key: "conv_k") var convK: MLXArray |
| 64 | + @ParameterInfo(key: "conv_v") var convV: MLXArray |
| 65 | + |
| 66 | + init(_ config: BaichuanM1Configuration, layerIdx: Int) { |
| 67 | + self.config = config |
| 68 | + self.layerIdx = layerIdx |
| 69 | + |
| 70 | + self.isSWA = config.slidingWindowLayers.contains(layerIdx) |
| 71 | + self.numHeads = |
| 72 | + isSWA && config.swaAttentionHeads != nil |
| 73 | + ? config.swaAttentionHeads! : config.attentionHeads |
| 74 | + self.numKVHeads = isSWA && config.swaKvHeads != nil ? config.swaKvHeads! : config.kvHeads |
| 75 | + |
| 76 | + self.hiddenSize = config.hiddenSize |
| 77 | + self.headDim = hiddenSize / numHeads |
| 78 | + self.scale = pow(Float(headDim), -0.5) |
| 79 | + |
| 80 | + _wPack.wrappedValue = Linear( |
| 81 | + config.hiddenSize, config.hiddenSize + 2 * numKVHeads * headDim, bias: false) |
| 82 | + _oProj.wrappedValue = Linear(numHeads * headDim, config.hiddenSize, bias: false) |
| 83 | + |
| 84 | + self.rope = RoPE(dimensions: headDim, traditional: false, base: config.ropeTheta) |
| 85 | + |
| 86 | + _convK.wrappedValue = MLXArray.zeros([1, 1, numKVHeads, 1, config.convWindow]) |
| 87 | + _convV.wrappedValue = MLXArray.zeros([1, 1, numKVHeads, 1, config.convWindow]) |
| 88 | + } |
| 89 | + |
| 90 | + func customConvolution(_ u: MLXArray, _ weights: MLXArray, state: MLXArray? = nil) -> MLXArray { |
| 91 | + let (B, H, L, D) = (u.dim(0), u.dim(1), u.dim(2), u.dim(3)) |
| 92 | + let reshapedWeights = weights.reshaped(1, H, config.convWindow, 1, 1) |
| 93 | + let w0 = reshapedWeights[0..., 0..., 0] |
| 94 | + let w1 = reshapedWeights[0..., 0..., 1] |
| 95 | + |
| 96 | + let state = state ?? MLXArray.zeros([B, H, 1, D], dtype: u.dtype) |
| 97 | + |
| 98 | + let uPrev: MLXArray = |
| 99 | + L > 1 ? concatenated([state, u[0..., 0..., ..<(L - 1), 0...]], axis: 2) : state |
| 100 | + |
| 101 | + return uPrev * w0 + u * w1 |
| 102 | + } |
| 103 | + |
| 104 | + func callAsFunction( |
| 105 | + _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache? |
| 106 | + ) -> MLXArray { |
| 107 | + let (B, L, D) = (x.dim(0), x.dim(1), x.dim(2)) |
| 108 | + |
| 109 | + let proj = wPack(x) |
| 110 | + let qkv = split(proj, indices: [D, D + self.numKVHeads * self.headDim], axis: -1) |
| 111 | + |
| 112 | + var queries = qkv[0].reshaped(B, L, numHeads, headDim).transposed(0, 2, 1, 3) |
| 113 | + var keys = qkv[1].reshaped(B, L, numKVHeads, headDim).transposed(0, 2, 1, 3) |
| 114 | + var values = qkv[2].reshaped(B, L, numKVHeads, headDim).transposed(0, 2, 1, 3) |
| 115 | + |
| 116 | + var offset = 0 |
| 117 | + var lastK: MLXArray? = nil |
| 118 | + var lastV: MLXArray? = nil |
| 119 | + |
| 120 | + if let cacheList = cache as? CacheList { |
| 121 | + offset = cacheList[1].offset |
| 122 | + if let mambaCache = cacheList[0] as? MambaCache { |
| 123 | + lastK = mambaCache[0] |
| 124 | + lastV = mambaCache[1] |
| 125 | + } |
| 126 | + } |
| 127 | + |
| 128 | + let kInit = keys |
| 129 | + let vInit = values |
| 130 | + |
| 131 | + keys = customConvolution(keys, convK, state: lastK) |
| 132 | + values = customConvolution(values, convV, state: lastV) |
| 133 | + |
| 134 | + queries = rope(queries, offset: offset) |
| 135 | + keys = rope(keys, offset: offset) |
| 136 | + |
| 137 | + if let cache = cache as? CacheList { |
| 138 | + let kvCache = cache[1] |
| 139 | + let (cachedKeys, cachedValues) = kvCache.update(keys: keys, values: values) |
| 140 | + keys = cachedKeys |
| 141 | + values = cachedValues |
| 142 | + |
| 143 | + if L > 0 { |
| 144 | + let convCache = cache[0] as! MambaCache |
| 145 | + convCache[0] = kInit[0..., 0..., (L - 1)..., 0...] |
| 146 | + convCache[1] = vInit[0..., 0..., (L - 1)..., 0...] |
| 147 | + } |
| 148 | + } |
| 149 | + |
| 150 | + let out = MLXFast.scaledDotProductAttention( |
| 151 | + queries: queries, keys: keys, values: values, scale: scale, mask: mask |
| 152 | + ) |
| 153 | + .transposed(0, 2, 1, 3) |
| 154 | + .reshaped(B, L, -1) |
| 155 | + |
| 156 | + return oProj(out) |
| 157 | + } |
| 158 | +} |
| 159 | + |
| 160 | +private class MLP: Module, UnaryLayer { |
| 161 | + @ModuleInfo(key: "gate_proj") var gateProj: Linear |
| 162 | + @ModuleInfo(key: "up_proj") var upProj: Linear |
| 163 | + @ModuleInfo(key: "down_proj") var downProj: Linear |
| 164 | + |
| 165 | + init(_ config: BaichuanM1Configuration) { |
| 166 | + _gateProj.wrappedValue = Linear(config.hiddenSize, config.intermediateSize, bias: false) |
| 167 | + _upProj.wrappedValue = Linear(config.hiddenSize, config.intermediateSize, bias: false) |
| 168 | + _downProj.wrappedValue = Linear(config.intermediateSize, config.hiddenSize, bias: false) |
| 169 | + } |
| 170 | + |
| 171 | + func callAsFunction(_ x: MLXArray) -> MLXArray { |
| 172 | + return downProj(silu(gateProj(x)) * upProj(x)) |
| 173 | + } |
| 174 | +} |
| 175 | + |
| 176 | +private class DecoderLayer: Module { |
| 177 | + @ModuleInfo(key: "self_attn") var attention: Attention |
| 178 | + let mlp: MLP |
| 179 | + @ModuleInfo(key: "input_layernorm") var inputLayernorm: RMSNorm |
| 180 | + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayernorm: RMSNorm |
| 181 | + |
| 182 | + init(_ config: BaichuanM1Configuration, layerIdx: Int) { |
| 183 | + _attention.wrappedValue = Attention(config, layerIdx: layerIdx) |
| 184 | + self.mlp = MLP(config) |
| 185 | + _inputLayernorm.wrappedValue = RMSNorm( |
| 186 | + dimensions: config.hiddenSize, eps: config.rmsNormEps) |
| 187 | + _postAttentionLayernorm.wrappedValue = RMSNorm( |
| 188 | + dimensions: config.hiddenSize, eps: config.rmsNormEps) |
| 189 | + } |
| 190 | + |
| 191 | + func callAsFunction( |
| 192 | + _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache? |
| 193 | + ) -> MLXArray { |
| 194 | + var r = attention(inputLayernorm(x), mask: mask, cache: cache) |
| 195 | + let x = x + r |
| 196 | + r = mlp(postAttentionLayernorm(x)) |
| 197 | + return x + r |
| 198 | + } |
| 199 | +} |
| 200 | + |
| 201 | +private class BaichuanM1ModelInner: Module { |
| 202 | + let args: BaichuanM1Configuration |
| 203 | + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding |
| 204 | + |
| 205 | + fileprivate let layers: [DecoderLayer] |
| 206 | + let norm: RMSNorm |
| 207 | + |
| 208 | + init(_ config: BaichuanM1Configuration) { |
| 209 | + self.args = config |
| 210 | + _embedTokens.wrappedValue = Embedding( |
| 211 | + embeddingCount: config.vocabularySize, dimensions: config.hiddenSize) |
| 212 | + self.layers = (0 ..< config.hiddenLayers).map { DecoderLayer(config, layerIdx: $0) } |
| 213 | + norm = RMSNorm(dimensions: config.hiddenSize, eps: config.rmsNormEps) |
| 214 | + } |
| 215 | + |
| 216 | + func callAsFunction( |
| 217 | + _ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil, |
| 218 | + cache: [KVCache]? |
| 219 | + ) -> MLXArray { |
| 220 | + var x = embedTokens(inputs) |
| 221 | + |
| 222 | + let mask = mask ?? createAttentionMask(h: x, cache: cache) |
| 223 | + |
| 224 | + for (i, layer) in layers.enumerated() { |
| 225 | + x = layer(x, mask: mask, cache: cache?[i]) |
| 226 | + } |
| 227 | + |
| 228 | + return norm(x) |
| 229 | + } |
| 230 | +} |
| 231 | + |
| 232 | +public class BaichuanM1Model: Module, LLMModel, KVCacheDimensionProvider { |
| 233 | + |
| 234 | + public let vocabularySize: Int |
| 235 | + public let kvHeads: [Int] |
| 236 | + |
| 237 | + private let model: BaichuanM1ModelInner |
| 238 | + let configuration: BaichuanM1Configuration |
| 239 | + |
| 240 | + @ModuleInfo(key: "lm_head") var lmHead: Linear? |
| 241 | + |
| 242 | + public init(_ config: BaichuanM1Configuration) { |
| 243 | + self.configuration = config |
| 244 | + self.vocabularySize = config.vocabularySize |
| 245 | + self.kvHeads = Array(repeating: config.kvHeads, count: config.hiddenLayers) |
| 246 | + self.model = BaichuanM1ModelInner(config) |
| 247 | + |
| 248 | + if !config.tieWordEmbeddings { |
| 249 | + _lmHead.wrappedValue = Linear(config.hiddenSize, config.vocabularySize, bias: false) |
| 250 | + } |
| 251 | + } |
| 252 | + |
| 253 | + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { |
| 254 | + var outputs = model(inputs, cache: cache) |
| 255 | + |
| 256 | + if let lmHead { |
| 257 | + outputs = lmHead(outputs) |
| 258 | + } |
| 259 | + |
| 260 | + return outputs |
| 261 | + } |
| 262 | + |
| 263 | + public func newCache(parameters: GenerateParameters?) -> [KVCache] { |
| 264 | + return model.layers.enumerated().map { (i, _) in |
| 265 | + let isSWA = configuration.slidingWindowLayers.contains(i) |
| 266 | + let convCache = MambaCache() |
| 267 | + let kvCache: KVCache = |
| 268 | + isSWA ? RotatingKVCache(maxSize: configuration.slidingWindow) : KVCacheSimple() |
| 269 | + return CacheList(convCache, kvCache) |
| 270 | + } |
| 271 | + } |
| 272 | + |
| 273 | + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { |
| 274 | + var weights = weights |
| 275 | + let isQuantized = weights["lm_head.scales"] != nil |
| 276 | + |
| 277 | + if !isQuantized, let w = weights["lm_head.weight"] { |
| 278 | + var w = w |
| 279 | + if w.dtype != .float32 { |
| 280 | + w = w.asType(.float32) |
| 281 | + } |
| 282 | + |
| 283 | + let norm = sqrt(sum(w * w, axes: [-1], keepDims: true)) |
| 284 | + w = (w / (norm + 1e-7)).asType(w.dtype) |
| 285 | + weights["lm_head.weight"] = w |
| 286 | + } |
| 287 | + |
| 288 | + if configuration.tieWordEmbeddings { |
| 289 | + weights["lm_head.weight"] = nil |
| 290 | + } |
| 291 | + |
| 292 | + return weights |
| 293 | + } |
| 294 | +} |
| 295 | + |
| 296 | +extension BaichuanM1Model: LoRAModel { |
| 297 | + public func loraLinearLayers() -> LoRALinearLayers { |
| 298 | + model.layers.map { ($0.attention, ["W_pack"]) } |
| 299 | + } |
| 300 | +} |
0 commit comments