feat: dense arch pack for qk-norm and Gemma family (Qwen3, Gemma1, Gemma3, SmolLM3, OLMo2/3)#558
Merged
Merged
Conversation
Add Qwen3, Gemma1, Gemma3, SmolLM3, and OLMo2/3 to the OpenXLA emitter as per-family deltas on the shared per-layer attention core (#494), the Gemma2 sliding-window schedule (#495), and the reusable validation harness (#496). Part of #493. The monolithic `gemma2: bool` config switch is decomposed into orthogonal flags so each family is a flag combination rather than a new code path: `embed_scale`, `norm_one_plus`, `mlp_geglu`, a `NormStyle` enum (Plain pre-norm / GemmaFf four-norm / OLMo reordered post-norm), an optional `QkNorm` (per-head for Qwen3/Gemma3, flat for OLMo2/3), a per-layer local RoPE base (Gemma3 dual local/global RoPE), a generalized sliding-window pattern period, and a per-layer NoPE mask (SmolLM3). Llama / Qwen2 / Gemma2 keep their exact previous flag combinations, so every existing emitted graph (and its committed golden) is byte-for-byte unchanged. Changes: - emitter/config.rs: `Config::from_json` arch detection + the orthogonal fields and `NormStyle` / `QkNorm` types; norm-placement and sliding/NoPE/local-rope helpers. - emitter/model.rs: the reserved #494 q/k-norm hook now applies per-head or flat RMSNorm before RoPE; `arch_norm` skips the input norm for OLMo post-norm; `seq_mlp` / `single_mlp` are norm-style aware; `AttnLayout::rope_qk` selects the local RoPE table per layer; NoPE layers skip RoPE. - emitter/rope.rs: local-base RoPE table builder for the Gemma3 / OLMo3 sliding layers. - iree.rs: `weight_names` mirrors the per-family arg schedule (conditional input norm, q/k norms, feed-forward norms) so the loader lines up with the emitted graph. - validation.rs: registers each family as a golden-less structural fixture (the Qwen2.5 precedent), gated in CI via validate_arch.sh. - spike/openxla/dense_arch_pack_check.py: an IREE-CPU execution check that builds a small synthetic HF model per family and compares the emitter's prefill logits to the HF fp32 oracle, so the family math is proven correct without a heavy checkpoint or the xla-iree feature. Validation: 69 pure-Rust tests pass (existing byte-exact goldens unchanged). The execution check is token-exact vs HF fp32 for Qwen3 (0.0 max logit diff), Gemma1 (3e-8), Gemma3 including dual RoPE (4e-8), SmolLM3 NoPE (5e-10), and OLMo2 (2e-9). OLMo3 is structure-complete but its full-size checkpoint needs yarn RoPE (rejected with a clear message) and is 32B, so its execution check is deferred to the post-merge real-checkpoint gate.
4 tasks
4 tasks
inureyes
added a commit
that referenced
this pull request
Jul 1, 2026
Integrate the dense arch pack (#558: Qwen3 / Gemma1/3 / SmolLM3 / OLMo2/3) that landed on main while the MoE FFN primitive (#500) was in flight. Both touch the same emitter seams (Config fields, LayerW, take_layer_weights, the MLP helpers, weight_names), so the resolution weaves them: the dense MLP weights (down/gate/up) and input_layernorm are both optional now, take_layer_weights and weight_names branch dense vs MoE while sharing the qk-norm / feed-forward-norm conditionals, and the single-decode / seq FFN dispatch (ffn_single / ffn_seq) picks the MoE block or the dense single_mlp / seq_mlp. Validated after the merge: cargo test -p mlxcel-xla --lib (80 passed, byte-exact goldens for both the dense fixtures and the qwen2-moe-tiny MoE fixture), cargo clippy -D warnings clean, and spike/openxla/moe_oracle.py token-exact (~1e-9) vs the HF Qwen2-MoE and Mixtral fp32 blocks.
inureyes
added a commit
that referenced
this pull request
Jul 1, 2026
Integrate origin/main (PR #558, issue #497: the qk-norm / Gemma dense pack for Qwen3, Gemma1/3, SmolLM3, OLMo2/3) into feature/issue-499-dense-arch-pack (issue #499: the remaining-Llama-family pack for Seed-OSS, MiMo, InternLM3, ExaOne). The two sides restructured the same shared files with disjoint, complementary architecture sets; this merge unifies them so one Config detects and one emitter serves both. config.rs: keep #497's orthogonal decomposition of the old `gemma2: bool` switch (NormStyle / QkNorm plus embed_scale / norm_one_plus / mlp_geglu / rope_local_base / sliding_pattern / use_rope_layers) and add #499's WeightScheme, the seed_oss / mimo / internlm3 / exaone match arms (config-driven q/k/v bias), the num_layers / layer_norm_epsilon alternate field reads, the default / dynamic rope types served as plain RoPE, and the ERNIE-4.5 / attention_out_bias / mlp_bias / non-SwiGLU rejections. Both flag sets now coexist on one Config, and both arch sets are detected. Weight naming was the main structural conflict: #497 edited an inline weight_names in iree.rs, while #499 relocated naming into a new weight_names.rs module with a per-scheme name table. Consolidated into the module. iree.rs imports weight_names, and weight_names.rs folds #497's arch-conditional order (a skippable input_layernorm for the OLMo reordered post-norm, the q/k norm weights, and the independent pre / post feed-forward norms) into #499's Llama / ExaOne schemes. One naming layer covers both arch sets, and Llama / Qwen2 / Gemma2 name lists stay byte-identical. validation.rs, assets, and the spike scripts are additive and disjoint, so both sides are kept. REGISTERED gains the four #499 golden fixtures (Seed-OSS / MiMo / InternLM3 / ExaOne); #497's golden-less STRUCTURAL_FAMILIES (Qwen3 / Gemma1/3 / SmolLM3 / OLMo2/3) stay alongside them. The emitter core (model.rs, rope.rs, builder.rs) is byte-identical to origin/main, and Llama / Qwen2 / Gemma2 parse to the same flags, so their emitted graphs are byte-for-byte unchanged (registered_fixtures_are_byte_exact passes). Validation: cargo test -p mlxcel-xla --lib is 84 passed / 0 failed; clippy -D warnings and fmt --check are clean; the spike execution checks are token-exact for both packs (Seed-OSS via synthetic_arch_check.py, and Qwen3 / Gemma1/3 / SmolLM3 / OLMo2 via dense_arch_pack_check.py).
14 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds Qwen3, Gemma1, Gemma3, SmolLM3, and OLMo2/3 to the OpenXLA (
mlxcel-xla) emitter as per-family deltas on the shared per-layer attention core (#494), the Gemma2 sliding-window schedule (#495), and the reusable validation harness (#496). Part of epic #493.The monolithic
gemma2: boolconfig switch is decomposed into orthogonal flags, so each dense family is a flag combination rather than a new code path, and Llama / Qwen2 / Gemma2 keep their exact previous flag combinations (every existing emitted graph and its committed golden is byte-for-byte unchanged).What changed
emitter/config.rs:Config::from_jsonarch detection for the six newmodel_types, plus the orthogonal fields (embed_scale,norm_one_plus,mlp_geglu,NormStyle,QkNorm,rope_local_base,sliding_pattern,use_rope_layers) and their norm-placement / sliding / NoPE / local-RoPE helpers.emitter/model.rs: the reserved refactor: share the full per-layer attention core across all emitter graph kinds #494 q/k-norm hook now applies per-head (Qwen3 / Gemma3) or flat (OLMo2/3) RMSNorm before RoPE;arch_normskips the input norm for the OLMo reordered post-norm;seq_mlp/single_mlpareNormStyle-aware;AttnLayout::rope_qkselects the per-layer local RoPE table (Gemma3 dual RoPE); NoPE layers skip RoPE (SmolLM3).emitter/rope.rs: a local-base RoPE table builder for the Gemma3 / OLMo3 sliding layers.iree.rs:weight_namesmirrors each family's arg schedule (conditionalinput_layernorm, q/k norms, feed-forward norms) so the loader lines up with the emitted graph.validation.rs: registers each family as a golden-less structural fixture (the Qwen2.5 precedent), gated in CI viavalidate_arch.sh.spike/openxla/dense_arch_pack_check.py(new): an IREE-CPU execution check that builds a small synthetic HF model per family, feeds identical random weights to the Rust-emitted prefill graph, and compares logits to the HF fp32 oracle. This proves the family math is correct (not merely self-consistent) without a heavy checkpoint or thexla-ireecargo feature.Architecture deltas (verified against transformers 5.12.1 source)
(1+w)norm + gelu_tanh GeGLU(1+w)q/k norm + dual local/global RoPE + 5:1 sliding, no soft-capno_rope_layers)Test plan
cargo test -p mlxcel-xla --lib(69 pass; existing byte-exact Llama goldens and Gemma2 / Qwen2 structural tests unchanged, confirming the decomposition is byte-preserving)cargo clippy -p mlxcel-xla --lib --tests(clean) andcargo fmt0.0, Gemma13e-8, Gemma3 (dual RoPE)4e-8, SmolLM3 (NoPE)5e-10, OLMo22e-9(all under the sub-0.01 tanh-near-tie bound)validate_arch.sh --structural-onlyruns the new golden-less family signaturesDeferrals (per the epic's allowance)
MLXCEL_BACKEND=xla) / serve token-exact runs need a localxla-ireebuild and are left to the orchestrator's post-merge token-exact gate; the emit is wired into the IREE loader path (weight_names) and proven correct by the synthetic HF oracle above.Closes #497