Skip to content

feat: int8/int4 packed weight quantization on the OpenXLA path (#516)#568

Merged
inureyes merged 1 commit into
mainfrom
feat/516-int8-packed-quant-openxla
Jul 1, 2026
Merged

feat: int8/int4 packed weight quantization on the OpenXLA path (#516)#568
inureyes merged 1 commit into
mainfrom
feat/516-int8-packed-quant-openxla

Conversation

@inureyes

@inureyes inureyes commented Jul 1, 2026

Copy link
Copy Markdown
Member

Summary

Picks up the int8 lever from #516 / #513 (ADR 0004). The OpenXLA loader stopped
dequantizing MLX 4/8-bit checkpoints to f32: with MLXCEL_XLA_QUANT=packed it keeps
the weights resident packed (ui32 weight + f16 scales/biases) and dequantizes
each one in the StableHLO graph, so the bandwidth-relevant packed form survives
into the graph where an int-native target could use it. Opt-in, off by default.

This landed the whole packed ABI, then measured it on a GB10 (the deferral's
"hardware gate"). The headline finding is below.

What changed

  • Emitter (builder.rs, model.rs): Builder::dequant_affine authors the MLX
    affine dequant in StableHLO (bit-unpack the ui32 lanes with shift/mask, convert,
    then q*scale + bias per group_size), bit-identical to the host
    weights::dequantize_affine, so the packed graph is token-exact by construction.
    take_weight routes the 7 per-layer projections through it across all four graph
    kinds (decode / ragged / batched / prefill). The dense/unquantized path is
    untouched, so the byte-exact goldens are unchanged.
  • Loader (weights.rs, iree.rs): WeightSpec::QuantRaw + weight_specs
    expansion upload the packed U32 and f16 scales/biases as separate device buffers;
    load_weights and the C shim ABI are now per-weight-dtype (f32 / f16 / ui32).
  • Gate: Config::supports_packed_quant (standard Llama layout: non-fused-qkv,
    non-fused-gate-up, non-dense-MLP, non-MoE) is shared by the emitter and the loader
    so they never diverge. v1 leaves embed / lm_head f32-resident.
  • Toolchain: scripts/iree/setup-cuda.sh reproduces the source-built cuda IREE
    runtime the GB10 path needs (mirrors setup-macos.sh).

Measured on GB10 (the fusion gate)

Llama-3.2-1B, MLX 4-bit group_size 64, greedy, warm vmfb, CUDA via the source-built
IREE runtime:

decode path tok/s GPU util correctness
f16, dequant-at-load (default) ~6.7 84% reference
packed, dequant-in-graph (MLXCEL_XLA_QUANT=packed) ~1.6 96% token-exact with the f16 path

The packed path is correct but ~4.3x slower, because IREE's CUDA codegen does
not fuse
the unpack+dequant into the matmul: the decode step is ~678 dispatches and
the reconstructed f32 weight is materialized to DRAM every step, so it pays more
bandwidth + compute, not less (util rises 84% -> 96%). The fusion flags
(--iree-dispatch-creation-enable-aggressive-fusion, --iree-opt-generalize-matmul,
--iree-dispatch-creation-enable-early-trunc-fusion) leave the dispatch count
unchanged (678 -> 677).

So the memory-bandwidth payoff is not realized by authoring the dequant in the
portable graph alone; it needs the target to fuse dequant into the matmul (a
quantized-matmul op, or an int8 dot_general lowering to the hardware int8 path).
This is the same split the f16 profiling reached, now confirmed for int8: the
graph-level change is in scope and transferable, but the fused low-precision kernel
is upstream IREE's job. The packed path therefore lands off by default, as the
correctness-verified ABI a fused quantized-matmul reuses.

How to run

eval "$(scripts/iree/setup-cuda.sh --env)"   # or --env after a first full run
cargo build --release --features xla-iree
# token-exact packed path (opt-in), else default dequant-at-load:
MLXCEL_XLA_QUANT=packed MLXCEL_BACKEND=xla MLXCEL_XLA_DEVICE=cuda \
  ./target/release/mlxcel generate -m <Llama-3.2-1B 4-bit dir> -p "..." -n 48

Validation

  • cargo test -p mlxcel-xla --lib: 111/111 (dequant-in-graph unit tests for 4/8-bit,
    the weight_specs packed-expansion contract test, and the byte-exact goldens all
    pass; the dense path is unchanged).
  • cargo clippy -p mlxcel-xla --features iree: clean (no new warnings).
  • End-to-end on GB10: packed decode is token-exact with the f16 path and with the
    original baseline output; perf as tabled above.

Follow-up

A fused quantized-matmul (int8 dot_general lowering, or a StableHLO quantized dot
that IREE/an NPU HAL driver fuses into a systolic int8 kernel) is what turns this from
correctness-only into the bandwidth win. This PR is its foundation. I did not
auto-close #516; leaving that to the maintainer to decide (close as
implemented-and-gated, or keep open for the fused lever).

Refs #516. Part of #513. ADR 0004.

Keep MLX 4/8-bit checkpoints resident PACKED (ui32 weight + f16 scales/biases)
and dequantize each weight in the StableHLO graph, instead of dequantizing to
f32 at load. Opt-in via MLXCEL_XLA_QUANT=packed; off by default.

- Emitter: Builder::dequant_affine authors the MLX affine dequant in StableHLO
  (bit-unpack the ui32 lanes, then q*scale + bias per group), bit-identical to
  the host weights::dequantize_affine, so the packed graph is token-exact.
  take_weight routes the 7 per-layer projections through it across all four graph
  kinds; the dense/unquantized path stays byte-exact (goldens unchanged).
- Loader: WeightSpec::QuantRaw + weight_specs expansion upload the packed U32 and
  f16 scales/biases per dtype; the C shim ABI is now per-weight-dtype.
- Gate: Config::supports_packed_quant (standard Llama layout) is shared by the
  emitter and loader so they never diverge; v1 leaves embed/lm_head f32-resident.
- Toolchain: scripts/iree/setup-cuda.sh reproduces the source-built cuda IREE
  runtime the GB10 path needs.

Measured on a GB10 (Llama-3.2-1B 4-bit, greedy, warm vmfb): token-exact with the
f16 path but about 4.3x slower (~1.6 vs ~6.7 tok/s), because IREE's CUDA codegen
does not fuse the in-graph dequant into the matmul (678 decode dispatches,
unchanged by the aggressive-fusion flags), so it rematerializes the f32 weights
each step. The memory-bandwidth lever needs a fused quantized-matmul; this lands
the correctness-verified packed ABI as its foundation, off by default. See ADR
0004 and the crate README for the full finding.

Refs #516.
@inureyes inureyes added the area:inference Generation, sampling, decoding (incl. speculative, DRY) label Jul 1, 2026
@inureyes inureyes merged commit 4dbb7d7 into main Jul 1, 2026
5 checks passed
@inureyes inureyes deleted the feat/516-int8-packed-quant-openxla branch July 1, 2026 07:05
inureyes added a commit that referenced this pull request Jul 1, 2026
… win) (#577)

On the f16 GPU path, upload the resident linear-projection weights as f16
instead of f32-then-demote-in-graph, halving their per-step weight DRAM read
with no dequant and no target fusion. Reuses the #568 per-dtype weight ABI
(the C shim already allocates FLOAT_16 buffers), so no C-side change.

The emitter's `take_weight` declares the seven per-layer projections
(down/gate/up/wk/wo/wq/wv) as f16 args under f16 precision; `dot_general`
skips its f32->f16 convert for an already-f16 weight, so the contraction sees
the same f16 operand and stays token-exact. The loader packs those projections
to f16 (round-to-nearest-even, matching the in-graph convert) into a new
`WeightSpec::Proj` / `WeightBuf::F16` and uploads them as WDT_F16. embed /
lm_head / norms / caches stay f32-resident; f32 / bf16 precision and the fused
(Phi3) layouts keep the f32 arg, so the byte-exact f32 goldens are unchanged.
Gated by `Config::supports_f16_resident` in both the emitter and the loader
from the same `resolve_precision(device)`, so the uploaded buffer dtype always
matches the emitted arg (a mismatch would fail `xla_llama_create`).

Validated on GB10 (CUDA, Llama-3.2-1B 4-bit, greedy): f16-resident is
token-exact with the f32 path (96/96) at 7.56 tok/s versus 5.23 f32 and ~6.8
for the prior f32-resident f16 path. On GB10 unified memory the halving is not
visible via nvidia-smi; it is structural (f16 buffers are 2 bytes/element) and
the throughput gain is its proxy. A new exhaustive test checks the f32->f16
packer round-trips every finite f16 pattern; a new emitter test pins exactly
the 7 per-layer projections as f16 args.

Part of #570
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:inference Generation, sampling, decoding (incl. speculative, DRY)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: int8 weight quantization with quantized matmul on the OpenXLA path (#449)

1 participant