Skip to content

feat: f16-resident weights on the OpenXLA path (#572)#577

Merged
inureyes merged 1 commit into
mainfrom
feat/572-f16-resident-weights
Jul 1, 2026
Merged

feat: f16-resident weights on the OpenXLA path (#572)#577
inureyes merged 1 commit into
mainfrom
feat/572-f16-resident-weights

Conversation

@inureyes

@inureyes inureyes commented Jul 1, 2026

Copy link
Copy Markdown
Member

Summary

Part of the OpenXLA low-precision performance epic (#570). On the f16 GPU path, upload the resident linear-projection weights as f16 instead of the previous "upload f32, demote to f16 inside the StableHLO graph" approach, halving their per-step weight DRAM read with no dequant and no target fusion. This is the fusion-free bandwidth win the #568 per-dtype weight ABI already enables, independent of the packed int8 fusion gate.

What changed

  • Emitter. take_weight declares the seven per-layer linear projections (down/gate/up/wk/wo/wq/wv) as f16 args under f16 precision (Ty::new(..., "f16") instead of Ty::f32). dot_general already skips its f32->f16 convert when an operand is already f16, so the activation is demoted as before and the weight flows straight in: the contraction sees the same f16 operand and stays token-exact.
  • Loader. A new WeightSpec::Proj (a projection whole-tensor, distinct from Whole = embed / norm / lm_head) is packed to f16 (round-to-nearest-even, matching IREE's in-graph convert) into a new WeightBuf::F16 and uploaded as WDT_F16. The C shim needs no change (it already maps the f16 dtype code to FLOAT_16, esize 2). embed / lm_head / norms / KV caches stay f32-resident.
  • Gate + plumbing. Config::supports_f16_resident (the non-fused-qkv / non-fused-gate-up layouts, where every take_weight projection maps to a single Proj spec) gates it in both the emitter and the loader, decided from the same resolve_precision(device) the graph emit uses. So the uploaded buffer dtype always matches the emitted arg (a mismatch would fail xla_llama_create at load). f32 / bf16 precision and fused (Phi3) layouts keep the f32 arg, so the byte-exact f32 goldens are unchanged.

Validation

  • Pure Rust (118 lib tests): an exhaustive test verifies the f32->f16 packer round-trips every finite f16 pattern (f32_to_f16_bits(half_to_f32(h)) == h) plus ties-to-even / saturation; a new emitter test pins exactly the 7 per-layer projections as f16 args under f16 and zero under f32; all byte-exact f32 goldens (registered_fixtures_are_byte_exact, emit_all_reproduces_registered_goldens) still pass, so the default path is unchanged.
  • GB10 (CUDA, Llama-3.2-1B 4-bit, greedy): f16-resident is token-exact with the f32 path (96/96 tokens byte-identical) at 7.56 tok/s versus 5.23 f32 and ~6.8 for the prior f32-resident f16 path. Token-exactness plus a successful load also confirm end-to-end that the loader's f16 upload lines up with the emitter's f16 args.

Notes

  • The resident-memory halving is structural (f16 projection buffers are 2 bytes/element vs 4; the shim allocates FLOAT_16). On GB10 unified memory it is not visible via nvidia-smi memory.used, so the throughput gain is its observable proxy.
  • This is orthogonal to the packed int8 path (MLXCEL_XLA_QUANT=packed), which keeps weights ui32-resident and takes precedence when its env opt-in is set.

Closes #572
Part of #570

… win)

On the f16 GPU path, upload the resident linear-projection weights as f16
instead of f32-then-demote-in-graph, halving their per-step weight DRAM read
with no dequant and no target fusion. Reuses the #568 per-dtype weight ABI
(the C shim already allocates FLOAT_16 buffers), so no C-side change.

The emitter's `take_weight` declares the seven per-layer projections
(down/gate/up/wk/wo/wq/wv) as f16 args under f16 precision; `dot_general`
skips its f32->f16 convert for an already-f16 weight, so the contraction sees
the same f16 operand and stays token-exact. The loader packs those projections
to f16 (round-to-nearest-even, matching the in-graph convert) into a new
`WeightSpec::Proj` / `WeightBuf::F16` and uploads them as WDT_F16. embed /
lm_head / norms / caches stay f32-resident; f32 / bf16 precision and the fused
(Phi3) layouts keep the f32 arg, so the byte-exact f32 goldens are unchanged.
Gated by `Config::supports_f16_resident` in both the emitter and the loader
from the same `resolve_precision(device)`, so the uploaded buffer dtype always
matches the emitted arg (a mismatch would fail `xla_llama_create`).

Validated on GB10 (CUDA, Llama-3.2-1B 4-bit, greedy): f16-resident is
token-exact with the f32 path (96/96) at 7.56 tok/s versus 5.23 f32 and ~6.8
for the prior f32-resident f16 path. On GB10 unified memory the halving is not
visible via nvidia-smi; it is structural (f16 buffers are 2 bytes/element) and
the throughput gain is its proxy. A new exhaustive test checks the f32->f16
packer round-trips every finite f16 pattern; a new emitter test pins exactly
the 7 per-layer projections as f16 args.

Part of #570
@inureyes inureyes added type:performance Performance improvements priority:high High priority area:inference Generation, sampling, decoding (incl. speculative, DRY) status:review Under review labels Jul 1, 2026
@inureyes inureyes merged commit 4bdba22 into main Jul 1, 2026
5 checks passed
@inureyes inureyes deleted the feat/572-f16-resident-weights branch July 1, 2026 09:11
@inureyes inureyes added status:done Completed and removed status:review Under review labels Jul 1, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:inference Generation, sampling, decoding (incl. speculative, DRY) priority:high High priority status:done Completed type:performance Performance improvements

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: f16-resident weights on the OpenXLA path (fusion-free bandwidth win)

1 participant