From 1ae67b556393b46e87899aef25d9d940947c1c28 Mon Sep 17 00:00:00 2001 From: Jeongkyu Shin Date: Wed, 1 Jul 2026 10:47:13 +0900 Subject: [PATCH] feat: dense arch pack for qk-norm and Gemma family in mlxcel-xla Add Qwen3, Gemma1, Gemma3, SmolLM3, and OLMo2/3 to the OpenXLA emitter as per-family deltas on the shared per-layer attention core (#494), the Gemma2 sliding-window schedule (#495), and the reusable validation harness (#496). Part of #493. The monolithic `gemma2: bool` config switch is decomposed into orthogonal flags so each family is a flag combination rather than a new code path: `embed_scale`, `norm_one_plus`, `mlp_geglu`, a `NormStyle` enum (Plain pre-norm / GemmaFf four-norm / OLMo reordered post-norm), an optional `QkNorm` (per-head for Qwen3/Gemma3, flat for OLMo2/3), a per-layer local RoPE base (Gemma3 dual local/global RoPE), a generalized sliding-window pattern period, and a per-layer NoPE mask (SmolLM3). Llama / Qwen2 / Gemma2 keep their exact previous flag combinations, so every existing emitted graph (and its committed golden) is byte-for-byte unchanged. Changes: - emitter/config.rs: `Config::from_json` arch detection + the orthogonal fields and `NormStyle` / `QkNorm` types; norm-placement and sliding/NoPE/local-rope helpers. - emitter/model.rs: the reserved #494 q/k-norm hook now applies per-head or flat RMSNorm before RoPE; `arch_norm` skips the input norm for OLMo post-norm; `seq_mlp` / `single_mlp` are norm-style aware; `AttnLayout::rope_qk` selects the local RoPE table per layer; NoPE layers skip RoPE. - emitter/rope.rs: local-base RoPE table builder for the Gemma3 / OLMo3 sliding layers. - iree.rs: `weight_names` mirrors the per-family arg schedule (conditional input norm, q/k norms, feed-forward norms) so the loader lines up with the emitted graph. - validation.rs: registers each family as a golden-less structural fixture (the Qwen2.5 precedent), gated in CI via validate_arch.sh. - spike/openxla/dense_arch_pack_check.py: an IREE-CPU execution check that builds a small synthetic HF model per family and compares the emitter's prefill logits to the HF fp32 oracle, so the family math is proven correct without a heavy checkpoint or the xla-iree feature. Validation: 69 pure-Rust tests pass (existing byte-exact goldens unchanged). The execution check is token-exact vs HF fp32 for Qwen3 (0.0 max logit diff), Gemma1 (3e-8), Gemma3 including dual RoPE (4e-8), SmolLM3 NoPE (5e-10), and OLMo2 (2e-9). OLMo3 is structure-complete but its full-size checkpoint needs yarn RoPE (rejected with a clear message) and is 32B, so its execution check is deferred to the post-merge real-checkpoint gate. --- scripts/xla/validate_arch.sh | 8 +- spike/openxla/dense_arch_pack_check.py | 344 ++++++++++++++++ src/lib/mlxcel-xla/src/emitter/config.rs | 437 ++++++++++++++------ src/lib/mlxcel-xla/src/emitter/mod.rs | 470 ++++++++++++++++++++- src/lib/mlxcel-xla/src/emitter/model.rs | 494 +++++++++++++++++------ src/lib/mlxcel-xla/src/emitter/rope.rs | 42 +- src/lib/mlxcel-xla/src/iree.rs | 56 +-- src/lib/mlxcel-xla/src/validation.rs | 143 +++++++ 8 files changed, 1695 insertions(+), 299 deletions(-) create mode 100644 spike/openxla/dense_arch_pack_check.py diff --git a/scripts/xla/validate_arch.sh b/scripts/xla/validate_arch.sh index d8ad951b..38e7595e 100755 --- a/scripts/xla/validate_arch.sh +++ b/scripts/xla/validate_arch.sh @@ -77,9 +77,11 @@ done # --- structural pre-gate (fast, pure Rust, no GPU / IREE) --- if [ "$SKIP_STRUCTURAL" -eq 0 ]; then - echo "== [structural] byte-exact emitter regression (cargo test) ==" - cargo test -p mlxcel-xla --lib \ - validation::tests::registered_fixtures_are_byte_exact -- --nocapture + echo "== [structural] byte-exact + per-family signature emitter gate (cargo test) ==" + # The whole validation tests module: the byte-exact goldens + # (registered_fixtures_are_byte_exact) plus the golden-less dense-family + # signatures (structural_families_emit_expected_signature, issue #497). + cargo test -p mlxcel-xla --lib validation::tests -- --nocapture echo "[structural] PASS" fi if [ "$STRUCTURAL_ONLY" -eq 1 ]; then diff --git a/spike/openxla/dense_arch_pack_check.py b/spike/openxla/dense_arch_pack_check.py new file mode 100644 index 00000000..b185d76d --- /dev/null +++ b/spike/openxla/dense_arch_pack_check.py @@ -0,0 +1,344 @@ +"""Execution check for issue #497: the dense arch pack (Qwen3, Gemma1, Gemma3, +SmolLM3, OLMo2). + +Proves the mlxcel-xla Rust emitter reproduces each family's attention/MLP math by +comparing its prefill last-token logits to an independent HF fp32 oracle, on a +SMALL SYNTHETIC model per family (random weights, tiny dims). The same weights are +fed to both sides, so the only variable is the emitted graph: a logit match proves +the family delta (per-head q/k norm for Qwen3 / Gemma3, flat q/k norm for OLMo2, the +Gemma embed-scale / (1+w) / GeGLU split, the OLMo reordered post-norm, SmolLM3 NoPE, +and Gemma3's dual local/global RoPE) is correct. This needs no real checkpoint and +no `xla-iree` cargo feature: it emits via the scoped pure-Rust +`dump_prefill_graph_for_execution_check` test and runs the graph with IREE's Python +llvm-cpu backend. Short prompt / CPU, streams progress (watchdog-safe). + +Method (mirrors spike/openxla/gemma2_sliding_window_check.py): + 1. Build one small HF model of the family (random weights; the 1-D norm weights, + which HF inits to 0/1, are randomized so the norms and q/k norms are exercised). + 2. Emit the matching emitter `config.json`'s prefill graph via the pure-Rust dump + test, compile with IREE (llvm-cpu), run on the frozen weights in the emitter's + arg order. + 3. Run HF (eager, fp32) on the same tokens; compare last-token logits (argmax + + max abs diff). + +Run (from the repo, with the spike venv's python): + spike/openxla/.venv/bin/python spike/openxla/dense_arch_pack_check.py + ... --family qwen3 # one family only + +Exit 0 = every requested family PASS. +""" + +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +import tempfile + +import numpy as np +import torch + +WORKTREE = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +PREFILL_LP = 256 # emitter MAX_SEQ / prefill bucket +PROMPT_LEN = 12 +TOL = 2e-2 # last-token logit tolerance (loose for the Gemma tanh GeGLU near-ties) + +# Small dims shared by the HF config and the emitter config.json. head_dim*n_q == +# hidden (square o_proj); n_kv < n_q exercises GQA; the flat OLMo2 q/k norm is then +# [n_q*head_dim] = 16 and [n_kv*head_dim] = 8. +HIDDEN = 16 +N_Q = 4 +N_KV = 2 +HEAD_DIM = 4 +INTER = 32 +N_LAYERS = 4 +VOCAB = 64 +EPS = 1e-6 +ROPE_THETA = 10000.0 + + +def arg_names(*, has_input_norm, qkv_bias, qk_norm, has_pre_ff, has_post_ff, untied): + """The emitter's weight arg order (mirrors `weight_names` in iree.rs exactly).""" + names = ["model.embed_tokens.weight", "model.norm.weight"] + if untied: + names.append("lm_head.weight") + for i in range(N_LAYERS): + p = f"model.layers.{i}." + names.append(p + "mlp.down_proj.weight") + names.append(p + "mlp.gate_proj.weight") + if has_input_norm: + names.append(p + "input_layernorm.weight") + for suf in ( + "post_attention_layernorm.weight", + "mlp.up_proj.weight", + "self_attn.k_proj.weight", + "self_attn.o_proj.weight", + "self_attn.q_proj.weight", + "self_attn.v_proj.weight", + ): + names.append(p + suf) + if qkv_bias: + for suf in ("k_proj.bias", "q_proj.bias", "v_proj.bias"): + names.append(p + "self_attn." + suf) + if qk_norm: + names.append(p + "self_attn.q_norm.weight") + names.append(p + "self_attn.k_norm.weight") + if has_pre_ff: + names.append(p + "pre_feedforward_layernorm.weight") + if has_post_ff: + names.append(p + "post_feedforward_layernorm.weight") + return names + + +def base_dims(): + return dict( + hidden_size=HIDDEN, + num_attention_heads=N_Q, + num_key_value_heads=N_KV, + head_dim=HEAD_DIM, + intermediate_size=INTER, + num_hidden_layers=N_LAYERS, + vocab_size=VOCAB, + rms_norm_eps=EPS, + rope_theta=ROPE_THETA, + max_position_embeddings=512, + attention_bias=False, + ) + + +def base_emitter_cfg(model_type): + return dict( + model_type=model_type, + hidden_size=HIDDEN, + num_attention_heads=N_Q, + num_key_value_heads=N_KV, + head_dim=HEAD_DIM, + intermediate_size=INTER, + num_hidden_layers=N_LAYERS, + vocab_size=VOCAB, + rms_norm_eps=EPS, + rope_theta=ROPE_THETA, + attention_bias=False, + tie_word_embeddings=True, + ) + + +def spec_qwen3(): + from transformers import Qwen3Config, Qwen3ForCausalLM + + cfg = Qwen3Config(**base_dims(), tie_word_embeddings=True) + return dict( + hf=(Qwen3Config, Qwen3ForCausalLM, cfg), + emitter_cfg=base_emitter_cfg("qwen3"), + arg_flags=dict( + has_input_norm=True, qkv_bias=False, qk_norm=True, + has_pre_ff=False, has_post_ff=False, untied=False, + ), + ) + + +def spec_gemma1(): + from transformers import GemmaConfig, GemmaForCausalLM + + cfg = GemmaConfig( + **base_dims(), tie_word_embeddings=True, + hidden_act="gelu_pytorch_tanh", hidden_activation="gelu_pytorch_tanh", + ) + ec = base_emitter_cfg("gemma") + ec["hidden_activation"] = "gelu_pytorch_tanh" + return dict( + hf=(GemmaConfig, GemmaForCausalLM, cfg), + emitter_cfg=ec, + arg_flags=dict( + has_input_norm=True, qkv_bias=False, qk_norm=False, + has_pre_ff=False, has_post_ff=False, untied=False, + ), + ) + + +def spec_gemma3(): + from transformers import Gemma3TextConfig, Gemma3ForCausalLM + + # Distinct local RoPE base (100) vs global (ROPE_THETA=10000) so the dual-RoPE is + # exercised; an inert sliding window (>= prompt) keeps the mask a no-op while the + # sliding layers still rotate on the local base (as HF does). + cfg = Gemma3TextConfig( + **base_dims(), tie_word_embeddings=True, + query_pre_attn_scalar=HEAD_DIM, rope_local_base_freq=100.0, + sliding_window=64, sliding_window_pattern=3, + attn_logit_softcapping=None, final_logit_softcapping=None, + hidden_activation="gelu_pytorch_tanh", + ) + ec = base_emitter_cfg("gemma3_text") + ec.update( + rope_local_base_freq=100.0, sliding_window=64, sliding_window_pattern=3, + query_pre_attn_scalar=HEAD_DIM, hidden_activation="gelu_pytorch_tanh", + ) + return dict( + hf=(Gemma3TextConfig, Gemma3ForCausalLM, cfg), + emitter_cfg=ec, + arg_flags=dict( + has_input_norm=True, qkv_bias=False, qk_norm=True, + has_pre_ff=True, has_post_ff=True, untied=False, + ), + ) + + +def spec_smollm3(): + from transformers import SmolLM3Config, SmolLM3ForCausalLM + + no_rope = [1, 1, 1, 0] # layer 3 NoPE + # SmolLM3Config's real default pad/bos/eos ids exceed the synthetic vocab; pin + # them inside it so the padding-idx embedding is valid. + cfg = SmolLM3Config( + **base_dims(), tie_word_embeddings=True, no_rope_layers=no_rope, + pad_token_id=0, bos_token_id=1, eos_token_id=2, + ) + ec = base_emitter_cfg("smollm3") + ec["no_rope_layers"] = no_rope + return dict( + hf=(SmolLM3Config, SmolLM3ForCausalLM, cfg), + emitter_cfg=ec, + arg_flags=dict( + has_input_norm=True, qkv_bias=False, qk_norm=False, + has_pre_ff=False, has_post_ff=False, untied=False, + ), + ) + + +def spec_olmo2(): + from transformers import Olmo2Config, Olmo2ForCausalLM + + cfg = Olmo2Config(**base_dims(), tie_word_embeddings=False) + ec = base_emitter_cfg("olmo2") + ec["tie_word_embeddings"] = False + return dict( + hf=(Olmo2Config, Olmo2ForCausalLM, cfg), + emitter_cfg=ec, + arg_flags=dict( + has_input_norm=False, qkv_bias=False, qk_norm=True, + has_pre_ff=False, has_post_ff=True, untied=True, + ), + ) + + +SPECS = { + "qwen3": spec_qwen3, + "gemma1": spec_gemma1, + "gemma3": spec_gemma3, + "smollm3": spec_smollm3, + "olmo2": spec_olmo2, +} + + +def build_reference_weights(model_cls, cfg): + torch.manual_seed(0) + model = model_cls(cfg).eval().float() + with torch.no_grad(): + for _, param in model.named_parameters(): + if param.dim() == 1: # RMSNorm weights (HF inits them to 0 or 1) + param.copy_(torch.randn_like(param) * 0.1) + return model + + +def emitter_logits(emitter_cfg, weights_np, tokens, positions, real_len, tag): + workdir = tempfile.mkdtemp(prefix=f"dense_{tag}_") + cfg_path = os.path.join(workdir, "config.json") + mlir_path = os.path.join(workdir, "prefill.mlir") + vmfb_path = os.path.join(workdir, "prefill.vmfb") + with open(cfg_path, "w") as fh: + json.dump(emitter_cfg, fh) + + print(f"[emit] {tag}: cargo dump prefill graph ...", flush=True) + subprocess.run( + [ + "cargo", "test", "-p", "mlxcel-xla", "--lib", + "emitter::tests::dump_prefill_graph_for_execution_check", + "--", "--ignored", "--nocapture", + ], + cwd=WORKTREE, + env={**os.environ, "MLXCEL_DUMP_CONFIG": cfg_path, "MLXCEL_DUMP_OUT": mlir_path}, + check=True, + stdout=subprocess.DEVNULL, + ) + + from iree.compiler.tools import compile_file + from iree.runtime import load_vm_flatbuffer_file + + print(f"[compile] {tag}: iree-compile (llvm-cpu) ...", flush=True) + compile_file( + mlir_path, output_file=vmfb_path, + input_type="stablehlo", target_backends=["llvm-cpu"], + ) + print(f"[run] {tag}: IREE prefill ...", flush=True) + mod = load_vm_flatbuffer_file(vmfb_path, driver="local-task") + out = mod.main(*(list(weights_np) + [tokens, positions, real_len])) + logits = out[0].to_host() if hasattr(out[0], "to_host") else np.asarray(out[0]) + return np.asarray(logits, dtype=np.float32) + + +def hf_logits(model, prompt): + model.config._attn_implementation = "eager" + with torch.no_grad(): + out = model(input_ids=torch.tensor(prompt[None, :], dtype=torch.long)) + return out.logits[0, PROMPT_LEN - 1].numpy().astype(np.float32) + + +def run_family(name): + print(f"\n===== {name} =====", flush=True) + spec = SPECS[name]() + _, model_cls, cfg = spec["hf"] + model = build_reference_weights(model_cls, cfg) + state = model.state_dict() + + names = arg_names(**spec["arg_flags"]) + missing = [n for n in names if n not in state] + if missing: + print(f"[{name}] MISSING weights in HF state_dict: {missing[:4]} ...", flush=True) + return False + weights_np = [np.ascontiguousarray(state[n].numpy(), dtype=np.float32) for n in names] + + rng = np.random.default_rng(1) + prompt = rng.integers(0, VOCAB, size=PROMPT_LEN).astype(np.int32) + tokens = np.zeros(PREFILL_LP, dtype=np.int32) + tokens[:PROMPT_LEN] = prompt + positions = np.arange(PREFILL_LP, dtype=np.int32) + real_len = np.asarray(PROMPT_LEN, dtype=np.int32) + + li = emitter_logits(spec["emitter_cfg"], weights_np, tokens, positions, real_len, name) + lh = hf_logits(model, prompt) + diff = float(np.max(np.abs(li - lh))) + ai, ah = int(li.argmax()), int(lh.argmax()) + ok = ai == ah and diff < TOL + print( + f"[{name}] argmax iree={ai} hf={ah} max|logit diff|={diff:.3e} " + f"-> {'PASS' if ok else 'FAIL'}", + flush=True, + ) + return ok + + +def main(): + ap = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + ap.add_argument("--family", choices=list(SPECS), help="one family (default: all)") + args = ap.parse_args() + families = [args.family] if args.family else list(SPECS) + results = {} + for fam in families: + try: + results[fam] = run_family(fam) + except Exception as e: # noqa: BLE001 - report and continue to the next family + print(f"[{fam}] ERROR: {type(e).__name__}: {e}", flush=True) + results[fam] = False + print("\n===== summary =====", flush=True) + for fam, ok in results.items(): + print(f" {fam:9s}: {'PASS' if ok else 'FAIL'}", flush=True) + all_ok = all(results.values()) + print("RESULT:", "PASS" if all_ok else "FAIL", flush=True) + return 0 if all_ok else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/lib/mlxcel-xla/src/emitter/config.rs b/src/lib/mlxcel-xla/src/emitter/config.rs index 86262aee..58c6f15f 100644 --- a/src/lib/mlxcel-xla/src/emitter/config.rs +++ b/src/lib/mlxcel-xla/src/emitter/config.rs @@ -1,20 +1,25 @@ -//! Emitter config for the Llama-family architectures the OpenXLA backend serves. +//! Emitter config for the dense architectures the OpenXLA backend serves. //! The hard-coded [`Config::llama_3_2_1b`] matches spike/openxla/model_jax.py; -//! [`Config::from_json`] reads the same shape from a checkpoint's `config.json` -//! (issue #449 M3 Stage 2d). Stage A covered the Llama architecture (llama3 RoPE, -//! no attention bias); Stage B adds Qwen2 (plain RoPE + QKV bias), so the config -//! carries the architecture switches the emitter branches on: the RoPE kind, -//! whether q/k/v projections have a bias, and whether the LM head is tied to the -//! token embedding (tied) or a separate `lm_head.weight` (untied, e.g. -//! Llama-3.1-8B and the larger Qwen2.5 checkpoints). +//! [`Config::from_json`] reads the same shape from a checkpoint's `config.json`. +//! +//! The config carries orthogonal architecture switches the emitter branches on, +//! so a new dense family is a combination of flags rather than a new code path: +//! the RoPE kind (llama3 vs plain, plus an optional per-layer local base for +//! Gemma3), whether q/k/v projections carry a bias (Qwen2), the LM-head tie, MLX +//! quantization, the Gemma embedding scale / `(1+w)` RMSNorm / GeGLU MLP, the +//! per-layer norm placement ([`NormStyle`]), an optional q/k normalization +//! ([`QkNorm`]: per-head for Qwen3 / Gemma3, flat for OLMo2 / OLMo3), the +//! sliding-window schedule (window + pattern period), and a per-layer NoPE mask +//! (SmolLM3). Llama / Qwen2 / Gemma2 keep their exact previous flag combinations, +//! so their emitted graphs are byte-for-byte unchanged. /// How the RoPE inverse-frequency table is computed. Both kinds share the /// `outer(pos, inv_freq)` table build (see [`rope`](super::rope)); they differ /// only in `inv_freq`. #[derive(Clone, Debug, PartialEq)] pub enum RopeScaling { - /// Plain RoPE: `inv_freq[i] = 1 / theta^(2i/head_dim)` (Qwen2, and plain-RoPE - /// Llama without a `rope_scaling` block). + /// Plain RoPE: `inv_freq[i] = 1 / theta^(2i/head_dim)` (Qwen2, Qwen3, Gemma, + /// SmolLM3, OLMo2, and plain-RoPE Llama without a `rope_scaling` block). Plain, /// Llama3 RoPE scaling, byte-for-byte with HF `_compute_llama3_parameters`. Llama3 { @@ -25,6 +30,40 @@ pub enum RopeScaling { }, } +/// Per-layer norm placement. The three dense patterns differ in where the +/// RMSNorms sit relative to the attention / MLP sublayers. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum NormStyle { + /// Llama / Qwen2 / Qwen3 / Gemma1 / SmolLM3: pre-norm. `input_layernorm` + /// normalizes the residual before attention; `post_attention_layernorm` + /// normalizes it before the MLP. Two norms per layer, both on the input side. + Plain, + /// Gemma2 / Gemma3: pre-norm wrapped by post-norms. `input_layernorm` before + /// attention, then `post_attention_layernorm` on the attention output before + /// the residual; `pre_feedforward_layernorm` before the MLP, then + /// `post_feedforward_layernorm` on the MLP output before the residual. Four + /// norms per layer. + GemmaFf, + /// OLMo2 / OLMo3: reordered (post) norm. No `input_layernorm`; attention and + /// the MLP consume the raw residual, and `post_attention_layernorm` / + /// `post_feedforward_layernorm` normalize each sublayer's OUTPUT before its + /// residual add. Two norms per layer, both on the output side. + OlmoPost, +} + +/// Optional q/k normalization applied to the projected query / key before RoPE. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct QkNorm { + /// `true` normalizes each head independently over `head_dim` (Qwen3, Gemma3; + /// weight shape `[head_dim]`). `false` normalizes the whole flattened + /// projection over `n_q*head_dim` / `n_kv*head_dim` (OLMo2, OLMo3; weight + /// shapes `[n_q*head_dim]` / `[n_kv*head_dim]`). + pub per_head: bool, + /// `true` uses Gemma's `(1 + weight)` RMSNorm (Gemma3); `false` the raw weight + /// (Qwen3, OLMo2, OLMo3). + pub one_plus: bool, +} + /// MLX affine weight quantization (`config.json` `quantization`). The linear / /// embedding `*.weight` tensors are stored packed as `U32` with companion /// `*.scales` / `*.biases`; the loader dequantizes them to f32 as @@ -46,44 +85,59 @@ pub struct Config { pub eps: f32, pub rope_theta: f64, pub vocab: usize, - /// RoPE inverse-frequency scheme (Stage B: `Plain` for Qwen2). + /// RoPE inverse-frequency scheme for the global (full-attention) layers. pub rope: RopeScaling, - /// q/k/v projections carry a bias (Qwen2). `o_proj` never does, and the MLP - /// projections never do, so this single switch covers the architecture delta. + /// q/k/v projections carry a bias (Qwen2 only; the HF `Qwen2Attention` hard- + /// codes `bias=True`). `o_proj` and the MLP projections never do. Qwen3 drops + /// the bias. pub qkv_bias: bool, /// The LM head shares the token-embedding matrix (HF `tie_word_embeddings`). - /// `true` (Llama-3.2-1B, Qwen2.5-0.5B) reuses `params['embed']` for the final - /// projection; `false` adds a separate `params['lm_head']` weight the tail - /// projects through instead (Llama-3.1-8B, larger Qwen2.5 sizes). + /// `true` reuses `params['embed']` for the final projection; `false` adds a + /// separate `params['lm_head']` weight (Llama-3.1-8B, larger Qwen2.5, OLMo2/3). pub tie_word_embeddings: bool, /// MLX affine weight quantization, if the checkpoint is quantized (`None` for - /// an unquantized bf16/f16/f32 checkpoint). The graph itself is unchanged (it - /// runs in f32); the loader dequantizes the packed weights at load. + /// an unquantized bf16/f16/f32 checkpoint). The graph runs in f32; the loader + /// dequantizes the packed weights at load. pub quantization: Option, - /// Gemma2 architecture switch. When true the emitter scales the input - /// embeddings by `sqrt(hidden)`, uses `(1 + weight)` RMSNorm, a GeGLU - /// (`gelu_tanh`) MLP, a post-norm on each sublayer (four norms per layer), and - /// attention / final logit soft-capping; `o_proj` is non-square - /// (`n_q*head_dim != hidden`). Llama / Qwen2 keep their existing path. - pub gemma2: bool, + /// Scale the input embeddings by `sqrt(hidden)` (the Gemma family). + pub embed_scale: bool, + /// Use Gemma's `(1 + weight)` RMSNorm on the layer / final norms (the Gemma + /// family). The q/k norm has its own `one_plus` flag in [`QkNorm`]. + pub norm_one_plus: bool, + /// GeGLU (`gelu_pytorch_tanh`) MLP activation instead of SwiGLU (silu) (the + /// Gemma family). + pub mlp_geglu: bool, + /// Per-layer RMSNorm placement (see [`NormStyle`]). + pub norm_style: NormStyle, + /// Optional q/k normalization before RoPE (Qwen3 / Gemma3 per-head, OLMo2 / + /// OLMo3 flat). `None` for Llama / Qwen2 / Gemma1/2 / SmolLM3. + pub qk_norm: Option, + /// Gemma3 (and OLMo3) local RoPE base for the sliding (local) layers: those + /// layers build their RoPE table from this base while the global layers use + /// `rope_theta`. `None` means every layer shares the single `rope` table. + pub rope_local_base: Option, /// Gemma2 query pre-attention scale base: the attention score scale is - /// `query_pre_attn_scalar^-0.5` (Gemma2; can differ from `head_dim`). `None` - /// uses `head_dim^-0.5` (Llama / Qwen2). + /// `query_pre_attn_scalar^-0.5`. `None` uses `head_dim^-0.5` (Llama / Qwen2 / + /// Qwen3 / Gemma1 / SmolLM3 / OLMo2/3). pub query_pre_attn_scalar: Option, /// Gemma2 attention logit soft-cap: `softcap * tanh(scores / softcap)` on the - /// pre-mask scores. `None` for Llama / Qwen2. + /// pre-mask scores. `None` for the other families (Gemma3's is null). pub attn_logit_softcap: Option, - /// Gemma2 final logit soft-cap on the LM-head logits. `None` for Llama / Qwen2. + /// Gemma2 final logit soft-cap on the LM-head logits. `None` otherwise. pub final_logit_softcap: Option, - /// Gemma2 sliding-window attention (issue #495): `Some(window)` makes the - /// local (even) layers attend only to the last `window` keys, while the - /// global (odd) layers keep full-context attention. Read from `config.json`'s - /// `sliding_window` (HF Gemma2 default 4096) for a gemma2 checkpoint. `None` - /// for Llama / Qwen2, whose every layer is global; the emitter then emits no - /// window ops, so those graphs are byte-identical. (Qwen2's own - /// `sliding_window` field is deliberately ignored: the emitter serves Qwen2 - /// with `use_sliding_window = false` semantics.) + /// Sliding-window attention size: `Some(window)` makes the local layers attend + /// only to the last `window` keys, while the global layers keep full context. + /// `None` means every layer is global. The local/global schedule is set by + /// [`Config::is_sliding_layer`] via [`sliding_pattern`](Self::sliding_pattern). pub sliding_window: Option, + /// Sliding-window schedule period: layer `li` is global iff `(li+1) % + /// sliding_pattern == 0`, otherwise local. Gemma2 uses 2 (even layers local); + /// Gemma3 uses `sliding_window_pattern` (6, i.e. 5 local : 1 global); OLMo3 + /// uses 4. Only meaningful when `sliding_window` is `Some`. + pub sliding_pattern: usize, + /// Per-layer NoPE mask (SmolLM3): `use_rope_layers[li] == false` skips RoPE on + /// that layer (`no_rope_layers`). `None` applies RoPE on every layer. + pub use_rope_layers: Option>, } impl Config { @@ -108,55 +162,35 @@ impl Config { qkv_bias: false, tie_word_embeddings: true, quantization: None, - gemma2: false, + embed_scale: false, + norm_one_plus: false, + mlp_geglu: false, + norm_style: NormStyle::Plain, + qk_norm: None, + rope_local_base: None, query_pre_attn_scalar: None, attn_logit_softcap: None, final_logit_softcap: None, sliding_window: None, + sliding_pattern: 2, + use_rope_layers: None, } } /// Build a [`Config`] from a model's `config.json` text. /// - /// Scope: the Llama and Qwen2 architectures (RMSNorm, SwiGLU MLP, GQA, tied or - /// untied embeddings). Llama uses llama3 RoPE scaling and no attention bias; - /// Qwen2 uses plain RoPE and a q/k/v projection bias; either may tie its LM - /// head to the token embedding or carry a separate `lm_head.weight`. Configs - /// the emitter cannot yet reproduce are rejected with a clear error rather than + /// Scope: the dense architectures Llama, Qwen2, Qwen3, Gemma1, Gemma2, Gemma3, + /// SmolLM3, and OLMo2/3 (RMSNorm variants, SwiGLU / GeGLU MLP, GQA/MHA, tied or + /// untied embeddings, optional q/k norm, sliding windows, NoPE). Configs the + /// emitter cannot yet reproduce are rejected with a clear error rather than /// silently mis-emitted: an unsupported `model_type`, a `llama` checkpoint with - /// `attention_bias`, or a `rope_scaling` whose `rope_type` is not `llama3`. + /// `attention_bias`, or a `rope_scaling` whose `rope_type` is not `llama3` + /// (e.g. yarn, which OLMo3 uses at full size). pub fn from_json_str(s: &str) -> Result { let v: serde_json::Value = serde_json::from_str(s).map_err(|e| format!("parse config.json: {e}"))?; let model_type = v.get("model_type").and_then(serde_json::Value::as_str); - let gemma2 = model_type == Some("gemma2"); - // Qwen2 always has a q/k/v projection bias (the HF `Qwen2Attention` hard- - // codes `bias=True`), and it is not a `config.json` field, so it is keyed - // off the architecture rather than read. - let qkv_bias = match model_type { - Some("llama") => { - // A `llama` checkpoint with attention bias would need the same bias - // emit Qwen2 uses, but that pairing is untested here, so reject it - // rather than emit an unvalidated graph. - if v.get("attention_bias").and_then(serde_json::Value::as_bool) == Some(true) { - return Err( - "the OpenXLA emitter does not support a `llama` checkpoint with \ - attention_bias = true (only Qwen2 carries a q/k/v bias here)" - .to_string(), - ); - } - false - } - Some("qwen2") => true, - Some("gemma2") => false, - other => { - return Err(format!( - "the OpenXLA emitter supports the Llama, Qwen2, and Gemma2 architectures; \ - config.json model_type = {other:?} (other Gemma variants are a follow-up)" - )); - } - }; // Tied (share `embed` for the head) vs untied (separate `lm_head.weight`). // HF `PretrainedConfig` defaults this to `true`, so an absent field means @@ -195,18 +229,162 @@ impl Config { .and_then(serde_json::Value::as_f64) .ok_or_else(|| format!("config.json missing number `{k}`")) }; + let of64 = |k: &str| -> Option { v.get(k).and_then(serde_json::Value::as_f64) }; + let ou = |k: &str| -> Option { + v.get(k) + .and_then(serde_json::Value::as_u64) + .map(|x| x as usize) + }; let hidden = u("hidden_size")?; let n_q = u("num_attention_heads")?; // head_dim is explicit in recent configs; otherwise it is hidden / heads. - let head_dim = v - .get("head_dim") - .and_then(serde_json::Value::as_u64) - .map(|x| x as usize) - .unwrap_or(hidden / n_q.max(1)); + let head_dim = ou("head_dim").unwrap_or(hidden / n_q.max(1)); + let n_layers = u("num_hidden_layers")?; - // rope_scaling is optional: absent -> plain RoPE (Qwen2.5, plain Llama); - // present -> only the llama3 scheme is supported (Stage A). + // Architecture-family flags, defaulted to the Llama baseline and overridden + // per model_type below. + let mut qkv_bias = false; + let mut embed_scale = false; + let mut norm_one_plus = false; + let mut mlp_geglu = false; + let mut norm_style = NormStyle::Plain; + let mut qk_norm: Option = None; + let mut rope_local_base: Option = None; + let mut query_pre_attn_scalar: Option = None; + let mut attn_logit_softcap: Option = None; + let mut final_logit_softcap: Option = None; + let mut sliding_window: Option = None; + let mut sliding_pattern = 2usize; + let mut use_rope_layers: Option> = None; + + // Read a Gemma family's soft-caps + query scale + sliding window (shared by + // gemma2 / gemma3). Gemma3's soft-caps are null, so this yields `None` there. + let read_gemma_common = + |qpa: &mut Option, asc: &mut Option, fsc: &mut Option| { + *qpa = Some(of64("query_pre_attn_scalar").unwrap_or(head_dim as f64)); + *asc = of64("attn_logit_softcapping").map(|x| x as f32); + *fsc = of64("final_logit_softcapping").map(|x| x as f32); + }; + + match model_type { + Some("llama") => { + // A `llama` checkpoint with attention bias would need the Qwen2 bias + // emit, untested here; reject rather than emit an unvalidated graph. + if v.get("attention_bias").and_then(serde_json::Value::as_bool) == Some(true) { + return Err( + "the OpenXLA emitter does not support a `llama` checkpoint with \ + attention_bias = true (only Qwen2 carries a q/k/v bias here)" + .to_string(), + ); + } + } + Some("qwen2") => { + qkv_bias = true; + } + Some("qwen3") => { + // Qwen3 drops the Qwen2 bias and adds a per-head q/k RMSNorm (raw + // weight, over head_dim) before RoPE. + qk_norm = Some(QkNorm { + per_head: true, + one_plus: false, + }); + } + Some("gemma") => { + // Gemma1: Llama-shaped norm placement, but embedding scale, `(1+w)` + // RMSNorm, and a GeGLU MLP. No soft-caps, sliding, or q/k norm. + embed_scale = true; + norm_one_plus = true; + mlp_geglu = true; + } + Some("gemma2") => { + embed_scale = true; + norm_one_plus = true; + mlp_geglu = true; + norm_style = NormStyle::GemmaFf; + sliding_pattern = 2; + sliding_window = Some(ou("sliding_window").unwrap_or(4096)); + read_gemma_common( + &mut query_pre_attn_scalar, + &mut attn_logit_softcap, + &mut final_logit_softcap, + ); + } + Some("gemma3") | Some("gemma3_text") => { + embed_scale = true; + norm_one_plus = true; + mlp_geglu = true; + norm_style = NormStyle::GemmaFf; + // Gemma3: per-head `(1+w)` q/k norm, a 5:1 local:global schedule + // (`sliding_window_pattern` = 6), and a local RoPE base for the + // sliding layers (`rope_local_base_freq`) distinct from `rope_theta`. + qk_norm = Some(QkNorm { + per_head: true, + one_plus: true, + }); + sliding_pattern = ou("sliding_window_pattern").unwrap_or(6).max(1); + sliding_window = Some(ou("sliding_window").unwrap_or(4096)); + rope_local_base = Some(of64("rope_local_base_freq").unwrap_or(10000.0)); + read_gemma_common( + &mut query_pre_attn_scalar, + &mut attn_logit_softcap, + &mut final_logit_softcap, + ); + } + Some("smollm3") => { + // SmolLM3: Llama-shaped, with a per-layer NoPE mask. HF stores + // `no_rope_layers[li]` as 1 = use RoPE, 0 = NoPE, so it maps directly + // to `use_rope_layers`. + if let Some(arr) = v + .get("no_rope_layers") + .and_then(serde_json::Value::as_array) + { + let flags: Vec = arr + .iter() + .map(|x| x.as_i64().map(|n| n != 0).unwrap_or(true)) + .collect(); + if flags.iter().any(|&b| !b) { + use_rope_layers = Some(flags); + } + } + } + Some("olmo2") => { + // OLMo2: reordered (post) norm and a FLAT q/k RMSNorm over the whole + // projection (raw weight). No input_layernorm. + norm_style = NormStyle::OlmoPost; + qk_norm = Some(QkNorm { + per_head: false, + one_plus: false, + }); + } + Some("olmo3") => { + // OLMo3: OLMo2 plus a sliding-window schedule. The full-size + // checkpoint additionally uses yarn RoPE scaling, which the rope + // block below rejects (a documented follow-up); a plain-RoPE OLMo3 + // config exercises the norm/qk/sliding structure. + norm_style = NormStyle::OlmoPost; + qk_norm = Some(QkNorm { + per_head: false, + one_plus: false, + }); + if let Some(w) = ou("sliding_window") { + sliding_window = Some(w); + // OLMo3 marks every `sliding_window_pattern`-th layer global; + // the layer_types list (3 sliding : 1 full) implies a period 4. + sliding_pattern = ou("sliding_window_pattern").unwrap_or(4).max(1); + } + rope_local_base = of64("rope_local_base_freq"); + } + other => { + return Err(format!( + "the OpenXLA emitter supports the Llama, Qwen2, Qwen3, Gemma1/2/3, \ + SmolLM3, and OLMo2/3 architectures; config.json model_type = {other:?}" + )); + } + } + + // rope_scaling is optional: absent -> plain RoPE; present -> only the llama3 + // scheme is supported (yarn, e.g. OLMo3 at full size, is a follow-up). let rope = match v.get("rope_scaling") { None | Some(serde_json::Value::Null) => RopeScaling::Plain, Some(scaling) => { @@ -244,46 +422,10 @@ impl Config { } }; - // Gemma2 logit soft-caps and the query pre-attention scale base (read only - // for a gemma2 checkpoint; the scale defaults to `head_dim` if absent). - let (query_pre_attn_scalar, attn_logit_softcap, final_logit_softcap) = if gemma2 { - ( - Some( - v.get("query_pre_attn_scalar") - .and_then(serde_json::Value::as_f64) - .unwrap_or(head_dim as f64), - ), - v.get("attn_logit_softcapping") - .and_then(serde_json::Value::as_f64) - .map(|x| x as f32), - v.get("final_logit_softcapping") - .and_then(serde_json::Value::as_f64) - .map(|x| x as f32), - ) - } else { - (None, None, None) - }; - - // Gemma2 sliding-window size (issue #495). Read only for a gemma2 - // checkpoint; an absent field falls back to the HF Gemma2 default of 4096. - // Non-gemma2 architectures get `None` (global attention on every layer), - // even if their config carries a `sliding_window` (e.g. Qwen2.5, which the - // emitter serves without sliding-window attention). - let sliding_window = if gemma2 { - Some( - v.get("sliding_window") - .and_then(serde_json::Value::as_u64) - .map(|x| x as usize) - .unwrap_or(4096), - ) - } else { - None - }; - Ok(Config { hidden, inter: u("intermediate_size")?, - n_layers: u("num_hidden_layers")?, + n_layers, n_q, n_kv: u("num_key_value_heads")?, head_dim, @@ -294,11 +436,18 @@ impl Config { qkv_bias, tie_word_embeddings, quantization, - gemma2, + embed_scale, + norm_one_plus, + mlp_geglu, + norm_style, + qk_norm, + rope_local_base, query_pre_attn_scalar, attn_logit_softcap, final_logit_softcap, sliding_window, + sliding_pattern, + use_rope_layers, }) } @@ -313,9 +462,9 @@ impl Config { self.n_q / self.n_kv } - /// Attention score scale. Llama / Qwen2 use `head_dim^-0.5`; Gemma2 uses + /// Attention score scale. Most families use `head_dim^-0.5`; Gemma2/3 use /// `query_pre_attn_scalar^-0.5` (computed in f64 to match HF, since it can - /// differ from `head_dim`). The Llama / Qwen2 branch is unchanged. + /// differ from `head_dim`). pub fn scale(&self) -> f32 { match self.query_pre_attn_scalar { Some(q) => q.powf(-0.5) as f32, @@ -323,19 +472,57 @@ impl Config { } } - /// Gemma2 input-embedding normalizer `sqrt(hidden)` (computed in f64 then + /// Gemma input-embedding normalizer `sqrt(hidden)` (computed in f64 then /// narrowed, matching HF's `hidden_size**0.5` cast to the activation dtype). pub fn embed_normalizer(&self) -> f32 { (self.hidden as f64).sqrt() as f32 } - /// Whether attention layer `li` uses sliding-window (local) attention (issue - /// #495). Gemma2 alternates local and global attention starting local, so the - /// even layers (0, 2, 4, …) are local and the odd layers are global, matching - /// HF `Gemma2DecoderLayer` (`is_sliding = not bool(layer_idx % 2)`). Only a - /// config with a sliding window (Gemma2) has local layers; Llama / Qwen2 - /// return `false` for every layer, so their emitted graphs are unchanged. + /// Whether attention layer `li` uses sliding-window (local) attention. A + /// windowed config marks layer `li` global iff `(li+1) % sliding_pattern == 0`, + /// otherwise local (Gemma2 period 2 = even local; Gemma3 period 6 = 5 local : 1 + /// global; OLMo3 period 4). A non-windowed config has no local layer, so its + /// emitted graphs are unchanged. pub fn is_sliding_layer(&self, li: usize) -> bool { - self.sliding_window.is_some() && li.is_multiple_of(2) + self.sliding_window.is_some() && !(li + 1).is_multiple_of(self.sliding_pattern.max(1)) + } + + /// Whether attention layer `li` applies RoPE. Every layer does unless the + /// config carries a NoPE mask (SmolLM3) that clears it. + pub fn layer_uses_rope(&self, li: usize) -> bool { + self.use_rope_layers + .as_ref() + .and_then(|v| v.get(li).copied()) + .unwrap_or(true) + } + + /// The local RoPE table base for the sliding (local) layers, when the config + /// has a distinct one (Gemma3 / OLMo3). `None` means every layer shares the + /// single global `rope` table. + pub fn local_rope_layer(&self, li: usize) -> bool { + self.rope_local_base.is_some() && self.is_sliding_layer(li) + } + + /// The layer has an `input_layernorm` applied to the residual before attention + /// (all styles except OLMo2/3's reordered post-norm). + pub fn has_input_norm(&self) -> bool { + self.norm_style != NormStyle::OlmoPost + } + + /// The layer normalizes the attention OUTPUT before the residual add + /// (`post_attention_layernorm` in Gemma2/3 and OLMo2/3). + pub fn has_post_attn_norm(&self) -> bool { + matches!(self.norm_style, NormStyle::GemmaFf | NormStyle::OlmoPost) + } + + /// The layer has a `pre_feedforward_layernorm` before the MLP (Gemma2/3). + pub fn has_pre_ff_norm(&self) -> bool { + self.norm_style == NormStyle::GemmaFf + } + + /// The layer normalizes the MLP OUTPUT before the residual add + /// (`post_feedforward_layernorm` in Gemma2/3 and OLMo2/3). + pub fn has_post_ff_norm(&self) -> bool { + matches!(self.norm_style, NormStyle::GemmaFf | NormStyle::OlmoPost) } } diff --git a/src/lib/mlxcel-xla/src/emitter/mod.rs b/src/lib/mlxcel-xla/src/emitter/mod.rs index 635d3b7f..698a577b 100644 --- a/src/lib/mlxcel-xla/src/emitter/mod.rs +++ b/src/lib/mlxcel-xla/src/emitter/mod.rs @@ -19,15 +19,19 @@ //! `config.json` at load, instead of being pinned to the bundled Llama-3.2-1B //! `.mlir` assets. //! -//! Scope: the Llama, Qwen2, and Gemma2 architectures. The `Config` is -//! parameterized by dimensions, so any checkpoint of a supported architecture (any -//! size) emits correctly. The architecture switches the emitter branches on are -//! the RoPE kind (llama3 scaling for Llama, plain for Qwen2 / Gemma2), whether the -//! q/k/v projections carry a bias (Qwen2), whether the LM head is tied or a -//! separate `lm_head.weight` (untied, e.g. Llama-3.1-8B), and the `gemma2` switch -//! (embedding scale, `(1+w)` RMSNorm, GeGLU, a four-norm layer, attention / final -//! logit soft-cap, non-square `o_proj`). Gemma2 is single-sequence only so far; -//! the batched / ragged serve graphs are a follow-up. +//! Scope: the dense families Llama, Qwen2, Qwen3, Gemma1, Gemma2, Gemma3, SmolLM3, +//! and OLMo2/3. The `Config` is parameterized by dimensions AND by orthogonal +//! architecture flags, so any checkpoint of a supported family (any size) emits +//! correctly and a new family is a flag combination rather than a new code path. +//! The switches the emitter branches on are the RoPE kind (llama3 vs plain, plus an +//! optional per-layer local base for Gemma3), the q/k/v bias (Qwen2), the LM-head +//! tie, MLX quantization, the Gemma embedding scale / `(1+w)` RMSNorm / GeGLU MLP, +//! the per-layer norm placement ([`NormStyle`](config::NormStyle): pre-norm, +//! Gemma four-norm, or OLMo reordered post-norm), an optional q/k normalization +//! ([`QkNorm`](config::QkNorm): per-head for Qwen3 / Gemma3, flat for OLMo2/3), the +//! sliding-window schedule, Gemma2/3 soft-caps, and a per-layer NoPE mask +//! (SmolLM3). The single-sequence, ragged-decode, and prefill graphs share one +//! per-layer attention core (issue #494), so every delta reaches all three. //! //! Pure Rust (no IREE), so it compiles and is unit-tested without the `iree` //! feature; only the IREE engine consumes it. The bundled `.mlir` assets remain @@ -55,7 +59,7 @@ pub(crate) use model::{ #[cfg(test)] mod tests { - use super::config::RopeScaling; + use super::config::{NormStyle, QkNorm, RopeScaling}; use super::*; const CONFIG_JSON: &str = include_str!("../../assets/llama-3.2-1b/config.json"); @@ -91,11 +95,18 @@ mod tests { qkv_bias, tie_word_embeddings: true, quantization: None, - gemma2: false, + embed_scale: false, + norm_one_plus: false, + mlp_geglu: false, + norm_style: NormStyle::Plain, + qk_norm: None, + rope_local_base: None, query_pre_attn_scalar: None, attn_logit_softcap: None, final_logit_softcap: None, sliding_window: None, + sliding_pattern: 2, + use_rope_layers: None, } } @@ -122,11 +133,18 @@ mod tests { qkv_bias: false, tie_word_embeddings: true, quantization: None, - gemma2: true, + embed_scale: true, + norm_one_plus: true, + mlp_geglu: true, + norm_style: NormStyle::GemmaFf, + qk_norm: None, + rope_local_base: None, query_pre_attn_scalar: Some(6.0), attn_logit_softcap: Some(50.0), final_logit_softcap: Some(30.0), sliding_window, + sliding_pattern: 2, + use_rope_layers: None, } } @@ -171,21 +189,22 @@ mod tests { assert_eq!(c.eps, 1e-6); } - /// A non-Llama/Qwen2/Gemma2 architecture and an unsupported `rope_scaling` are - /// each rejected with a clear message rather than mis-emitted. (Untied - /// embeddings and Gemma2 are no longer rejected; see - /// `from_json_accepts_untied_embeddings` / `from_json_parses_gemma2`.) + /// An unsupported `model_type` and an unsupported `rope_scaling` are each + /// rejected with a clear message rather than mis-emitted. (Untied embeddings and + /// the Gemma / Qwen3 / SmolLM3 / OLMo2/3 families are now accepted; see their + /// own parse tests.) #[test] fn from_json_rejects_unsupported_configs() { - let gemma3 = r#"{"model_type":"gemma3","tie_word_embeddings":true,"hidden_size":8, + let mamba = r#"{"model_type":"mamba","tie_word_embeddings":true,"hidden_size":8, "num_attention_heads":2,"intermediate_size":16,"num_hidden_layers":2, "num_key_value_heads":1,"rms_norm_eps":1e-6,"rope_theta":1e4,"vocab_size":10}"#; assert!( - Config::from_json_str(gemma3) + Config::from_json_str(mamba) .unwrap_err() .contains("model_type") ); + // yarn RoPE (e.g. OLMo3 at full size) is not supported yet. let yarn = r#"{"model_type":"qwen2","tie_word_embeddings":true,"hidden_size":8, "num_attention_heads":2,"intermediate_size":16,"num_hidden_layers":2, "num_key_value_heads":1,"rms_norm_eps":1e-6,"rope_theta":1e4,"vocab_size":10, @@ -209,7 +228,10 @@ mod tests { "query_pre_attn_scalar":224,"attn_logit_softcapping":50.0, "final_logit_softcapping":30.0,"hidden_activation":"gelu_pytorch_tanh"}"#; let c = Config::from_json_str(g).expect("gemma2 parses"); - assert!(c.gemma2); + assert_eq!(c.norm_style, NormStyle::GemmaFf); + assert!(c.embed_scale && c.norm_one_plus && c.mlp_geglu); + assert!(c.qk_norm.is_none(), "Gemma2 has no q/k norm"); + assert!(c.rope_local_base.is_none(), "Gemma2 has a single RoPE base"); assert_eq!(c.rope, RopeScaling::Plain); assert_eq!(c.query_pre_attn_scalar, Some(224.0)); assert_eq!(c.attn_logit_softcap, Some(50.0)); @@ -666,4 +688,414 @@ mod tests { // The f32 default carries no f16 at all (byte-exact path preserved). assert!(!emit_decode_with(&c, true, Precision::F32).contains("f16")); } + + // =================================================================== + // issue #497: dense arch pack (Qwen3, Gemma1/3, SmolLM3, OLMo2/3) + // =================================================================== + + /// A small Plain (Llama-shaped) config to derive the new families from. Tiny + /// dims keep the emitted text small while exercising every shared-core path; + /// `head_dim` (4) deliberately differs from `hidden / n_q`, and `n_q*head_dim` + /// (12) from `hidden` (8), so the flat q-norm and non-square o_proj widths are + /// genuinely distinct (as in real checkpoints). + fn dense_base() -> Config { + Config { + hidden: 8, + inter: 16, + n_layers: 4, + n_q: 3, + n_kv: 1, + head_dim: 4, + eps: 1e-6, + rope_theta: 1e4, + vocab: 12, + rope: RopeScaling::Plain, + qkv_bias: false, + tie_word_embeddings: true, + quantization: None, + embed_scale: false, + norm_one_plus: false, + mlp_geglu: false, + norm_style: NormStyle::Plain, + qk_norm: None, + rope_local_base: None, + query_pre_attn_scalar: None, + attn_logit_softcap: None, + final_logit_softcap: None, + sliding_window: None, + sliding_pattern: 2, + use_rope_layers: None, + } + } + + fn qwen3_like() -> Config { + Config { + qk_norm: Some(QkNorm { + per_head: true, + one_plus: false, + }), + ..dense_base() + } + } + + fn gemma1_like() -> Config { + Config { + embed_scale: true, + norm_one_plus: true, + mlp_geglu: true, + ..dense_base() + } + } + + fn gemma3_like() -> Config { + Config { + embed_scale: true, + norm_one_plus: true, + mlp_geglu: true, + norm_style: NormStyle::GemmaFf, + qk_norm: Some(QkNorm { + per_head: true, + one_plus: true, + }), + rope_local_base: Some(1e3), + query_pre_attn_scalar: Some(4.0), + sliding_window: Some(2), + sliding_pattern: 3, + ..dense_base() + } + } + + fn olmo2_like() -> Config { + Config { + norm_style: NormStyle::OlmoPost, + qk_norm: Some(QkNorm { + per_head: false, + one_plus: false, + }), + tie_word_embeddings: false, + ..dense_base() + } + } + + /// The three shared-core graph kinds (single decode, ragged decode, prefill), + /// as `(name, emitter)` pairs, so a family delta is checked reaching all of them. + fn shared_core_kinds(c: &Config) -> [(&'static str, String); 3] { + [ + ("decode", emit_decode(c, false)), + ("ragged", emit_decode_ragged(c, 4, false)), + ("prefill", emit_prefill(c, false)), + ] + } + + /// Qwen3's per-head q/k RMSNorm reaches every shared-core graph kind: turning it + /// on adds exactly the two `[head_dim]` norm weights per layer and the two extra + /// `rsqrt`s (one for q, one for k) that normalize each head before RoPE, and + /// nothing else. A with/without diff over an otherwise-identical config isolates + /// exactly the q/k-norm surface. + #[test] + fn qwen3_per_head_qk_norm_reaches_every_shared_core_kind() { + let with = qwen3_like(); + let without = dense_base(); + let nl = with.n_layers; + for ((_, g_with), (name, g_without)) in shared_core_kinds(&with) + .iter() + .zip(shared_core_kinds(&without)) + { + assert_eq!( + arg_count(g_with) - arg_count(&g_without), + 2 * nl, + "{name}: q_norm + k_norm per layer" + ); + assert_eq!(occurs(g_with, "['q_norm']"), nl, "{name}: one q_norm/layer"); + assert_eq!(occurs(g_with, "['k_norm']"), nl, "{name}: one k_norm/layer"); + assert_eq!( + occurs(g_with, "stablehlo.rsqrt") - occurs(&g_without, "stablehlo.rsqrt"), + 2 * nl, + "{name}: one rsqrt each for the q and k head-norm per layer" + ); + // Per-head: the norm weight is [head_dim] (4), not the flat n_q*head_dim. + assert!( + g_with.contains("tensor<4xf32> loc(\"params['layers'][0]['q_norm']\")"), + "{name}: qwen3 q_norm is per-head [head_dim]" + ); + } + } + + /// OLMo2 is the reordered post-norm structure with a FLAT q/k norm and an untied + /// head: no `input_layernorm`, the q/k norm sized over the whole projection + /// (`n_q*head_dim` = 12, `n_kv*head_dim` = 4), a `post_feedforward_layernorm` but + /// no `pre_feedforward_layernorm`, and a separate `lm_head`. Asserted on every + /// shared-core kind so the post-norm reaches all of them. + #[test] + fn olmo2_flat_qk_norm_and_post_norm_structure() { + let c = olmo2_like(); + let nl = c.n_layers; + for (name, g) in shared_core_kinds(&c) { + assert_eq!( + occurs(&g, "['in_ln']"), + 0, + "{name}: OLMo2 has no input norm" + ); + assert_eq!(occurs(&g, "['q_norm']"), nl, "{name}: one q_norm/layer"); + assert_eq!(occurs(&g, "['k_norm']"), nl, "{name}: one k_norm/layer"); + assert_eq!( + occurs(&g, "['post_ff_ln']"), + nl, + "{name}: post-feedforward norm/layer" + ); + assert_eq!(occurs(&g, "['pre_ff_ln']"), 0, "{name}: no pre-ff norm"); + assert_eq!(occurs(&g, "params['lm_head']"), 1, "{name}: untied head"); + // Flat: q_norm is [n_q*head_dim] (12), NOT [head_dim]. + assert!( + g.contains("tensor<12xf32> loc(\"params['layers'][0]['q_norm']\")"), + "{name}: olmo2 q_norm is flat [n_q*head_dim]" + ); + } + } + + /// Gemma1 has the Gemma activation/scale surface (GeGLU `tanh`, `(1+w)` norm, + /// embedding scale) but the Llama TWO-norm layer (an `input_layernorm`, no + /// pre/post feed-forward norms) and no q/k norm, distinguishing it from Gemma2/3. + /// The embedding scale is isolated by a with/without diff. + #[test] + fn gemma1_is_two_norm_geglu_with_embed_scale() { + let g1 = gemma1_like(); + let nl = g1.n_layers; + for (name, g) in shared_core_kinds(&g1) { + assert!(g.contains("stablehlo.tanh"), "{name}: GeGLU emits tanh"); + assert_eq!( + occurs(&g, "['in_ln']"), + nl, + "{name}: Gemma1 keeps input norm" + ); + assert_eq!( + occurs(&g, "['pre_ff_ln']"), + 0, + "{name}: no pre-ff norm (2-norm)" + ); + assert_eq!( + occurs(&g, "['post_ff_ln']"), + 0, + "{name}: no post-ff norm (2-norm)" + ); + assert_eq!( + occurs(&g, "['q_norm']"), + 0, + "{name}: Gemma1 has no q/k norm" + ); + } + // The embedding scale is one const + broadcast + multiply in the head. + let no_scale = Config { + embed_scale: false, + ..gemma1_like() + }; + let d_with = emit_decode(&g1, false); + let d_without = emit_decode(&no_scale, false); + assert_eq!( + occurs(&d_with, "stablehlo.multiply") - occurs(&d_without, "stablehlo.multiply"), + 1, + "embed scale adds exactly one head multiply" + ); + } + + /// SmolLM3's NoPE mask skips RoPE on the marked layers: the rotate-half + /// `concatenate` (two per rope'd layer, q and k) drops by exactly two per NoPE + /// layer, and nothing else changes. A with/without diff over the NoPE mask + /// isolates it on every shared-core kind. + #[test] + fn smollm3_nope_skips_rope_on_marked_layers() { + // Layer 3 is NoPE (the SmolLM3 every-fourth-layer pattern at n_layers = 4). + let with_nope = Config { + use_rope_layers: Some(vec![true, true, true, false]), + ..dense_base() + }; + assert!(!with_nope.layer_uses_rope(3), "layer 3 is NoPE"); + assert!(with_nope.layer_uses_rope(0), "layer 0 keeps RoPE"); + let all_rope = dense_base(); + for ((_, g_nope), (name, g_all)) in shared_core_kinds(&with_nope) + .iter() + .zip(shared_core_kinds(&all_rope)) + { + assert_eq!( + occurs(&g_all, "stablehlo.concatenate") - occurs(g_nope, "stablehlo.concatenate"), + 2, + "{name}: the one NoPE layer drops the q and k rotate-half concatenates" + ); + } + } + + /// Gemma3 pairs the Gemma2 four-norm layer and a per-head `(1+w)` q/k norm with a + /// DUAL RoPE base: the sliding layers rotate on a local-base table distinct from + /// the global one, so the graph carries two extra `[MAX_SEQ, head_dim]` constant + /// tables (cos_local, sin_local) versus a single-RoPE twin. + #[test] + fn gemma3_dual_rope_and_qk_norm_reach_the_shared_core() { + let g3 = gemma3_like(); + let nl = g3.n_layers; + for (name, g) in shared_core_kinds(&g3) { + assert_eq!(occurs(&g, "['q_norm']"), nl, "{name}: per-head q norm"); + assert_eq!( + occurs(&g, "['pre_ff_ln']"), + nl, + "{name}: Gemma 4-norm (pre)" + ); + assert_eq!( + occurs(&g, "['post_ff_ln']"), + nl, + "{name}: Gemma 4-norm (post)" + ); + assert!(g.contains("stablehlo.tanh"), "{name}: GeGLU tanh"); + } + // Dual RoPE: two extra dense hex-blob constant tables (the local cos/sin). + let single = Config { + rope_local_base: None, + ..gemma3_like() + }; + for ((_, g_dual), (name, g_single)) in shared_core_kinds(&g3) + .iter() + .zip(shared_core_kinds(&single)) + { + assert_eq!( + occurs(g_dual, "stablehlo.constant dense<\"0x") + - occurs(&g_single, "stablehlo.constant dense<\"0x"), + 2, + "{name}: dual-RoPE adds the local cos + sin tables" + ); + } + } + + /// The local (sliding) layers rotate on the local RoPE table and the global + /// layers on the global table (Gemma3 dual-RoPE). With `sliding_pattern = 3` and + /// `n_layers = 4`, layers 0/1/3 are sliding and layer 2 is global, so + /// `local_rope_layer` selects the local table on exactly the sliding layers. + #[test] + fn gemma3_local_rope_selected_on_sliding_layers() { + let g3 = gemma3_like(); // sliding_pattern = 3 + assert!(g3.local_rope_layer(0), "layer 0 sliding -> local rope"); + assert!(g3.local_rope_layer(1), "layer 1 sliding -> local rope"); + assert!( + !g3.local_rope_layer(2), + "layer 2 global (3rd) -> global rope" + ); + assert!(g3.local_rope_layer(3), "layer 3 sliding -> local rope"); + } + + /// Each new family's real `config.json` shape parses to the expected flags. + #[test] + fn from_json_parses_new_dense_families() { + // Qwen3: per-head q/k norm, no bias, explicit head_dim, plain RoPE. + let qwen3 = r#"{"model_type":"qwen3","hidden_size":1024,"num_attention_heads":16, + "num_key_value_heads":8,"head_dim":128,"intermediate_size":3072, + "num_hidden_layers":28,"rms_norm_eps":1e-6,"rope_theta":1000000,"vocab_size":151936, + "attention_bias":false,"tie_word_embeddings":true}"#; + let c = Config::from_json_str(qwen3).expect("qwen3 parses"); + assert_eq!( + c.qk_norm, + Some(QkNorm { + per_head: true, + one_plus: false + }) + ); + assert!(!c.qkv_bias, "Qwen3 drops the Qwen2 bias"); + assert_eq!(c.head_dim, 128, "explicit head_dim != hidden/heads"); + assert_eq!(c.norm_style, NormStyle::Plain); + + // Gemma1: Plain norm, embed scale + (1+w) + GeGLU, no q/k norm, no sliding. + let gemma = r#"{"model_type":"gemma","hidden_size":2048,"num_attention_heads":8, + "num_key_value_heads":1,"head_dim":256,"intermediate_size":16384, + "num_hidden_layers":18,"rms_norm_eps":1e-6,"rope_theta":10000.0,"vocab_size":256000, + "hidden_activation":"gelu_pytorch_tanh"}"#; + let c = Config::from_json_str(gemma).expect("gemma parses"); + assert_eq!( + c.norm_style, + NormStyle::Plain, + "Gemma1 is Llama-shaped 2-norm" + ); + assert!(c.embed_scale && c.norm_one_plus && c.mlp_geglu); + assert!(c.qk_norm.is_none() && c.sliding_window.is_none()); + + // Gemma3: GemmaFf 4-norm, per-head (1+w) q/k norm, dual RoPE, 5:1 sliding. + let gemma3 = r#"{"model_type":"gemma3_text","hidden_size":1152,"num_attention_heads":4, + "num_key_value_heads":1,"head_dim":256,"intermediate_size":6912, + "num_hidden_layers":26,"rms_norm_eps":1e-6,"rope_theta":1000000, + "rope_local_base_freq":10000,"sliding_window":512,"sliding_window_pattern":6, + "query_pre_attn_scalar":256,"attn_logit_softcapping":null, + "final_logit_softcapping":null,"vocab_size":262144, + "hidden_activation":"gelu_pytorch_tanh"}"#; + let c = Config::from_json_str(gemma3).expect("gemma3 parses"); + assert_eq!(c.norm_style, NormStyle::GemmaFf); + assert_eq!( + c.qk_norm, + Some(QkNorm { + per_head: true, + one_plus: true + }) + ); + assert_eq!(c.rope_local_base, Some(10000.0), "distinct local RoPE base"); + assert_eq!(c.sliding_window, Some(512)); + assert_eq!(c.sliding_pattern, 6, "5 local : 1 global"); + assert!(c.attn_logit_softcap.is_none(), "Gemma3 drops the soft-caps"); + + // SmolLM3: NoPE mask (every 4th layer), no q/k norm. + let smollm3 = r#"{"model_type":"smollm3","hidden_size":2048,"num_attention_heads":16, + "num_key_value_heads":4,"intermediate_size":11008,"num_hidden_layers":8, + "rms_norm_eps":1e-6,"rope_theta":5000000.0,"vocab_size":128256, + "no_rope_layers":[1,1,1,0,1,1,1,0]}"#; + let c = Config::from_json_str(smollm3).expect("smollm3 parses"); + assert!(c.qk_norm.is_none()); + let rope = c.use_rope_layers.as_ref().expect("NoPE mask present"); + assert_eq!(rope, &[true, true, true, false, true, true, true, false]); + assert!(!c.layer_uses_rope(3) && c.layer_uses_rope(0)); + + // OLMo2: reordered post-norm, flat q/k norm, untied. + let olmo2 = r#"{"model_type":"olmo2","hidden_size":4096,"num_attention_heads":32, + "num_key_value_heads":32,"intermediate_size":11008,"num_hidden_layers":32, + "rms_norm_eps":1e-6,"rope_theta":500000,"vocab_size":100352, + "tie_word_embeddings":false}"#; + let c = Config::from_json_str(olmo2).expect("olmo2 parses"); + assert_eq!(c.norm_style, NormStyle::OlmoPost); + assert_eq!( + c.qk_norm, + Some(QkNorm { + per_head: false, + one_plus: false + }) + ); + assert!(!c.tie_word_embeddings); + + // A plain-RoPE OLMo3 (structure only; the full checkpoint's yarn RoPE is a + // documented follow-up rejected by the rope guard). + let olmo3 = r#"{"model_type":"olmo3","hidden_size":5120,"num_attention_heads":40, + "num_key_value_heads":8,"intermediate_size":27648,"num_hidden_layers":64, + "rms_norm_eps":1e-6,"rope_theta":500000,"vocab_size":100278, + "sliding_window":4096,"sliding_window_pattern":4,"tie_word_embeddings":false}"#; + let c = Config::from_json_str(olmo3).expect("plain-rope olmo3 parses"); + assert_eq!(c.norm_style, NormStyle::OlmoPost); + assert_eq!( + c.qk_norm, + Some(QkNorm { + per_head: false, + one_plus: false + }) + ); + assert_eq!(c.sliding_window, Some(4096)); + assert_eq!(c.sliding_pattern, 4, "3 sliding : 1 global"); + } + + /// OLMo3 at full size uses yarn RoPE, which is rejected with a clear message + /// (the documented follow-up), rather than silently mis-emitted. + #[test] + fn from_json_rejects_yarn_olmo3() { + let olmo3_yarn = r#"{"model_type":"olmo3","hidden_size":5120,"num_attention_heads":40, + "num_key_value_heads":8,"intermediate_size":27648,"num_hidden_layers":64, + "rms_norm_eps":1e-6,"rope_theta":500000,"vocab_size":100278,"sliding_window":4096, + "tie_word_embeddings":false,"rope_scaling":{"rope_type":"yarn","factor":8.0, + "original_max_position_embeddings":8192}}"#; + assert!( + Config::from_json_str(olmo3_yarn) + .unwrap_err() + .contains("rope_type"), + "yarn RoPE is rejected" + ); + } } diff --git a/src/lib/mlxcel-xla/src/emitter/model.rs b/src/lib/mlxcel-xla/src/emitter/model.rs index 1d970589..638e7cee 100644 --- a/src/lib/mlxcel-xla/src/emitter/model.rs +++ b/src/lib/mlxcel-xla/src/emitter/model.rs @@ -15,7 +15,7 @@ //! is byte-identical to before. use super::builder::{Builder, Precision, Ty, Val, precision_from_env}; -use super::config::Config; +use super::config::{Config, NormStyle}; use super::rope; const MAX_SEQ: usize = 256; @@ -34,7 +34,9 @@ const PREFILL_LP: usize = MAX_SEQ; struct LayerW { down: Val, gate: Val, - in_ln: Val, + /// `input_layernorm` (`None` for OLMo2/3, whose reordered post-norm has no + /// input norm; the attention projects the raw residual instead). + in_ln: Option, post_ln: Val, up: Val, wk: Val, @@ -44,9 +46,15 @@ struct LayerW { bk: Option, bq: Option, bv: Option, - /// Gemma2 pre/post feed-forward norms (`None` for Llama / Qwen2). Gemma2 wraps - /// each sublayer in a pre- and a post-norm: `post_ln` becomes the POST-attn - /// norm, `pre_ff_ln` the pre-MLP norm, `post_ff_ln` the post-MLP norm. + /// q/k norm weights (`None` unless the arch has `qk_norm`). Per-head families + /// (Qwen3 / Gemma3) size them `[head_dim]`; flat families (OLMo2/3) size them + /// `[n_q*head_dim]` / `[n_kv*head_dim]`. + q_norm: Option, + k_norm: Option, + /// Gemma2/3 pre/post feed-forward norms and the OLMo2/3 post-feedforward norm + /// (`None` for the plain families). Gemma2/3 wrap each sublayer: `post_ln` is + /// the POST-attn norm, `pre_ff_ln` the pre-MLP norm, `post_ff_ln` the post-MLP + /// norm. OLMo2/3 have `post_ln` (post-attn) and `post_ff_ln` (post-MLP) only. pre_ff_ln: Option, post_ff_ln: Option, } @@ -107,12 +115,14 @@ fn head_weight<'a>(embed: &'a Val, lm_head: &'a Option) -> &'a Val { lm_head.as_ref().unwrap_or(embed) } -/// Append layer `li`'s weights (and, for `qkv_bias`, its q/k/v biases) in the one -/// canonical order every graph kind shares, so the emitted arg order matches -/// `weight_names` in `iree.rs` exactly. JAX-alphabetical weights (down, gate, -/// in_ln, post_ln, up, wk, wo, wq, wv), then — when `c.qkv_bias` — the k/q/v -/// projection biases (alphabetical, matching the wk, idx: &mut usize, c: &Config, li: usize) -> LayerW { let h = c.hidden; let inter = c.inter; @@ -121,13 +131,16 @@ fn take_layer_weights(decls: &mut Vec, idx: &mut usize, c: &Config, li: let p = |k: &str| format!("params['layers'][{}]['{}']", li, k); let down = take_arg(decls, idx, Ty::f32(vec![h, inter]), p("down")); let gate = take_arg(decls, idx, Ty::f32(vec![inter, h]), p("gate")); - let in_ln = take_arg(decls, idx, Ty::f32(vec![h]), p("in_ln")); + // input_layernorm: present unless the reordered (OLMo) post-norm drops it. + let in_ln = c + .has_input_norm() + .then(|| take_arg(decls, idx, Ty::f32(vec![h]), p("in_ln"))); let post_ln = take_arg(decls, idx, Ty::f32(vec![h]), p("post_ln")); let up = take_arg(decls, idx, Ty::f32(vec![inter, h]), p("up")); let wk = take_arg(decls, idx, Ty::f32(vec![kv, h]), p("wk")); // o_proj maps `[n_q*head_dim]` -> `[hidden]`, so its weight is `[h, qd]` (HF's // `[out, in]`). For Llama / Qwen2 `qd == h`, so this renders the same square - // type as before (byte-identical); Gemma2 is genuinely non-square. + // type as before (byte-identical); Gemma is genuinely non-square. let wo = take_arg(decls, idx, Ty::f32(vec![h, qd]), p("wo")); let wq = take_arg(decls, idx, Ty::f32(vec![qd, h]), p("wq")); let wv = take_arg(decls, idx, Ty::f32(vec![kv, h]), p("wv")); @@ -139,15 +152,28 @@ fn take_layer_weights(decls: &mut Vec, idx: &mut usize, c: &Config, li: } else { (None, None, None) }; - // Gemma2's two extra per-layer norms, appended after the q/k/v biases slot in - // the same order `weight_names` lists them (pre then post feed-forward). - let (pre_ff_ln, post_ff_ln) = if c.gemma2 { - let pre = take_arg(decls, idx, Ty::f32(vec![h]), p("pre_ff_ln")); - let post = take_arg(decls, idx, Ty::f32(vec![h]), p("post_ff_ln")); - (Some(pre), Some(post)) - } else { - (None, None) + // q/k norm weights, after the biases. Per-head (Qwen3 / Gemma3) sizes them + // `[head_dim]`; flat (OLMo2/3) sizes them `[n_q*head_dim]` / `[n_kv*head_dim]`. + let (q_norm, k_norm) = match c.qk_norm { + Some(qn) => { + let (qsz, ksz) = if qn.per_head { + (c.head_dim, c.head_dim) + } else { + (qd, kv) + }; + let qn_w = take_arg(decls, idx, Ty::f32(vec![qsz]), p("q_norm")); + let kn_w = take_arg(decls, idx, Ty::f32(vec![ksz]), p("k_norm")); + (Some(qn_w), Some(kn_w)) + } + None => (None, None), }; + // Feed-forward norms. Gemma2/3 add pre AND post; OLMo2/3 add post only. + let pre_ff_ln = c + .has_pre_ff_norm() + .then(|| take_arg(decls, idx, Ty::f32(vec![h]), p("pre_ff_ln"))); + let post_ff_ln = c + .has_post_ff_norm() + .then(|| take_arg(decls, idx, Ty::f32(vec![h]), p("post_ff_ln"))); LayerW { down, gate, @@ -161,6 +187,8 @@ fn take_layer_weights(decls: &mut Vec, idx: &mut usize, c: &Config, li: bk, bq, bv, + q_norm, + k_norm, pre_ff_ln, post_ff_ln, } @@ -260,6 +288,12 @@ fn render_signature(decls: &[ArgDecl]) -> String { struct Consts { cos_table: Val, sin_table: Val, + /// Gemma3 / OLMo3 local RoPE tables (`Some` only when the config has a local + /// base), used by the sliding (local) layers; the global layers keep the + /// `cos_table` / `sin_table`. Emitted only when present, so single-RoPE + /// families are byte-identical. + cos_local: Option, + sin_local: Option, zero: Val, one: Val, neg_inf: Val, @@ -275,6 +309,17 @@ fn emit_consts(b: &mut Builder, c: &Config) -> Consts { let (cos, sin) = rope::rope_tables(c, MAX_SEQ); let cos_table = b.const_tensor_f32(&cos, vec![MAX_SEQ, c.head_dim]); let sin_table = b.const_tensor_f32(&sin, vec![MAX_SEQ, c.head_dim]); + // Gemma3 / OLMo3 local RoPE tables (distinct base for the sliding layers). + let (cos_local, sin_local) = match c.rope_local_base { + Some(base) => { + let (cl, sl) = rope::rope_tables_local(c, MAX_SEQ, base); + ( + Some(b.const_tensor_f32(&cl, vec![MAX_SEQ, c.head_dim])), + Some(b.const_tensor_f32(&sl, vec![MAX_SEQ, c.head_dim])), + ) + } + None => (None, None), + }; let zero = b.const_f32(0.0); let one = b.const_f32(1.0); let neg_inf = b.const_f32(f32::NEG_INFINITY); @@ -287,6 +332,8 @@ fn emit_consts(b: &mut Builder, c: &Config) -> Consts { Consts { cos_table, sin_table, + cos_local, + sin_local, zero, one, neg_inf, @@ -311,25 +358,109 @@ fn rms_norm(b: &mut Builder, x: &Val, w: &Val, k: &Consts, hidden: usize) -> Val b.multiply(&xr, w) } -/// Gemma2 `(1 + weight)` norm scale (`weight + 1` over the `[hidden]` feature -/// axis). Gemma stores the RMSNorm weight offset by one, so the gemma2 paths pass -/// `gemma_norm_w(...)` where Llama / Qwen2 pass the raw weight. -fn gemma_norm_w(b: &mut Builder, w: &Val, k: &Consts, hidden: usize) -> Val { - let one = b.broadcast(&k.one, &[], vec![hidden]); +/// Gemma `(1 + weight)` norm scale (`weight + 1` over a `[dim]` feature axis). +/// Gemma stores the RMSNorm weight offset by one, so the Gemma paths pass +/// `gemma_norm_w(...)` where the other families pass the raw weight. +fn gemma_norm_w(b: &mut Builder, w: &Val, k: &Consts, dim: usize) -> Val { + let one = b.broadcast(&k.one, &[], vec![dim]); b.add(w, &one) } -/// The RMSNorm weight to feed `rms_norm`: `1 + w` for Gemma2, the raw `w` -/// otherwise. A `Val` clone is just a handle copy (no emitted op), so the -/// Llama / Qwen2 graphs are unchanged. +/// The RMSNorm weight to feed `rms_norm`: `1 + w` for the Gemma family, the raw +/// `w` otherwise. A `Val` clone is just a handle copy (no emitted op), so the +/// non-Gemma graphs are unchanged. fn norm_w(b: &mut Builder, w: &Val, c: &Config, k: &Consts, hidden: usize) -> Val { - if c.gemma2 { + if c.norm_one_plus { gemma_norm_w(b, w, k, hidden) } else { w.clone() } } +/// RMSNorm over the LAST axis of `x` (any rank), with `w` broadcast over that axis +/// and the optional Gemma `(1+w)` offset. The one helper serves both q/k-norm +/// flavors: per-head norm passes the head-shaped tensor `[.., heads, d]` (reduce +/// over `d`, weight `[d]`), while flat norm passes the folded `[.., heads*d]` +/// (reduce over `heads*d`, weight `[heads*d]`). Emitted only on a `qk_norm` arch, +/// so every existing graph is unchanged. +fn last_axis_rms_norm(b: &mut Builder, x: &Val, w: &Val, k: &Consts, one_plus: bool) -> Val { + let shape = x.ty.shape.clone(); + let last = shape.len() - 1; + let d = shape[last]; + let lead: Vec = (0..last).collect(); + let df = b.const_f32(d as f32); + let sq = b.multiply(x, x); + let ssum = b.reduce_add(&sq, last, &k.zero); // [..lead..] + let red_shape = ssum.ty.shape.clone(); + let dfb = b.broadcast(&df, &[], red_shape.clone()); + let mean = b.divide(&ssum, &dfb); + let epsb = b.broadcast(&k.eps, &[], red_shape); + let meps = b.add(&mean, &epsb); + let r = b.rsqrt(&meps); + let rb = b.broadcast(&r, &lead, shape.clone()); + let xr = b.multiply(x, &rb); + let wv = if one_plus { + gemma_norm_w(b, w, k, d) + } else { + w.clone() + }; + let wb = b.broadcast(&wv, &[last], shape); + b.multiply(&xr, &wb) +} + +/// Fold the last two axes of a head-shaped tensor `[.., heads, d]` into +/// `[.., heads*d]` (the flat q/k-norm feature layout). +fn fold_heads(b: &mut Builder, x: &Val) -> Val { + let mut shape = x.ty.shape.clone(); + let d = shape.pop().expect("head dim"); + let heads = shape.pop().expect("head count"); + shape.push(heads * d); + b.reshape(x, shape) +} + +/// Apply the reserved q/k normalization (issue #494 hook), if the arch has one, to +/// the projected q / k before RoPE. Per-head (Qwen3 / Gemma3) norms each head over +/// `head_dim`; flat (OLMo2/3) folds the heads into the feature axis, norms the +/// whole `[.., heads*d]`, and unfolds. `q` is `[.., n_q, d]` and `kk` is +/// `[.., n_kv, d]` (single decode has no leading axis; the seq paths carry `[N, +/// ...]`). No-op (and no emitted op) for a config without `qk_norm`. +fn apply_qk_norm( + b: &mut Builder, + c: &Config, + k: &Consts, + lw: &LayerW, + q: Val, + kk: Val, +) -> (Val, Val) { + let Some(qn) = c.qk_norm else { + return (q, kk); + }; + let qw = lw + .q_norm + .as_ref() + .expect("qk_norm arch has a q_norm weight"); + let kw = lw + .k_norm + .as_ref() + .expect("qk_norm arch has a k_norm weight"); + if qn.per_head { + let q = last_axis_rms_norm(b, &q, qw, k, qn.one_plus); + let kk = last_axis_rms_norm(b, &kk, kw, k, qn.one_plus); + (q, kk) + } else { + // Flat: fold heads into the feature axis, norm, unfold back to head layout. + let q_shape = q.ty.shape.clone(); + let k_shape = kk.ty.shape.clone(); + let q_folded = fold_heads(b, &q); + let q_normed = last_axis_rms_norm(b, &q_folded, qw, k, qn.one_plus); + let q = b.reshape(&q_normed, q_shape); + let k_folded = fold_heads(b, &kk); + let k_normed = last_axis_rms_norm(b, &k_folded, kw, k, qn.one_plus); + let kk = b.reshape(&k_normed, k_shape); + (q, kk) + } +} + /// Gemma2 `gelu_pytorch_tanh` activation, elementwise over `x` (any shape): /// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`. fn gelu_tanh(b: &mut Builder, x: &Val) -> Val { @@ -363,26 +494,52 @@ fn softcap(b: &mut Builder, x: &Val, cap: f32) -> Val { b.multiply(&t, &capb) } +/// The pre-MLP input at a layout's rank: normed for the pre-norm families (Llama / +/// Qwen / Gemma use `post_attention_layernorm` or `pre_feedforward_layernorm`), the +/// RAW residual for the OLMo reordered post-norm (whose MLP consumes the residual +/// and norms its output instead). Returns the (possibly cloned) input the MLP +/// projects. `norm` applies the layout's rank-appropriate RMSNorm. +fn mlp_pre_norm( + b: &mut Builder, + c: &Config, + k: &Consts, + lw: &LayerW, + x: &Val, + norm: impl FnOnce(&mut Builder, &Val) -> Val, +) -> Val { + match c.norm_style { + NormStyle::OlmoPost => x.clone(), + NormStyle::GemmaFf => { + let w = norm_w( + b, + lw.pre_ff_ln.as_ref().expect("gemma pre_ff_ln"), + c, + k, + c.hidden, + ); + norm(b, &w) + } + NormStyle::Plain => { + let w = norm_w(b, &lw.post_ln, c, k, c.hidden); + norm(b, &w) + } + } +} + /// The seq-shaped (`[n, H]`) per-layer MLP plus its surrounding norms, shared by -/// every multi-row graph (prefill, ragged decode). Llama / Qwen2: a pre-MLP -/// `post_attention_layernorm` then SwiGLU. Gemma2: a pre-MLP -/// `pre_feedforward_layernorm`, GeGLU, and a post-MLP `post_feedforward_layernorm`. -/// Returns the residual already added (`x + down`). For a non-Gemma2 config it -/// emits exactly the op sequence the graphs carried inline, so their text is -/// byte-identical. Writing it once is the lever that makes a new architecture's -/// MLP delta (here, GeGLU + the two FF norms) reach every serve graph at once. +/// every multi-row graph (prefill, ragged decode). Pre-norm families norm the input +/// (`post_attention_layernorm` for Llama/Qwen, `pre_feedforward_layernorm` for +/// Gemma2/3); the OLMo post-norm consumes the raw residual and norms the OUTPUT. +/// The activation is SwiGLU (silu) or, for the Gemma family, GeGLU (gelu_tanh); a +/// post-MLP `post_feedforward_layernorm` follows for Gemma2/3 and OLMo2/3. Returns +/// the residual already added (`x + down`). Llama / Qwen2 / Gemma2 emit exactly the +/// op sequence they carried before, so their text is byte-identical. fn seq_mlp(b: &mut Builder, c: &Config, lw: &LayerW, k: &Consts, x: &Val, n: usize) -> Val { let h = c.hidden; - let pre_mlp = if c.gemma2 { - lw.pre_ff_ln.as_ref().expect("gemma2 pre_ff_ln") - } else { - &lw.post_ln - }; - let pre_mlp_w = norm_w(b, pre_mlp, c, k, h); - let hn2 = rms_norm_seq(b, x, &pre_mlp_w, k, n, h); - let gate = b.linear_seq(&hn2, &lw.gate); - let up = b.linear_seq(&hn2, &lw.up); - let act = if c.gemma2 { + let mlp_in = mlp_pre_norm(b, c, k, lw, x, |b, w| rms_norm_seq(b, x, w, k, n, h)); + let gate = b.linear_seq(&mlp_in, &lw.gate); + let up = b.linear_seq(&mlp_in, &lw.up); + let act = if c.mlp_geglu { gelu_tanh(b, &gate) } else { let neg = b.negate(&gate); @@ -394,14 +551,8 @@ fn seq_mlp(b: &mut Builder, c: &Config, lw: &LayerW, k: &Consts, x: &Val, n: usi }; let act = b.multiply(&act, &up); let down = b.linear_seq(&act, &lw.down); - let down = if c.gemma2 { - let w = norm_w( - b, - lw.post_ff_ln.as_ref().expect("gemma2 post_ff_ln"), - c, - k, - h, - ); + let down = if c.has_post_ff_norm() { + let w = norm_w(b, lw.post_ff_ln.as_ref().expect("post_ff_ln"), c, k, h); rms_norm_seq(b, &down, &w, k, n, h) } else { down @@ -409,6 +560,58 @@ fn seq_mlp(b: &mut Builder, c: &Config, lw: &LayerW, k: &Consts, x: &Val, n: usi b.add(x, &down) } +/// The rank-1 (`[H]`) single-token-decode analog of [`seq_mlp`]: same norm +/// placement and activation, at the single-sequence decode rank. Llama / Qwen2 / +/// Gemma2 emit exactly the op sequence the decode graph carried inline, so their +/// text is byte-identical. +fn single_mlp(b: &mut Builder, c: &Config, lw: &LayerW, k: &Consts, x: &Val) -> Val { + let h = c.hidden; + let mlp_in = mlp_pre_norm(b, c, k, lw, x, |b, w| rms_norm(b, x, w, k, h)); + let gate = b.linear(&mlp_in, &lw.gate); + let up = b.linear(&mlp_in, &lw.up); + let act = if c.mlp_geglu { + gelu_tanh(b, &gate) + } else { + // silu(gate) = gate * sigmoid(gate), sigmoid(z) = 1/(1+exp(-z)) + let neg = b.negate(&gate); + let ex = b.exponential(&neg); + let one_b = b.broadcast(&k.one, &[], vec![c.inter]); + let denom = b.add(&one_b, &ex); + let sig = b.divide(&one_b, &denom); + b.multiply(&gate, &sig) + }; + let act = b.multiply(&act, &up); + let down = b.linear(&act, &lw.down); + let down = if c.has_post_ff_norm() { + let w = norm_w(b, lw.post_ff_ln.as_ref().expect("post_ff_ln"), c, k, h); + rms_norm(b, &down, &w, k, h) + } else { + down + }; + b.add(x, &down) +} + +/// Select the RoPE (cos, sin) tensors for a layer: the local-base pair on a local +/// (sliding) layer of a dual-RoPE config, else the global pair. `local` is +/// [`Config::local_rope_layer`]; a single-RoPE config carries `None` locals, so +/// this returns the global pair and the emit is byte-identical. +fn pick_rope<'a>( + local: bool, + cos: &'a Val, + sin: &'a Val, + cos_local: &'a Option, + sin_local: &'a Option, +) -> (&'a Val, &'a Val) { + if local { + ( + cos_local.as_ref().unwrap_or(cos), + sin_local.as_ref().unwrap_or(sin), + ) + } else { + (cos, sin) + } +} + /// HF half-split RoPE on x:[heads, d]; cos/sin are [d] for the position. fn apply_rope(b: &mut Builder, x: &Val, cos: &Val, sin: &Val, heads: usize, d: usize) -> Val { let half = d / 2; @@ -456,6 +659,10 @@ enum AttnLayout { Single { cos: Val, sin: Val, + /// Gemma3 / OLMo3 local-base RoPE vectors for the sliding layers (`Some` + /// only for a dual-RoPE config); the global layers use `cos` / `sin`. + cos_local: Option, + sin_local: Option, mask: Val, mask_local: Option, cache_len: Val, @@ -467,6 +674,8 @@ enum AttnLayout { bsz: usize, cos: Val, sin: Val, + cos_local: Option, + sin_local: Option, mask: Val, mask_local: Option, pos: Val, @@ -479,6 +688,8 @@ enum AttnLayout { lp: usize, cos: Val, sin: Val, + cos_local: Option, + sin_local: Option, mask: Val, mask_local: Option, }, @@ -541,22 +752,49 @@ impl AttnLayout { (q, kk, vv) } - /// Apply this kind's RoPE to q and k (v is never rotated). - fn rope_qk(&self, b: &mut Builder, c: &Config, q: &Val, kk: &Val) -> (Val, Val) { + /// Apply this kind's RoPE to q and k for layer `li` (v is never rotated). A + /// dual-RoPE config (Gemma3 / OLMo3) selects the local-base table on a sliding + /// layer and the global table on a full layer; single-RoPE configs always use + /// the global table (byte-identical to before). `li` is unused for the latter. + fn rope_qk(&self, b: &mut Builder, c: &Config, li: usize, q: &Val, kk: &Val) -> (Val, Val) { let d = c.head_dim; let (nq, nkv) = (c.n_q, c.n_kv); + let local = c.local_rope_layer(li); match self { - AttnLayout::Single { cos, sin, .. } => { + AttnLayout::Single { + cos, + sin, + cos_local, + sin_local, + .. + } => { + let (cos, sin) = pick_rope(local, cos, sin, cos_local, sin_local); let q = apply_rope(b, q, cos, sin, nq, d); let kk = apply_rope(b, kk, cos, sin, nkv, d); (q, kk) } - AttnLayout::Ragged { bsz, cos, sin, .. } => { + AttnLayout::Ragged { + bsz, + cos, + sin, + cos_local, + sin_local, + .. + } => { + let (cos, sin) = pick_rope(local, cos, sin, cos_local, sin_local); let q = apply_rope_ragged(b, q, cos, sin, *bsz, nq, d); let kk = apply_rope_ragged(b, kk, cos, sin, *bsz, nkv, d); (q, kk) } - AttnLayout::Prefill { lp, cos, sin, .. } => { + AttnLayout::Prefill { + lp, + cos, + sin, + cos_local, + sin_local, + .. + } => { + let (cos, sin) = pick_rope(local, cos, sin, cos_local, sin_local); let q = apply_rope_seq(b, q, cos, sin, *lp, nq, d); let kk = apply_rope_seq(b, kk, cos, sin, *lp, nkv, d); (q, kk) @@ -819,23 +1057,31 @@ fn apply_scale_and_softcap(b: &mut Builder, c: &Config, k: &Consts, scores: Val) } } -/// The architecture RMSNorm applied at a layout's rank: the Gemma2 `(1 + w)` -/// weight offset (a no-op handle-copy for Llama / Qwen2) followed by the layout's -/// rank-appropriate RMSNorm. +/// The input RMSNorm applied at a layout's rank: the Gemma `(1 + w)` weight offset +/// (a no-op handle-copy for the non-Gemma families) followed by the layout's +/// rank-appropriate RMSNorm. The OLMo reordered post-norm has NO input norm +/// (`w_raw` is `None`), so the attention projects the raw residual unchanged. fn arch_norm( b: &mut Builder, c: &Config, k: &Consts, layout: &AttnLayout, x: &Val, - w_raw: &Val, + w_raw: &Option, ) -> Val { - let w = norm_w(b, w_raw, c, k, c.hidden); - layout.norm(b, c, k, x, &w) + match w_raw { + Some(w_raw) => { + let w = norm_w(b, w_raw, c, k, c.hidden); + layout.norm(b, c, k, x, &w) + } + None => x.clone(), + } } -/// Gemma2's post-attention RMSNorm on the sublayer output before the residual (a -/// no-op for Llama / Qwen2, which have no such norm), applied at the layout's rank. +/// The post-attention RMSNorm on the attention output before the residual, for the +/// families that have one (`post_attention_layernorm` in Gemma2/3 and OLMo2/3), +/// applied at the layout's rank. A no-op (handle copy) for the plain pre-norm +/// families (Llama / Qwen2/3 / Gemma1 / SmolLM3), which have no such norm. fn post_attn_norm( b: &mut Builder, c: &Config, @@ -844,7 +1090,7 @@ fn post_attn_norm( attn_out: Val, lw: &LayerW, ) -> Val { - if c.gemma2 { + if c.has_post_attn_norm() { let w = norm_w(b, &lw.post_ln, c, k, c.hidden); layout.norm(b, c, k, &attn_out, &w) } else { @@ -873,11 +1119,18 @@ fn emit_attention( ) -> Val { let hn = arch_norm(b, c, k, layout, x, &lw.in_ln); let (q, kk, vv) = layout.project_qkv(b, c, &hn, lw); - // Reserved hook: a future per-head q/k normalization (e.g. Qwen3) is applied - // to q and kk here, once, and reaches the single / ragged / prefill paths - // together. No dense family the emitter serves emits it yet, so nothing is - // emitted today and every existing graph is unchanged. - let (q, kk) = layout.rope_qk(b, c, &q, &kk); + // Reserved q/k-norm hook (issue #494): Qwen3 / Gemma3 norm each head over + // head_dim, OLMo2/3 norm the whole flat projection; both before RoPE, applied + // once here so they reach the single / ragged / prefill paths together. A + // config without `qk_norm` emits nothing, so every existing graph is unchanged. + let (q, kk) = apply_qk_norm(b, c, k, lw, q, kk); + // RoPE, unless this layer is a NoPE layer (SmolLM3), and on the layer's own + // RoPE base (dual-RoPE local/global for Gemma3 / OLMo3). + let (q, kk) = if c.layer_uses_rope(li) { + layout.rope_qk(b, c, li, &q, &kk) + } else { + (q, kk) + }; let (kslab, vslab) = layout.write_read_kv(b, k, c, li, &kk, &vv, kcache, vcache); let scores = layout.raw_scores(b, c, &q, &kslab); let scores = apply_scale_and_softcap(b, c, k, scores); @@ -907,8 +1160,8 @@ pub fn emit_decode_with(c: &Config, sample: bool, precision: Precision) -> Strin // --- head: embed gather, rope vectors, decode key mask --- let emb_row = b.dynamic_slice(&a.embed, &[&a.token, &k.c0], vec![1, h]); let mut x = b.reshape(&emb_row, vec![h]); - // Gemma2 scales the input embeddings by sqrt(hidden). - if c.gemma2 { + // Gemma scales the input embeddings by sqrt(hidden). + if c.embed_scale { let norm = b.const_f32(c.embed_normalizer()); let nb = b.broadcast(&norm, &[], vec![h]); x = b.multiply(&x, &nb); @@ -918,6 +1171,17 @@ pub fn emit_decode_with(c: &Config, sample: bool, precision: Precision) -> Strin let cos_vec = b.reshape(&cos_row, vec![d]); let sin_row = b.dynamic_slice(&k.sin_table, &[&a.pos, &k.c0], vec![1, d]); let sin_vec = b.reshape(&sin_row, vec![d]); + // Dual-RoPE (Gemma3 / OLMo3): the local-base [d] vectors for the sliding layers. + let (cos_local, sin_local) = match (&k.cos_local, &k.sin_local) { + (Some(ct), Some(st)) => { + let cr = b.dynamic_slice(ct, &[&a.pos, &k.c0], vec![1, d]); + let cl = b.reshape(&cr, vec![d]); + let sr = b.dynamic_slice(st, &[&a.pos, &k.c0], vec![1, d]); + let sl = b.reshape(&sr, vec![d]); + (Some(cl), Some(sl)) + } + _ => (None, None), + }; // mask: keys s valid iff s <= cache_len -> additive 0 / -1e30, shape [S] let ii = b.iota(MAX_SEQ); @@ -945,6 +1209,8 @@ pub fn emit_decode_with(c: &Config, sample: bool, precision: Precision) -> Strin let layout = AttnLayout::Single { cos: cos_vec, sin: sin_vec, + cos_local, + sin_local, mask: kmask, mask_local: kmask_local, cache_len: a.cache_len.clone(), @@ -959,48 +1225,12 @@ pub fn emit_decode_with(c: &Config, sample: bool, precision: Precision) -> Strin // attention block (shared per-layer core, issue #494) x = emit_attention(&mut b, c, &k, lw, li, &x, &layout, &mut kcache, &mut vcache); - // MLP. Pre-MLP norm: Llama / Qwen2 use post_attention_layernorm; Gemma2 - // uses pre_feedforward_layernorm (post_attention_layernorm became the - // post-attn norm above). Activation: SwiGLU (silu) vs Gemma2 GeGLU (gelu). - let pre_mlp = if c.gemma2 { - lw.pre_ff_ln.as_ref().expect("gemma2 pre_ff_ln") - } else { - &lw.post_ln - }; - let pre_mlp_w = norm_w(&mut b, pre_mlp, c, &k, h); - let hn2 = rms_norm(&mut b, &x, &pre_mlp_w, &k, h); - let gate = b.linear(&hn2, &lw.gate); - let up = b.linear(&hn2, &lw.up); - let act = if c.gemma2 { - gelu_tanh(&mut b, &gate) - } else { - // silu(gate) = gate * sigmoid(gate), sigmoid(z) = 1/(1+exp(-z)) - let neg = b.negate(&gate); - let ex = b.exponential(&neg); - let one_b = b.broadcast(&k.one, &[], vec![c.inter]); - let denom = b.add(&one_b, &ex); - let sig = b.divide(&one_b, &denom); - b.multiply(&gate, &sig) - }; - let act = b.multiply(&act, &up); - let down = b.linear(&act, &lw.down); - // Gemma2: post-MLP norm before the residual. - let down = if c.gemma2 { - let w = norm_w( - &mut b, - lw.post_ff_ln.as_ref().expect("gemma2 post_ff_ln"), - c, - &k, - h, - ); - rms_norm(&mut b, &down, &w, &k, h) - } else { - down - }; - x = b.add(&x, &down); + // MLP + its norms (shared with the seq graphs; norm placement / activation + // per the arch), at the single-token decode rank. + x = single_mlp(&mut b, c, lw, &k, &x); } - // --- tail: final norm + LM head (tied embed or untied lm_head), Gemma2 final + // --- tail: final norm + LM head (tied embed or untied lm_head), Gemma final // logit soft-cap, then optional on-device argmax --- let final_w = norm_w(&mut b, &a.final_norm, c, &k, h); let xf = rms_norm(&mut b, &x, &final_w, &k, h); @@ -1189,8 +1419,14 @@ pub fn emit_decode_batched_with( for li in 0..c.n_layers { let lw = &a.layers[li]; - // attention block (RMSNorm over H reuses the [N,H] seq variant, N=B) - let hn = rms_norm_seq(&mut b, &x, &lw.in_ln, &k, bsz, h); // [B, H] + // attention block (RMSNorm over H reuses the [N,H] seq variant, N=B). The + // superseded uniform-B batched graph serves only the pre-norm families + // (Llama / Qwen2), which always carry `input_layernorm`. + let in_ln = lw + .in_ln + .as_ref() + .expect("uniform-B batched decode is emitted only for pre-norm archs"); + let hn = rms_norm_seq(&mut b, &x, in_ln, &k, bsz, h); // [B, H] let q = b.linear_seq(&hn, &lw.wq); // [B, qd] let q = add_proj_bias_seq(&mut b, q, &lw.bq, bsz, nq * d); let q = b.reshape(&q, vec![bsz, nq, d]); @@ -1459,8 +1695,8 @@ pub fn emit_decode_ragged_with( // --- head: per-row embed gather, per-row rope gather, per-row key mask --- let tok_idx = b.reshape(&a.token, vec![bsz, 1]); let mut x = b.gather(&a.embed, &tok_idx); // [B, H] - // Gemma2 scales the input embeddings by sqrt(hidden). - if c.gemma2 { + // Gemma scales the input embeddings by sqrt(hidden). + if c.embed_scale { let norm = b.const_f32(c.embed_normalizer()); let nb = b.broadcast(&norm, &[], vec![bsz, h]); x = b.multiply(&x, &nb); @@ -1470,6 +1706,11 @@ pub fn emit_decode_ragged_with( let pos_idx = b.reshape(&a.pos, vec![bsz, 1]); let cos = b.gather(&k.cos_table, &pos_idx); // [B, d] let sin = b.gather(&k.sin_table, &pos_idx); // [B, d] + // Dual-RoPE (Gemma3 / OLMo3): per-row local-base gathers for the sliding layers. + let (cos_local, sin_local) = match (&k.cos_local, &k.sin_local) { + (Some(ct), Some(st)) => (Some(b.gather(ct, &pos_idx)), Some(b.gather(st, &pos_idx))), + _ => (None, None), + }; // per-row key mask [B,S]: key s valid for row b iff s <= cache_len[b] let ii = b.iota(MAX_SEQ); // [S] @@ -1495,6 +1736,8 @@ pub fn emit_decode_ragged_with( bsz, cos, sin, + cos_local, + sin_local, mask: kmask, mask_local: kmask_local, pos: a.pos.clone(), @@ -1679,8 +1922,8 @@ pub fn emit_prefill_with(c: &Config, sample: bool, precision: Precision) -> Stri // --- head: embed gather, per-position rope vectors, [Lp,Lp] causal mask --- let tok_idx = b.reshape(&a.tokens, vec![lp, 1]); let mut x = b.gather(&a.embed, &tok_idx); // [Lp, H] - // Gemma2 scales the input embeddings by sqrt(hidden). - if c.gemma2 { + // Gemma scales the input embeddings by sqrt(hidden). + if c.embed_scale { let norm = b.const_f32(c.embed_normalizer()); let nb = b.broadcast(&norm, &[], vec![lp, h]); x = b.multiply(&x, &nb); @@ -1689,6 +1932,11 @@ pub fn emit_prefill_with(c: &Config, sample: bool, precision: Precision) -> Stri let pos_idx = b.reshape(&a.positions, vec![lp, 1]); let cos = b.gather(&k.cos_table, &pos_idx); // [Lp, d] let sin = b.gather(&k.sin_table, &pos_idx); // [Lp, d] + // Dual-RoPE (Gemma3 / OLMo3): per-position local-base gathers for sliding layers. + let (cos_local, sin_local) = match (&k.cos_local, &k.sin_local) { + (Some(ct), Some(st)) => (Some(b.gather(ct, &pos_idx)), Some(b.gather(st, &pos_idx))), + _ => (None, None), + }; // causal mask [Lp, Lp]: query i attends key j iff j <= i -> additive 0/-1e30 let irow = b.iota(lp); @@ -1717,6 +1965,8 @@ pub fn emit_prefill_with(c: &Config, sample: bool, precision: Precision) -> Stri lp, cos, sin, + cos_local, + sin_local, mask: cmask, mask_local: cmask_local, }; diff --git a/src/lib/mlxcel-xla/src/emitter/rope.rs b/src/lib/mlxcel-xla/src/emitter/rope.rs index 9d09eb5a..cee4f22e 100644 --- a/src/lib/mlxcel-xla/src/emitter/rope.rs +++ b/src/lib/mlxcel-xla/src/emitter/rope.rs @@ -75,12 +75,26 @@ fn llama3_inv_freq( inv } -/// Build cos and sin tables of shape [max_seq, head_dim] as flat row-major f32. -/// emb = concat([freqs, freqs], -1) where freqs = outer(pos, inv_freq). -pub fn rope_tables(c: &Config, max_seq: usize) -> (Vec, Vec) { - let inv = inv_freq(c); - let half = c.head_dim / 2; - let d = c.head_dim; +/// Plain RoPE base frequencies for an explicit base (Gemma3 / OLMo3 local layers, +/// whose sliding layers use `rope_local_base_freq` instead of `rope_theta`): +/// `inv_freq[i] = 1 / base^((2i)/head_dim)`. +pub fn plain_inv_freq_with_base(head_dim: usize, base: f64) -> Vec { + let half = head_dim / 2; + (0..half) + .map(|i| { + let exponent = (2 * i) as f64 / head_dim as f64; + 1.0 / base.powf(exponent) + }) + .collect() +} + +/// Build cos and sin tables of shape [max_seq, head_dim] as flat row-major f32 from +/// a precomputed `inv_freq`. emb = concat([freqs, freqs], -1) where freqs = +/// outer(pos, inv_freq). Shared by the global table and the Gemma3/OLMo3 local +/// table. +pub fn rope_tables_from_inv(inv: &[f64], head_dim: usize, max_seq: usize) -> (Vec, Vec) { + let half = head_dim / 2; + let d = head_dim; let mut cos = vec![0.0f32; max_seq * d]; let mut sin = vec![0.0f32; max_seq * d]; for p in 0..max_seq { @@ -96,3 +110,19 @@ pub fn rope_tables(c: &Config, max_seq: usize) -> (Vec, Vec) { } (cos, sin) } + +/// Build cos and sin tables of shape [max_seq, head_dim] for the config's global +/// RoPE scheme. +pub fn rope_tables(c: &Config, max_seq: usize) -> (Vec, Vec) { + rope_tables_from_inv(&inv_freq(c), c.head_dim, max_seq) +} + +/// Build the local cos/sin tables for a config with a distinct local RoPE base +/// (Gemma3 / OLMo3 sliding layers). Plain RoPE at `base`. +pub fn rope_tables_local(c: &Config, max_seq: usize, base: f64) -> (Vec, Vec) { + rope_tables_from_inv( + &plain_inv_freq_with_base(c.head_dim, base), + c.head_dim, + max_seq, + ) +} diff --git a/src/lib/mlxcel-xla/src/iree.rs b/src/lib/mlxcel-xla/src/iree.rs index 61434652..6ca634d8 100644 --- a/src/lib/mlxcel-xla/src/iree.rs +++ b/src/lib/mlxcel-xla/src/iree.rs @@ -28,11 +28,13 @@ //! shim, which keeps the weights resident on the device and threads the KV cache //! across steps. Then [`IreeLlama::prefill`] / [`IreeLlama::decode`] are token-in //! / token-out. Emitting from config (issue #449 M3 Stage 2d) replaced the bundled -//! Llama-3.2-1B `.mlir` assets, so any checkpoint of a supported architecture -//! loads: Llama (any size) and Qwen2 (plain RoPE + q/k/v bias; Stage B), the -//! latter adding its bias tensors to `weight_names` to match the emitted graph. -//! An untied checkpoint (`tie_word_embeddings = false`, e.g. Llama-3.1-8B and the -//! larger Qwen2.5 sizes) adds its `lm_head.weight` to `weight_names`, matching the +//! Llama-3.2-1B `.mlir` assets, so any checkpoint of a supported dense family loads: +//! Llama, Qwen2, Qwen3, Gemma1/2/3, SmolLM3, and OLMo2/3 (issue #497). Each family's +//! per-layer weight order in `weight_names` mirrors the emitter's arg schedule: the +//! Qwen2 q/k/v biases, the Qwen3 / Gemma3 / OLMo2/3 q/k norms, the Gemma2/3 feed- +//! forward norms, and (for OLMo2/3's reordered post-norm) the absence of an +//! `input_layernorm`. An untied checkpoint (`tie_word_embeddings = false`, e.g. +//! Llama-3.1-8B, larger Qwen2.5, OLMo2/3) adds its `lm_head.weight`, matching the //! separate `params['lm_head']` arg the emitter takes for the final projection. //! //! Proven token-exact against the HF temp-0 reference in @@ -124,11 +126,13 @@ pub(crate) const RAGGED_B_VALUES: &[usize] = &[4, 8]; /// The weight names in the emitter's exact arg order: embed, final_norm, then — /// for an untied checkpoint (`tie_word_embeddings = false`) — `lm_head.weight`, -/// then per layer down, gate, in_ln, post_ln, up, wk, wo, wq, wv, and — for a -/// `qkv_bias` architecture (Qwen2) — the k/q/v projection biases. The layer count, -/// the untied head, and the presence of biases come from the model config so the -/// order matches the emitted graph's args (`take_lm_head` / `take_layer_weights` -/// in `emitter/model.rs`). +/// then per layer the base weights and the arch-conditional extras. The base order +/// is down, gate, input_layernorm, post_attention_layernorm, up, wk, wo, wq, wv; +/// then, matching `take_layer_weights` in `emitter/model.rs` exactly, the k/q/v +/// biases (`qkv_bias`, Qwen2), the q/k norms (`qk_norm`, Qwen3 / Gemma3 / OLMo2/3), +/// and the feed-forward norms. `input_layernorm` is skipped for the OLMo reordered +/// post-norm (which has none). Every knob comes from the model config so the buffer +/// order lines up with the emitted graph args. fn weight_names(cfg: &Config) -> Vec { let mut names = vec![ "model.embed_tokens.weight".to_string(), @@ -141,10 +145,13 @@ fn weight_names(cfg: &Config) -> Vec { } for i in 0..cfg.n_layers { let p = format!("model.layers.{i}."); + names.push(format!("{p}mlp.down_proj.weight")); + names.push(format!("{p}mlp.gate_proj.weight")); + // input_layernorm: present unless the reordered (OLMo) post-norm drops it. + if cfg.has_input_norm() { + names.push(format!("{p}input_layernorm.weight")); + } for suf in [ - "mlp.down_proj.weight", - "mlp.gate_proj.weight", - "input_layernorm.weight", "post_attention_layernorm.weight", "mlp.up_proj.weight", "self_attn.k_proj.weight", @@ -154,8 +161,7 @@ fn weight_names(cfg: &Config) -> Vec { ] { names.push(format!("{p}{suf}")); } - // Qwen2 q/k/v projection biases, appended per layer in the same k/q/v - // order `take_layer_weights` adds them to the emitted graph args. + // Qwen2 q/k/v projection biases, in the same k/q/v order the emitter adds. if cfg.qkv_bias { for suf in [ "self_attn.k_proj.bias", @@ -165,15 +171,17 @@ fn weight_names(cfg: &Config) -> Vec { names.push(format!("{p}{suf}")); } } - // Gemma2 has two extra per-layer norms (pre/post feed-forward), appended - // in the same order `take_layer_weights` takes their graph args. - if cfg.gemma2 { - for suf in [ - "pre_feedforward_layernorm.weight", - "post_feedforward_layernorm.weight", - ] { - names.push(format!("{p}{suf}")); - } + // q/k norms (Qwen3 / Gemma3 per-head, OLMo2/3 flat), q then k. + if cfg.qk_norm.is_some() { + names.push(format!("{p}self_attn.q_norm.weight")); + names.push(format!("{p}self_attn.k_norm.weight")); + } + // Feed-forward norms: Gemma2/3 add pre AND post; OLMo2/3 add post only. + if cfg.has_pre_ff_norm() { + names.push(format!("{p}pre_feedforward_layernorm.weight")); + } + if cfg.has_post_ff_norm() { + names.push(format!("{p}post_feedforward_layernorm.weight")); } } names diff --git a/src/lib/mlxcel-xla/src/validation.rs b/src/lib/mlxcel-xla/src/validation.rs index 4eb50082..ba5d2d50 100644 --- a/src/lib/mlxcel-xla/src/validation.rs +++ b/src/lib/mlxcel-xla/src/validation.rs @@ -361,6 +361,107 @@ pub(crate) static LLAMA_3_2_1B: ArchFixture = ArchFixture { /// byte-exact gate; see the module docs for the freeze workflow. pub(crate) static REGISTERED: &[&ArchFixture] = &[&LLAMA_3_2_1B]; +/// A golden-less structural fixture (issue #497): a small synthetic `config.json` +/// for a dense family plus the signature every emitted shared-core graph must +/// carry. It registers a family in the harness WITHOUT bundling byte-exact goldens, +/// which suits the new dense pack: their real checkpoints are large (Gemma / +/// SmolLM3 / OLMo2), post-cutoff-heavy, or need a follow-up (OLMo3 yarn RoPE), so +/// freezing real goldens would bloat the repo and pin an un-execution-proven graph. +/// The exact op deltas are locked by the emitter's with/without-diff tests, and the +/// execution tier (`spike/openxla/dense_arch_pack_check.py`) proves correctness on a +/// small synthetic model per family. Mirrors the Qwen2.5 golden-less emit test. +pub(crate) struct StructuralFixture { + /// Family id (e.g. `"qwen3"`). + pub arch: &'static str, + /// A small synthetic `config.json` carrying the family's real arch flags. + pub config_json: &'static str, + /// Substrings every shared-core graph (prefill / decode / ragged) must contain. + pub must_contain: &'static [&'static str], + /// Substrings none of those graphs may contain (the family's absent features). + pub must_not_contain: &'static [&'static str], +} + +/// The shared-core graph kinds a structural fixture is checked against (the serve +/// path: host-sampled prefill / decode logits and ragged decode). +pub(crate) const STRUCTURAL_KINDS: &[GraphKind] = &[ + GraphKind::Prefill { sample: false }, + GraphKind::Decode { sample: false }, + GraphKind::DecodeRagged { + b_max: 4, + sample: false, + }, +]; + +/// The registered golden-less dense families (issue #497). Small synthetic dims +/// (hidden 8, head_dim 4, 4 layers) keep the emit tiny while exercising every +/// arch delta; `head_dim` differs from `hidden / n_q` and `n_q*head_dim` from +/// `hidden`, so the flat q-norm and non-square o_proj widths are genuinely distinct. +pub(crate) static STRUCTURAL_FAMILIES: &[StructuralFixture] = &[ + StructuralFixture { + arch: "qwen3", + config_json: r#"{"model_type":"qwen3","hidden_size":8,"num_attention_heads":3, + "num_key_value_heads":1,"head_dim":4,"intermediate_size":16,"num_hidden_layers":4, + "rms_norm_eps":1e-6,"rope_theta":1000000,"vocab_size":12,"attention_bias":false}"#, + must_contain: &["['q_norm']", "['k_norm']", "['in_ln']"], + must_not_contain: &["['bq']", "['pre_ff_ln']", "['post_ff_ln']"], + }, + StructuralFixture { + arch: "gemma1", + config_json: r#"{"model_type":"gemma","hidden_size":8,"num_attention_heads":2, + "num_key_value_heads":1,"head_dim":4,"intermediate_size":16,"num_hidden_layers":4, + "rms_norm_eps":1e-6,"rope_theta":10000.0,"vocab_size":12, + "hidden_activation":"gelu_pytorch_tanh"}"#, + must_contain: &["stablehlo.tanh", "['in_ln']"], + must_not_contain: &["['pre_ff_ln']", "['post_ff_ln']", "['q_norm']"], + }, + StructuralFixture { + arch: "gemma3", + config_json: r#"{"model_type":"gemma3_text","hidden_size":8,"num_attention_heads":2, + "num_key_value_heads":1,"head_dim":4,"intermediate_size":16,"num_hidden_layers":4, + "rms_norm_eps":1e-6,"rope_theta":1000000,"rope_local_base_freq":10000, + "sliding_window":2,"sliding_window_pattern":3,"query_pre_attn_scalar":4, + "vocab_size":12,"hidden_activation":"gelu_pytorch_tanh"}"#, + must_contain: &[ + "['q_norm']", + "['pre_ff_ln']", + "['post_ff_ln']", + "stablehlo.tanh", + ], + must_not_contain: &["['bq']"], + }, + StructuralFixture { + arch: "smollm3", + config_json: r#"{"model_type":"smollm3","hidden_size":8,"num_attention_heads":2, + "num_key_value_heads":1,"intermediate_size":16,"num_hidden_layers":4, + "rms_norm_eps":1e-6,"rope_theta":5000000.0,"vocab_size":12, + "no_rope_layers":[1,1,1,0]}"#, + must_contain: &["['in_ln']"], + must_not_contain: &["['q_norm']", "['pre_ff_ln']", "['bq']"], + }, + StructuralFixture { + arch: "olmo2", + config_json: r#"{"model_type":"olmo2","hidden_size":8,"num_attention_heads":3, + "num_key_value_heads":1,"head_dim":4,"intermediate_size":16,"num_hidden_layers":4, + "rms_norm_eps":1e-6,"rope_theta":500000,"vocab_size":12,"tie_word_embeddings":false}"#, + must_contain: &[ + "['q_norm']", + "['k_norm']", + "['post_ff_ln']", + "params['lm_head']", + ], + must_not_contain: &["['in_ln']", "['pre_ff_ln']"], + }, + StructuralFixture { + arch: "olmo3", + config_json: r#"{"model_type":"olmo3","hidden_size":8,"num_attention_heads":3, + "num_key_value_heads":1,"head_dim":4,"intermediate_size":16,"num_hidden_layers":4, + "rms_norm_eps":1e-6,"rope_theta":500000,"vocab_size":12,"sliding_window":2, + "sliding_window_pattern":4,"tie_word_embeddings":false}"#, + must_contain: &["['q_norm']", "['post_ff_ln']", "params['lm_head']"], + must_not_contain: &["['in_ln']", "['pre_ff_ln']"], + }, +]; + #[cfg(test)] mod tests { use super::*; @@ -434,6 +535,48 @@ mod tests { } } + /// Every registered golden-less dense family (issue #497) parses and emits each + /// shared-core graph kind carrying its expected structural signature: the + /// arch's signature args / ops are present and its absent features are absent, + /// in prefill, single decode, and ragged decode alike. This is the harness + /// registration for the new dense pack (Qwen3, Gemma1/3, SmolLM3, OLMo2/3), + /// whose byte-exact op deltas are pinned by the emitter's with/without-diff + /// tests and whose correctness is proven by the execution tier. + #[test] + fn structural_families_emit_expected_signature() { + for fx in STRUCTURAL_FAMILIES { + let graphs = emit_graphs(fx.config_json, STRUCTURAL_KINDS) + .unwrap_or_else(|e| panic!("{}: {e}", fx.arch)); + assert_eq!( + graphs.len(), + STRUCTURAL_KINDS.len(), + "{}: one graph/kind", + fx.arch + ); + for (kind, mlir) in &graphs { + assert!( + mlir.contains("stablehlo."), + "{} {kind}: not StableHLO", + fx.arch + ); + for needle in fx.must_contain { + assert!( + mlir.contains(needle), + "{} {kind}: missing signature {needle:?}", + fx.arch + ); + } + for needle in fx.must_not_contain { + assert!( + !mlir.contains(needle), + "{} {kind}: unexpected {needle:?}", + fx.arch + ); + } + } + } + } + /// The precision guard is a pure predicate over the env value, so it needs no /// racy env mutation to test: only `f16` / `bf16` are non-default. #[test]