feat: dense arch pack (Seed-OSS, MiMo, InternLM3, ExaOne) in mlxcel-xla (#499)#560
Merged
Conversation
Add four dense Llama-family architectures to the OpenXLA (mlxcel-xla) emitter as config / naming deltas over the two already-proven forwards (Llama and Qwen2), part of epic #493. Each was verified from its modeling code to use the emitter's half-split RoPE, `cat(freqs, freqs)` table and `head_dim^-0.5` scaling, RMSNorm and SwiGLU, so its emitted graph is a proven graph rather than a new one; the emitter's forward pass (model.rs) is untouched. Families: Seed-OSS (`seed_oss`, q/k/v bias from `attention_bias`, untied, `default` rope served as plain); MiMo (`mimo`, which subclasses Qwen2 verbatim, its multi-token-prediction heads ignored, config `sliding_window` served globally); InternLM3 (`internlm3`, untied, `dynamic` NTK rope served as plain in-context); ExaOne 3.x (`exaone`, llama3 rope, tied, GPT-2-style tensor names via a new `WeightScheme::Exaone`, and the `num_layers` / `layer_norm_epsilon` alternate config fields). Config parsing (`emitter/config.rs`) generalizes the q/k/v bias source (Qwen2 hard-coded, else `attention_bias` / `qkv_bias`), widens the accepted rope types (`default` / in-context `dynamic` -> plain), and reads the alternate field names, while rejecting out-of-scope deltas with clear errors: an attention output bias, an MLP bias, a non-SwiGLU activation, an unsupported rope type, and interleaved (GPT-J) RoPE. Tensor naming moves to a new pure-Rust `weight_names` module so ExaOne's GPT-2-style remap is unit-tested without the `iree` feature (the loader in `iree.rs` now calls it); the `Llama` scheme reproduces the previous names byte-for-byte, so Llama / Qwen2 / Gemma2 loading is unchanged. Validation: each family is registered as a byte-exact structural fixture (small synthetic configs keep the goldens small; the real `config.json` parsing is asserted in `config::tests`), and `dense_pack_families_reuse_proven_graphs` asserts each emits StableHLO byte-for-byte identical to a proven Llama / Qwen2 reference across the single, prefill, and ragged graph kinds. Two Python harnesses under `spike/openxla` prove forward parity through IREE without building `xla-iree`: `synthetic_arch_check.py` matches a tiny random `SeedOssForCausalLM` to the emitter's prefill logits to ~1e-9, and `arch_execution_check.py` drives a real-checkpoint continuation against an HF fp32 oracle (it surfaced ERNIE-4.5's interleaved RoPE, a delta config inspection misses). The heavy token-exact / serve gate runs post-merge via `scripts/xla/validate_arch.sh`. Deferred to follow-ups (documented, out of scope for a dense half-split pack): ERNIE-4.5 (interleaved GPT-J RoPE), Mistral / Ministral (YaRN rope plus a VLM `language_model.` prefix), ExaOne4 (QK-norm plus post-norm placement), InternLM2 (fused `wqkv` needing loader surgery), GLM4-flash and Solar-Open (MoE + MLA, not dense), Baichuan-M1 (conv + differentiated sliding-window heads), Apertus (xIELU activation), and Hunyuan-dense (cross-layer attention).
Member
Author
Merge note: conflicts with #558 (complementary arch sets)This PR was branched from
The merge needs to re-base this PR's additions onto #558's redesigned
Per the workflow, no merge was performed here. |
Integrate origin/main (PR #558, issue #497: the qk-norm / Gemma dense pack for Qwen3, Gemma1/3, SmolLM3, OLMo2/3) into feature/issue-499-dense-arch-pack (issue #499: the remaining-Llama-family pack for Seed-OSS, MiMo, InternLM3, ExaOne). The two sides restructured the same shared files with disjoint, complementary architecture sets; this merge unifies them so one Config detects and one emitter serves both. config.rs: keep #497's orthogonal decomposition of the old `gemma2: bool` switch (NormStyle / QkNorm plus embed_scale / norm_one_plus / mlp_geglu / rope_local_base / sliding_pattern / use_rope_layers) and add #499's WeightScheme, the seed_oss / mimo / internlm3 / exaone match arms (config-driven q/k/v bias), the num_layers / layer_norm_epsilon alternate field reads, the default / dynamic rope types served as plain RoPE, and the ERNIE-4.5 / attention_out_bias / mlp_bias / non-SwiGLU rejections. Both flag sets now coexist on one Config, and both arch sets are detected. Weight naming was the main structural conflict: #497 edited an inline weight_names in iree.rs, while #499 relocated naming into a new weight_names.rs module with a per-scheme name table. Consolidated into the module. iree.rs imports weight_names, and weight_names.rs folds #497's arch-conditional order (a skippable input_layernorm for the OLMo reordered post-norm, the q/k norm weights, and the independent pre / post feed-forward norms) into #499's Llama / ExaOne schemes. One naming layer covers both arch sets, and Llama / Qwen2 / Gemma2 name lists stay byte-identical. validation.rs, assets, and the spike scripts are additive and disjoint, so both sides are kept. REGISTERED gains the four #499 golden fixtures (Seed-OSS / MiMo / InternLM3 / ExaOne); #497's golden-less STRUCTURAL_FAMILIES (Qwen3 / Gemma1/3 / SmolLM3 / OLMo2/3) stay alongside them. The emitter core (model.rs, rope.rs, builder.rs) is byte-identical to origin/main, and Llama / Qwen2 / Gemma2 parse to the same flags, so their emitted graphs are byte-for-byte unchanged (registered_fixtures_are_byte_exact passes). Validation: cargo test -p mlxcel-xla --lib is 84 passed / 0 failed; clippy -D warnings and fmt --check are clean; the spike execution checks are token-exact for both packs (Seed-OSS via synthetic_arch_check.py, and Qwen3 / Gemma1/3 / SmolLM3 / OLMo2 via dense_arch_pack_check.py).
14 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds four dense Llama-family architectures to the OpenXLA (
mlxcel-xla) emitter as config / naming deltas over the two already-proven forwards (Llama and Qwen2), part of epic #493. The emitter's forward pass (emitter/model.rs) is untouched: each family was verified from its modeling code to use the emitter's half-split RoPE,cat(freqs, freqs)table andhead_dim^-0.5scaling, RMSNorm and SwiGLU, so its emitted graph is a proven graph rather than a new one.Families added
seed_oss): q/k/v projection bias (fromattention_bias), untied embeddings,rope_type = "default"served as plain RoPE. The proven Qwen2 bias forward with standard names.mimo):MiMoForCausalLMsubclassesQwen2ForCausalLMand reusesQwen2Attention/Qwen2MLP/Qwen2RMSNormverbatim (q/k/v bias, untied, plain RoPE); its multi-token-prediction heads are not loaded and its configsliding_windowis served globally (None), as for Qwen2.internlm3): standard names, untied,rope_type = "dynamic"served as plain RoPE (dynamic NTK is identity within the original context; long-context rescale is a follow-up).exaone): llama3 RoPE, tied, the GPT-2-styletransformer.h.{i}...tensor names via a newWeightScheme::Exaone, and thenum_layers/layer_norm_epsilonalternate config fields. Verified against the checkpoint'smodeling_exaone.py(gated MLPc_proj(act(c_fc_0(x)) * c_fc_1(x)), soc_fc_0is the gate andc_fc_1the up projection).What changed
emitter/config.rs: generalizes the q/k/v bias source (Qwen2 hard-coded, elseattention_bias/qkv_bias), widens the accepted rope types (default/ in-contextdynamic-> plain), reads the alternate field names, and adds aWeightScheme. Out-of-scope deltas are rejected with clear errors: an attention output bias, an MLP bias, a non-SwiGLU activation, an unsupported rope type (e.g.yarn), and interleaved (GPT-J) RoPE (e.g. ERNIE-4.5).weight_names.rs(new): the emitter's weight arg order moves to a pure-Rust module so ExaOne's GPT-2-style remap is unit-tested without theireefeature;iree.rsnow calls it. TheLlamascheme reproduces the previous names byte-for-byte, so Llama / Qwen2 / Gemma2 loading is unchanged.validation.rs: registers each family as a byte-exact structural fixture (small synthetic configs keep the goldens small) and addsdense_pack_families_reuse_proven_graphs, which asserts each family emits StableHLO byte-for-byte identical to a proven Llama / Qwen2 reference across the single, prefill, and ragged graph kinds.assets/{seed_oss,mimo,internlm3,exaone}/: fixtureconfig.json+ frozendecode.mlir/prefill.mlir+ a README documenting each delta and its validation.spike/openxla/synthetic_arch_check.py(new): builds a tiny randomSeedOssForCausalLMand matches the emitter's prefill logits through IREE to HF;spike/openxla/arch_execution_check.py(new): drives a real-checkpoint continuation against an HF fp32 oracle. Neither buildsxla-iree.Validation
cargo test -p mlxcel-xla --lib: 72 passed (byte-exact fixtures, the reuse-proven-graphs equivalence, config-parse and weight-name unit tests, and the existing Llama/Qwen2/Gemma2 gates all green; the llama-3.2-1b goldens are unchanged).synthetic_arch_check.pymatches Seed-OSS to HF eager tomax|logit diff| = 3.7e-9. Thearch_execution_check.pyharness ran end to end on the smallest checkpoint and correctly rejected ERNIE-4.5 (interleaved RoPE diverged from HF by ~8, vs ~1e-9 for a supported family), which is why it is deferred.cargo clippy -p mlxcel-xla --lib --tests -- -D warnings: clean;cargo fmtapplied.scripts/xla/validate_arch.shon the real checkpoints.Deferred (documented follow-ups)
Out of scope for a dense half-split pack, each needing a distinct emitter subsystem: ERNIE-4.5 (interleaved GPT-J RoPE), Mistral / Ministral (YaRN rope + VLM
language_model.prefix), ExaOne4 (QK-norm + post-norm placement), InternLM2 (fusedwqkvloader surgery), GLM4-flash and Solar-Open (MoE + MLA, not dense), Baichuan-M1 (conv + differentiated sliding-window heads), Apertus (xIELU activation), and Hunyuan-dense (cross-layer attention). Theconfig.jsonparser rejects each with a clear, specific error.Test plan
cargo test -p mlxcel-xla --lib(72 passed)cargo clippy -p mlxcel-xla --lib --tests -- -D warnings(clean)spike/openxla/synthetic_arch_check.py(Seed-OSS forward parity ~1e-9)scripts/xla/validate_arch.sh --model <ckpt>token-exact / serve gate per family (needs the native IREE build)Closes #499