Skip to content

feat: dense arch pack for parallel-block and norm-variant families (#498)#561

Merged
inureyes merged 2 commits into
mainfrom
feature/issue-498-dense-arch-pack
Jul 1, 2026
Merged

feat: dense arch pack for parallel-block and norm-variant families (#498)#561
inureyes merged 2 commits into
mainfrom
feature/issue-498-dense-arch-pack

Conversation

@inureyes

@inureyes inureyes commented Jul 1, 2026

Copy link
Copy Markdown
Member

Summary

Adds seven dense LLM families to the OpenXLA (mlxcel-xla) StableHLO emitter as config-driven deltas on the shared per-layer core (#494): Cohere and Cohere2, Phi3, StableLM, StarCoder2, Granite, and MiniCPM (v1). Each is a set of Config switches, so it is authored once and reaches the single-token decode, ragged (continuous-batching) decode, and prefill graphs together. Part of #493.

What changed

  • emitter/config.rs: Config::from_json now detects each family and carries the new switches: mean-subtract LayerNorm (+ optional affine bias), a parallel attention+MLP block (single input_layernorm, one residual), interleaved (GPT-J) and partial RoPE, a per-layer RoPE gate (Cohere2 NoPE on full-attention layers), a dense (non-gated) GELU MLP, q/k/v/o and MLP biases, and the scalar multipliers (Granite attention_multiplier/embedding_multiplier/residual_multiplier/logits_scaling; MiniCPM scale_emb/scale_depth/dim_model_base; Cohere logit_scale). MiniCPM ships as model_type = "llama" with the scale_* fields, so it is detected by their presence.
  • emitter/model.rs, emitter/rope.rs: the per-layer body is factored into one shared emit_transformer_layer (attention parts + MLP body + residual, sequential or parallel) used by all three graph kinds; the norm generalizes to LayerNorm-with-bias, the RoPE to half-split / interleaved / partial, and the residual and final-logit scaling are added. Every change is guarded so the Llama family emits byte-for-byte identically.
  • weights.rs: the loader weight order and the Phi3 fused-projection split now live in weight_specs / slice_rows (pure logic, unit-tested without the IREE runtime). iree.rs's loader consumes them, loading each fused qkv_proj / gate_up_proj once and row-slicing it into the emitter's separate args.
  • validation.rs + assets/<arch>/: each family is registered as a byte-exact structural ArchFixture (small synthetic config + frozen prefill / decode goldens), gated by registered_fixtures_are_byte_exact, with per-delta invariant tests in emitter/mod.rs.
  • spike/openxla/dense_arch_check.py: an execution-check harness that emits each family's prefill graph, compiles/runs it with IREE, and compares last-token logits to an HF fp32 oracle on a synthetic model.

Validation

  • The six in-transformers families (Cohere, Cohere2, Phi3, StableLM, StarCoder2, Granite) are proven token-exact (last-token argmax match, max logit diff <= 9.3e-9) against an HF fp32 oracle via dense_arch_check.py; the goldens were frozen from that validated emitter.
  • The Llama-3.2-1B byte-exact goldens are unchanged and the Qwen2 / Gemma2 invariants hold, so the refactor is non-regressive for the existing archs.
  • MiniCPM (v1) execution correctness is covered transitively (Granite proves the scalar machinery; Phi3 / StableLM prove untied heads); its real-checkpoint gate is deferred (not in transformers, needs a hand oracle).
  • MiniCPM3 (MLA attention + LongRoPE) is a genuinely different attention and is rejected with a clear follow-up message rather than mis-emitted; it is deferred per the epic.
  • The real-checkpoint token-exact and serve (xla_batch_bench) gates run through the IREE feature build (scripts/xla/validate_arch.sh) post-merge, since the iree cargo feature is not built in-agent.

Test plan

  • cargo test -p mlxcel-xla --lib (73 passed): byte-exact fixtures for all 8 registered archs, weight-spec / Phi3-split unit tests, and per-family config + invariant tests
  • cargo clippy -p mlxcel-xla --lib --tests -- -D warnings clean; cargo fmt applied
  • spike/openxla/.venv/bin/python spike/openxla/dense_arch_check.py all -> all six families PASS token-exact vs HF fp32

Closes #498

Add seven dense LLM families to the OpenXLA (mlxcel-xla) StableHLO emitter as config-driven deltas on the shared per-layer core (issue #494): Cohere and Cohere2, Phi3, StableLM, StarCoder2, Granite, and MiniCPM (v1). Each is a set of `Config` switches, so it is authored once and reaches the single-token decode, ragged (continuous-batching) decode, and prefill graphs together. Part of #493.

Config detection (emitter/config.rs) grows the per-family switches: a mean-subtract LayerNorm with an optional affine bias (Cohere/Cohere2, StableLM, StarCoder2) alongside the existing RMSNorm and Gemma2 `(1+w)` norm; a parallel attention+MLP block with a single shared `input_layernorm` and one residual (Cohere/Cohere2); interleaved (GPT-J) and partial RoPE, plus a per-layer RoPE gate for Cohere2's position-free (NoPE) full-attention layers; a dense (non-gated) GELU MLP and q/k/v/o and MLP biases (StarCoder2); and the scalar multipliers (Granite `attention_multiplier` / `embedding_multiplier` / `residual_multiplier` / `logits_scaling`, MiniCPM `scale_emb` / `scale_depth` / `dim_model_base`, Cohere `logit_scale`). MiniCPM ships as `model_type = "llama"` but keeps the `scale_*` fields, so it is detected by their presence. MiniCPM3 (MLA attention plus LongRoPE) is rejected with a clear follow-up message rather than mis-emitted.

The emitter (emitter/model.rs, emitter/rope.rs) refactors the per-layer body into one shared `emit_transformer_layer` (attention parts + MLP body + residual, sequential or parallel) used by all three graph kinds, generalizes the norm to LayerNorm-with-bias, the RoPE to half-split / interleaved / partial, and adds the residual and final-logit scaling. Every change is guarded so the Llama family emits byte-for-byte identically (the committed Llama-3.2-1B goldens are unchanged, and the Qwen2 / Gemma2 invariants hold).

The loader weight order and the Phi3 fused-projection split live in weights.rs (`weight_specs`, unit-tested without the IREE runtime) and are consumed by the feature-gated iree.rs loader, which loads each fused `qkv_proj` / `gate_up_proj` once and row-slices it into the emitter's separate args.

Validation: each family is registered as a byte-exact structural `ArchFixture` in crate::validation (small synthetic config plus frozen prefill / decode goldens), gated by `registered_fixtures_are_byte_exact`, with per-delta invariant tests. Correctness of the six in-transformers families (Cohere, Cohere2, Phi3, StableLM, StarCoder2, Granite) is proven token-exact (last-token argmax, max logit diff <= 9.3e-9) against an HF fp32 oracle on a synthetic model via the new spike/openxla/dense_arch_check.py; the goldens were frozen from that validated emitter. The real-checkpoint token-exact and serve (xla_batch_bench) gates run through the IREE feature build (scripts/xla/validate_arch.sh) post-merge.
@inureyes inureyes added type:enhancement New features, capabilities, or significant additions priority:medium Medium priority area:architecture Architecture and code structure changes status:review Under review status:done Completed and removed status:review Under review labels Jul 1, 2026
Brings the issue #498 branch (Cohere/Cohere2/Phi3/StableLM/StarCoder2/Granite/MiniCPM) up to date with origin/main, which added the #497 pack (Qwen3/Gemma1/Gemma3/SmolLM3/OLMo2/OLMo3, decomposed Config flags on the shared #494 attention core) and the #499 pack (Seed-OSS/MiMo/InternLM3/ExaOne, a WeightScheme field and a consolidated weight_names module). All three packs now coexist and emit correctly.

Config: one Config carries the union of all flags. main's decomposition of gemma2:bool into orthogonal flags (NormStyle/QkNorm/embed_scale/norm_one_plus/mlp_geglu/rope_local_base/sliding_pattern/use_rope_layers/weight_scheme) is kept, and #498's flags (layernorm/norm_bias/parallel_block/attn_o_bias/mlp_bias/dense_mlp/rope_interleaved/rotary_dim/rope_on_sliding_only/attention_multiplier/embedding_multiplier/residual_multiplier/logit_mul/logit_div/fused_qkv/fused_gate_up) ride alongside. Every pack's arch detection coexists and every rejection is kept (ERNIE-4.5 interleaved RoPE, MiniCPM3 MLA, yarn, out-of-scope biases and activations), with the o_proj/MLP-bias and non-silu guards scoped so StarCoder2/Granite are not wrongly rejected. sliding_pattern is unified to main's usize (Cohere2 sets its sliding_window_pattern).

model.rs (the hardest reconcile): #498's LayerNorm, parallel block, interleaved+partial RoPE, dense MLP, and residual+logit scaling are folded into main's shared emit_attention + AttnLayout core. emit_attention now returns (hn, attn_out) so emit_transformer_layer owns the residual (sequential with the Granite/MiniCPM residual multiplier, or Cohere's parallel x + attn(ln) + mlp(ln)); the merged emit_mlp_body/mlp_pre_norm drive the activation off mlp_geglu and the feed-forward norms off has_pre_ff_norm/has_post_ff_norm; the RoPE gate folds Cohere2's rope-on-sliding-only into layer_uses_rope alongside SmolLM3's NoPE mask; main's dual-RoPE (pick_rope) and qk-norm (apply_qk_norm) are preserved. Llama/Qwen2/Gemma2 and the #497/#499 arch goldens stay byte-identical.

Weight handling is one coherent layer: weights::weight_specs is the single ordering authority (keeping the Phi3 fused qkv_proj/gate_up_proj row-slice load path) and sources its per-arch names from weight_names::scheme_names, while weight_names::weight_names delegates to it; iree.rs loads via weight_specs (Whole/Rows), preserving main's shard and affine-dequant paths. rope.rs unions #497's local-base tables with #498's interleaved and partial (rotary_width) layout. validation.rs registers all 12 arch fixtures on one dense_fixture! macro parameterized by the goldens' sample mode.

Validation: cargo test -p mlxcel-xla --lib passes 97/0 (all 12 byte-exact fixtures plus the Llama/Qwen2/Gemma2 goldens), clippy and fmt clean. Execution parity against HF fp32 confirmed token-exact across all three packs (cohere/starcoder2/granite, gemma3/qwen3/olmo2, seed_oss).
@inureyes inureyes merged commit 8582731 into main Jul 1, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:architecture Architecture and code structure changes priority:medium Medium priority status:done Completed type:enhancement New features, capabilities, or significant additions

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: dense arch pack for parallel-block and norm-variant families (Cohere/Cohere2, Phi3, StableLM, StarCoder2, Granite, MiniCPM/MiniCPM3)

1 participant