
Commit d339a82

add bailing moe (#395)
1 parent 46ec70a commit d339a82

File tree

2 files changed: +372 −0 lines changed

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 7 additions & 0 deletions
@@ -60,6 +60,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
             "lille-130m": create(Lille130mConfiguration.self, Lille130mModel.init),
             "olmoe": create(OlmoEConfiguration.self, OlmoEModel.init),
             "olmo2": create(Olmo2Configuration.self, Olmo2Model.init),
+            "bailing_moe": create(BailingMoeConfiguration.self, BailingMoeModel.init),
         ]
     }
 }
@@ -314,6 +315,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         defaultPrompt: "Why is the sky blue?"
     )

+    static public let ling_mini_2_2bit = ModelConfiguration(
+        id: "mlx-community/Ling-mini-2.0-2bit-DWQ",
+        defaultPrompt: "Why is the sky blue?"
+    )
+
     private static func all() -> [ModelConfiguration] {
         [
             codeLlama13b4bit,
@@ -359,6 +365,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
             lille_130m_bf16,
             olmoe_1b_7b_0125_instruct_4bit,
             olmo_2_1124_7B_Instruct_4bit,
+            ling_mini_2_2bit,
         ]
     }

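A rough usage sketch (not part of this commit): loading the newly registered Ling Mini checkpoint through the factory. It assumes the existing LLMModelFactory.shared.loadContainer(configuration:) API in this repository; any name not in the diff above is an assumption.

import MLXLLM
import MLXLMCommon

func loadLingMini() async throws -> ModelContainer {
    // "bailing_moe" in the checkpoint's config.json resolves to BailingMoeModel
    // via the LLMTypeRegistry entry added above; the registry entry picks the
    // 2-bit DWQ weights from mlx-community/Ling-mini-2.0-2bit-DWQ.
    try await LLMModelFactory.shared.loadContainer(
        configuration: LLMRegistry.ling_mini_2_2bit)
}
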
BailingMoe.swift

Lines changed: 365 additions & 0 deletions
@@ -0,0 +1,365 @@
//
// BailingMoe.swift
// LLM
//
// Port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/bailing_moe.py
// This architecture is used by the Ling-family models (e.g., Ling Mini).
//

import Foundation
import MLX
import MLXLMCommon
import MLXNN

public struct BailingMoeConfiguration: Codable, Sendable {
    var modelType: String
    var hiddenSize: Int
    var intermediateSize: Int
    var maxPositionEmbeddings: Int?
    var moeIntermediateSize: Int
    var numExperts: Int
    var numSharedExperts: Int
    var normTopkProb: Bool
    var attentionHeads: Int
    var numExpertsPerToken: Int
    var hiddenLayers: Int
    var kvHeads: Int
    var rmsNormEps: Float
    var ropeTheta: Float
    var vocabularySize: Int
    var firstKDenseReplace: Int

    // Optional features
    var ropeScaling: [String: StringOrNumber]? = nil
    var useBias: Bool = false
    var useQKVBias: Bool = false
    var useQKNorm: Bool = false
    var tieWordEmbeddings: Bool = false
    var partialRotaryFactor: Float = 1.0
    var moeRouterEnableExpertBias: Bool = false
    var routedScalingFactor: Float = 1.0
    var scoreFunction: String = "softmax"
    var nGroup: Int = 1
    var topkGroup: Int = 4
    var moeSharedExpertIntermediateSize: Int? = nil

    enum CodingKeys: String, CodingKey {
        case modelType = "model_type"
        case hiddenSize = "hidden_size"
        case intermediateSize = "intermediate_size"
        case maxPositionEmbeddings = "max_position_embeddings"
        case moeIntermediateSize = "moe_intermediate_size"
        case numExperts = "num_experts"
        case numSharedExperts = "num_shared_experts"
        case normTopkProb = "norm_topk_prob"
        case attentionHeads = "num_attention_heads"
        case numExpertsPerToken = "num_experts_per_tok"
        case hiddenLayers = "num_hidden_layers"
        case kvHeads = "num_key_value_heads"
        case rmsNormEps = "rms_norm_eps"
        case ropeTheta = "rope_theta"
        case vocabularySize = "vocab_size"
        case firstKDenseReplace = "first_k_dense_replace"
        case ropeScaling = "rope_scaling"
        case useBias = "use_bias"
        case useQKVBias = "use_qkv_bias"
        case useQKNorm = "use_qk_norm"
        case tieWordEmbeddings = "tie_word_embeddings"
        case partialRotaryFactor = "partial_rotary_factor"
        case moeRouterEnableExpertBias = "moe_router_enable_expert_bias"
        case routedScalingFactor = "routed_scaling_factor"
        case scoreFunction = "score_function"
        case nGroup = "n_group"
        case topkGroup = "topk_group"
        case moeSharedExpertIntermediateSize = "moe_shared_expert_intermediate_size"
    }
}

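A hedged sketch (not part of the commit) of the config.json this struct decodes. The key names come from CodingKeys above; the values are invented toy numbers, not a real Ling checkpoint. All non-Optional keys are spelled out because the synthesized decoder does not fall back to the Swift-side default values when a key is missing.

// Toy config for illustration only; real checkpoints ship their own config.json.
let toyConfigJSON = """
{
    "model_type": "bailing_moe",
    "hidden_size": 256,
    "intermediate_size": 512,
    "moe_intermediate_size": 128,
    "num_experts": 8,
    "num_shared_experts": 1,
    "norm_topk_prob": true,
    "num_attention_heads": 8,
    "num_experts_per_tok": 2,
    "num_hidden_layers": 4,
    "num_key_value_heads": 2,
    "rms_norm_eps": 1e-6,
    "rope_theta": 10000.0,
    "vocab_size": 32000,
    "first_k_dense_replace": 1,
    "use_bias": false,
    "use_qkv_bias": false,
    "use_qk_norm": true,
    "tie_word_embeddings": false,
    "partial_rotary_factor": 1.0,
    "moe_router_enable_expert_bias": false,
    "routed_scaling_factor": 1.0,
    "score_function": "sigmoid",
    "n_group": 1,
    "topk_group": 1
}
""".data(using: .utf8)!

// Decodes via the snake_case mapping defined in CodingKeys.
let config = try! JSONDecoder().decode(BailingMoeConfiguration.self, from: toyConfigJSON)
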
private class Attention: Module {
    let args: BailingMoeConfiguration
    let heads: Int
    let kvHeads: Int
    let headDim: Int
    let ropeDim: Int
    let scale: Float

    @ModuleInfo(key: "query_key_value") var qkv: Linear
    @ModuleInfo(key: "dense") var wo: Linear

    @ModuleInfo(key: "query_layernorm") var qNorm: RMSNorm?
    @ModuleInfo(key: "key_layernorm") var kNorm: RMSNorm?

    let rope: RoPE

    init(_ args: BailingMoeConfiguration) {
        self.args = args
        self.heads = args.attentionHeads
        self.kvHeads = args.kvHeads
        self.headDim = args.hiddenSize / heads
        self.ropeDim = Int(Float(headDim) * args.partialRotaryFactor)
        self.scale = pow(Float(headDim), -0.5)

        _qkv.wrappedValue = Linear(
            args.hiddenSize,
            (heads + 2 * kvHeads) * headDim,
            bias: args.useQKVBias
        )
        _wo.wrappedValue = Linear(heads * headDim, args.hiddenSize, bias: args.useBias)

        if args.useQKNorm {
            _qNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)
            _kNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)
        } else {
            _qNorm.wrappedValue = nil
            _kNorm.wrappedValue = nil
        }

        self.rope = RoPE(
            dimensions: ropeDim, traditional: false, base: args.ropeTheta,
            scale: 1.0)
    }

    func callAsFunction(
        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
    ) -> MLXArray {
        let (B, L) = (x.dim(0), x.dim(1))

        let qSize = heads * headDim
        let kSize = kvHeads * headDim
        let qkvOut = qkv(x)
        let splits = split(qkvOut, indices: [qSize, qSize + kSize], axis: -1)
        var queries = splits[0]
        var keys = splits[1]
        var values = splits[2]

        // reshape to (B, L, H, Hd), apply optional per-head norms, then transpose to (B, H, L, Hd)
        queries = queries.reshaped(B, L, heads, -1)
        keys = keys.reshaped(B, L, kvHeads, -1)

        if let qNorm { queries = qNorm(queries) }
        if let kNorm { keys = kNorm(keys) }

        queries = queries.transposed(0, 2, 1, 3)
        keys = keys.transposed(0, 2, 1, 3)
        values = values.reshaped(B, L, kvHeads, -1).transposed(0, 2, 1, 3)

        if let cache {
            queries = rope(queries, offset: cache.offset)
            keys = rope(keys, offset: cache.offset)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = attentionWithCacheUpdate(
            queries: queries,
            keys: keys,
            values: values,
            cache: cache,
            scale: scale,
            mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return wo(output)
    }
}

private class BailingMoeMLP: Module, UnaryLayer {
    @ModuleInfo(key: "gate_proj") var gate: Linear
    @ModuleInfo(key: "down_proj") var down: Linear
    @ModuleInfo(key: "up_proj") var up: Linear

    init(_ args: BailingMoeConfiguration, hiddenDims: Int? = nil) {
        let inter = hiddenDims ?? args.intermediateSize
        _gate.wrappedValue = Linear(args.hiddenSize, inter, bias: args.useBias)
        _down.wrappedValue = Linear(inter, args.hiddenSize, bias: args.useBias)
        _up.wrappedValue = Linear(args.hiddenSize, inter, bias: args.useBias)
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray { down(silu(gate(x)) * up(x)) }
}

private class BailingMoeGate: Module, UnaryLayer {
    let topK: Int
    let nGroup: Int
    let topkGroup: Int
    let numExperts: Int
    let routedScalingFactor: Float
    let normTopkProb: Bool
    let scoreFunction: String

    @ModuleInfo(key: "gate_proj") var gate: Linear
    @ModuleInfo(key: "expert_bias") var expertBias: MLXArray

    init(_ args: BailingMoeConfiguration) {
        self.topK = args.numExpertsPerToken
        self.nGroup = args.nGroup
        self.topkGroup = args.topkGroup
        self.routedScalingFactor = args.routedScalingFactor
        self.normTopkProb = args.normTopkProb
        self.scoreFunction = args.scoreFunction
        self.numExperts = args.numExperts

        _gate.wrappedValue = Linear(args.hiddenSize, args.numExperts, bias: false)
        _expertBias.wrappedValue = zeros([args.numExperts])
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        // Returns the raw router logits; callers use groupSelect to get inds and scores.
        gate(x)
    }

    func groupSelect(_ x: MLXArray) -> (inds: MLXArray, scores: MLXArray) {
        let (bsz, seqLen) = (x.dim(0), x.dim(1))

        let logits = gate(x)
        var scores = sigmoid(logits.asType(.float32))
        let scoresForChoice = scores + expertBias
        let groupScores = scoresForChoice.reshaped(bsz, seqLen, self.nGroup, -1)

        let topKGroup = top(groupScores, k: 2, axis: -1).sum(axis: -1, keepDims: true)
        var k = nGroup - topkGroup
        let groupIdx = argPartition(topKGroup, kth: k - 1, axis: -2)[.ellipsis, ..<k, 0...]
        scores = putAlong(groupScores, groupIdx, values: MLXArray(0.0), axis: -2)
        scores = flattened(scores, start: -2, end: -1)

        k = topK
        let inds = argPartition(-scores, kth: k - 1, axis: -1)[.ellipsis, ..<k]
        scores = takeAlong(scores, inds, axis: -1)
        if topK > 1, normTopkProb {
            let denominator = scores.sum(axis: -1, keepDims: true) + 1e-20
            scores = scores / denominator
        }
        scores = scores * routedScalingFactor
        return (inds, scores.asType(logits.dtype))
    }
}

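To make the routing in groupSelect concrete, here is a toy, plain-Swift rendering of the same idea (it skips the MLX ops, the sigmoid scoring, the expert bias, and the top-k re-normalization/scaling): experts are split into n_group groups, each group is scored by the sum of its two best experts, only the topk_group best groups survive, and the final top_k experts are chosen from the survivors.

// Toy illustration with 8 experts, n_group = 2, topk_group = 1, top_k = 2.
let scores: [Float] = [0.9, 0.1, 0.2, 0.3, 0.8, 0.7, 0.05, 0.4]
let nGroup = 2, topkGroup = 1, topK = 2
let groupSize = scores.count / nGroup

// Score each group by the sum of its two best experts.
let groupScore: [Float] = (0 ..< nGroup).map { g in
    scores[g * groupSize ..< (g + 1) * groupSize].sorted(by: >).prefix(2).reduce(0, +)
}
let keptGroups = Set(
    groupScore.indices.sorted { groupScore[$0] > groupScore[$1] }.prefix(topkGroup))

// Zero out experts in dropped groups, then take the global top-k.
let masked = scores.enumerated().map { keptGroups.contains($0.offset / groupSize) ? $0.element : 0 }
let experts = masked.indices.sorted { masked[$0] > masked[$1] }.prefix(topK)

// Prints [4, 5]: expert 0 has the highest raw score, but its group lost,
// so the group constraint keeps the selection inside group 1.
print(Array(experts))
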
private class BailingMoeSparseMoeBlock: Module, UnaryLayer {
    let args: BailingMoeConfiguration
    @ModuleInfo(key: "switch_mlp") var switchMLP: SwitchGLU
    @ModuleInfo(key: "gate") var gate: BailingMoeGate
    @ModuleInfo(key: "shared_experts") var sharedExperts: BailingMoeMLP?

    init(_ args: BailingMoeConfiguration) {
        self.args = args
        _switchMLP.wrappedValue = SwitchGLU(
            inputDims: args.hiddenSize, hiddenDims: args.moeIntermediateSize,
            numExperts: args.numExperts,
            bias: args.useBias
        )
        _gate.wrappedValue = BailingMoeGate(args)

        if args.numSharedExperts > 0 {
            let sharedDim =
                (args.moeSharedExpertIntermediateSize ?? args.moeIntermediateSize)
                * args.numSharedExperts
            _sharedExperts.wrappedValue = BailingMoeMLP(args, hiddenDims: sharedDim)
        } else {
            _sharedExperts.wrappedValue = nil
        }
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        let (inds, weights) = gate.groupSelect(x)
        var out = switchMLP(x, inds)
        out = (out * weights[.ellipsis, .newAxis]).sum(axis: -2)
        if let shared = sharedExperts {
            out = out + shared(x)
        }
        return out
    }
}

private class TransformerBlock: Module {
    let args: BailingMoeConfiguration
    let layerIdx: Int

    @ModuleInfo(key: "attention") var attention: Attention
    @ModuleInfo(key: "mlp") var mlp: Module & UnaryLayer
    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm

    init(_ args: BailingMoeConfiguration, layerIdx: Int) {
        self.args = args
        self.layerIdx = layerIdx

        _attention.wrappedValue = Attention(args)
        _inputLayerNorm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
        _postAttentionLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)

        if args.numExperts > 0 && layerIdx >= args.firstKDenseReplace {
            _mlp.wrappedValue = BailingMoeSparseMoeBlock(args)
        } else {
            _mlp.wrappedValue = BailingMoeMLP(args)
        }
    }

    func callAsFunction(
        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
    ) -> MLXArray {
        let r = attention(inputLayerNorm(x), mask: mask, cache: cache)
        let h = x + r
        let r2 = mlp(postAttentionLayerNorm(h))
        return h + r2
    }
}

private class BailingMoeModelInner: Module {
    @ModuleInfo(key: "word_embeddings") var embedTokens: Embedding
    let layers: [TransformerBlock]
    let norm: RMSNorm

    init(_ args: BailingMoeConfiguration) {
        precondition(args.vocabularySize > 0)
        _embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
        self.layers = (0 ..< args.hiddenLayers).map { TransformerBlock(args, layerIdx: $0) }
        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
        var h = embedTokens(inputs)
        let mask = createAttentionMask(h: h, cache: cache)
        for (i, layer) in layers.enumerated() {
            h = layer(h, mask: mask, cache: cache?[i])
        }
        return norm(h)
    }
}

public class BailingMoeModel: Module, LLMModel, KVCacheDimensionProvider {
    public let vocabularySize: Int
    public let kvHeads: [Int]
    fileprivate let model: BailingMoeModelInner
    let configuration: BailingMoeConfiguration
    @ModuleInfo(key: "lm_head") var lmHead: Linear?

    public init(_ args: BailingMoeConfiguration) {
        self.configuration = args
        self.vocabularySize = args.vocabularySize
        self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
        self.model = BailingMoeModelInner(args)
        if !args.tieWordEmbeddings {
            _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
        }
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        let out = model(inputs, cache: cache)
        if let lmHead {
            return lmHead(out)
        } else {
            return model.embedTokens.asLinear(out)
        }
    }
}

extension BailingMoeModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["query_key_value"]) }
    }
}
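
A final hedged sketch (again not from the commit): constructing the model directly from the toy `config` decoded in the earlier snippet, with no weights loaded, and inspecting the surfaces the runtime and LoRA tooling rely on. The expected values in the comments follow from the toy config and are for illustration only.

let model = BailingMoeModel(config)

print(model.vocabularySize)            // 32000, from vocab_size
print(model.kvHeads)                   // one entry per layer: [2, 2, 2, 2]
print(model.loraLinearLayers().count)  // 4 — the attention module of each block
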
