From de6429b8b2d571d7e98ece52ddfc07c3ee4a5516 Mon Sep 17 00:00:00 2001 From: Jeongkyu Shin Date: Wed, 1 Jul 2026 18:08:11 +0900 Subject: [PATCH] feat: f16-resident weights on the OpenXLA path (fusion-free bandwidth win) On the f16 GPU path, upload the resident linear-projection weights as f16 instead of f32-then-demote-in-graph, halving their per-step weight DRAM read with no dequant and no target fusion. Reuses the #568 per-dtype weight ABI (the C shim already allocates FLOAT_16 buffers), so no C-side change. The emitter's `take_weight` declares the seven per-layer projections (down/gate/up/wk/wo/wq/wv) as f16 args under f16 precision; `dot_general` skips its f32->f16 convert for an already-f16 weight, so the contraction sees the same f16 operand and stays token-exact. The loader packs those projections to f16 (round-to-nearest-even, matching the in-graph convert) into a new `WeightSpec::Proj` / `WeightBuf::F16` and uploads them as WDT_F16. embed / lm_head / norms / caches stay f32-resident; f32 / bf16 precision and the fused (Phi3) layouts keep the f32 arg, so the byte-exact f32 goldens are unchanged. Gated by `Config::supports_f16_resident` in both the emitter and the loader from the same `resolve_precision(device)`, so the uploaded buffer dtype always matches the emitted arg (a mismatch would fail `xla_llama_create`). Validated on GB10 (CUDA, Llama-3.2-1B 4-bit, greedy): f16-resident is token-exact with the f32 path (96/96) at 7.56 tok/s versus 5.23 f32 and ~6.8 for the prior f32-resident f16 path. On GB10 unified memory the halving is not visible via nvidia-smi; it is structural (f16 buffers are 2 bytes/element) and the throughput gain is its proxy. A new exhaustive test checks the f32->f16 packer round-trips every finite f16 pattern; a new emitter test pins exactly the 7 per-layer projections as f16 args. Part of #570 --- src/lib/mlxcel-xla/src/emitter/builder.rs | 7 ++ src/lib/mlxcel-xla/src/emitter/config.rs | 15 +++ src/lib/mlxcel-xla/src/emitter/mod.rs | 22 +++- src/lib/mlxcel-xla/src/emitter/model.rs | 15 ++- src/lib/mlxcel-xla/src/iree.rs | 35 +++++- src/lib/mlxcel-xla/src/weights.rs | 132 ++++++++++++++++++++-- 6 files changed, 211 insertions(+), 15 deletions(-) diff --git a/src/lib/mlxcel-xla/src/emitter/builder.rs b/src/lib/mlxcel-xla/src/emitter/builder.rs index 234fecd1..acabb3b5 100644 --- a/src/lib/mlxcel-xla/src/emitter/builder.rs +++ b/src/lib/mlxcel-xla/src/emitter/builder.rs @@ -175,6 +175,13 @@ impl Builder { self } + /// The contraction precision this builder emits at (`F32` by default, else the + /// resolved narrow precision). Read by the weight-arg declaration to decide the + /// resident weight dtype (issue #572: an f16 projection arg on the f16 path). + pub(crate) fn precision(&self) -> Precision { + self.precision + } + pub fn body(&self) -> &str { &self.body } diff --git a/src/lib/mlxcel-xla/src/emitter/config.rs b/src/lib/mlxcel-xla/src/emitter/config.rs index 91f1aa24..b737a2b7 100644 --- a/src/lib/mlxcel-xla/src/emitter/config.rs +++ b/src/lib/mlxcel-xla/src/emitter/config.rs @@ -1087,6 +1087,21 @@ impl Config { && self.moe.is_none() } + /// Whether the issue #572 f16-resident-weight path can apply: the standard + /// (non-fused-qkv, non-fused-gate-up) layout where every linear projection the + /// emitter takes via `take_weight` maps to a single whole-tensor loader spec + /// (`WeightSpec::Proj`), not a row-slice of a fused tensor. Under f16 precision + /// such projections are declared f16 args and uploaded f16-resident, halving + /// their per-step weight DRAM read. Gated identically in the emitter + /// (`take_weight`) and the loader (`weights::load_weights`) so the uploaded + /// buffer dtype always matches the emitted arg. Independent of quantization (a + /// dequantized 4-bit checkpoint is packed to f16 here just like a bf16 / f16 + /// one); the packed in-graph path (`supports_packed_quant`) takes precedence + /// when its env opt-in is set, since it keeps the weights ui32-resident instead. + pub(crate) fn supports_f16_resident(&self) -> bool { + !self.fused_qkv && !self.fused_gate_up + } + /// Attention score scale. Granite supplies the raw multiplier directly /// (`attention_multiplier`, which replaces `head_dim^-0.5`); Gemma2/3 use /// `query_pre_attn_scalar^-0.5` (computed in f64 to match HF, since it can diff --git a/src/lib/mlxcel-xla/src/emitter/mod.rs b/src/lib/mlxcel-xla/src/emitter/mod.rs index 16b026c7..d78701d2 100644 --- a/src/lib/mlxcel-xla/src/emitter/mod.rs +++ b/src/lib/mlxcel-xla/src/emitter/mod.rs @@ -62,7 +62,7 @@ pub(crate) use config::{MoeConfig, SharedExpertConfig}; // so allow the re-export to be unused in the other. `emit_decode_batched` (the // superseded uniform-B Stage-1 graph) is re-exported for the validation harness. #[allow(unused_imports)] -pub(crate) use builder::{quant_in_graph, resolve_precision}; +pub(crate) use builder::{Precision, quant_in_graph, resolve_precision}; #[allow(unused_imports)] pub(crate) use model::{ emit_decode, emit_decode_batched, emit_decode_ragged, emit_decode_ragged_with, @@ -778,6 +778,26 @@ mod tests { assert!(!emit_decode_with(&c, true, Precision::F32).contains("f16")); } + /// issue #572: under f16 precision exactly the 7 per-layer linear projections + /// (down/gate/up/wk/wo/wq/wv) are declared f16 weight args (uploaded f16-resident); + /// embed / lm_head / norms / caches stay f32. Each signature arg renders as + /// `%argN: loc("...")`, so an f16 weight arg is the substring `xf16> loc(`; + /// counting them pins down that ONLY the projections are narrowed. + #[test] + fn f16_precision_declares_the_projection_weight_args_f16() { + use super::builder::Precision; + let c = Config::llama_3_2_1b(); + let f16 = emit_prefill_with(&c, true, Precision::F16); + assert_eq!( + f16.matches("xf16> loc(").count(), + 7 * c.n_layers, + "expected exactly the 7 per-layer projections as f16 weight args" + ); + // The f32 default declares no f16 weight args (byte-exact goldens preserved). + let f32 = emit_prefill_with(&c, true, Precision::F32); + assert_eq!(f32.matches("xf16> loc(").count(), 0); + } + // ====================================================================== // dense arch pack (issue #498) // ====================================================================== diff --git a/src/lib/mlxcel-xla/src/emitter/model.rs b/src/lib/mlxcel-xla/src/emitter/model.rs index 6854e1d1..7631f4a7 100644 --- a/src/lib/mlxcel-xla/src/emitter/model.rs +++ b/src/lib/mlxcel-xla/src/emitter/model.rs @@ -180,7 +180,20 @@ fn take_weight( ); b.dequant_affine(&packed, &scales, &biases, qc.bits, qc.group_size) } - _ => take_arg(decls, idx, Ty::f32(vec![out, in_]), loc), + // issue #572: on the f16 GPU path, declare the resident projection weight as + // an f16 arg (uploaded f16-resident by the loader) instead of an f32 arg the + // dot demotes in-graph. `dot_general` sees the same f16 operand either way + // (its f32->f16 convert is skipped for an already-f16 weight), so this stays + // token-exact while halving the weight's per-step DRAM read. f32 / bf16 + // precision and the fused layouts keep the f32 arg (byte-identical goldens). + _ => { + let elt = if b.precision() == Precision::F16 && c.supports_f16_resident() { + "f16" + } else { + "f32" + }; + take_arg(decls, idx, Ty::new(vec![out, in_], elt), loc) + } } } diff --git a/src/lib/mlxcel-xla/src/iree.rs b/src/lib/mlxcel-xla/src/iree.rs index f670cc0a..625fe22a 100644 --- a/src/lib/mlxcel-xla/src/iree.rs +++ b/src/lib/mlxcel-xla/src/iree.rs @@ -57,7 +57,8 @@ use memmap2::Mmap; use safetensors::{Dtype, SafeTensors}; use crate::emitter::{ - Config, emit_decode_ragged_with, emit_decode_with, emit_prefill_with, resolve_precision, + Config, Precision, emit_decode_ragged_with, emit_decode_with, emit_prefill_with, + resolve_precision, }; // The loader reads the per-architecture checkpoint-weight order from // `weights::weight_specs`, which sources its names from `weight_names::scheme_names` @@ -67,7 +68,7 @@ use crate::emitter::{ // `dequantize_affine_stacked`). Both are pure-Rust and unit-tested without `iree`. use crate::weights::{ QuantPart, WeightSpec, bf16_to_f32, dequantize_affine, dequantize_affine_stacked, f16_to_f32, - f32_le_to_f32, slice_rows, weight_specs, + f32_le_to_f32, pack_f16, slice_rows, weight_specs, }; /// Weight-buffer element dtype passed to the C shim (issue #516 per-weight ABI): @@ -79,10 +80,12 @@ const WDT_F16: c_int = 1; const WDT_U32: c_int = 2; /// A resident-weight host buffer, kept alive until the shim copies it to the device. -/// Either f32 values (a widened / dequantized weight, the proven path) or raw bytes -/// (an MLX packed-U32 weight or its f16 scales / biases, issue #516 packed path). +/// f32 values (a widened / dequantized weight, the proven path), f16 values (an +/// f16-resident projection, issue #572), or raw bytes (an MLX packed-U32 weight or +/// its f16 scales / biases, issue #516 packed path). enum WeightBuf { F32(Vec), + F16(Vec), Raw(Vec), } @@ -91,6 +94,7 @@ impl WeightBuf { fn as_u8_ptr(&self) -> *const u8 { match self { WeightBuf::F32(v) => v.as_ptr() as *const u8, + WeightBuf::F16(v) => v.as_ptr() as *const u8, WeightBuf::Raw(v) => v.as_ptr(), } } @@ -369,6 +373,7 @@ fn resolve_weight_shards(model_dir: &Path, names: &[String]) -> Result Result<(Vec, Vec, Vec, Vec), String> { let specs = weight_specs(cfg); let names: Vec = specs.iter().map(|s| s.tensor_name().to_string()).collect(); @@ -538,6 +543,22 @@ fn load_weights( } bufs[i] = WeightBuf::F32(data); } + // issue #572: a linear projection. On the f16 GPU path pack it to an + // f16 resident buffer (WDT_F16), matching the emitter's f16 weight arg + // and halving its per-step DRAM read; otherwise upload f32, identical + // to the old `Whole`. `data` / `shape` are the widened row-major weight. + WeightSpec::Proj(_) => { + ranks[i] = shape.len() as c_int; + for (k, &s) in shape.iter().enumerate() { + dims[i * 4 + k] = s as i64; + } + if resident_f16 { + dtypes[i] = WDT_F16; + bufs[i] = WeightBuf::F16(pack_f16(&data)); + } else { + bufs[i] = WeightBuf::F32(data); + } + } WeightSpec::Rows { start, end, .. } => { if shape.len() != 2 { return Err(format!( @@ -577,7 +598,11 @@ fn create_ctx( prefill_vmfb: &Path, decode_vmfb: &Path, ) -> Result<*mut XlaCtx, String> { - let (bufs, dtypes, ranks, dims) = load_weights(model_dir, cfg)?; + // issue #572: the f16 GPU path uploads the projection weights f16-resident to + // match the emitter's f16 args. Uses the same resolve_precision(device) the graph + // emit uses, so the uploaded buffer dtype always lines up with the emitted arg. + let resident_f16 = resolve_precision(device) == Precision::F16 && cfg.supports_f16_resident(); + let (bufs, dtypes, ranks, dims) = load_weights(model_dir, cfg, resident_f16)?; let ptrs: Vec<*const c_void> = bufs .iter() .map(|b| b.as_u8_ptr() as *const c_void) diff --git a/src/lib/mlxcel-xla/src/weights.rs b/src/lib/mlxcel-xla/src/weights.rs index 9c7e8824..16d75164 100644 --- a/src/lib/mlxcel-xla/src/weights.rs +++ b/src/lib/mlxcel-xla/src/weights.rs @@ -34,6 +34,13 @@ pub(crate) enum WeightSpec { /// Load the whole checkpoint tensor `name` (widened to f32, an MLX-quantized /// U32 tensor dequantized to f32). Whole(String), + /// Load the whole checkpoint tensor `name` for a linear projection the emitter + /// took via `take_weight` (issue #572). Widened to f32 like [`WeightSpec::Whole`], + /// but the loader packs it to an f16 resident buffer when the f16 GPU path is + /// active (`Config::supports_f16_resident` + f16 precision), matching the emitter's + /// f16 weight arg; otherwise it uploads f32, byte-identical to the old `Whole`. + /// Kept distinct from `Whole` (embed / norm / lm_head), which stays f32-resident. + Proj(String), /// Load rows `[start, end)` of the checkpoint tensor `name` (a fused Phi3 /// projection, split into an emitter arg). Row-major `[out, in]`, so this is /// the `[start, end)` slice of the `out` axis. @@ -65,6 +72,7 @@ impl WeightSpec { pub(crate) fn tensor_name(&self) -> &str { match self { WeightSpec::Whole(n) => n, + WeightSpec::Proj(n) => n, WeightSpec::Rows { name, .. } => name, WeightSpec::QuantRaw { name, .. } => name, } @@ -113,7 +121,7 @@ fn push_proj(out: &mut Vec, name: String, quant: bool) { part: QuantPart::Biases, }); } else { - out.push(WeightSpec::Whole(name)); + out.push(WeightSpec::Proj(name)); } } @@ -368,6 +376,65 @@ pub(crate) fn f32_le_to_f32(bytes: &[u8]) -> Vec { .collect() } +/// One f32 -> IEEE 754 half (f16) bit pattern, round-to-nearest, ties-to-even (the +/// IEEE default), matching a `stablehlo.convert` f32 -> f16. So a projection weight +/// packed here (issue #572, f16-resident) is bit-identical to demoting the same f32 +/// weight inside the graph, and the contraction sees the same f16 operand and stays +/// token-exact. It is the exact inverse of [`half_to_f32`]: +/// `f32_to_f16_bits(half_to_f32(h)) == h` for every finite, non-NaN f16 `h`. +pub(crate) fn f32_to_f16_bits(x: f32) -> u16 { + let bits = x.to_bits(); + let sign = ((bits >> 16) & 0x8000) as u16; + let abs = bits & 0x7fff_ffff; + + // NaN / Inf (f32 exponent all ones): NaN -> a canonical quiet f16 NaN, Inf -> f16 Inf. + if abs >= 0x7f80_0000 { + return sign | if abs > 0x7f80_0000 { 0x7e00 } else { 0x7c00 }; + } + + // f16 biased exponent = (f32 biased exponent - 127) + 15. + let e = (abs >> 23) as i32 - 127 + 15; + + if e >= 0x1f { + return sign | 0x7c00; // overflow -> Inf + } + + if e <= 0 { + // Subnormal f16, or underflow to a signed zero. + if e < -10 { + return sign; // below half the smallest subnormal -> +/- 0 + } + // 24-bit significand (implicit leading 1), shifted into the subnormal range + // and rounded to nearest, ties to even. + let mant = (abs & 0x007f_ffff) | 0x0080_0000; + let shift = (14 - e) as u32; // e in [-10, 0] -> shift in [14, 24] + let q = mant >> shift; + let rem = mant & ((1 << shift) - 1); + let half = 1u32 << (shift - 1); + let round = u32::from(rem > half || (rem == half && q & 1 == 1)); + // q + round may reach 0x400, which is exactly the smallest normal (correct). + return sign | (q + round) as u16; + } + + // Normal f16: keep the top 10 mantissa bits, round to nearest even on bit 12. A + // mantissa carry rolls into the exponent, an exponent carry into 0x7c00 (Inf) -- + // both the correct results. + let mant = abs & 0x007f_ffff; + let base = ((e as u32) << 10) | (mant >> 13); + let rem = mant & 0x1fff; // the 13 dropped low bits + let half = 0x1000u32; // 1 << 12 + let round = u32::from(rem > half || (rem == half && base & 1 == 1)); + sign | (base + round) as u16 +} + +/// Pack a row-major f32 weight to its little-endian f16 bit pattern for an +/// f16-resident device upload (issue #572), via [`f32_to_f16_bits`] (RNE, matching +/// the in-graph demotion). The `u16` values are native-endian; the shim copies the +/// raw bytes, so on a little-endian host they land as the f16 buffer IREE expects. +pub(crate) fn pack_f16(data: &[f32]) -> Vec { + data.iter().map(|&x| f32_to_f16_bits(x)).collect() +} + /// Dequantize one MLX affine-quantized weight to row-major `[out, in]` f32. /// /// `packed` is the row-major `[out, in_packed]` u32 weight (little-endian bytes, @@ -512,13 +579,23 @@ pub(crate) fn dequantize_affine_stacked( mod tests { use super::*; - /// The Llama family weight order is the legacy all-`Whole` sequence (embed, - /// norm, then 9 per layer), so the #498 spec loader is byte-identical for it. + /// The Llama family weight order is embed + norm (f32-resident `Whole`) then, per + /// layer, the 7 linear projections as `Proj` (f16-resident-capable, issue #572) + /// and the 2 norms as `Whole`, in the fixed order below. Names / order unchanged. #[test] - fn weight_specs_llama_is_the_legacy_whole_order() { + fn weight_specs_llama_projections_are_proj_norms_are_whole() { let c = Config::llama_3_2_1b(); let specs = weight_specs(&c); - assert!(specs.iter().all(|s| matches!(s, WeightSpec::Whole(_)))); + // Every projection weight (`*_proj.weight`) is `Proj`; embed / norm / the + // per-layer norms are `Whole`. No fused rows or quant parts for plain Llama. + for s in &specs { + let n = s.tensor_name(); + match s { + WeightSpec::Proj(_) => assert!(n.ends_with("_proj.weight"), "Proj {n}"), + WeightSpec::Whole(_) => assert!(!n.ends_with("_proj.weight"), "Whole {n}"), + other => panic!("unexpected spec {other:?}"), + } + } assert_eq!(specs.len(), 2 + 9 * c.n_layers); let names: Vec<&str> = specs.iter().map(WeightSpec::tensor_name).collect(); assert_eq!(names[0], "model.embed_tokens.weight"); @@ -539,6 +616,43 @@ mod tests { ); } + #[test] + fn f32_to_f16_round_trips_every_finite_half() { + // half_to_f32(h) is exactly representable, so packing it back must recover h + // for every finite, non-NaN f16 pattern (signed zeros, subnormals, normals). + for h in 0u16..=u16::MAX { + if (h >> 10) & 0x1f == 0x1f { + continue; // skip Inf / NaN (NaN is non-canonical); Inf covered below + } + assert_eq!( + f32_to_f16_bits(half_to_f32(h)), + h, + "round-trip failed for f16 bits {h:#06x}" + ); + } + } + + #[test] + fn f32_to_f16_rounds_ties_to_even_and_saturates() { + // Exact tie between 1.0 (0x3c00, even mantissa) and its successor 0x3c01 + // rounds down to the even neighbour; the tie one step up rounds to 0x3c02. + assert_eq!( + f32_to_f16_bits((half_to_f32(0x3c00) + half_to_f32(0x3c01)) / 2.0), + 0x3c00 + ); + assert_eq!( + f32_to_f16_bits((half_to_f32(0x3c01) + half_to_f32(0x3c02)) / 2.0), + 0x3c02 + ); + // Above the f16 max (65504) saturates to +/-Inf; f16 max itself is exact. + assert_eq!(f32_to_f16_bits(65504.0), 0x7bff); + assert_eq!(f32_to_f16_bits(70000.0), 0x7c00); + assert_eq!(f32_to_f16_bits(-70000.0), 0xfc00); + // Signed zero. + assert_eq!(f32_to_f16_bits(0.0), 0x0000); + assert_eq!(f32_to_f16_bits(-0.0), 0x8000); + } + /// Phi3 row-slices the fused `qkv_proj` ([Q|K|V]) and `gate_up_proj` (gate then /// up) into the emitter's separate args, and is untied (`lm_head` after norm). #[test] @@ -886,12 +1000,14 @@ mod tests { assert!( matches!(&specs[i + 2], WeightSpec::QuantRaw { name, part: QuantPart::Biases } if name.ends_with("q_proj.biases")) ); - // quant = false is the legacy all-Whole order (purely additive packed path). + // quant = false is the unquantized layout: projections are `Proj` (issue + // #572), everything else `Whole`, and there are no packed parts (the packed + // path is purely additive). assert!( weight_specs_q(&c, false) .iter() - .all(|s| matches!(s, WeightSpec::Whole(_))), - "unquantized layout is unchanged" + .all(|s| matches!(s, WeightSpec::Whole(_) | WeightSpec::Proj(_))), + "unquantized layout has no packed parts" ); } }