feat: dense arch pack for parallel-block and norm-variant families (#498) #561
Merged
Add seven dense LLM families to the OpenXLA (mlxcel-xla) StableHLO emitter as config-driven deltas on the shared per-layer core (issue #494): Cohere and Cohere2, Phi3, StableLM, StarCoder2, Granite, and MiniCPM (v1). Each is a set of `Config` switches, so it is authored once and reaches the single-token decode, ragged (continuous-batching) decode, and prefill graphs together. Part of #493.

Config detection (emitter/config.rs) grows the per-family switches: a mean-subtract LayerNorm with an optional affine bias (Cohere/Cohere2, StableLM, StarCoder2) alongside the existing RMSNorm and Gemma2 `(1+w)` norm; a parallel attention+MLP block with a single shared `input_layernorm` and one residual (Cohere/Cohere2); interleaved (GPT-J) and partial RoPE, plus a per-layer RoPE gate for Cohere2's position-free (NoPE) full-attention layers; a dense (non-gated) GELU MLP and q/k/v/o and MLP biases (StarCoder2); and the scalar multipliers (Granite `attention_multiplier` / `embedding_multiplier` / `residual_multiplier` / `logits_scaling`, MiniCPM `scale_emb` / `scale_depth` / `dim_model_base`, Cohere `logit_scale`). MiniCPM ships as `model_type = "llama"` but keeps the `scale_*` fields, so it is detected by their presence. MiniCPM3 (MLA attention plus LongRoPE) is rejected with a clear follow-up message rather than mis-emitted.

The emitter (emitter/model.rs, emitter/rope.rs) refactors the per-layer body into one shared `emit_transformer_layer` (attention parts + MLP body + residual, sequential or parallel) used by all three graph kinds, generalizes the norm to LayerNorm-with-bias and the RoPE to half-split / interleaved / partial, and adds the residual and final-logit scaling. Every change is guarded so the Llama family emits byte-for-byte identically (the committed Llama-3.2-1B goldens are unchanged, and the Qwen2 / Gemma2 invariants hold).
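The three norm variants named above (RMSNorm, the Gemma2 `(1+w)` norm, and mean-subtract LayerNorm with an optional bias) differ only in whether the mean is subtracted, how the weight is applied, and whether a bias is added. The following is a minimal numeric sketch of that relationship on a single hidden vector; `NormKind` and `apply_norm` are hypothetical names for illustration and are not the emitter's types, which emit StableHLO rather than compute values.

```rust
/// Which normalization a family uses (illustrative names only,
/// not taken from the mlxcel-xla emitter).
#[derive(Clone, Copy, Debug, PartialEq)]
enum NormKind {
    /// x / rms(x) * w
    Rms,
    /// x / rms(x) * (1 + w), the Gemma2-style weight offset
    RmsOnePlus,
    /// (x - mean) / sqrt(var + eps) * w, plus an optional bias
    LayerNorm,
}

fn apply_norm(kind: NormKind, x: &[f32], w: &[f32], bias: Option<&[f32]>, eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    // Only LayerNorm subtracts the mean; the RMS variants keep it at zero.
    let mean = if kind == NormKind::LayerNorm {
        x.iter().sum::<f32>() / n
    } else {
        0.0
    };
    // With mean == 0 this is the mean square, so one formula covers all three.
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
    let inv = 1.0 / (var + eps).sqrt();
    x.iter()
        .enumerate()
        .map(|(i, v)| {
            let scale = if kind == NormKind::RmsOnePlus { 1.0 + w[i] } else { w[i] };
            let b = bias.map_or(0.0, |b| b[i]);
            (v - mean) * inv * scale + b
        })
        .collect()
}
```

Treating RMSNorm as LayerNorm with a zero mean is one way to see why a single guarded code path can serve every family without changing what the Llama family emits.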
The loader weight order and the Phi3 fused-projection split live in weights.rs (`weight_specs`, unit-tested without the IREE runtime) and are consumed by the feature-gated iree.rs loader, which loads each fused `qkv_proj` / `gate_up_proj` once and row-slices it into the emitter's separate args.

Validation: each family is registered as a byte-exact structural `ArchFixture` in crate::validation (small synthetic config plus frozen prefill / decode goldens), gated by `registered_fixtures_are_byte_exact`, with per-delta invariant tests. Correctness of the six in-transformers families (Cohere, Cohere2, Phi3, StableLM, StarCoder2, Granite) is proven token-exact (last-token argmax, max logit diff <= 9.3e-9) against an HF fp32 oracle on a synthetic model via the new spike/openxla/dense_arch_check.py; the goldens were frozen from that validated emitter. The real-checkpoint token-exact and serve (xla_batch_bench) gates run through the IREE feature build (scripts/xla/validate_arch.sh) post-merge.
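The token-exact criterion above combines two measurements on the last token's logits: the argmax must match the oracle's, and the largest element-wise difference is reported. A rough sketch of such a comparison follows. The real harness is the Python script `dense_arch_check.py`; the function names and the explicit `tol` parameter here are assumptions for illustration, and the source states 9.3e-9 as an observed maximum rather than as a threshold.

```rust
/// Index of the largest logit (the first one wins on ties).
fn argmax(logits: &[f32]) -> usize {
    let mut best = 0;
    for (i, v) in logits.iter().enumerate() {
        if *v > logits[best] {
            best = i;
        }
    }
    best
}

/// Largest element-wise absolute difference between two logit vectors.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0, f32::max)
}

/// Hypothetical pass/fail rule: same predicted token and logits within `tol`.
fn token_exact(emitted: &[f32], oracle: &[f32], tol: f32) -> bool {
    emitted.len() == oracle.len()
        && argmax(emitted) == argmax(oracle)
        && max_abs_diff(emitted, oracle) <= tol
}
```

Checking the argmax separately from the difference matters because a small logit drift can still flip the predicted token when the top two logits are close.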
Brings the issue #498 branch (Cohere/Cohere2/Phi3/StableLM/StarCoder2/Granite/MiniCPM) up to date with origin/main, which added the #497 pack (Qwen3/Gemma1/Gemma3/SmolLM3/OLMo2/OLMo3, decomposed Config flags on the shared #494 attention core) and the #499 pack (Seed-OSS/MiMo/InternLM3/ExaOne, a WeightScheme field and a consolidated weight_names module). All three packs now coexist and emit correctly.

Config: one Config carries the union of all flags. main's decomposition of gemma2:bool into orthogonal flags (NormStyle/QkNorm/embed_scale/norm_one_plus/mlp_geglu/rope_local_base/sliding_pattern/use_rope_layers/weight_scheme) is kept, and #498's flags (layernorm/norm_bias/parallel_block/attn_o_bias/mlp_bias/dense_mlp/rope_interleaved/rotary_dim/rope_on_sliding_only/attention_multiplier/embedding_multiplier/residual_multiplier/logit_mul/logit_div/fused_qkv/fused_gate_up) ride alongside. Every pack's arch detection coexists and every rejection is kept (ERNIE-4.5 interleaved RoPE, MiniCPM3 MLA, yarn, out-of-scope biases and activations), with the o_proj/MLP-bias and non-silu guards scoped so StarCoder2/Granite are not wrongly rejected. sliding_pattern is unified to main's usize (Cohere2 sets its sliding_window_pattern).

model.rs (the hardest reconcile): #498's LayerNorm, parallel block, interleaved+partial RoPE, dense MLP, and residual+logit scaling are folded into main's shared emit_attention + AttnLayout core. emit_attention now returns (hn, attn_out) so emit_transformer_layer owns the residual (sequential with the Granite/MiniCPM residual multiplier, or Cohere's parallel x + attn(ln) + mlp(ln)); the merged emit_mlp_body/mlp_pre_norm drive the activation off mlp_geglu and the feed-forward norms off has_pre_ff_norm/has_post_ff_norm; the RoPE gate folds Cohere2's rope-on-sliding-only into layer_uses_rope alongside SmolLM3's NoPE mask; main's dual-RoPE (pick_rope) and qk-norm (apply_qk_norm) are preserved.
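The two residual wirings that emit_transformer_layer owns can be sketched with scalars standing in for tensors. This is an illustration under assumptions, not the emitter's code: `Block` and `layer` are hypothetical names, the norm/attention/MLP stages are passed as closures, and the sequential branch reuses one `norm` closure for both sub-blocks even though a real sequential layer has separate pre-attention and pre-MLP norms.

```rust
/// How one layer combines attention, MLP, and the residual stream
/// (illustrative; not the emitter's real types).
enum Block {
    /// h = x + m * attn(norm(x)); out = h + m * mlp(norm(h))
    Sequential { residual_multiplier: f32 },
    /// out = x + attn(norm(x)) + mlp(norm(x)): one shared norm, one residual
    Parallel,
}

fn layer(
    block: &Block,
    x: f32,
    norm: impl Fn(f32) -> f32,
    attn: impl Fn(f32) -> f32,
    mlp: impl Fn(f32) -> f32,
) -> f32 {
    match block {
        Block::Sequential { residual_multiplier: m } => {
            // The MLP sees the stream after the attention residual is added.
            let h = x + m * attn(norm(x));
            h + m * mlp(norm(h))
        }
        Block::Parallel => {
            // Attention and MLP both read the same normalized input.
            let hn = norm(x);
            x + attn(hn) + mlp(hn)
        }
    }
}
```

With an identity norm, `attn = 2v`, and `mlp = v + 1`, an input of 1.0 gives 7.0 for the sequential wiring at multiplier 1.0 and 5.0 for the parallel wiring, which shows the two are not interchangeable. Having the attention stage hand back both the normalized input and its output is what lets the caller pick either wiring.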
Llama/Qwen2/Gemma2 and the #497/#499 arch goldens stay byte-identical.

Weight handling is one coherent layer: weights::weight_specs is the single ordering authority (keeping the Phi3 fused qkv_proj/gate_up_proj row-slice load path) and sources its per-arch names from weight_names::scheme_names, while weight_names::weight_names delegates to it; iree.rs loads via weight_specs (Whole/Rows), preserving main's shard and affine-dequant paths.

rope.rs unions #497's local-base tables with #498's interleaved and partial (rotary_width) layout.

validation.rs registers all 12 arch fixtures on one dense_fixture! macro parameterized by the goldens' sample mode.

Validation: cargo test -p mlxcel-xla --lib passes 97/0 (all 12 byte-exact fixtures plus the Llama/Qwen2/Gemma2 goldens), clippy and fmt clean. Execution parity against HF fp32 confirmed token-exact across all three packs (cohere/starcoder2/granite, gemma3/qwen3/olmo2, seed_oss).
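The Whole/Rows load path for a fused projection can be illustrated as follows. This is a sketch under assumptions: the variant names follow the Whole/Rows wording above, but the fields, the function signatures, and `fused_qkv_rows` are hypothetical, and the `[q; k; v]` row order is assumed for the fused `qkv_proj` weight stored as a row-major `[rows, cols]` matrix.

```rust
/// How one emitter argument is sourced from a checkpoint tensor
/// (field names are assumptions, not the crate's definition).
#[derive(Debug, PartialEq)]
enum Source {
    /// Use the tensor as stored.
    Whole,
    /// Use `len` rows starting at row `start` of a fused tensor.
    Rows { start: usize, len: usize },
}

/// Row ranges for q, k, v inside a fused [q; k; v] projection.
fn fused_qkv_rows(n_heads: usize, n_kv_heads: usize, head_dim: usize) -> [Source; 3] {
    let q = n_heads * head_dim;
    let kv = n_kv_heads * head_dim;
    [
        Source::Rows { start: 0, len: q },
        Source::Rows { start: q, len: kv },
        Source::Rows { start: q + kv, len: kv },
    ]
}

/// Copy the selected rows out of a row-major [rows, cols] matrix.
fn slice_rows(data: &[f32], cols: usize, src: &Source) -> Vec<f32> {
    match src {
        Source::Whole => data.to_vec(),
        Source::Rows { start, len } => data[start * cols..(start + len) * cols].to_vec(),
    }
}
```

Because the row ranges are plain arithmetic on head counts, this kind of logic can be unit-tested without a runtime: the fused tensor is read once and each separate argument is a contiguous slice of it.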
Summary
Adds seven dense LLM families to the OpenXLA (`mlxcel-xla`) StableHLO emitter as config-driven deltas on the shared per-layer core (#494): Cohere and Cohere2, Phi3, StableLM, StarCoder2, Granite, and MiniCPM (v1). Each is a set of `Config` switches, so it is authored once and reaches the single-token decode, ragged (continuous-batching) decode, and prefill graphs together. Part of #493.

What changed
- `emitter/config.rs`: `Config::from_json` now detects each family and carries the new switches: mean-subtract LayerNorm (+ optional affine bias), a parallel attention+MLP block (single `input_layernorm`, one residual), interleaved (GPT-J) and partial RoPE, a per-layer RoPE gate (Cohere2 NoPE on full-attention layers), a dense (non-gated) GELU MLP, q/k/v/o and MLP biases, and the scalar multipliers (Granite `attention_multiplier` / `embedding_multiplier` / `residual_multiplier` / `logits_scaling`; MiniCPM `scale_emb` / `scale_depth` / `dim_model_base`; Cohere `logit_scale`). MiniCPM ships as `model_type = "llama"` with the `scale_*` fields, so it is detected by their presence.
- `emitter/model.rs`, `emitter/rope.rs`: the per-layer body is factored into one shared `emit_transformer_layer` (attention parts + MLP body + residual, sequential or parallel) used by all three graph kinds; the norm generalizes to LayerNorm-with-bias, the RoPE to half-split / interleaved / partial, and the residual and final-logit scaling are added. Every change is guarded so the Llama family emits byte-for-byte identically.
- `weights.rs`: the loader weight order and the Phi3 fused-projection split now live in `weight_specs` / `slice_rows` (pure logic, unit-tested without the IREE runtime). `iree.rs`'s loader consumes them, loading each fused `qkv_proj` / `gate_up_proj` once and row-slicing it into the emitter's separate args.
- `validation.rs` + `assets/<arch>/`: each family is registered as a byte-exact structural `ArchFixture` (small synthetic config + frozen prefill / decode goldens), gated by `registered_fixtures_are_byte_exact`, with per-delta invariant tests in `emitter/mod.rs`.
- `spike/openxla/dense_arch_check.py`: an execution-check harness that emits each family's prefill graph, compiles/runs it with IREE, and compares last-token logits to an HF fp32 oracle on a synthetic model.

Validation
- Correctness of the six in-transformers families (Cohere, Cohere2, Phi3, StableLM, StarCoder2, Granite) is proven token-exact (last-token argmax, max logit diff <= 9.3e-9) against an HF fp32 oracle on a synthetic model via `dense_arch_check.py`; the goldens were frozen from that validated emitter.
- […] `transformers`, needs a hand oracle).
- The real-checkpoint token-exact and serve (`xla_batch_bench`) gates run through the IREE feature build (`scripts/xla/validate_arch.sh`) post-merge, since the `iree` cargo feature is not built in-agent.

Test plan
- `cargo test -p mlxcel-xla --lib` (73 passed): byte-exact fixtures for all 8 registered archs, weight-spec / Phi3-split unit tests, and per-family config + invariant tests
- `cargo clippy -p mlxcel-xla --lib --tests -- -D warnings` clean; `cargo fmt` applied
- `spike/openxla/.venv/bin/python spike/openxla/dense_arch_check.py all` -> all six families PASS token-exact vs HF fp32

Closes #498