Skip to content

feat: add the MoE FFN graph primitive (router + top-k dispatch)#559

Merged
inureyes merged 3 commits into
mainfrom
feature/issue-500-moe-ffn-primitive
Jul 1, 2026
Merged

feat: add the MoE FFN graph primitive (router + top-k dispatch)#559
inureyes merged 3 commits into
mainfrom
feature/issue-500-moe-ffn-primitive

Conversation

@inureyes

@inureyes inureyes commented Jul 1, 2026

Copy link
Copy Markdown
Member

Summary

Add a shared Mixture-of-Experts FFN graph primitive to the OpenXLA (xla-iree) emitter (issue #500), the foundation the MoE families (#501) build on. A MoE layer replaces the dense SwiGLU MLP with a router that selects the top-k of N experts and combines their SwiGLU outputs, plus an optional shared-expert branch. The routing math is proven token-exact against a genuine HF fp32 MoE block, and the emitted graphs are guarded byte-for-byte in CI.

What changed

The MoE FFN graph primitive (emitter/moe.rs, new)

  • Router linear, then softmax over ALL experts, then top-k selection (iterative argmax, reproducing torch.topk's lower-index tie-break), then renormalize to sum to one when norm_topk_prob, then scale by routed_scaling_factor, then a weighted combine of the selected experts' SwiGLU outputs.
  • Optional shared-expert branch: a parallel SwiGLU added to the routed output, gated by sigmoid(x @ Wg^T) for Qwen2-MoE, ungated for DeepSeek.
  • Masked-dense dispatch over the mlx-lm stacked switch_mlp weights (batched over the expert axis), numerically identical to gathering only the selected experts (0 * y == 0, x + 0 == x in IEEE-754); a sparse gather is a throughput follow-up that does not change the result.
  • Emitted once for [N, H] activations, so single-token decode (reshaped to [1, H]), prefill, ragged decode, and batched decode all route from one authoring site (mirroring the shared attention core, refactor: share the full per-layer attention core across all emitter graph kinds #494). A dense (non-MoE) config emits no MoE op, so every shipped dense checkpoint stays byte-for-byte unchanged.

Plumbing

  • emitter/config.rs: Config::from_json parses the MoE fields (n_experts, n_experts_per_tok, moe_intermediate_size, norm_topk_prob, routed_scaling_factor, shared-expert config, first_k_dense) for Mixtral and Qwen2-MoE, whose attention this emitter already reproduces. Qwen3-MoE (per-head q/k norm) and DeepSeek-V2 (multi-head latent attention) are deferred to feat: MoE family coverage on the shared FFN primitive (Mixtral, Qwen2/Qwen3-MoE, OLMoE, DeepSeek-V2/V3, gpt-oss, and more) #501 with a clear message rather than mis-emitted.
  • emitter/model.rs: take_moe_weights takes the router + stacked experts + optional shared expert in the canonical per-layer MoE order; the FFN sites dispatch MoE vs dense; emit_moe_probe isolates the MoE block for the execution check.
  • iree.rs: weight_names emits the mlx-lm switch_mlp expert names (per family prefix) in the emitted arg order; load_weights dequantizes the rank-3 stacked expert weights.
  • weights.rs: dequantize_affine_stacked dequantizes each expert slab of a [experts, out, in_packed] MLX affine weight.

Validation

  • Structural (in-CI gate): a tiny qwen2_moe fixture (crate::validation::QWEN2_MOE_TINY, assets/qwen2-moe-tiny/) freezes the prefill / decode / ragged serve graphs and guards them byte-for-byte.
  • Execution: spike/openxla/moe_oracle.py compiles the emitted MoE block probe with IREE (llvm-cpu) and compares it to a genuine HF fp32 MoE block. The emitter matches Qwen2-MoE (shared expert + gate + norm_topk_prob) and Mixtral (no shared expert, always renormalized) to ~1e-9 (fp32 reduction-order noise).

Deferred (documented, allowed by the epic)

Full-model token-exact on a real MoE checkpoint is deferred: the smallest on disk (deepseek-v2-lite) uses multi-head latent attention, and the others (Qwen3-MoE q/k norm; Mixtral / Phi-3.5-MoE at 22-25 GB) either need attention work or are too large to run in-agent. The routing / dispatch math is instead proven at the block level against the HF fp32 oracle. The full-checkpoint execution gate (CLI MLXCEL_BACKEND=xla generate, serve reference-exact via xla_batch_bench, IREE-CUDA on GB10) and the deferred attentions land with the families in #501.

Test plan

  • cargo test -p mlxcel-xla --lib (71 passed; byte-exact MoE fixture, structural MoE tests, stacked-dequant tests)
  • cargo clippy -p mlxcel-xla --lib --tests -- -D warnings (clean)
  • cargo fmt -p mlxcel-xla
  • spike/openxla/.venv/bin/python spike/openxla/moe_oracle.py (RESULT: PASS; Qwen2-MoE + Mixtral blocks token-exact ~1e-9 vs HF)

Closes #500

Add a shared Mixture-of-Experts FFN graph block to the OpenXLA (xla-iree) emitter, the foundation the MoE families (issue #501) build on. A MoE layer replaces the dense SwiGLU MLP with a router that selects the top-k of N experts: a softmax over all experts forms the routing probabilities, the top-k are picked (iterative argmax, reproducing torch.topk's lower-index tie-break), renormalized to sum to one when norm_topk_prob and scaled by routed_scaling_factor, then used to combine the selected experts' SwiGLU outputs, plus an optional shared-expert branch (Qwen2-MoE gates it by sigmoid(x @ Wg^T); DeepSeek adds it ungated). The dispatch is masked-dense over the mlx-lm stacked switch_mlp weights (batched over the expert axis), which is numerically identical to gathering only the selected experts (0 * y == 0 and x + 0 == x in IEEE-754); a sparse gather is a throughput follow-up that does not change the result.

The new emitter/moe.rs owns the routing math, emitted once for [N, H] activations so single-token decode (reshaped to [1, H]), prefill, ragged decode, and batched decode all route from one authoring site, mirroring the shared attention core (#494). A dense (non-MoE) config emits no MoE op, so every shipped dense checkpoint stays byte-for-byte unchanged.

Plumbing:

- emitter/config.rs: Config::from_json parses the MoE fields (n_experts, n_experts_per_tok, moe_intermediate_size, norm_topk_prob, routed_scaling_factor, shared-expert config, first_k_dense) for Mixtral and Qwen2-MoE, whose attention this emitter already reproduces; Qwen3-MoE (per-head q/k norm) and DeepSeek-V2 (multi-head latent attention) are deferred to #501 with a clear message rather than mis-emitted.
- emitter/model.rs: the per-layer arg schema takes the router + stacked experts + optional shared expert (take_moe_weights) in the canonical MoE order; ffn_seq and the single-decode / batched inline paths dispatch MoE vs dense; a standalone emit_moe_probe isolates the MoE block for the execution check.
- iree.rs: weight_names emits the mlx-lm switch_mlp expert names (per family prefix) in the emitted arg order, and load_weights dequantizes the rank-3 stacked expert weights.
- weights.rs: dequantize_affine_stacked dequantizes each expert slab of a [experts, out, in_packed] MLX affine weight.

Validation:

- Structural: a tiny qwen2_moe fixture (crate::validation QWEN2_MOE_TINY) freezes the prefill / decode / ragged serve graphs and guards them byte-for-byte in CI.
- Execution: spike/openxla/moe_oracle.py compiles the emitted MoE block probe with IREE and compares it to a genuine HF fp32 MoE block; the emitter matches Qwen2-MoE (shared expert + gate + norm_topk_prob) and Mixtral (no shared, always renormalized) to ~1e-9 (fp32 reduction-order noise), proving the routing / dispatch math.

Full-model token-exact on a real MoE checkpoint is deferred: the smallest on disk (deepseek-v2-lite) uses multi-head latent attention, and the others (Qwen3-MoE q/k norm; Mixtral / Phi-3.5-MoE at 22-25 GB) either need attention work or are too large to run in-agent, so the routing math is proven at the block level against the HF oracle instead. Those attentions, and the full-checkpoint execution gate, land with the families in #501.

Refs #500
@inureyes inureyes added type:enhancement New features, capabilities, or significant additions priority:high High priority area:architecture Architecture and code structure changes status:done Completed labels Jul 1, 2026
inureyes added 2 commits July 1, 2026 11:15
Integrate the dense arch pack (#558: Qwen3 / Gemma1/3 / SmolLM3 / OLMo2/3) that landed on main while the MoE FFN primitive (#500) was in flight. Both touch the same emitter seams (Config fields, LayerW, take_layer_weights, the MLP helpers, weight_names), so the resolution weaves them: the dense MLP weights (down/gate/up) and input_layernorm are both optional now, take_layer_weights and weight_names branch dense vs MoE while sharing the qk-norm / feed-forward-norm conditionals, and the single-decode / seq FFN dispatch (ffn_single / ffn_seq) picks the MoE block or the dense single_mlp / seq_mlp.

Validated after the merge: cargo test -p mlxcel-xla --lib (80 passed, byte-exact goldens for both the dense fixtures and the qwen2-moe-tiny MoE fixture), cargo clippy -D warnings clean, and spike/openxla/moe_oracle.py token-exact (~1e-9) vs the HF Qwen2-MoE and Mixtral fp32 blocks.
Integrate the #497/#498/#499 dense arch packs from main into the #500 MoE FFN branch and resolve the semantic conflict in the emitter.

main restructured emitter/model.rs so emit_attention returns (hn, attn_out) and emit_transformer_layer owns the residual (sequential or parallel) and emits the dense FFN via emit_mlp_body. The MoE FFN is wired in as the alternative FFN body selected by is_moe_layer: a new emit_ffn_body dispatches to moe::moe_block (over the pre-normed hidden, no residual) on a MoE layer and to emit_mlp_body on a dense layer, so emit_transformer_layer keeps main's residual ownership and every dense arch emits byte-for-byte as main does. The single-token decode path reshapes [H] to [1, H] for the seq-only MoE primitive; the superseded uniform-B batched decode gets the same MoE-vs-dense split inline. The now-unused moe_ffn_seq/moe_ffn_single wrappers are dropped (moe_block is the one primitive).

Weight handling: the MoE expert bank (router, stacked switch_mlp gate/up/down, optional shared expert plus its sigmoid gate) is appended per MoE layer in weights::weight_specs (the single ordering authority main introduced), mirroring take_moe_weights in the emitter; the branch's inline iree.rs weight_names is dropped in favor of that path. The loader gains a rank-3 stacked-expert dequant (dequantize_affine_stacked) alongside the rank-2 path, keeping the Phi3 fused-split dense path intact.

config.rs unions the MoE fields (MoeConfig / SharedExpertConfig / moe / is_moe_layer, Mixtral + Qwen2-MoE detection) onto main's already-unioned dense-pack config. validation.rs registers the qwen2-moe-tiny fixture on the byte-exact gate next to the 12 dense fixtures; its decode golden is re-frozen for the norm-at-[H] single-token path (a numerically identical single-row RMSNorm).

Validation: cargo test -p mlxcel-xla --lib is 108 passed / 0 failed (all dense goldens byte-identical, MoE fixture green); clippy and fmt clean; moe_oracle.py is block-exact vs HF Qwen2-MoE and Mixtral (~1e-9); dense_arch_check.py is token-exact for the #498 pack.
@inureyes inureyes merged commit 12059e7 into main Jul 1, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:architecture Architecture and code structure changes priority:high High priority status:done Completed type:enhancement New features, capabilities, or significant additions

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: add the MoE FFN graph primitive (router + top-k expert dispatch)

1 participant