test: add reusable per-architecture XLA validation harness #556
Merged
Add a two-tier validation harness so adding an OpenXLA architecture is turnkey, replacing the ad-hoc per-arch flow (dequant offline, run an HF oracle, run xla_oracle_check by hand).

Structural tier (src/lib/mlxcel-xla/src/validation.rs, pure Rust, no GPU): a per-architecture ArchFixture registry plus check_arch, which parses a checkpoint's config.json, emits each graph, and diffs it byte-for-byte against the committed golden assets/<arch>/*.mlir, localizing the first differing line on drift. It honors MLXCEL_XLA_PRECISION (a byte-exact run under f16/bf16 is rejected, since that emit differs from the f32 goldens); emit_graphs is the golden-less freeze primitive for a new family. The emitter's own byte-exact test now delegates to check_arch so the goldens have a single source of truth. Demonstrated clean on llama-3.2-1b.

Execution tier (opt-in, needs a real IREE build and a checkpoint): scripts/xla/validate_arch.sh is the one command. It produces the HF fp32 oracle with spike/openxla/oracle_continuation.py (loads fp32, dequantizing an MLX 4bit/8bit checkpoint offline first with the same affine formula as weights.rs), then runs xla_oracle_check (single-seq greedy == HF oracle) and xla_batch_bench (each batched request == its single-seq reference), reporting pass/fail for both. The offline-dequant math is self-tested against the Rust loader (oracle_continuation.py --selftest).

Usage and the turnkey add-a-family workflow are documented in the validation module docs and src/lib/mlxcel-xla/README.md.
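The drift localization described for the structural tier can be illustrated with a minimal sketch. This is not the code in validation.rs: `first_diff_line` is a hypothetical helper name, and the sketch only assumes that the emitted graph and the golden file are both available as strings. It compares the raw bytes first, then walks the two texts line by line and reports the first line number where they disagree.

```rust
/// Hypothetical helper (not the harness's actual API): returns `None` when the
/// emitted text is byte-identical to the golden, otherwise the 1-based number
/// of the first differing line together with both versions of that line.
fn first_diff_line(emitted: &str, golden: &str) -> Option<(usize, String, String)> {
    // Byte-exact gate: identical bytes mean there is nothing to localize.
    if emitted.as_bytes() == golden.as_bytes() {
        return None;
    }
    // Split on '\n' rather than using `lines()`, so that a missing trailing
    // newline or a stray '\r' still shows up as a difference.
    let mut e = emitted.split('\n');
    let mut g = golden.split('\n');
    let mut line = 1;
    loop {
        match (e.next(), g.next()) {
            // Same content on this line: keep scanning.
            (Some(a), Some(b)) if a == b => line += 1,
            // First mismatch, or one side ended early.
            (a, b) => {
                return Some((
                    line,
                    a.unwrap_or("<eof>").to_string(),
                    b.unwrap_or("<eof>").to_string(),
                ));
            }
        }
    }
}
```

Reporting only the first differing line keeps the failure message short even when a single emitter change shifts every subsequent line of a large graph.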
…-validation-harness # Conflicts: # src/lib/mlxcel-xla/src/emitter/mod.rs
Summary
Adds a reusable, two-tier validation harness so adding an OpenXLA (`mlxcel-xla`) architecture is turnkey, replacing the ad-hoc per-arch flow (dequantize a checkpoint offline, run an HF fp32 oracle, run `xla_oracle_check` by hand). Downstream families (#497-#501) consume it.

What changed
- `src/lib/mlxcel-xla/src/validation.rs` (new): the structural tier, a pure-Rust byte-exact gate. It adds a per-architecture `ArchFixture` registry plus `check_arch`, which parses a checkpoint's `config.json`, emits each graph, and diffs it against the committed golden `assets/<arch>/*.mlir`, localizing the first differing line on drift. `emit_graphs` is the golden-less freeze primitive for a brand-new family, and `MLXCEL_FREEZE_GOLDENS=1 cargo test ... freeze_goldens` re-freezes a registered one. The gate honors `MLXCEL_XLA_PRECISION`: the goldens are the default `f32` graphs, so a byte-exact run under `f16`/`bf16` is rejected with a clear message instead of a confusing diff.
- `src/lib/mlxcel-xla/src/lib.rs`, `src/lib/mlxcel-xla/src/emitter/mod.rs`: wire in the `validation` module (same `cfg(any(feature = "iree", test))` gating as the sibling modules), re-export `emit_decode_batched`, and route the emitter's existing `from_json_reproduces_bundled_assets_byte_for_byte` test through `check_arch` so the goldens have a single source of truth.
- `scripts/xla/validate_arch.sh` (new): the one-command execution tier. Given `--model <checkpoint>`, it runs the structural pre-gate, produces the HF fp32 oracle, then runs `xla_oracle_check` (single-seq greedy == HF oracle) and `xla_batch_bench` (each batched request == its single-seq reference), reporting pass/fail for both.
- `spike/openxla/oracle_continuation.py` (new): the arch-generic oracle producer. It loads a checkpoint in fp32, dequantizing an MLX 4bit/8bit checkpoint offline first with the same affine formula as `src/weights.rs` (`--selftest` verifies the dequant math against the Rust hand examples), then records the no-EOS greedy continuation.
- `src/lib/mlxcel-xla/README.md`: documents both tiers and the turnkey add-a-family workflow.

Test plan
- `cargo test -p mlxcel-xla --lib validation` (structural gate): 6/6 pass; `registered_fixtures_are_byte_exact` reports a clean byte-exact pass on `llama-3.2-1b`.
- `cargo test -p mlxcel-xla --lib emitter`: 14/14 pass, including the now-delegating `from_json_reproduces_bundled_assets_byte_for_byte` and the gemma2 test from the rebased #554 (refactor: share per-layer attention core across emitter graph kinds).
- `cargo clippy -p mlxcel-xla --lib --tests -- -D warnings`: clean.
- `MLXCEL_FREEZE_GOLDENS=1 cargo test ... freeze_goldens` leaves `assets/` byte-for-byte unchanged (freeze reproduces the committed goldens).
- `oracle_continuation.py --selftest`: dequant matches `weights.rs` (4-bit and 8-bit hand examples).
- Execution tier (`scripts/xla/validate_arch.sh --model <ckpt>`) needs a real `xla-iree` build plus IREE GPU/CPU, so by the two-tier design it is the opt-in run for a GPU session (not CI).

Closes #496
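
For readers unfamiliar with the affine dequantization that the `--selftest` item checks, here is a minimal sketch. It is not the code in `weights.rs` or `oracle_continuation.py`. It assumes an MLX-style layout: `bits`-wide unsigned values packed into `u32` words with the lowest bits first, and one scale and one bias per group of `group_size` consecutive elements, so each weight is `scale * q + bias`. The function name `dequant_affine` and its signature are hypothetical.

```rust
/// Hypothetical sketch of group-wise affine dequantization.
/// Assumption: `packed` holds `32 / bits` unsigned values per u32 word,
/// lowest bits first, and element `i` belongs to group `i / group_size`.
fn dequant_affine(
    packed: &[u32],
    scales: &[f32],
    biases: &[f32],
    bits: u32,
    group_size: usize,
) -> Vec<f32> {
    assert!(bits == 4 || bits == 8, "sketch covers 4-bit and 8-bit only");
    let per_word = (32 / bits) as usize;
    let mask = (1u32 << bits) - 1;
    let mut out = Vec::with_capacity(packed.len() * per_word);
    for (w_idx, word) in packed.iter().enumerate() {
        for j in 0..per_word {
            let idx = w_idx * per_word + j;
            // Extract the j-th quantized value from this word.
            let q = (*word >> (j as u32 * bits)) & mask;
            // Apply the affine map with this element's group parameters.
            let g = idx / group_size;
            out.push(scales[g] * q as f32 + biases[g]);
        }
    }
    out
}
```

As a hand example under these assumptions, the 4-bit word `0x7654_3210` unpacks to the values 0 through 7, so a scale of 0.5 and a bias of -1.0 map it to -1.0, -0.5, and so on up to 2.5.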