feat: int8/int4 packed weight quantization on the OpenXLA path (#516) #568
Merged
Keep MLX 4/8-bit checkpoints resident PACKED (ui32 weight + f16 scales/biases) and dequantize each weight in the StableHLO graph, instead of dequantizing to f32 at load. Opt-in via MLXCEL_XLA_QUANT=packed; off by default.

- Emitter: Builder::dequant_affine authors the MLX affine dequant in StableHLO (bit-unpack the ui32 lanes, then q*scale + bias per group), bit-identical to the host weights::dequantize_affine, so the packed graph is token-exact. take_weight routes the 7 per-layer projections through it across all four graph kinds; the dense/unquantized path stays byte-exact (goldens unchanged).
- Loader: WeightSpec::QuantRaw + weight_specs expansion upload the packed U32 and f16 scales/biases per dtype; the C shim ABI is now per-weight-dtype.
- Gate: Config::supports_packed_quant (standard Llama layout) is shared by the emitter and loader so they never diverge; v1 leaves embed/lm_head f32-resident.
- Toolchain: scripts/iree/setup-cuda.sh reproduces the source-built cuda IREE runtime the GB10 path needs.

Measured on a GB10 (Llama-3.2-1B 4-bit, greedy, warm vmfb): token-exact with the f16 path but about 4.3x slower (~1.6 vs ~6.7 tok/s), because IREE's CUDA codegen does not fuse the in-graph dequant into the matmul (678 decode dispatches, unchanged by the aggressive-fusion flags), so it rematerializes the f32 weights each step. The memory-bandwidth lever needs a fused quantized-matmul; this lands the correctness-verified packed ABI as its foundation, off by default. See ADR 0004 and the crate README for the full finding.

Refs #516.
inureyes added a commit that referenced this pull request on Jul 1, 2026:
… win) (#577)

On the f16 GPU path, upload the resident linear-projection weights as f16 instead of f32-then-demote-in-graph, halving their per-step weight DRAM read with no dequant and no target fusion. Reuses the #568 per-dtype weight ABI (the C shim already allocates FLOAT_16 buffers), so no C-side change.

The emitter's `take_weight` declares the seven per-layer projections (down/gate/up/wk/wo/wq/wv) as f16 args under f16 precision; `dot_general` skips its f32->f16 convert for an already-f16 weight, so the contraction sees the same f16 operand and stays token-exact. The loader packs those projections to f16 (round-to-nearest-even, matching the in-graph convert) into a new `WeightSpec::Proj` / `WeightBuf::F16` and uploads them as WDT_F16. embed / lm_head / norms / caches stay f32-resident; f32 / bf16 precision and the fused (Phi3) layouts keep the f32 arg, so the byte-exact f32 goldens are unchanged.

Gated by `Config::supports_f16_resident` in both the emitter and the loader from the same `resolve_precision(device)`, so the uploaded buffer dtype always matches the emitted arg (a mismatch would fail `xla_llama_create`).

Validated on GB10 (CUDA, Llama-3.2-1B 4-bit, greedy): f16-resident is token-exact with the f32 path (96/96) at 7.56 tok/s versus 5.23 f32 and ~6.8 for the prior f32-resident f16 path. On GB10 unified memory the halving is not visible via nvidia-smi; it is structural (f16 buffers are 2 bytes/element) and the throughput gain is its proxy. A new exhaustive test checks the f32->f16 packer round-trips every finite f16 pattern; a new emitter test pins exactly the 7 per-layer projections as f16 args.

Part of #570
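The round-to-nearest-even packer and the exhaustive round-trip check mentioned in that commit message can be illustrated with a small standalone sketch. This is not the crate's code: `f16_to_f32` and `f32_to_f16_rne` are hypothetical names, and the sketch assumes plain IEEE 754 binary16 bit patterns stored in a `u16`.

```rust
/// Widen an IEEE 754 binary16 bit pattern to f32 (exact for every input).
fn f16_to_f32(h: u16) -> f32 {
    let sign = ((h as u32) & 0x8000) << 16;
    let exp = ((h >> 10) & 0x1f) as u32;
    let man = (h & 0x3ff) as u32;
    let bits = match (exp, man) {
        (0, 0) => sign, // signed zero
        (0, _) => {
            // Subnormal: man * 2^-24, which f32 represents exactly.
            let mag = man as f32 * f32::from_bits(0x3380_0000);
            sign | mag.to_bits()
        }
        (0x1f, 0) => sign | 0x7f80_0000,               // infinity
        (0x1f, _) => sign | 0x7fc0_0000 | (man << 13), // NaN
        _ => sign | ((exp + 112) << 23) | (man << 13), // normal, rebias 15 -> 127
    };
    f32::from_bits(bits)
}

/// Narrow f32 to a binary16 bit pattern with round-to-nearest-even.
fn f32_to_f16_rne(x: f32) -> u16 {
    let bits = x.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let abs = bits & 0x7fff_ffff;
    if abs > 0x7f80_0000 {
        return sign | 0x7e00; // NaN -> quiet NaN
    }
    let exp = (abs >> 23) as i32 - 112; // rebias 127 -> 15
    if exp >= 31 {
        return sign | 0x7c00; // infinity, or too large for f16
    }
    if exp < -10 {
        return sign; // below half of the smallest f16 subnormal
    }
    let man = abs & 0x007f_ffff;
    // Normal: keep exponent + mantissa and drop 13 bits.
    // Subnormal: restore the hidden bit and drop (14 - exp) bits.
    let (sig, shift) = if exp >= 1 {
        (((exp as u32) << 23) | man, 13u32)
    } else {
        (man | 0x0080_0000, (14 - exp) as u32)
    };
    let mut half = sig >> shift;
    let rem = sig & ((1u32 << shift) - 1);
    let halfway = 1u32 << (shift - 1);
    // Round up past the midpoint, or at the midpoint when the result is odd.
    // A carry out of the mantissa correctly bumps the exponent (up to inf).
    if rem > halfway || (rem == halfway && (half & 1) == 1) {
        half += 1;
    }
    sign | half as u16
}

fn main() {
    // Every finite f16 pattern must survive f16 -> f32 -> f16 unchanged.
    for h in 0..=u16::MAX {
        if (h >> 10) & 0x1f == 0x1f {
            continue; // skip inf / NaN
        }
        assert_eq!(f32_to_f16_rne(f16_to_f32(h)), h);
    }
    println!("all finite f16 patterns round-trip");
}
```

The round-trip loop only proves the packer is the identity on values that are already representable in f16; tie-breaking on values in between needs separate spot checks, such as the midpoint between 1.0 and the next f16 value.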
Summary
Picks up the int8 lever from #516 / #513 (ADR 0004). The OpenXLA loader no longer has to dequantize MLX 4/8-bit checkpoints to f32 at load: with `MLXCEL_XLA_QUANT=packed` it keeps the weights resident packed (`ui32` weight + f16 scales/biases) and dequantizes each one in the StableHLO graph, so the bandwidth-relevant packed form survives into the graph where an int-native target could use it. Opt-in, off by default.
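As a host-side illustration of the packed form described above, the sketch below unpacks 4- or 8-bit lanes from `u32` words and applies `q*scale + bias` per group. It is not the crate's `weights::dequantize_affine`: the function names are hypothetical, scales and biases are taken as `f32` for simplicity (the real buffers are f16), and the lowest-bits-first lane order is an assumption about the MLX layout.

```rust
/// Unpack `bits`-wide unsigned lanes from u32 words.
/// Assumption: lane 0 sits in the lowest bits of each word.
fn unpack_lanes(packed: &[u32], bits: u32) -> Vec<u32> {
    assert!(bits == 4 || bits == 8, "sketch covers 4- and 8-bit only");
    let per_word = 32 / bits;
    let mask = (1u32 << bits) - 1;
    let mut out = Vec::with_capacity(packed.len() * per_word as usize);
    for &word in packed {
        for lane in 0..per_word {
            out.push((word >> (lane * bits)) & mask);
        }
    }
    out
}

/// Affine dequant: w[i] = q[i] * scale[g] + bias[g], with g = i / group_size.
fn dequant_affine_host_sketch(
    packed: &[u32],
    scales: &[f32],
    biases: &[f32],
    bits: u32,
    group_size: usize,
) -> Vec<f32> {
    let q = unpack_lanes(packed, bits);
    assert_eq!(q.len() % group_size, 0);
    assert_eq!(scales.len(), q.len() / group_size);
    assert_eq!(biases.len(), scales.len());
    q.iter()
        .enumerate()
        .map(|(i, &v)| {
            let g = i / group_size;
            v as f32 * scales[g] + biases[g]
        })
        .collect()
}

fn main() {
    // Two words of 4-bit lanes; group_size 8 keeps the example small
    // (the checkpoint measured in this PR uses group_size 64).
    let packed = [0x7654_3210u32, 0xFEDC_BA98];
    let w = dequant_affine_host_sketch(&packed, &[0.5, 0.25], &[-1.0, 2.0], 4, 8);
    println!("{:?}", w);
}
```

The in-graph version performs the same three steps (shift/mask, convert, multiply-add) as StableHLO ops, which is why the two can be compared bit for bit.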
This PR lands the whole packed ABI, then measures it on a GB10 (the deferral's "hardware gate"). The headline finding is below.
What changed
- Emitter (`builder.rs`, `model.rs`): `Builder::dequant_affine` authors the MLX affine dequant in StableHLO (bit-unpack the `ui32` lanes with shift/mask, convert, then `q*scale + bias` per `group_size`), bit-identical to the host `weights::dequantize_affine`, so the packed graph is token-exact by construction. `take_weight` routes the 7 per-layer projections through it across all four graph kinds (decode / ragged / batched / prefill). The dense/unquantized path is untouched, so the byte-exact goldens are unchanged.
- Loader (`weights.rs`, `iree.rs`): `WeightSpec::QuantRaw` + `weight_specs` expansion upload the packed U32 and f16 scales/biases as separate device buffers; `load_weights` and the C shim ABI are now per-weight-dtype (f32 / f16 / ui32).
- Gate: `Config::supports_packed_quant` (standard Llama layout: non-fused-qkv, non-fused-gate-up, non-dense-MLP, non-MoE) is shared by the emitter and the loader so they never diverge. v1 leaves embed / lm_head f32-resident.
- Toolchain: `scripts/iree/setup-cuda.sh` reproduces the source-built CUDA IREE runtime the GB10 path needs (mirrors `setup-macos.sh`).
Measured on GB10 (the fusion gate)
Llama-3.2-1B, MLX 4-bit `group_size` 64, greedy, warm vmfb, CUDA via the source-built IREE runtime:

| Path | Decode throughput |
| --- | --- |
| f16 path | ~6.7 tok/s |
| packed (`MLXCEL_XLA_QUANT=packed`) | ~1.6 tok/s |

The packed path is correct but ~4.3x slower, because IREE's CUDA codegen does not fuse the unpack+dequant into the matmul: the decode step is ~678 dispatches and the reconstructed f32 weight is materialized to DRAM every step, so it pays more bandwidth + compute, not less (GPU utilization rises 84% -> 96%). The fusion flags (`--iree-dispatch-creation-enable-aggressive-fusion`, `--iree-opt-generalize-matmul`, `--iree-dispatch-creation-enable-early-trunc-fusion`) leave the dispatch count essentially unchanged (678 -> 677).
So the memory-bandwidth payoff is not realized by authoring the dequant in the portable graph alone; it needs the target to fuse dequant into the matmul (a quantized-matmul op, or an int8 `dot_general` lowering to the hardware int8 path).
This is the same split the f16 profiling reached, now confirmed for int8: the graph-level change is in scope and transferable, but the fused low-precision kernel is upstream IREE's job. The packed path therefore lands off by default, as the correctness-verified ABI a fused quantized-matmul reuses.
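A rough per-weight byte count shows why an unfused dequant costs bandwidth instead of saving it. The sketch below is an illustrative model, not a measurement: it assumes one f16 scale and one f16 bias per group, and it assumes the unfused path writes each reconstructed f32 weight to DRAM once and reads it back once per step. The function names are hypothetical.

```rust
/// DRAM bytes read per weight element when the packed form feeds the
/// matmul directly: `bits` of payload plus an f16 scale and an f16 bias
/// amortized over each group.
fn packed_bytes_per_weight(bits: u32, group_size: u32) -> f64 {
    let payload = bits as f64 / 8.0;
    let scale_and_bias = (2.0 + 2.0) / group_size as f64;
    payload + scale_and_bias
}

/// DRAM bytes moved per weight element when the dequant is a separate
/// dispatch. Assumption: read packed, write the f32 result, then read
/// that f32 back in the matmul dispatch.
fn unfused_bytes_per_weight(bits: u32, group_size: u32) -> f64 {
    let f32_bytes = 4.0;
    packed_bytes_per_weight(bits, group_size) + f32_bytes + f32_bytes
}

fn main() {
    let fused = packed_bytes_per_weight(4, 64);
    let unfused = unfused_bytes_per_weight(4, 64);
    println!("4-bit, group_size 64, fused:   {fused} bytes/weight");
    println!("4-bit, group_size 64, unfused: {unfused} bytes/weight");
    println!("f32-resident weight:           4 bytes/weight");
    println!("f16-resident weight:           2 bytes/weight");
}
```

Under these assumptions the unfused path moves more bytes per weight than an f32-resident weight does, while a fused kernel would move well under one byte per weight. Dispatch overhead and the extra unpack arithmetic are not modeled, so the numbers are not expected to reproduce the measured slowdown.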
How to run
Validation
- `cargo test -p mlxcel-xla --lib`: 111/111 (dequant-in-graph unit tests for 4/8-bit, the `weight_specs` packed-expansion contract test, and the byte-exact goldens all pass; the dense path is unchanged).
- `cargo clippy -p mlxcel-xla --features iree`: clean (no new warnings).
- GB10 end-to-end (CUDA): token-exact with the original baseline output; perf as tabled above.
Follow-up
A fused quantized-matmul (int8 `dot_general` lowering, or a StableHLO quantized dot that IREE/an NPU HAL driver fuses into a systolic int8 kernel) is what turns this from correctness-only into the bandwidth win. This PR is its foundation. I did not auto-close #516; leaving that to the maintainer to decide (close as implemented-and-gated, or keep open for the fused lever).
Refs #516. Part of #513. ADR 0004.