
Commit eb10e75

Port of Ernie4 5 (#348)
* Port of Ernie4 5
1 parent 8dfd9ee commit eb10e75

3 files changed: +244 -1 lines changed

Applications/LLMEval/ContentView.swift

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ class LLMEvaluator {
     let timeTool = Tool<EmptyInput, TimeOutput>(
         name: "get_time",
         description: "Get the current time",
-        parameters: [],
+        parameters: []
     ) { _ in
         TimeOutput(time: Date.now.formatted())
     }

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 7 additions & 0 deletions
@@ -49,6 +49,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
             "glm4": create(GLM4Configuration.self, GLM4Model.init),
             "acereason": create(Qwen2Configuration.self, Qwen2Model.init),
             "bitnet": create(BitnetConfiguration.self, BitnetModel.init),
+            "ernie4_5": create(Ernie45Configuration.self, Ernie45Model.init),
         ]
     }

@@ -231,6 +232,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         defaultPrompt: "Why is the sky blue?"
     )

+    static public let ernie_45_0_3BPT_bf16_ft = ModelConfiguration(
+        id: "mlx-community/ERNIE-4.5-0.3B-PT-bf16-ft",
+        defaultPrompt: "Why is the sky blue?"
+    )
+
     private static func all() -> [ModelConfiguration] {
         [
             codeLlama13b4bit,
@@ -263,6 +269,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
             glm4_9b_4bit,
             acereason_7b_4bit,
             bitnet_b1_58_2b_4t_4bit,
+            ernie_45_0_3BPT_bf16_ft,
         ]
     }

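With the "ernie4_5" type and the new ModelConfiguration registered above, the model loads through the same factory path as the other registry entries. The sketch below is illustrative and not part of this commit; it assumes the LLMModelFactory.shared.loadContainer(configuration:) entry point that the LLMEval app already uses in this repository.

import MLXLLM
import MLXLMCommon

// Illustrative sketch: download and load the newly registered ERNIE 4.5 model.
// Assumes LLMModelFactory.loadContainer(configuration:) as used by LLMEval.
func loadErnie45() async throws -> ModelContainer {
    try await LLMModelFactory.shared.loadContainer(
        configuration: LLMRegistry.ernie_45_0_3BPT_bf16_ft)
}
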
Libraries/MLXLLM/Models/Ernie4_5.swift

Lines changed: 236 additions & 0 deletions
@@ -0,0 +1,236 @@
//
// Ernie4_5.swift
// mlx-swift-examples
//
// Created by Sachin Desai on 7/3/25.
//

import Foundation
import MLX
import MLXLMCommon
import MLXNN

// Port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/ernie4_5.py

public struct Ernie45Configuration: Codable {
    var hiddenSize: Int
    var intermediateSize: Int
    var maxPositionEmbeddings: Int
    var numAttentionHeads: Int
    var numKeyValueHeads: Int
    var headDim: Int?
    var numHiddenLayers: Int
    var rmsNormEps: Float
    var vocabularySize: Int
    var ropeTheta: Float
    var useBias: Bool
    var tieWordEmbeddings: Bool

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case intermediateSize = "intermediate_size"
        case maxPositionEmbeddings = "max_position_embeddings"
        case numAttentionHeads = "num_attention_heads"
        case numKeyValueHeads = "num_key_value_heads"
        case headDim = "head_dim"
        case numHiddenLayers = "num_hidden_layers"
        case rmsNormEps = "rms_norm_eps"
        case vocabularySize = "vocab_size"
        case ropeTheta = "rope_theta"
        case useBias = "use_bias"
        case tieWordEmbeddings = "tie_word_embeddings"
    }

    public init(from decoder: Decoder) throws {
        let container: KeyedDecodingContainer<Ernie45Configuration.CodingKeys> =
            try decoder.container(keyedBy: Ernie45Configuration.CodingKeys.self)

        self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
        self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
        self.maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
        self.numAttentionHeads = try container.decode(Int.self, forKey: .numAttentionHeads)
        self.numKeyValueHeads = try container.decode(Int.self, forKey: .numKeyValueHeads)
        self.headDim = try container.decodeIfPresent(Int.self, forKey: .headDim)
        self.numHiddenLayers = try container.decode(Int.self, forKey: .numHiddenLayers)
        self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
        self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
        self.ropeTheta = try container.decode(Float.self, forKey: .ropeTheta)
        self.useBias = try container.decode(Bool.self, forKey: .useBias)
        self.tieWordEmbeddings = try container.decode(Bool.self, forKey: .tieWordEmbeddings)
    }
}

private class Attention: Module {
    let nHeads: Int
    let nKVHeads: Int
    let headDim: Int
    let scale: Float

    @ModuleInfo(key: "q_proj") var qProj: Linear
    @ModuleInfo(key: "k_proj") var kProj: Linear
    @ModuleInfo(key: "v_proj") var vProj: Linear
    @ModuleInfo(key: "o_proj") var oProj: Linear

    let rope: RoPE

    public init(_ args: Ernie45Configuration) {
        let dim = args.hiddenSize
        self.nHeads = args.numAttentionHeads
        self.nKVHeads = args.numKeyValueHeads
        self.headDim = args.headDim ?? (dim / args.numAttentionHeads)
        self.scale = pow(Float(headDim), -0.5)

        self._qProj.wrappedValue = Linear(dim, nHeads * headDim, bias: args.useBias)
        self._kProj.wrappedValue = Linear(dim, nKVHeads * headDim, bias: args.useBias)
        self._vProj.wrappedValue = Linear(dim, nKVHeads * headDim, bias: args.useBias)
        self._oProj.wrappedValue = Linear(nHeads * headDim, dim, bias: args.useBias)

        self.rope = RoPE(
            dimensions: headDim,
            traditional: true,
            base: args.ropeTheta
        )
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
    ) -> MLXArray {
        let (B, L) = (x.dim(0), x.dim(1))

        var queries = qProj(x)
        var keys = kProj(x)
        var values = vProj(x)

        queries = queries.reshaped(B, L, nHeads, -1).transposed(0, 2, 1, 3)
        keys = keys.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)
        values = values.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)

        if let cache {
            queries = rope(queries, offset: cache.offset)
            keys = rope(keys, offset: cache.offset)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = attentionWithCacheUpdate(
            queries: queries,
            keys: keys,
            values: values,
            cache: cache,
            scale: scale,
            mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return oProj(output)
    }
}

private class MLP: Module, UnaryLayer {
    @ModuleInfo(key: "gate_proj") var gateProj: Linear
    @ModuleInfo(key: "down_proj") var downProj: Linear
    @ModuleInfo(key: "up_proj") var upProj: Linear

    public init(dim: Int, hiddenDim: Int, useBias: Bool = false) {
        self._gateProj.wrappedValue = Linear(dim, hiddenDim, bias: useBias)
        self._downProj.wrappedValue = Linear(hiddenDim, dim, bias: useBias)
        self._upProj.wrappedValue = Linear(dim, hiddenDim, bias: useBias)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        downProj(silu(gateProj(x)) * upProj(x))
    }
}

private class DecoderLayer: Module {
    @ModuleInfo(key: "self_attn") var attention: Attention
    let mlp: MLP

    @ModuleInfo(key: "input_layernorm") var inputLayernorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayernorm: RMSNorm

    public init(_ args: Ernie45Configuration) {
        self._attention.wrappedValue = Attention(args)
        self.mlp = MLP(
            dim: args.hiddenSize, hiddenDim: args.intermediateSize, useBias: args.useBias)
        self._inputLayernorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        self._postAttentionLayernorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
    ) -> MLXArray {
        var r = attention(inputLayernorm(x), mask: mask, cache: cache)
        let h = x + r
        r = mlp(postAttentionLayernorm(h))
        return h + r
    }
}

private class Ernie45ModelInner: Module {
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
    let layers: [DecoderLayer]
    let norm: RMSNorm

    public init(_ args: Ernie45Configuration) {
        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize
        )
        self.layers = (0 ..< args.numHiddenLayers).map { _ in
            DecoderLayer(args)
        }
        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
        var h = embedTokens(inputs)

        let mask = createAttentionMask(h: h, cache: cache)

        for (i, layer) in layers.enumerated() {
            h = layer(h, mask: mask, cache: cache?[i])
        }

        return norm(h)
    }
}

public class Ernie45Model: Module, LLMModel, KVCacheDimensionProvider {
    public let vocabularySize: Int
    public let kvHeads: [Int]

    private let model: Ernie45ModelInner
    @ModuleInfo(key: "lm_head") var lmHead: Linear?

    public init(_ args: Ernie45Configuration) {
        self.vocabularySize = args.vocabularySize
        self.kvHeads = Array(repeating: args.numKeyValueHeads, count: args.numHiddenLayers)
        self.model = Ernie45ModelInner(args)

        if !args.tieWordEmbeddings {
            self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
        }
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        let out = model(inputs, cache: cache)

        if let lmHead {
            return lmHead(out)
        } else {
            return model.embedTokens.asLinear(out)
        }
    }
}

// MARK: - LoRA

extension Ernie45Model: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}
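
For reference, a minimal sketch of how the configuration and model above could be exercised on their own. It is not part of the commit, and the JSON values are placeholders rather than the published ERNIE-4.5-0.3B hyperparameters.

import Foundation

// Placeholder values chosen only to illustrate the keys Ernie45Configuration decodes;
// a real config.json from the converted model supplies the actual numbers.
func makeErnie45FromPlaceholderConfig() throws -> Ernie45Model {
    let configJSON = """
        {
          "hidden_size": 1024,
          "intermediate_size": 3072,
          "max_position_embeddings": 131072,
          "num_attention_heads": 16,
          "num_key_value_heads": 2,
          "head_dim": 128,
          "num_hidden_layers": 18,
          "rms_norm_eps": 1e-05,
          "vocab_size": 103424,
          "rope_theta": 500000,
          "use_bias": false,
          "tie_word_embeddings": true
        }
        """.data(using: .utf8)!
    let config = try JSONDecoder().decode(Ernie45Configuration.self, from: configJSON)
    // With tie_word_embeddings = true, Ernie45Model reuses embed_tokens for the output
    // projection instead of allocating a separate lm_head.
    return Ernie45Model(config)
}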
