feat: add the MoE FFN graph primitive (router + top-k dispatch) #559
Merged
Add a shared Mixture-of-Experts FFN graph block to the OpenXLA (xla-iree) emitter, the foundation the MoE families (issue #501) build on. A MoE layer replaces the dense SwiGLU MLP with a router that selects the top-k of N experts: a softmax over all experts forms the routing probabilities, the top-k are picked (iterative argmax, reproducing torch.topk's lower-index tie-break), renormalized to sum to one when norm_topk_prob and scaled by routed_scaling_factor, then used to combine the selected experts' SwiGLU outputs, plus an optional shared-expert branch (Qwen2-MoE gates it by sigmoid(x @ Wg^T); DeepSeek adds it ungated).
The dispatch is masked-dense over the mlx-lm stacked switch_mlp weights (batched over the expert axis), which is numerically identical to gathering only the selected experts (0 * y == 0 and x + 0 == x in IEEE-754); a sparse gather is a throughput follow-up that does not change the result.
The new emitter/moe.rs owns the routing math, emitted once for [N, H] activations so single-token decode (reshaped to [1, H]), prefill, ragged decode, and batched decode all route from one authoring site, mirroring the shared attention core (#494). A dense (non-MoE) config emits no MoE op, so every shipped dense checkpoint stays byte-for-byte unchanged.
Plumbing:
- emitter/config.rs: Config::from_json parses the MoE fields (n_experts, n_experts_per_tok, moe_intermediate_size, norm_topk_prob, routed_scaling_factor, shared-expert config, first_k_dense) for Mixtral and Qwen2-MoE, whose attention this emitter already reproduces; Qwen3-MoE (per-head q/k norm) and DeepSeek-V2 (multi-head latent attention) are deferred to #501 with a clear message rather than mis-emitted.
- emitter/model.rs: the per-layer arg schema takes the router + stacked experts + optional shared expert (take_moe_weights) in the canonical MoE order; ffn_seq and the single-decode / batched inline paths dispatch MoE vs dense; a standalone emit_moe_probe isolates the MoE block for the execution check.
- iree.rs: weight_names emits the mlx-lm switch_mlp expert names (per family prefix) in the emitted arg order, and load_weights dequantizes the rank-3 stacked expert weights.
- weights.rs: dequantize_affine_stacked dequantizes each expert slab of a [experts, out, in_packed] MLX affine weight.
Validation:
- Structural: a tiny qwen2_moe fixture (crate::validation QWEN2_MOE_TINY) freezes the prefill / decode / ragged serve graphs and guards them byte-for-byte in CI.
- Execution: spike/openxla/moe_oracle.py compiles the emitted MoE block probe with IREE and compares it to a genuine HF fp32 MoE block; the emitter matches Qwen2-MoE (shared expert + gate + norm_topk_prob) and Mixtral (no shared, always renormalized) to ~1e-9 (fp32 reduction-order noise), proving the routing / dispatch math.
Full-model token-exact on a real MoE checkpoint is deferred: the smallest on disk (deepseek-v2-lite) uses multi-head latent attention, and the others (Qwen3-MoE q/k norm; Mixtral / Phi-3.5-MoE at 22-25 GB) either need attention work or are too large to run in-agent, so the routing math is proven at the block level against the HF oracle instead. Those attentions, and the full-checkpoint execution gate, land with the families in #501.
Refs #500
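For reviewers who want the routing order in one place, here is a minimal plain-Python sketch of the steps this commit message describes: softmax over all experts, top-k by iterative argmax with ties going to the lower index, optional renormalization, then scaling. It is not the code in emitter/moe.rs; the function names (softmax, topk_iterative_argmax, route) are hypothetical and chosen only for illustration.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def topk_iterative_argmax(probs, k):
    # Repeated argmax; the strict '>' keeps the lower index on ties,
    # which is the tie-break the commit message describes.
    remaining = list(probs)
    picked = []
    for _ in range(k):
        best = 0
        for i in range(1, len(remaining)):
            if remaining[i] > remaining[best]:
                best = i
        picked.append(best)
        remaining[best] = -math.inf  # mask the winner out of later rounds
    return picked


def route(logits, k, norm_topk_prob=True, routed_scaling_factor=1.0):
    # Hypothetical helper: returns (expert indices, combine weights) for one token.
    probs = softmax(logits)
    idx = topk_iterative_argmax(probs, k)
    weights = [probs[i] for i in idx]
    if norm_topk_prob:
        s = sum(weights)
        weights = [w / s for w in weights]  # selected weights now sum to one
    weights = [w * routed_scaling_factor for w in weights]
    return idx, weights


# Experts 1 and 2 tie on the largest logit; the lower index is picked first.
idx, weights = route([1.0, 3.0, 3.0, 0.5], k=2)
```

With norm_topk_prob disabled the selected weights are the raw softmax probabilities, so they sum to less than one whenever k < N.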
Integrate the dense arch pack (#558: Qwen3 / Gemma1/3 / SmolLM3 / OLMo2/3) that landed on main while the MoE FFN primitive (#500) was in flight. Both touch the same emitter seams (Config fields, LayerW, take_layer_weights, the MLP helpers, weight_names), so the resolution weaves them: the dense MLP weights (down/gate/up) and input_layernorm are both optional now, take_layer_weights and weight_names branch dense vs MoE while sharing the qk-norm / feed-forward-norm conditionals, and the single-decode / seq FFN dispatch (ffn_single / ffn_seq) picks the MoE block or the dense single_mlp / seq_mlp. Validated after the merge: cargo test -p mlxcel-xla --lib (80 passed, byte-exact goldens for both the dense fixtures and the qwen2-moe-tiny MoE fixture), cargo clippy -D warnings clean, and spike/openxla/moe_oracle.py token-exact (~1e-9) vs the HF Qwen2-MoE and Mixtral fp32 blocks.
Integrate the #497/#498/#499 dense arch packs from main into the #500 MoE FFN branch and resolve the semantic conflict in the emitter. main restructured emitter/model.rs so emit_attention returns (hn, attn_out) and emit_transformer_layer owns the residual (sequential or parallel) and emits the dense FFN via emit_mlp_body. The MoE FFN is wired in as the alternative FFN body selected by is_moe_layer: a new emit_ffn_body dispatches to moe::moe_block (over the pre-normed hidden, no residual) on a MoE layer and to emit_mlp_body on a dense layer, so emit_transformer_layer keeps main's residual ownership and every dense arch emits byte-for-byte as main does. The single-token decode path reshapes [H] to [1, H] for the seq-only MoE primitive; the superseded uniform-B batched decode gets the same MoE-vs-dense split inline. The now-unused moe_ffn_seq/moe_ffn_single wrappers are dropped (moe_block is the one primitive).
Weight handling: the MoE expert bank (router, stacked switch_mlp gate/up/down, optional shared expert plus its sigmoid gate) is appended per MoE layer in weights::weight_specs (the single ordering authority main introduced), mirroring take_moe_weights in the emitter; the branch's inline iree.rs weight_names is dropped in favor of that path. The loader gains a rank-3 stacked-expert dequant (dequantize_affine_stacked) alongside the rank-2 path, keeping the Phi3 fused-split dense path intact.
config.rs unions the MoE fields (MoeConfig / SharedExpertConfig / moe / is_moe_layer, Mixtral + Qwen2-MoE detection) onto main's already-unioned dense-pack config. validation.rs registers the qwen2-moe-tiny fixture on the byte-exact gate next to the 12 dense fixtures; its decode golden is re-frozen for the norm-at-[H] single-token path (a numerically identical single-row RMSNorm).
Validation: cargo test -p mlxcel-xla --lib is 108 passed / 0 failed (all dense goldens byte-identical, MoE fixture green); clippy and fmt clean; moe_oracle.py is block-exact vs HF Qwen2-MoE and Mixtral (~1e-9); dense_arch_check.py is token-exact for the #498 pack.
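The rank-3 stacked-expert dequant mentioned in the weight-handling notes can be pictured as the rank-2 affine dequant applied once per expert slab. The sketch below is an illustration in plain Python, not the crate's dequantize_affine_stacked. It assumes the usual MLX affine layout: several bits-wide values packed low bits first into each 32-bit word, and one scale and bias per group_size input columns, with w = q * scale + bias. All function names are hypothetical, and group_size=8 in the example is chosen only to keep the data small.

```python
def dequantize_row(packed, scales, biases, bits, group_size):
    # One output row: unpack each 32-bit word low bits first (assumed layout),
    # then apply the per-group affine map w = q * scale + bias.
    per_word = 32 // bits
    mask = (1 << bits) - 1
    row = []
    for word_index, word in enumerate(packed):
        for j in range(per_word):
            q = (word >> (bits * j)) & mask
            col = word_index * per_word + j
            group = col // group_size
            row.append(q * scales[group] + biases[group])
    return row


def dequantize_rank2(packed, scales, biases, bits, group_size):
    # [out, in_packed] -> [out, in]
    return [
        dequantize_row(p, s, b, bits, group_size)
        for p, s, b in zip(packed, scales, biases)
    ]


def dequantize_stacked(packed, scales, biases, bits, group_size):
    # [experts, out, in_packed] -> [experts, out, in]: one rank-2 dequant per slab.
    return [
        dequantize_rank2(p, s, b, bits, group_size)
        for p, s, b in zip(packed, scales, biases)
    ]


# Two experts, one output row each, eight 4-bit values per row (one word).
packed = [[[0x76543210]], [[0xFFFFFFFF]]]
scales = [[[0.5]], [[2.0]]]
biases = [[[-1.0]], [[0.0]]]
dense = dequantize_stacked(packed, scales, biases, bits=4, group_size=8)
```

Because each slab is handled independently, the stacked path introduces no new arithmetic beyond the rank-2 path; only the extra leading expert axis changes.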
Summary
Add a shared Mixture-of-Experts FFN graph primitive to the OpenXLA (xla-iree) emitter (issue #500), the foundation the MoE families (#501) build on. A MoE layer replaces the dense SwiGLU MLP with a router that selects the top-k of N experts and combines their SwiGLU outputs, plus an optional shared-expert branch. The routing math is proven token-exact against a genuine HF fp32 MoE block, and the emitted graphs are guarded byte-for-byte in CI.
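The combine step can be illustrated with a small plain-Python sketch comparing the masked-dense form (every expert is computed, unselected experts get weight zero) with a gather over only the selected experts. This is a toy under stated assumptions, not the emitted graph: the weights come from a hypothetical make_expert helper, the shared-expert branch is left out, and the gather visits experts in ascending index order so both paths accumulate in the same order.

```python
import math


def silu(v):
    return v / (1.0 + math.exp(-v))


def matvec(w, x):
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def swiglu_expert(x, gate_w, up_w, down_w):
    # down( silu(gate x) * (up x) )
    g = matvec(gate_w, x)
    u = matvec(up_w, x)
    return matvec(down_w, [silu(a) * b for a, b in zip(g, u)])


def combine_masked_dense(x, experts, idx, weights):
    # Scatter the top-k weights into a dense vector; unselected experts get 0.0.
    dense_w = [0.0] * len(experts)
    for i, w in zip(idx, weights):
        dense_w[i] = w
    out = [0.0] * len(x)
    for e, (gate_w, up_w, down_w) in enumerate(experts):
        y = swiglu_expert(x, gate_w, up_w, down_w)
        out = [o + dense_w[e] * yi for o, yi in zip(out, y)]
    return out


def combine_gather(x, experts, idx, weights):
    # Only the selected experts are evaluated, in ascending expert index.
    out = [0.0] * len(x)
    for i, w in sorted(zip(idx, weights)):
        y = swiglu_expert(x, *experts[i])
        out = [o + w * yi for o, yi in zip(out, y)]
    return out


def make_expert(seed, hidden=2, inter=3):
    # Hypothetical deterministic weights, for illustration only.
    def mat(rows, cols, offset):
        return [[math.sin(seed + offset + r * cols + c) for c in range(cols)]
                for r in range(rows)]
    return mat(inter, hidden, 0.0), mat(inter, hidden, 10.0), mat(hidden, inter, 20.0)


experts = [make_expert(s) for s in range(4)]
x = [0.5, -1.25]
idx, weights = [2, 0], [0.7, 0.3]
dense = combine_masked_dense(x, experts, idx, weights)
sparse = combine_gather(x, experts, idx, weights)
assert dense == sparse  # zero-weighted terms leave the accumulator unchanged
```

The equality holds exactly here because every expert output is finite: a zero weight times a finite value is zero, and adding zero leaves the running sum unchanged.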
What changed
The MoE FFN graph primitive (emitter/moe.rs, new)
- Router: a softmax over all experts, top-k selection (iterative argmax, reproducing torch.topk's lower-index tie-break), renormalization to sum to one when norm_topk_prob, then scale by routed_scaling_factor, then a weighted combine of the selected experts' SwiGLU outputs.
- Optional shared-expert branch: gated by sigmoid(x @ Wg^T) for Qwen2-MoE, ungated for DeepSeek.
- Masked-dense dispatch over the mlx-lm stacked switch_mlp weights (batched over the expert axis), numerically identical to gathering only the selected experts (0 * y == 0, x + 0 == x in IEEE-754); a sparse gather is a throughput follow-up that does not change the result.
- Emitted once for [N, H] activations, so single-token decode (reshaped to [1, H]), prefill, ragged decode, and batched decode all route from one authoring site (mirroring the shared attention core, #494). A dense (non-MoE) config emits no MoE op, so every shipped dense checkpoint stays byte-for-byte unchanged.
Plumbing
- emitter/config.rs: Config::from_json parses the MoE fields (n_experts, n_experts_per_tok, moe_intermediate_size, norm_topk_prob, routed_scaling_factor, shared-expert config, first_k_dense) for Mixtral and Qwen2-MoE, whose attention this emitter already reproduces. Qwen3-MoE (per-head q/k norm) and DeepSeek-V2 (multi-head latent attention) are deferred to #501 with a clear message rather than mis-emitted.
- emitter/model.rs: take_moe_weights takes the router + stacked experts + optional shared expert in the canonical per-layer MoE order; the FFN sites dispatch MoE vs dense; emit_moe_probe isolates the MoE block for the execution check.
- iree.rs: weight_names emits the mlx-lm switch_mlp expert names (per family prefix) in the emitted arg order; load_weights dequantizes the rank-3 stacked expert weights.
- weights.rs: dequantize_affine_stacked dequantizes each expert slab of a [experts, out, in_packed] MLX affine weight.
Validation
- Structural: a tiny qwen2_moe fixture (crate::validation::QWEN2_MOE_TINY, assets/qwen2-moe-tiny/) freezes the prefill / decode / ragged serve graphs and guards them byte-for-byte.
- Execution: spike/openxla/moe_oracle.py compiles the emitted MoE block probe with IREE (llvm-cpu) and compares it to a genuine HF fp32 MoE block. The emitter matches Qwen2-MoE (shared expert + gate + norm_topk_prob) and Mixtral (no shared expert, always renormalized) to ~1e-9 (fp32 reduction-order noise).
Deferred (documented, allowed by the epic)
Full-model token-exact on a real MoE checkpoint is deferred: the smallest on disk (deepseek-v2-lite) uses multi-head latent attention, and the others (Qwen3-MoE q/k norm; Mixtral / Phi-3.5-MoE at 22-25 GB) either need attention work or are too large to run in-agent. The routing / dispatch math is instead proven at the block level against the HF fp32 oracle. The full-checkpoint execution gate (CLI MLXCEL_BACKEND=xla generate, serve reference-exact via xla_batch_bench, IREE-CUDA on GB10) and the deferred attentions land with the families in #501.
Test plan
- cargo test -p mlxcel-xla --lib (71 passed; byte-exact MoE fixture, structural MoE tests, stacked-dequant tests)
- cargo clippy -p mlxcel-xla --lib --tests -- -D warnings (clean)
- cargo fmt -p mlxcel-xla
- spike/openxla/.venv/bin/python spike/openxla/moe_oracle.py (RESULT: PASS; Qwen2-MoE + Mixtral blocks token-exact ~1e-9 vs HF)
Closes #500