Skip to content

test: add reusable per-architecture XLA validation harness#556

Merged
inureyes merged 2 commits into
mainfrom
feature/issue-496-xla-validation-harness
Jul 1, 2026
Merged

test: add reusable per-architecture XLA validation harness#556
inureyes merged 2 commits into
mainfrom
feature/issue-496-xla-validation-harness

Conversation

@inureyes

@inureyes inureyes commented Jul 1, 2026

Copy link
Copy Markdown
Member

Summary

Adds a reusable, two-tier validation harness so adding an OpenXLA (mlxcel-xla) architecture is turnkey, replacing the ad-hoc per-arch flow (dequant a checkpoint offline, run an HF fp32 oracle, run xla_oracle_check by hand). Downstream families (#497-#501) consume it.

What changed

  • src/lib/mlxcel-xla/src/validation.rs (new): the structural tier, a pure-Rust byte-exact gate. A per-architecture ArchFixture registry plus check_arch parses a checkpoint's config.json, emits each graph, and diffs it against the committed golden assets/<arch>/*.mlir, localizing the first differing line on drift. emit_graphs is the golden-less freeze primitive for a brand-new family, and MLXCEL_FREEZE_GOLDENS=1 cargo test ... freeze_goldens re-freezes a registered one. It honors MLXCEL_XLA_PRECISION: the goldens are the default f32 graphs, so a byte-exact run under f16/bf16 is rejected with a clear message instead of a confusing diff.
  • src/lib/mlxcel-xla/src/lib.rs, src/lib/mlxcel-xla/src/emitter/mod.rs: wire in the validation module (same cfg(any(feature = "iree", test)) gating as the sibling modules), re-export emit_decode_batched, and route the emitter's existing from_json_reproduces_bundled_assets_byte_for_byte test through check_arch so the goldens have a single source of truth.
  • scripts/xla/validate_arch.sh (new): the execution tier one command. Given --model <checkpoint> it runs the structural pre-gate, produces the HF fp32 oracle, then runs xla_oracle_check (single-seq greedy == HF oracle) and xla_batch_bench (each batched request == its single-seq reference), reporting pass/fail for both.
  • spike/openxla/oracle_continuation.py (new): the arch-generic oracle producer. Loads a checkpoint in fp32, dequantizing an MLX 4bit/8bit checkpoint offline first with the same affine formula as src/weights.rs (--selftest verifies the dequant math against the Rust hand examples), then records the no-EOS greedy continuation.
  • src/lib/mlxcel-xla/README.md: documents both tiers and the turnkey add-a-family workflow.

Test plan

  • cargo test -p mlxcel-xla --lib validation (structural gate): 6/6 pass; registered_fixtures_are_byte_exact reports a clean byte-exact pass on llama-3.2-1b.
  • cargo test -p mlxcel-xla --lib emitter: 14/14 pass, including the now-delegating from_json_reproduces_bundled_assets_byte_for_byte and the rebased refactor: share per-layer attention core across emitter graph kinds #554 gemma2 test.
  • cargo clippy -p mlxcel-xla --lib --tests -- -D warnings: clean.
  • MLXCEL_FREEZE_GOLDENS=1 cargo test ... freeze_goldens leaves assets/ byte-for-byte unchanged (freeze reproduces the committed goldens).
  • oracle_continuation.py --selftest: dequant matches weights.rs (4-bit and 8-bit hand examples).
  • Execution-tier live run (scripts/xla/validate_arch.sh --model <ckpt>) needs a real xla-iree build plus IREE GPU/CPU, so it is the opt-in run for a GPU session (not CI), by the two-tier design.

Closes #496

Add a two-tier validation harness so adding an OpenXLA architecture is turnkey, replacing the ad-hoc per-arch flow (dequant offline, run an HF oracle, run xla_oracle_check by hand).

Structural tier (src/lib/mlxcel-xla/src/validation.rs, pure Rust, no GPU): a per-architecture ArchFixture registry plus check_arch, which parses a checkpoint's config.json, emits each graph, and diffs it byte-for-byte against the committed golden assets/<arch>/*.mlir, localizing the first differing line on drift. It honors MLXCEL_XLA_PRECISION (a byte-exact run under f16/bf16 is rejected, since that emit differs from the f32 goldens); emit_graphs is the golden-less freeze primitive for a new family. The emitter's own byte-exact test now delegates to check_arch so the goldens have a single source of truth. Demonstrated clean on llama-3.2-1b.

Execution tier (opt-in, needs a real IREE build and a checkpoint): scripts/xla/validate_arch.sh is the one command. It produces the HF fp32 oracle with spike/openxla/oracle_continuation.py (loads fp32, dequantizing an MLX 4bit/8bit checkpoint offline first with the same affine formula as weights.rs), then runs xla_oracle_check (single-seq greedy == HF oracle) and xla_batch_bench (each batched request == its single-seq reference), reporting pass/fail for both. The offline-dequant math is self-tested against the Rust loader (oracle_continuation.py --selftest).

Usage and the turnkey add-a-family workflow are documented in the validation module docs and src/lib/mlxcel-xla/README.md.
@inureyes inureyes added type:enhancement New features, capabilities, or significant additions priority:medium Medium priority area:architecture Architecture and code structure changes status:review Under review status:done Completed and removed status:review Under review labels Jul 1, 2026
…-validation-harness

# Conflicts:
#	src/lib/mlxcel-xla/src/emitter/mod.rs
@inureyes inureyes merged commit e7ffc53 into main Jul 1, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:architecture Architecture and code structure changes priority:medium Medium priority status:done Completed type:enhancement New features, capabilities, or significant additions

Projects

None yet

Development

Successfully merging this pull request may close these issues.

test: add a reusable per-architecture token-exact validation harness

1 participant