From 9a190cba4ec0c08ed5b92fcdea19a2f927c11062 Mon Sep 17 00:00:00 2001 From: Jeongkyu Shin Date: Wed, 1 Jul 2026 09:39:10 +0900 Subject: [PATCH] test: add reusable per-architecture XLA validation harness Add a two-tier validation harness so adding an OpenXLA architecture is turnkey, replacing the ad-hoc per-arch flow (dequant offline, run an HF oracle, run xla_oracle_check by hand). Structural tier (src/lib/mlxcel-xla/src/validation.rs, pure Rust, no GPU): a per-architecture ArchFixture registry plus check_arch, which parses a checkpoint's config.json, emits each graph, and diffs it byte-for-byte against the committed golden assets//*.mlir, localizing the first differing line on drift. It honors MLXCEL_XLA_PRECISION (a byte-exact run under f16/bf16 is rejected, since that emit differs from the f32 goldens); emit_graphs is the golden-less freeze primitive for a new family. The emitter's own byte-exact test now delegates to check_arch so the goldens have a single source of truth. Demonstrated clean on llama-3.2-1b. Execution tier (opt-in, needs a real IREE build and a checkpoint): scripts/xla/validate_arch.sh is the one command. It produces the HF fp32 oracle with spike/openxla/oracle_continuation.py (loads fp32, dequantizing an MLX 4bit/8bit checkpoint offline first with the same affine formula as weights.rs), then runs xla_oracle_check (single-seq greedy == HF oracle) and xla_batch_bench (each batched request == its single-seq reference), reporting pass/fail for both. The offline-dequant math is self-tested against the Rust loader (oracle_continuation.py --selftest). Usage and the turnkey add-a-family workflow are documented in the validation module docs and src/lib/mlxcel-xla/README.md. --- scripts/xla/validate_arch.sh | 127 +++++++ spike/openxla/oracle_continuation.py | 290 ++++++++++++++++ src/lib/mlxcel-xla/README.md | 55 +++ src/lib/mlxcel-xla/src/emitter/mod.rs | 39 +-- src/lib/mlxcel-xla/src/lib.rs | 8 + src/lib/mlxcel-xla/src/validation.rs | 483 ++++++++++++++++++++++++++ 6 files changed, 973 insertions(+), 29 deletions(-) create mode 100755 scripts/xla/validate_arch.sh create mode 100644 spike/openxla/oracle_continuation.py create mode 100644 src/lib/mlxcel-xla/src/validation.rs diff --git a/scripts/xla/validate_arch.sh b/scripts/xla/validate_arch.sh new file mode 100755 index 00000000..d8ad951b --- /dev/null +++ b/scripts/xla/validate_arch.sh @@ -0,0 +1,127 @@ +#!/usr/bin/env bash +# One-command per-architecture validation for the OpenXLA / IREE backend (issue #496). +# +# Given a checkpoint, this drives every gate an architecture must pass: +# +# [structural] byte-exact emitter regression (pure Rust, no GPU) - a fast +# pre-gate that the registered fixtures still emit their goldens. +# [gate 1/2] token-exact single-sequence vs an HF fp32 oracle (xla_oracle_check). +# [gate 2/2] serve reference-exact: every batched request equals its single-seq +# reference (xla_batch_bench). +# +# The oracle is produced here too: spike/openxla/oracle_continuation.py loads the +# checkpoint in fp32 (dequantizing an MLX 4-bit/8-bit checkpoint offline first) and +# records the greedy continuation the token-exact gate diffs against. +# +# The two execution gates need a real IREE build (the `xla-iree` feature), so set +# the IREE environment first (see src/lib/mlxcel-xla/README.md): +# - CPU (prebuilt dist): export IREE_DIST=/path/to/iree-dist +# - CUDA (GB10): export IREE_CUDA_HOME=... IREE_CUDA_COMPILE=... (--device cuda) +# - macOS (Metal): eval "$(scripts/iree/setup-macos.sh --env)" +# The structural pre-gate needs none of that; run it alone with --structural-only. +# +# Usage: +# scripts/xla/validate_arch.sh --model [options] +# +# Options: +# --model checkpoint directory (required) +# --device IREE HAL device (default: $MLXCEL_XLA_DEVICE or local-task) +# --prompt oracle prompt (default: "The capital of France is") +# --max-new oracle continuation length / token-exact steps (default: 40) +# --batch serve B_max slots (default: 4) +# --requests serve request count (default: 2*batch) +# --maxcap serve per-request token budget clamp (default: 24) +# --oracle oracle JSON path (default: a temp file) +# --skip-structural skip the byte-exact pre-gate +# --structural-only run only the byte-exact pre-gate (no IREE build, no GPU) +# -h, --help this help +# +# Exit status: 0 only if every run gate passes. +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +# The oracle venv (spike/openxla/README.md) is gitignored; override its python with +# MLXCEL_ORACLE_PYTHON when it lives outside this checkout (e.g. a shared setup). +VENV_PY="${MLXCEL_ORACLE_PYTHON:-$REPO_ROOT/spike/openxla/.venv/bin/python}" +ORACLE_PY="$REPO_ROOT/spike/openxla/oracle_continuation.py" + +usage() { sed -n '2,/^set -euo/p' "${BASH_SOURCE[0]}" | sed 's/^# \{0,1\}//; $d'; } + +MODEL="" +DEVICE="${MLXCEL_XLA_DEVICE:-local-task}" +PROMPT="The capital of France is" +MAXNEW=40 +BATCH=4 +REQUESTS="" +MAXCAP=24 +ORACLE="" +SKIP_STRUCTURAL=0 +STRUCTURAL_ONLY=0 + +while [ $# -gt 0 ]; do + case "$1" in + --model) MODEL="${2:?}"; shift 2 ;; + --device) DEVICE="${2:?}"; shift 2 ;; + --prompt) PROMPT="${2:?}"; shift 2 ;; + --max-new) MAXNEW="${2:?}"; shift 2 ;; + --batch) BATCH="${2:?}"; shift 2 ;; + --requests) REQUESTS="${2:?}"; shift 2 ;; + --maxcap) MAXCAP="${2:?}"; shift 2 ;; + --oracle) ORACLE="${2:?}"; shift 2 ;; + --skip-structural) SKIP_STRUCTURAL=1; shift ;; + --structural-only) STRUCTURAL_ONLY=1; shift ;; + -h|--help) usage; exit 0 ;; + *) echo "error: unknown argument: $1" >&2; usage; exit 2 ;; + esac +done + +# --- structural pre-gate (fast, pure Rust, no GPU / IREE) --- +if [ "$SKIP_STRUCTURAL" -eq 0 ]; then + echo "== [structural] byte-exact emitter regression (cargo test) ==" + cargo test -p mlxcel-xla --lib \ + validation::tests::registered_fixtures_are_byte_exact -- --nocapture + echo "[structural] PASS" +fi +if [ "$STRUCTURAL_ONLY" -eq 1 ]; then + echo "RESULT: PASS (structural only)" + exit 0 +fi + +# --- validation of the execution-tier inputs --- +[ -n "$MODEL" ] || { echo "error: --model is required" >&2; usage; exit 2; } +[ -d "$MODEL" ] || { echo "error: model directory not found: $MODEL" >&2; exit 2; } +[ -x "$VENV_PY" ] || { echo "error: oracle venv python not found at $VENV_PY (see spike/openxla/README.md)" >&2; exit 3; } +REQUESTS="${REQUESTS:-$((2 * BATCH))}" +ORACLE="${ORACLE:-$(mktemp -t xla_oracle.XXXXXX.json)}" + +echo "== validate_arch: model=$MODEL device=$DEVICE ==" + +# --- Tier 1: produce the HF fp32 oracle (offline dequant if MLX-quantized) --- +echo "== [oracle] producing HF fp32 continuation -> $ORACLE ==" +"$VENV_PY" "$ORACLE_PY" --model "$MODEL" --out "$ORACLE" \ + --prompt "$PROMPT" --max-new "$MAXNEW" + +# --- Gate 1/2: token-exact single-sequence vs the oracle --- +echo "== [gate 1/2] token-exact single-seq (xla_oracle_check) ==" +gate1=0 +cargo run --release --features xla-iree --example xla_oracle_check -- \ + --model "$MODEL" --oracle "$ORACLE" --device "$DEVICE" || gate1=$? + +# --- Gate 2/2: serve reference-exact (reuses the oracle's prompt_ids) --- +echo "== [gate 2/2] serve reference-exact (xla_batch_bench) ==" +gate2=0 +cargo run --release --features xla-iree --example xla_batch_bench -- \ + --model "$MODEL" --prompts "$ORACLE" --device "$DEVICE" \ + --batch "$BATCH" --requests "$REQUESTS" --maxcap "$MAXCAP" || gate2=$? + +echo "" +echo "== summary ==" +[ "$SKIP_STRUCTURAL" -eq 0 ] && echo "structural (byte-exact) : PASS" || echo "structural (byte-exact) : SKIPPED" +[ "$gate1" -eq 0 ] && echo "token-exact single-seq : PASS" || echo "token-exact single-seq : FAIL" +[ "$gate2" -eq 0 ] && echo "serve reference-exact : PASS" || echo "serve reference-exact : FAIL" +if [ "$gate1" -eq 0 ] && [ "$gate2" -eq 0 ]; then + echo "RESULT: PASS (both execution gates)" + exit 0 +fi +echo "RESULT: FAIL" +exit 1 diff --git a/spike/openxla/oracle_continuation.py b/spike/openxla/oracle_continuation.py new file mode 100644 index 00000000..d8399f62 --- /dev/null +++ b/spike/openxla/oracle_continuation.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +"""Architecture-generic HF fp32 greedy oracle for the OpenXLA token-exactness gate +(issue #496). + +Given a checkpoint (bf16 / f16 / f32, or an MLX affine 4-bit / 8-bit quantized +checkpoint), this produces the external reference the `xla_oracle_check` example +diffs against: the pure next-token-argmax trajectory for N steps with NO EOS stop, +loaded and run in fp32 (the exact widening the XLA path applies to its weights). +The first generated token is the argmax after the FULL prompt, matching +`XlaReferenceEngine::prefill_first`. + +For a quantized checkpoint it first dequantizes the packed weights to f32 offline +(the "offline dequant to f32 oracle" step), using the same affine formula the Rust +loader applies (`src/lib/mlxcel-xla/src/weights.rs::dequantize_affine`: +`w = q * scale + bias`, `q` unpacked low-order-first, f16 scales/biases), so the +oracle weights match what the engine dequantizes to and token-exactness is +meaningful. + +Usage: + # bf16 / f32 checkpoint (no dequant needed): + spike/openxla/.venv/bin/python spike/openxla/oracle_continuation.py \\ + --model /models/qwen2.5-0.5b-bf16 --out /tmp/oracle.json \\ + --prompt "The capital of France is" --max-new 40 + + # MLX 4-bit / 8-bit checkpoint (dequantized to f32 first): + ... --model /models/qwen2.5-0.5b-4bit --out /tmp/oracle.json + + # verify the dequant math mirrors the Rust loader (numpy only, no model): + ... --selftest + +Output JSON: {"prompt_text", "prompt_ids": [int...], "ref_token_ids": [int...]}. +""" + +from __future__ import annotations + +import argparse +import json +import os +import shutil +import sys +import tempfile + +import numpy as np + + +def dequantize_affine( + packed: np.ndarray, + scales: np.ndarray, + biases: np.ndarray, + bits: int, + group_size: int, +) -> np.ndarray: + """Dequantize one MLX affine-quantized weight to row-major ``[out, in]`` f32. + + Mirrors ``src/lib/mlxcel-xla/src/weights.rs::dequantize_affine`` exactly so the + oracle and the XLA engine dequantize to identical f32 weights. + + Args: + packed: ``[out, in_packed]`` uint32 weight (``in_packed = in * bits / 32``), + with ``bits``-wide codes packed low-order-first per u32. + scales: ``[out, in / group_size]`` per-group scale (widened to f32). + biases: ``[out, in / group_size]`` per-group bias (widened to f32). + bits: 4 or 8. + group_size: input columns sharing one scale/bias. + + Returns: + ``[out, in]`` float32 array, ``w[o, i] = q[o, i] * scale[o, i // group_size] + + bias[o, i // group_size]``. + """ + if bits not in (4, 8): + raise ValueError(f"unsupported quantization bits {bits} (expected 4 or 8)") + out, in_packed = packed.shape + per_u32 = 32 // bits + in_ = in_packed * per_u32 + n_groups = scales.shape[1] + if group_size <= 0 or in_ != n_groups * group_size: + raise ValueError( + f"group_size {group_size} x n_groups {n_groups} != in dimension {in_}" + ) + mask = np.uint32((1 << bits) - 1) + shifts = np.arange(per_u32, dtype=np.uint32) * np.uint32(bits) + # [out, in_packed, per_u32]: code j of packed[o, p] is (u >> (bits*j)) & mask. + codes = (packed[:, :, None] >> shifts[None, None, :]) & mask + codes = codes.reshape(out, in_).astype(np.float32) + scale_full = np.repeat(scales.astype(np.float32), group_size, axis=1) + bias_full = np.repeat(biases.astype(np.float32), group_size, axis=1) + return codes * scale_full + bias_full + + +def _selftest() -> None: + """Assert the numpy dequant matches the Rust ``weights.rs`` hand-examples.""" + # 4-bit: u32 0x8765_4321 -> nibbles [1..8] low-first; groups (2.0, +10), (0.5, -1). + packed4 = np.array([[0x87654321]], dtype=np.uint32) + scales = np.array([[2.0, 0.5]], dtype=np.float32) + biases = np.array([[10.0, -1.0]], dtype=np.float32) + w4 = dequantize_affine(packed4, scales, biases, 4, 4) + want4 = np.array([[12.0, 14.0, 16.0, 18.0, 1.5, 2.0, 2.5, 3.0]], dtype=np.float32) + assert np.array_equal(w4, want4), f"4-bit mismatch: {w4}" + + # 8-bit: u32 0x281E_140A -> bytes [10, 20, 30, 40] low-first; same two groups. + packed8 = np.array([[0x281E140A]], dtype=np.uint32) + w8 = dequantize_affine(packed8, scales, biases, 8, 2) + want8 = np.array([[30.0, 50.0, 14.0, 19.0]], dtype=np.float32) + assert np.array_equal(w8, want8), f"8-bit mismatch: {w8}" + + +def read_quantization(model_dir: str) -> dict | None: + """Return ``{"bits", "group_size"}`` if ``config.json`` marks an MLX-quantized + checkpoint, else ``None``.""" + with open(os.path.join(model_dir, "config.json")) as f: + cfg = json.load(f) + q = cfg.get("quantization") or cfg.get("quantization_config") + if not q: + return None + return {"bits": int(q["bits"]), "group_size": int(q["group_size"])} + + +def _shard_files(model_dir: str) -> list[tuple[str, list[str]]]: + """``[(safetensors_path, [tensor_name...])]`` for a single-file or sharded + checkpoint.""" + index = os.path.join(model_dir, "model.safetensors.index.json") + if os.path.exists(index): + with open(index) as f: + weight_map = json.load(f)["weight_map"] + by_file: dict[str, list[str]] = {} + for name, filename in weight_map.items(): + by_file.setdefault(filename, []).append(name) + return [(os.path.join(model_dir, fn), names) for fn, names in by_file.items()] + single = os.path.join(model_dir, "model.safetensors") + if os.path.exists(single): + from safetensors import safe_open + + with safe_open(single, framework="numpy") as f: + return [(single, list(f.keys()))] + raise FileNotFoundError(f"no model.safetensors or index.json in {model_dir}") + + +def dequant_checkpoint(model_dir: str, out_dir: str, bits: int, group_size: int) -> None: + """Write an HF-loadable f32 checkpoint (dequantized weights, no quantization + block) into ``out_dir``.""" + from safetensors import safe_open + from safetensors.numpy import save_file + + os.makedirs(out_dir, exist_ok=True) + tensors: dict[str, np.ndarray] = {} + for path, names in _shard_files(model_dir): + nameset = set(names) + # numpy handle for the uint32 packed weights; torch handle so bf16 / f16 + # copy-through tensors and f16 scales/biases widen to f32 exactly. + with ( + safe_open(path, framework="numpy") as fnp, + safe_open(path, framework="pt") as fpt, + ): + for name in names: + if name.endswith(".scales") or name.endswith(".biases"): + continue # consumed with the paired .weight + base = name[: -len(".weight")] if name.endswith(".weight") else None + if base is not None and f"{base}.scales" in nameset: + packed = fnp.get_tensor(name).astype(np.uint32) + scales = fpt.get_tensor(f"{base}.scales").float().numpy() + biases = fpt.get_tensor(f"{base}.biases").float().numpy() + tensors[name] = dequantize_affine( + packed, scales, biases, bits, group_size + ) + else: + tensors[name] = fpt.get_tensor(name).float().numpy() + save_file(tensors, os.path.join(out_dir, "model.safetensors")) + + with open(os.path.join(model_dir, "config.json")) as f: + cfg = json.load(f) + cfg.pop("quantization", None) + cfg.pop("quantization_config", None) + with open(os.path.join(out_dir, "config.json"), "w") as f: + json.dump(cfg, f, indent=2) + + # Tokenizer + generation config so the dequantized dir loads like the original. + passthrough = { + "generation_config.json", + "special_tokens_map.json", + "vocab.json", + "merges.txt", + "added_tokens.json", + "chat_template.jinja", + } + for fn in os.listdir(model_dir): + if fn.startswith("tokenizer") or fn in passthrough: + shutil.copy(os.path.join(model_dir, fn), os.path.join(out_dir, fn)) + + +def hf_greedy_oracle(model_dir: str, prompt: str, n_new: int) -> dict: + """Run HF fp32 greedy (pure argmax, no EOS stop) for ``n_new`` steps.""" + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_dir) + model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float32) + model.eval() + + prompt_ids = tok(prompt, return_tensors="pt").input_ids # [1, L] + ids = prompt_ids.clone() + ref: list[int] = [] + with torch.no_grad(): + for _ in range(n_new): + logits = model(ids).logits[:, -1, :] # [1, V] + nxt = int(torch.argmax(logits, dim=-1).item()) + ref.append(nxt) + ids = torch.cat([ids, torch.tensor([[nxt]])], dim=1) + return { + "prompt_text": prompt, + "prompt_ids": prompt_ids[0].tolist(), + "ref_token_ids": ref, + "decoded": tok.decode(ref), + } + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + ap.add_argument("--model", help="checkpoint directory") + ap.add_argument("--out", help="output oracle JSON path") + ap.add_argument("--prompt", default="The capital of France is") + ap.add_argument("--max-new", type=int, default=40) + ap.add_argument( + "--dequant-dir", + default=None, + help="where to write the dequantized f32 checkpoint (default: a temp dir " + "removed on exit); implies --keep-dequant if set", + ) + ap.add_argument( + "--keep-dequant", + action="store_true", + help="do not remove the dequantized checkpoint afterwards", + ) + ap.add_argument( + "--selftest", + action="store_true", + help="verify the dequant math mirrors the Rust loader (numpy only), then exit", + ) + args = ap.parse_args() + + if args.selftest: + _selftest() + print("oracle_continuation self-test: OK (dequant matches weights.rs)") + return 0 + + if not args.model or not args.out: + ap.error("--model and --out are required (unless --selftest)") + + quant = read_quantization(args.model) + source = args.model + tmp_dir = None + if quant is not None: + target = args.dequant_dir or tempfile.mkdtemp(prefix="mlxcel-dequant-") + tmp_dir = None if (args.dequant_dir or args.keep_dequant) else target + print( + f"[oracle] {args.model} is MLX-quantized (bits={quant['bits']}, " + f"group_size={quant['group_size']}); dequantizing to f32 -> {target}", + flush=True, + ) + dequant_checkpoint(args.model, target, quant["bits"], quant["group_size"]) + source = target + + print( + f"[oracle] HF fp32 greedy from {source}: prompt={args.prompt!r} " + f"max_new={args.max_new}", + flush=True, + ) + result = hf_greedy_oracle(source, args.prompt, args.max_new) + with open(args.out, "w") as f: + json.dump( + { + "prompt_text": result["prompt_text"], + "prompt_ids": result["prompt_ids"], + "ref_token_ids": result["ref_token_ids"], + }, + f, + ) + print( + f"[oracle] wrote {args.out}: {len(result['prompt_ids'])} prompt tokens, " + f"{len(result['ref_token_ids'])} reference tokens", + flush=True, + ) + print(f"[oracle] continuation: {result['decoded']!r}", flush=True) + + if tmp_dir is not None: + shutil.rmtree(tmp_dir, ignore_errors=True) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/lib/mlxcel-xla/README.md b/src/lib/mlxcel-xla/README.md index ac93bb46..02332a2f 100644 --- a/src/lib/mlxcel-xla/README.md +++ b/src/lib/mlxcel-xla/README.md @@ -209,6 +209,60 @@ MLXCEL_XLA_DEVICE=cuda cargo run --release --features xla-iree \ --example xla_batch_bench -- --device cuda --batch 8 --requests 16 --maxcap 48 ``` +## Per-architecture validation harness (#496) + +Validating an architecture has two tiers, split so the cheap one can gate every +change while the expensive one stays opt-in. Both are reusable, so adding a family +is turnkey. + +**Structural (byte-exact, pure Rust, no GPU).** The emitter must reproduce a +frozen StableHLO golden for each registered architecture, byte for byte. This is +the fast regression gate that catches a graph which drifted from its validated +form. It is the `src/validation.rs` module, run as a normal test: + +```bash +cargo test -p mlxcel-xla --lib validation +``` + +`validation::REGISTERED` lists the fixtures (currently `llama-3.2-1b`, whose +goldens live in `assets/llama-3.2-1b/`). The check honors `MLXCEL_XLA_PRECISION`: +the goldens are the default `f32` graphs, so a byte-exact run under `f16` / `bf16` +is rejected with a clear message instead of a confusing diff. + +**Execution (token-exact + reference-exact).** One command produces the HF fp32 +oracle and runs both run gates (needs a real IREE build and a checkpoint): + +```bash +export IREE_DIST=/path/to/iree-dist # or IREE_CUDA_HOME=... for CUDA +scripts/xla/validate_arch.sh --model --device local-task +``` + +It (0) runs the structural pre-gate, (1) produces the oracle with +`spike/openxla/oracle_continuation.py` (loads the checkpoint in fp32, +dequantizing an MLX 4-bit / 8-bit checkpoint offline first with the same affine +formula as `src/weights.rs`), (2) runs `xla_oracle_check` (single-seq greedy == +HF oracle), and (3) runs `xla_batch_bench` (every batched request == its +single-seq reference). It exits non-zero if any run gate fails. +`--structural-only` runs just the pure-Rust pre-gate (no IREE, no GPU). + +### Adding a family is turnkey + +1. Emit correctly: extend `Config::from_json` and the emitter for the + architecture (the structural invariant tests in `emitter/mod.rs` cover the new + switches: RoPE kind, q/k/v bias, tied / untied head, soft-caps). +2. Prove it: `scripts/xla/validate_arch.sh --model ` must report a + clean token-exact + reference-exact pass on a real checkpoint. +3. Freeze goldens (optional, for a byte-exact CI guard): emit each graph with + `validation::emit_graphs(config_json, kinds)` and write them to + `assets//*.mlir`; an already-registered family re-freezes in place with + `MLXCEL_FREEZE_GOLDENS=1 cargo test -p mlxcel-xla --lib validation::tests::freeze_goldens`. +4. Register: add an `ArchFixture` to `validation::REGISTERED`; the byte-exact test + then guards it forever. + +Not every family bundles goldens: Qwen2.5 (`assets/qwen2.5-0.5b/`) is emitted at +load and covered by the emitter's structural tests plus the execution tier, with +no committed `.mlir`. The harness still drives its emit through `emit_graphs`. + ## File map | Path | Purpose | @@ -216,6 +270,7 @@ MLXCEL_XLA_DEVICE=cuda cargo run --release --features xla-iree \ | `src/lib.rs` | `XlaInferenceSession`: the single-sequence `InferenceSession` impl + greedy drive loop. | | `src/iree.rs` | (feature `iree`) FFI to the shim; `IreeLlama` (single-seq) and `IreeRaggedLlama` (batched) load weights, compile + run the graphs. | | `src/batch.rs` | (feature `iree`) `XlaBatchEngine`: the continuous-batching engine (slots + queue + admit/decode/evict) and `XlaReferenceEngine` (single-seq reference for validation). The backend-neutral `Scheduler` bookkeeping is unit-tested without IREE. | +| `src/validation.rs` | (issue #496) Reusable per-architecture structural harness: the `ArchFixture` registry + `check_arch` byte-exact golden gate + `emit_graphs` freeze primitive. Pure Rust; runs under `cargo test`. The execution tier lives in `scripts/xla/validate_arch.sh` + `spike/openxla/oracle_continuation.py`. | | `csrc/xla_iree.c` | C shim over the IREE runtime C API (one session, resident weights, threaded KV; single-seq `prefill`/`decode` plus the ragged `prefill_slot`/`decode_ragged` with a device-side per-slot KV write). | | `build.rs` | (feature `iree`) compiles the shim against `IREE_DIST` headers. The runtime link recipe lives in the **root** `mlxcel/build.rs` (a dependency's link-args do not propagate to the binary). | | `assets/llama-3.2-1b/` | The #451-emitted `prefill` / `decode_step` StableHLO graphs (on-device-argmax variant) plus the ragged `decode_ragged_b{4,8}` graphs, compiled to vmfbs at load. | diff --git a/src/lib/mlxcel-xla/src/emitter/mod.rs b/src/lib/mlxcel-xla/src/emitter/mod.rs index feb41f15..c366ac54 100644 --- a/src/lib/mlxcel-xla/src/emitter/mod.rs +++ b/src/lib/mlxcel-xla/src/emitter/mod.rs @@ -40,7 +40,7 @@ mod model; mod rope; pub(crate) use config::Config; -pub(crate) use model::{emit_decode, emit_decode_ragged, emit_prefill}; +pub(crate) use model::{emit_decode, emit_decode_batched, emit_decode_ragged, emit_prefill}; #[cfg(test)] mod tests { @@ -48,11 +48,6 @@ mod tests { use super::*; const CONFIG_JSON: &str = include_str!("../../assets/llama-3.2-1b/config.json"); - const PREFILL: &str = include_str!("../../assets/llama-3.2-1b/prefill.mlir"); - const DECODE: &str = include_str!("../../assets/llama-3.2-1b/decode.mlir"); - const PREFILL_LOGITS: &str = include_str!("../../assets/llama-3.2-1b/prefill_logits.mlir"); - const RAGGED_B4: &str = include_str!("../../assets/llama-3.2-1b/decode_ragged_logits_b4.mlir"); - const RAGGED_B8: &str = include_str!("../../assets/llama-3.2-1b/decode_ragged_logits_b8.mlir"); const QWEN_CONFIG_JSON: &str = include_str!("../../assets/qwen2.5-0.5b/config.json"); fn occurs(haystack: &str, needle: &str) -> usize { @@ -119,31 +114,17 @@ mod tests { } } - /// The whole point of Stage A: a `Config` parsed from the real - /// Llama-3.2-1B-Instruct `config.json` emits every bundled graph - /// byte-for-byte. This proves the load-time emit path reproduces the assets - /// the engine shipped with, so switching from `include_str!` to emit-at-load - /// cannot change the compiled graphs for this model. + /// A `Config` parsed from the real Llama-3.2-1B-Instruct `config.json` emits + /// every bundled graph byte-for-byte, so switching from `include_str!` to + /// emit-at-load cannot change the compiled graphs for this model. Asserted + /// through the reusable per-architecture validation harness (issue #496), + /// which owns the golden fixtures (`crate::validation::LLAMA_3_2_1B`); adding + /// a family is then a registry row rather than a copy of this test. #[test] fn from_json_reproduces_bundled_assets_byte_for_byte() { - let c = Config::from_json_str(CONFIG_JSON).expect("parse Llama-3.2-1B config.json"); - assert_eq!(emit_prefill(&c, true), PREFILL, "prefill.mlir (argmax)"); - assert_eq!(emit_decode(&c, true), DECODE, "decode.mlir (argmax)"); - assert_eq!( - emit_prefill(&c, false), - PREFILL_LOGITS, - "prefill_logits.mlir" - ); - assert_eq!( - emit_decode_ragged(&c, 4, false), - RAGGED_B4, - "decode_ragged_logits_b4.mlir" - ); - assert_eq!( - emit_decode_ragged(&c, 8, false), - RAGGED_B8, - "decode_ragged_logits_b8.mlir" - ); + let report = crate::validation::check_arch(&crate::validation::LLAMA_3_2_1B) + .expect("llama-3.2-1b fixture parses at the default precision"); + assert!(report.passed(), "{report}"); } /// `from_json` reads the same values the spike hard-coded, so it emits diff --git a/src/lib/mlxcel-xla/src/lib.rs b/src/lib/mlxcel-xla/src/lib.rs index 933a7e9d..c2ec1593 100644 --- a/src/lib/mlxcel-xla/src/lib.rs +++ b/src/lib/mlxcel-xla/src/lib.rs @@ -83,6 +83,14 @@ mod weights; #[allow(dead_code)] mod emitter; +// Reusable per-architecture validation harness (issue #496). Pure Rust; present +// under `iree` (so the harness is available to tooling) and under `test` (the +// byte-exact structural gate runs here). The engine never calls it, so dead_code +// is allowed under a non-test `iree` build, matching the sibling modules above. +#[cfg(any(feature = "iree", test))] +#[allow(dead_code)] +mod validation; + #[cfg(feature = "iree")] pub use batch::{EngineEvent, FinishReason, XlaBatchEngine, XlaReferenceEngine}; #[cfg(feature = "iree")] diff --git a/src/lib/mlxcel-xla/src/validation.rs b/src/lib/mlxcel-xla/src/validation.rs new file mode 100644 index 00000000..4eb50082 --- /dev/null +++ b/src/lib/mlxcel-xla/src/validation.rs @@ -0,0 +1,483 @@ +// Copyright 2025-2026 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Reusable per-architecture validation harness (issue #496). +//! +//! Adding an architecture to the OpenXLA backend needs two kinds of proof, and +//! this crate splits them into two tiers so the cheap one can gate every change +//! and the expensive one stays opt-in: +//! +//! 1. **Structural (byte-exact) - this module.** The Rust emitter must reproduce +//! a frozen, trusted StableHLO graph for the architecture, byte for byte. It +//! is pure Rust (no IREE, no GPU, no checkpoint), so it is the cheap +//! regression gate that catches a graph which drifted from its validated form. +//! [`check_arch`] parses a fixture's `config.json`, emits each declared graph, +//! and diffs it against the committed golden `assets//*.mlir`, localizing +//! the first differing line on a mismatch. +//! +//! 2. **Execution (token-exact / reference-exact) - outside this crate.** The +//! compiled graph must run token-exact against an HF fp32 oracle (single +//! sequence) and reference-exact through the batched serve path. That needs +//! real IREE execution and a checkpoint, so it lives as the `xla_oracle_check` +//! and `xla_batch_bench` examples, driven end to end (oracle included) by +//! `scripts/xla/validate_arch.sh`. See that script and the crate `README.md`. +//! +//! # Adding a family is turnkey +//! +//! Freezing goldens is a one-time authoring step, gated on the execution tier +//! having proven the emit is correct for a real checkpoint: +//! +//! 1. Prove token-exactness with `scripts/xla/validate_arch.sh --model `. +//! 2. Freeze the now-trusted graphs to `assets//`: [`emit_graphs`] emits +//! each graph from the checkpoint's `config.json` with no committed golden +//! required, so it bootstraps a brand-new family; write each result to its +//! `.mlir` file (the [`ArchFixture::emit_all`] / `freeze_goldens` path +//! re-freezes an already-registered family). +//! 3. Register the fixture: add an [`ArchFixture`] pairing the config with its +//! golden `.mlir` files and append it to [`REGISTERED`]. The data-driven test +//! then guards it byte-for-byte forever. +//! +//! Not every family bundles goldens: an architecture whose graphs are emitted at +//! load (Qwen2.5, see `assets/qwen2.5-0.5b/README.md`) is covered by the +//! emitter's structural invariant tests and the execution tier instead. The +//! harness still drives its emit through [`emit_graphs`] (see the tests below). +//! +//! The emitter reads `MLXCEL_XLA_PRECISION` at emit time (default `f32`); the +//! committed goldens are the default-precision graphs, so [`check_arch`] rejects +//! a byte-exact run under a non-default precision (`f16` / `bf16`) with a clear +//! error rather than reporting a confusing diff. + +use crate::emitter::{Config, emit_decode, emit_decode_batched, emit_decode_ragged, emit_prefill}; + +/// One emitted graph kind, matching the emitter's entry points. `sample = true` +/// ends the graph in an on-device argmax returning a token id (the single-sequence +/// session path); `sample = false` returns the raw logits (the batched serve path +/// samples on the host). +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum GraphKind { + /// Bucketed prompt prefill ([`emit_prefill`]). + Prefill { sample: bool }, + /// Single-token decode step ([`emit_decode`]). + Decode { sample: bool }, + /// Ragged (continuous-batching) decode for `b_max` slots ([`emit_decode_ragged`]). + DecodeRagged { b_max: usize, sample: bool }, + /// Uniform-`b_max` batched decode ([`emit_decode_batched`]). + DecodeBatched { b_max: usize, sample: bool }, +} + +impl GraphKind { + /// Emit this graph for `cfg`, at the ambient `MLXCEL_XLA_PRECISION`. + fn emit(self, cfg: &Config) -> String { + match self { + GraphKind::Prefill { sample } => emit_prefill(cfg, sample), + GraphKind::Decode { sample } => emit_decode(cfg, sample), + GraphKind::DecodeRagged { b_max, sample } => emit_decode_ragged(cfg, b_max, sample), + GraphKind::DecodeBatched { b_max, sample } => emit_decode_batched(cfg, b_max, sample), + } + } +} + +impl core::fmt::Display for GraphKind { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + GraphKind::Prefill { sample } => write!(f, "prefill(sample={sample})"), + GraphKind::Decode { sample } => write!(f, "decode(sample={sample})"), + GraphKind::DecodeRagged { b_max, sample } => { + write!(f, "decode_ragged(b={b_max}, sample={sample})") + } + GraphKind::DecodeBatched { b_max, sample } => { + write!(f, "decode_batched(b={b_max}, sample={sample})") + } + } + } +} + +/// One golden graph within an architecture fixture: which graph to emit and the +/// expected MLIR (the committed `assets//`). +pub(crate) struct GraphFixture { + /// The graph kind to emit and compare. + pub kind: GraphKind, + /// File name under `assets//`, used in diagnostics and by the freeze path. + pub golden_name: &'static str, + /// The golden MLIR text (`include_str!` of the committed asset). + pub golden: &'static str, +} + +/// A per-architecture structural fixture: a checkpoint's `config.json` and the +/// set of graphs the emitter must reproduce byte-for-byte. +pub(crate) struct ArchFixture { + /// Directory name under `assets/` and id in reports (e.g. `"llama-3.2-1b"`). + pub arch: &'static str, + /// The model's `config.json` text (`include_str!` of the committed asset). + pub config_json: &'static str, + /// The golden graphs this architecture pins. + pub graphs: &'static [GraphFixture], +} + +impl ArchFixture { + /// Emit every declared graph, returning `(golden_name, mlir)` pairs. The + /// re-freeze primitive for an already-registered family: write each pair to + /// `assets//` to refresh the goldens after an intentional, + /// execution-tier-validated emitter change. Emits at the ambient + /// `MLXCEL_XLA_PRECISION`. + /// + /// # Errors + /// Returns the config-parse error if `config_json` is not a supported arch. + pub fn emit_all(&self) -> Result, String> { + let cfg = + Config::from_json_str(self.config_json).map_err(|e| format!("{}: {e}", self.arch))?; + Ok(self + .graphs + .iter() + .map(|g| (g.golden_name, g.kind.emit(&cfg))) + .collect()) + } +} + +/// Emit each requested graph for the architecture in `config_json`, at the ambient +/// `MLXCEL_XLA_PRECISION`. The authoring/bootstrap primitive: it needs no committed +/// golden, so it can freeze a brand-new family before its [`ArchFixture`] exists. +/// +/// # Errors +/// Returns the config-parse error if `config_json` is not a supported architecture. +pub(crate) fn emit_graphs( + config_json: &str, + kinds: &[GraphKind], +) -> Result, String> { + let cfg = Config::from_json_str(config_json)?; + Ok(kinds.iter().map(|&k| (k, k.emit(&cfg))).collect()) +} + +/// The first line at which two MLIR texts differ (1-based), with both sides, so a +/// frozen-reference drift is easy to localize. +pub(crate) struct LineDiff { + /// 1-based line number of the first difference. + pub line: usize, + /// The golden's line at `line` (`` if the golden is shorter). + pub expected: String, + /// The emitted line at `line` (`` if the emit is shorter). + pub actual: String, + /// Total lines in the golden. + pub expected_lines: usize, + /// Total lines in the emitted graph. + pub actual_lines: usize, +} + +/// The outcome of checking one golden graph. +pub(crate) struct GraphOutcome { + /// The graph kind that was emitted. + pub kind: GraphKind, + /// The golden file name it was compared against. + pub golden_name: &'static str, + /// `None` on a byte-exact match; the first differing line otherwise. + pub diff: Option, +} + +impl GraphOutcome { + /// The emit matched the golden byte-for-byte. + pub fn matched(&self) -> bool { + self.diff.is_none() + } +} + +/// The outcome of checking one architecture fixture. +pub(crate) struct ArchReport { + /// The fixture's architecture id. + pub arch: &'static str, + /// Per-graph outcomes, in fixture order. + pub graphs: Vec, +} + +impl ArchReport { + /// Every golden matched byte-for-byte. + pub fn passed(&self) -> bool { + self.graphs.iter().all(GraphOutcome::matched) + } +} + +impl core::fmt::Display for ArchReport { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + writeln!( + f, + "arch {}: {}", + self.arch, + if self.passed() { "PASS" } else { "FAIL" } + )?; + for g in &self.graphs { + match &g.diff { + None => writeln!(f, " ok {} ({})", g.golden_name, g.kind)?, + Some(d) => writeln!( + f, + " DIFF {} ({}): first mismatch at line {} ({} golden lines vs {} emitted)\n\ + \x20 golden : {}\n\ + \x20 emitted: {}", + g.golden_name, + g.kind, + d.line, + d.expected_lines, + d.actual_lines, + d.expected, + d.actual + )?, + } + } + Ok(()) + } +} + +/// Whether an `MLXCEL_XLA_PRECISION` value selects the default (`f32`) emit that +/// the committed goldens were frozen at. Only `"f16"` / `"bf16"` are non-default, +/// mirroring `emitter::builder::precision_from_env`. +fn precision_is_default(value: Option<&str>) -> bool { + !matches!(value, Some("f16") | Some("bf16")) +} + +/// Reject a byte-exact run under a non-default precision: the goldens are the f32 +/// graphs, so an `f16` / `bf16` emit would diff confusingly against them. +fn ambient_precision_is_default() -> Result<(), String> { + let v = std::env::var("MLXCEL_XLA_PRECISION").ok(); + if precision_is_default(v.as_deref()) { + Ok(()) + } else { + Err(format!( + "the structural goldens are the default-precision (f32) graphs; unset \ + MLXCEL_XLA_PRECISION (currently {:?}) to run the byte-exact check", + v.unwrap_or_default() + )) + } +} + +/// Parse a fixture's `config.json`, emit each declared graph, and compare it to the +/// committed golden byte-for-byte, localizing the first differing line on a +/// mismatch. +/// +/// The committed goldens are the default-precision (`f32`) graphs, so this errors +/// if `MLXCEL_XLA_PRECISION` selects `f16` / `bf16` (which would emit a different +/// graph). Inspect [`ArchReport::passed`]; the `Display` renders a per-graph +/// pass/fail summary with the first-diff location. +/// +/// # Errors +/// Returns an error if the ambient precision is not the default the goldens were +/// frozen at, or if `config.json` does not parse to a supported architecture. +pub(crate) fn check_arch(fx: &ArchFixture) -> Result { + ambient_precision_is_default()?; + let cfg = Config::from_json_str(fx.config_json).map_err(|e| format!("{}: {e}", fx.arch))?; + let graphs = fx + .graphs + .iter() + .map(|g| { + let emitted = g.kind.emit(&cfg); + GraphOutcome { + kind: g.kind, + golden_name: g.golden_name, + diff: first_line_diff(g.golden, &emitted), + } + }) + .collect(); + Ok(ArchReport { + arch: fx.arch, + graphs, + }) +} + +/// The first line at which `golden` and `emitted` differ, or `None` if identical. +fn first_line_diff(golden: &str, emitted: &str) -> Option { + if golden == emitted { + return None; + } + let (mut g, mut e) = (golden.lines(), emitted.lines()); + let mut line = 0usize; + loop { + line += 1; + match (g.next(), e.next()) { + (Some(a), Some(b)) if a == b => {} + (a, b) => { + return Some(LineDiff { + line, + expected: a.unwrap_or("").to_string(), + actual: b.unwrap_or("").to_string(), + expected_lines: golden.lines().count(), + actual_lines: emitted.lines().count(), + }); + } + } + } +} + +/// Llama-3.2-1B-Instruct: the reference architecture, with committed goldens. Its +/// graphs are the issue #451-emitted StableHLO the backend ships in +/// `assets/llama-3.2-1b/` (on-device-argmax `prefill` / `decode`, plus the +/// host-sampled `prefill_logits` and ragged `decode` serve graphs for B_max 4/8). +pub(crate) static LLAMA_3_2_1B: ArchFixture = ArchFixture { + arch: "llama-3.2-1b", + config_json: include_str!("../assets/llama-3.2-1b/config.json"), + graphs: &[ + GraphFixture { + kind: GraphKind::Prefill { sample: true }, + golden_name: "prefill.mlir", + golden: include_str!("../assets/llama-3.2-1b/prefill.mlir"), + }, + GraphFixture { + kind: GraphKind::Decode { sample: true }, + golden_name: "decode.mlir", + golden: include_str!("../assets/llama-3.2-1b/decode.mlir"), + }, + GraphFixture { + kind: GraphKind::Prefill { sample: false }, + golden_name: "prefill_logits.mlir", + golden: include_str!("../assets/llama-3.2-1b/prefill_logits.mlir"), + }, + GraphFixture { + kind: GraphKind::DecodeRagged { + b_max: 4, + sample: false, + }, + golden_name: "decode_ragged_logits_b4.mlir", + golden: include_str!("../assets/llama-3.2-1b/decode_ragged_logits_b4.mlir"), + }, + GraphFixture { + kind: GraphKind::DecodeRagged { + b_max: 8, + sample: false, + }, + golden_name: "decode_ragged_logits_b8.mlir", + golden: include_str!("../assets/llama-3.2-1b/decode_ragged_logits_b8.mlir"), + }, + ], +}; + +/// Every registered structural fixture. Append a family here to add it to the +/// byte-exact gate; see the module docs for the freeze workflow. +pub(crate) static REGISTERED: &[&ArchFixture] = &[&LLAMA_3_2_1B]; + +#[cfg(test)] +mod tests { + use super::*; + + const QWEN_CONFIG_JSON: &str = include_str!("../assets/qwen2.5-0.5b/config.json"); + + /// The structural gate: every registered fixture emits its committed goldens + /// byte-for-byte. Demonstrates llama-3.2-1b passing (issue #496); this is the + /// single source of truth the emitter's `from_json_reproduces_bundled_assets` + /// test delegates to. + #[test] + fn registered_fixtures_are_byte_exact() { + for fx in REGISTERED { + let report = check_arch(fx).unwrap_or_else(|e| panic!("{}: {e}", fx.arch)); + assert!(report.passed(), "{report}"); + } + } + + /// The guard actually catches drift: a wrong golden fails and the report + /// localizes the first differing line, so a downstream emitter change that + /// shifts a graph is caught with a pointer, not a wall of text. + #[test] + fn check_arch_detects_drift() { + static DRIFT_GRAPHS: [GraphFixture; 1] = [GraphFixture { + kind: GraphKind::Decode { sample: true }, + golden_name: "decode.mlir", + golden: "// intentionally wrong first line\n", + }]; + let drifted = ArchFixture { + arch: "llama-3.2-1b (drift check)", + config_json: include_str!("../assets/llama-3.2-1b/config.json"), + graphs: &DRIFT_GRAPHS, + }; + let report = check_arch(&drifted).expect("config parses"); + assert!(!report.passed(), "a wrong golden must fail"); + let d = report.graphs[0] + .diff + .as_ref() + .expect("a mismatch is localized"); + assert_eq!(d.line, 1, "the first line differs"); + } + + /// The harness drives a second, golden-less architecture end to end through + /// the emit path: Qwen2.5-0.5B parses and every graph kind emits non-empty + /// StableHLO carrying the Qwen2 q/k/v bias args, proving the harness is + /// arch-agnostic without bundling large goldens (Qwen graphs are emitted at + /// load; see `assets/qwen2.5-0.5b/README.md`). + #[test] + fn emit_graphs_drives_a_golden_less_arch() { + let kinds = [ + GraphKind::Prefill { sample: true }, + GraphKind::Decode { sample: false }, + GraphKind::DecodeRagged { + b_max: 4, + sample: false, + }, + GraphKind::DecodeBatched { + b_max: 4, + sample: false, + }, + ]; + let graphs = emit_graphs(QWEN_CONFIG_JSON, &kinds).expect("qwen2.5-0.5b config parses"); + assert_eq!(graphs.len(), kinds.len(), "one graph per requested kind"); + for (kind, mlir) in &graphs { + assert!(!mlir.is_empty(), "{kind} emitted empty MLIR"); + assert!(mlir.contains("stablehlo."), "{kind} is not StableHLO"); + assert!( + mlir.contains("['bq']"), + "{kind} missing the Qwen2 q bias arg" + ); + } + } + + /// The precision guard is a pure predicate over the env value, so it needs no + /// racy env mutation to test: only `f16` / `bf16` are non-default. + #[test] + fn precision_default_detection() { + assert!(precision_is_default(None), "unset is the f32 default"); + assert!(precision_is_default(Some("f32"))); + assert!(precision_is_default(Some("anything-else"))); + assert!(!precision_is_default(Some("f16"))); + assert!(!precision_is_default(Some("bf16"))); + } + + /// `emit_all` reproduces the committed goldens for a registered fixture, so + /// re-freezing is a no-op when the emitter is unchanged (the freeze path is + /// safe to run) and the `(golden_name, mlir)` pairing lines up with the files. + #[test] + fn emit_all_reproduces_registered_goldens() { + for fx in REGISTERED { + let emitted = fx.emit_all().unwrap_or_else(|e| panic!("{}: {e}", fx.arch)); + assert_eq!(emitted.len(), fx.graphs.len()); + for ((name, mlir), gf) in emitted.iter().zip(fx.graphs) { + assert_eq!(*name, gf.golden_name); + assert_eq!(mlir, gf.golden, "{}/{name} re-emits its golden", fx.arch); + } + } + } + + /// Re-freeze the committed goldens for every registered fixture. A no-op unless + /// `MLXCEL_FREEZE_GOLDENS=1`; when set it rewrites each `assets//*.mlir` + /// from the current emitter output. Run it to refresh goldens after an + /// intentional, execution-tier-validated emitter change; the byte-exact test + /// then guards the refreshed files. A brand-new (unregistered) family freezes + /// with [`emit_graphs`] before its fixture exists (see the module docs). + #[test] + fn freeze_goldens() { + if std::env::var("MLXCEL_FREEZE_GOLDENS").as_deref() != Ok("1") { + return; + } + let root = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("assets"); + for fx in REGISTERED { + for (name, mlir) in fx.emit_all().unwrap_or_else(|e| panic!("{}: {e}", fx.arch)) { + let p = root.join(fx.arch).join(name); + std::fs::write(&p, mlir).unwrap_or_else(|e| panic!("write {}: {e}", p.display())); + eprintln!("froze {}", p.display()); + } + } + } +}