Skip to content

feat: dense arch pack for qk-norm and Gemma family (Qwen3, Gemma1, Gemma3, SmolLM3, OLMo2/3)#558

Merged
inureyes merged 1 commit into
mainfrom
feature/issue-497-dense-arch-pack
Jul 1, 2026
Merged

feat: dense arch pack for qk-norm and Gemma family (Qwen3, Gemma1, Gemma3, SmolLM3, OLMo2/3)#558
inureyes merged 1 commit into
mainfrom
feature/issue-497-dense-arch-pack

Conversation

@inureyes

@inureyes inureyes commented Jul 1, 2026

Copy link
Copy Markdown
Member

Summary

Adds Qwen3, Gemma1, Gemma3, SmolLM3, and OLMo2/3 to the OpenXLA (mlxcel-xla) emitter as per-family deltas on the shared per-layer attention core (#494), the Gemma2 sliding-window schedule (#495), and the reusable validation harness (#496). Part of epic #493.

The monolithic gemma2: bool config switch is decomposed into orthogonal flags, so each dense family is a flag combination rather than a new code path, and Llama / Qwen2 / Gemma2 keep their exact previous flag combinations (every existing emitted graph and its committed golden is byte-for-byte unchanged).

What changed

  • emitter/config.rs: Config::from_json arch detection for the six new model_types, plus the orthogonal fields (embed_scale, norm_one_plus, mlp_geglu, NormStyle, QkNorm, rope_local_base, sliding_pattern, use_rope_layers) and their norm-placement / sliding / NoPE / local-RoPE helpers.
  • emitter/model.rs: the reserved refactor: share the full per-layer attention core across all emitter graph kinds #494 q/k-norm hook now applies per-head (Qwen3 / Gemma3) or flat (OLMo2/3) RMSNorm before RoPE; arch_norm skips the input norm for the OLMo reordered post-norm; seq_mlp / single_mlp are NormStyle-aware; AttnLayout::rope_qk selects the per-layer local RoPE table (Gemma3 dual RoPE); NoPE layers skip RoPE (SmolLM3).
  • emitter/rope.rs: a local-base RoPE table builder for the Gemma3 / OLMo3 sliding layers.
  • iree.rs: weight_names mirrors each family's arg schedule (conditional input_layernorm, q/k norms, feed-forward norms) so the loader lines up with the emitted graph.
  • validation.rs: registers each family as a golden-less structural fixture (the Qwen2.5 precedent), gated in CI via validate_arch.sh.
  • spike/openxla/dense_arch_pack_check.py (new): an IREE-CPU execution check that builds a small synthetic HF model per family, feeds identical random weights to the Rust-emitted prefill graph, and compares logits to the HF fp32 oracle. This proves the family math is correct (not merely self-consistent) without a heavy checkpoint or the xla-iree cargo feature.

Architecture deltas (verified against transformers 5.12.1 source)

Family Delta on the shared core
Qwen3 Qwen2-shaped, no q/k/v bias, per-head q/k RMSNorm (raw) before RoPE
Gemma1 Llama 2-norm structure + embed scale + (1+w) norm + gelu_tanh GeGLU
Gemma3 Gemma2 4-norm + per-head (1+w) q/k norm + dual local/global RoPE + 5:1 sliding, no soft-cap
SmolLM3 Llama + per-layer NoPE mask (no_rope_layers)
OLMo2 reordered post-norm (no input norm) + flat q/k RMSNorm + untied head
OLMo3 OLMo2 + sliding schedule (full-size yarn RoPE deferred)

Test plan

  • cargo test -p mlxcel-xla --lib (69 pass; existing byte-exact Llama goldens and Gemma2 / Qwen2 structural tests unchanged, confirming the decomposition is byte-preserving)
  • cargo clippy -p mlxcel-xla --lib --tests (clean) and cargo fmt
  • Execution check token-exact vs HF fp32 oracle (IREE llvm-cpu), max abs logit diff: Qwen3 0.0, Gemma1 3e-8, Gemma3 (dual RoPE) 4e-8, SmolLM3 (NoPE) 5e-10, OLMo2 2e-9 (all under the sub-0.01 tanh-near-tie bound)
  • validate_arch.sh --structural-only runs the new golden-less family signatures

Deferrals (per the epic's allowance)

  • Real-checkpoint end-to-end CLI (MLXCEL_BACKEND=xla) / serve token-exact runs need a local xla-iree build and are left to the orchestrator's post-merge token-exact gate; the emit is wired into the IREE loader path (weight_names) and proven correct by the synthetic HF oracle above.
  • OLMo3 is structure-complete but its full-size checkpoint uses yarn RoPE (rejected with a clear message, a documented follow-up) and is 32B, so its execution check is deferred. Its post-norm + flat q/k-norm + sliding structure is covered by the OLMo2 execution proof and a plain-RoPE OLMo3 parse test.

Closes #497

Add Qwen3, Gemma1, Gemma3, SmolLM3, and OLMo2/3 to the OpenXLA emitter as per-family deltas on the shared per-layer attention core (#494), the Gemma2 sliding-window schedule (#495), and the reusable validation harness (#496). Part of #493.

The monolithic `gemma2: bool` config switch is decomposed into orthogonal flags so each family is a flag combination rather than a new code path: `embed_scale`, `norm_one_plus`, `mlp_geglu`, a `NormStyle` enum (Plain pre-norm / GemmaFf four-norm / OLMo reordered post-norm), an optional `QkNorm` (per-head for Qwen3/Gemma3, flat for OLMo2/3), a per-layer local RoPE base (Gemma3 dual local/global RoPE), a generalized sliding-window pattern period, and a per-layer NoPE mask (SmolLM3). Llama / Qwen2 / Gemma2 keep their exact previous flag combinations, so every existing emitted graph (and its committed golden) is byte-for-byte unchanged.

Changes:

- emitter/config.rs: `Config::from_json` arch detection + the orthogonal fields and `NormStyle` / `QkNorm` types; norm-placement and sliding/NoPE/local-rope helpers.
- emitter/model.rs: the reserved #494 q/k-norm hook now applies per-head or flat RMSNorm before RoPE; `arch_norm` skips the input norm for OLMo post-norm; `seq_mlp` / `single_mlp` are norm-style aware; `AttnLayout::rope_qk` selects the local RoPE table per layer; NoPE layers skip RoPE.
- emitter/rope.rs: local-base RoPE table builder for the Gemma3 / OLMo3 sliding layers.
- iree.rs: `weight_names` mirrors the per-family arg schedule (conditional input norm, q/k norms, feed-forward norms) so the loader lines up with the emitted graph.
- validation.rs: registers each family as a golden-less structural fixture (the Qwen2.5 precedent), gated in CI via validate_arch.sh.
- spike/openxla/dense_arch_pack_check.py: an IREE-CPU execution check that builds a small synthetic HF model per family and compares the emitter's prefill logits to the HF fp32 oracle, so the family math is proven correct without a heavy checkpoint or the xla-iree feature.

Validation: 69 pure-Rust tests pass (existing byte-exact goldens unchanged). The execution check is token-exact vs HF fp32 for Qwen3 (0.0 max logit diff), Gemma1 (3e-8), Gemma3 including dual RoPE (4e-8), SmolLM3 NoPE (5e-10), and OLMo2 (2e-9). OLMo3 is structure-complete but its full-size checkpoint needs yarn RoPE (rejected with a clear message) and is 32B, so its execution check is deferred to the post-merge real-checkpoint gate.
@inureyes inureyes added type:enhancement New features, capabilities, or significant additions priority:medium Medium priority area:models Model architectures, weights, loading, metadata area:architecture Architecture and code structure changes status:done Completed labels Jul 1, 2026
@inureyes inureyes merged commit cbc0f06 into main Jul 1, 2026
5 checks passed
inureyes added a commit that referenced this pull request Jul 1, 2026
Integrate the dense arch pack (#558: Qwen3 / Gemma1/3 / SmolLM3 / OLMo2/3) that landed on main while the MoE FFN primitive (#500) was in flight. Both touch the same emitter seams (Config fields, LayerW, take_layer_weights, the MLP helpers, weight_names), so the resolution weaves them: the dense MLP weights (down/gate/up) and input_layernorm are both optional now, take_layer_weights and weight_names branch dense vs MoE while sharing the qk-norm / feed-forward-norm conditionals, and the single-decode / seq FFN dispatch (ffn_single / ffn_seq) picks the MoE block or the dense single_mlp / seq_mlp.

Validated after the merge: cargo test -p mlxcel-xla --lib (80 passed, byte-exact goldens for both the dense fixtures and the qwen2-moe-tiny MoE fixture), cargo clippy -D warnings clean, and spike/openxla/moe_oracle.py token-exact (~1e-9) vs the HF Qwen2-MoE and Mixtral fp32 blocks.
inureyes added a commit that referenced this pull request Jul 1, 2026
Integrate origin/main (PR #558, issue #497: the qk-norm / Gemma dense pack for Qwen3, Gemma1/3, SmolLM3, OLMo2/3) into feature/issue-499-dense-arch-pack (issue #499: the remaining-Llama-family pack for Seed-OSS, MiMo, InternLM3, ExaOne). The two sides restructured the same shared files with disjoint, complementary architecture sets; this merge unifies them so one Config detects and one emitter serves both.

config.rs: keep #497's orthogonal decomposition of the old `gemma2: bool` switch (NormStyle / QkNorm plus embed_scale / norm_one_plus / mlp_geglu / rope_local_base / sliding_pattern / use_rope_layers) and add #499's WeightScheme, the seed_oss / mimo / internlm3 / exaone match arms (config-driven q/k/v bias), the num_layers / layer_norm_epsilon alternate field reads, the default / dynamic rope types served as plain RoPE, and the ERNIE-4.5 / attention_out_bias / mlp_bias / non-SwiGLU rejections. Both flag sets now coexist on one Config, and both arch sets are detected.

Weight naming was the main structural conflict: #497 edited an inline weight_names in iree.rs, while #499 relocated naming into a new weight_names.rs module with a per-scheme name table. Consolidated into the module. iree.rs imports weight_names, and weight_names.rs folds #497's arch-conditional order (a skippable input_layernorm for the OLMo reordered post-norm, the q/k norm weights, and the independent pre / post feed-forward norms) into #499's Llama / ExaOne schemes. One naming layer covers both arch sets, and Llama / Qwen2 / Gemma2 name lists stay byte-identical.

validation.rs, assets, and the spike scripts are additive and disjoint, so both sides are kept. REGISTERED gains the four #499 golden fixtures (Seed-OSS / MiMo / InternLM3 / ExaOne); #497's golden-less STRUCTURAL_FAMILIES (Qwen3 / Gemma1/3 / SmolLM3 / OLMo2/3) stay alongside them.

The emitter core (model.rs, rope.rs, builder.rs) is byte-identical to origin/main, and Llama / Qwen2 / Gemma2 parse to the same flags, so their emitted graphs are byte-for-byte unchanged (registered_fixtures_are_byte_exact passes). Validation: cargo test -p mlxcel-xla --lib is 84 passed / 0 failed; clippy -D warnings and fmt --check are clean; the spike execution checks are token-exact for both packs (Seed-OSS via synthetic_arch_check.py, and Qwen3 / Gemma1/3 / SmolLM3 / OLMo2 via dense_arch_pack_check.py).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:architecture Architecture and code structure changes area:models Model architectures, weights, loading, metadata priority:medium Medium priority status:done Completed type:enhancement New features, capabilities, or significant additions

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: dense arch pack for qk-norm and Gemma family (Qwen3, Gemma1, Gemma3, SmolLM3, OLMo2/OLMo3)

1 participant