Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/lib/mlxcel-xla/src/emitter/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ impl Builder {
self
}

/// Returns the contraction precision this builder emits at (`F32` by default, else
/// the resolved narrow precision).
///
/// Read by the weight-arg declaration (`take_weight`) to decide the resident weight
/// dtype (issue #572: an f16 projection arg on the f16 path). The field is handed
/// back by value from `&self`, so callers get a copy and cannot mutate the builder
/// through it.
pub(crate) fn precision(&self) -> Precision {
self.precision
}

/// Borrows the builder's `body` field as a string slice (read-only view; the
/// builder keeps ownership).
pub fn body(&self) -> &str {
&self.body
}
Expand Down
15 changes: 15 additions & 0 deletions src/lib/mlxcel-xla/src/emitter/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,21 @@ impl Config {
&& self.moe.is_none()
}

/// Reports whether the issue #572 f16-resident-weight path may apply to this
/// configuration.
///
/// True only for the standard layout, i.e. neither fused-qkv nor fused-gate-up. In
/// that layout every linear projection the emitter takes via `take_weight` maps to
/// one whole-tensor loader spec (`WeightSpec::Proj`) rather than a row-slice of a
/// fused tensor, so under f16 precision it can be declared as an f16 arg and
/// uploaded f16-resident (half the bytes of the f32 form).
///
/// The same predicate is consulted on both sides so the uploaded buffer dtype
/// matches the emitted arg: by the emitter in `take_weight`, and by the loader side
/// in `create_ctx`, which forwards the result to `load_weights` as `resident_f16`.
/// It does not look at quantization; in `take_weight` the packed in-graph arm is
/// matched before this gate is reached, so that path keeps precedence when active.
pub(crate) fn supports_f16_resident(&self) -> bool {
    // De Morgan form of "not fused-qkv and not fused-gate-up".
    !(self.fused_qkv || self.fused_gate_up)
}

/// Attention score scale. Granite supplies the raw multiplier directly
/// (`attention_multiplier`, which replaces `head_dim^-0.5`); Gemma2/3 use
/// `query_pre_attn_scalar^-0.5` (computed in f64 to match HF, since it can
Expand Down
22 changes: 21 additions & 1 deletion src/lib/mlxcel-xla/src/emitter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub(crate) use config::{MoeConfig, SharedExpertConfig};
// so allow the re-export to be unused in the other. `emit_decode_batched` (the
// superseded uniform-B Stage-1 graph) is re-exported for the validation harness.
#[allow(unused_imports)]
pub(crate) use builder::{quant_in_graph, resolve_precision};
pub(crate) use builder::{Precision, quant_in_graph, resolve_precision};
#[allow(unused_imports)]
pub(crate) use model::{
emit_decode, emit_decode_batched, emit_decode_ragged, emit_decode_ragged_with,
Expand Down Expand Up @@ -778,6 +778,26 @@ mod tests {
assert!(!emit_decode_with(&c, true, Precision::F32).contains("f16"));
}

/// issue #572: with f16 precision, only the 7 per-layer linear projections
/// (down/gate/up/wk/wo/wq/wv) become f16 weight args; embed / lm_head / norms /
/// caches remain f32. A signature arg is rendered as `%argN: <ty> loc("...")`, so
/// every f16 weight arg contributes one occurrence of `xf16> loc(`. Counting those
/// occurrences therefore checks that nothing but the projections was narrowed.
#[test]
fn f16_precision_declares_the_projection_weight_args_f16() {
    use super::builder::Precision;

    // Marker left in the emitted text by each f16-typed signature arg.
    const F16_ARG_MARKER: &str = "xf16> loc(";

    let cfg = Config::llama_3_2_1b();
    let f16_arg_count = |precision: Precision| {
        emit_prefill_with(&cfg, true, precision)
            .matches(F16_ARG_MARKER)
            .count()
    };

    assert_eq!(
        f16_arg_count(Precision::F16),
        7 * cfg.n_layers,
        "expected exactly the 7 per-layer projections as f16 weight args"
    );
    // The f32 default must not declare any f16 weight arg (byte-exact goldens kept).
    assert_eq!(f16_arg_count(Precision::F32), 0);
}

// ======================================================================
// dense arch pack (issue #498)
// ======================================================================
Expand Down
15 changes: 14 additions & 1 deletion src/lib/mlxcel-xla/src/emitter/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,20 @@ fn take_weight(
);
b.dequant_affine(&packed, &scales, &biases, qc.bits, qc.group_size)
}
_ => take_arg(decls, idx, Ty::f32(vec![out, in_]), loc),
// issue #572: on the f16 GPU path, declare the resident projection weight as
// an f16 arg (uploaded f16-resident by the loader) instead of an f32 arg the
// dot demotes in-graph. `dot_general` sees the same f16 operand either way
// (its f32->f16 convert is skipped for an already-f16 weight), so this stays
// token-exact while halving the weight's per-step DRAM read. f32 / bf16
// precision and the fused layouts keep the f32 arg (byte-identical goldens).
_ => {
let elt = if b.precision() == Precision::F16 && c.supports_f16_resident() {
"f16"
} else {
"f32"
};
take_arg(decls, idx, Ty::new(vec![out, in_], elt), loc)
}
}
}

Expand Down
35 changes: 30 additions & 5 deletions src/lib/mlxcel-xla/src/iree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ use memmap2::Mmap;
use safetensors::{Dtype, SafeTensors};

use crate::emitter::{
Config, emit_decode_ragged_with, emit_decode_with, emit_prefill_with, resolve_precision,
Config, Precision, emit_decode_ragged_with, emit_decode_with, emit_prefill_with,
resolve_precision,
};
// The loader reads the per-architecture checkpoint-weight order from
// `weights::weight_specs`, which sources its names from `weight_names::scheme_names`
Expand All @@ -67,7 +68,7 @@ use crate::emitter::{
// `dequantize_affine_stacked`). Both are pure-Rust and unit-tested without `iree`.
use crate::weights::{
QuantPart, WeightSpec, bf16_to_f32, dequantize_affine, dequantize_affine_stacked, f16_to_f32,
f32_le_to_f32, slice_rows, weight_specs,
f32_le_to_f32, pack_f16, slice_rows, weight_specs,
};

/// Weight-buffer element dtype passed to the C shim (issue #516 per-weight ABI):
Expand All @@ -79,10 +80,12 @@ const WDT_F16: c_int = 1;
const WDT_U32: c_int = 2;

/// A resident-weight host buffer, kept alive until the shim copies it to the device.
/// Either f32 values (a widened / dequantized weight, the proven path) or raw bytes
/// (an MLX packed-U32 weight or its f16 scales / biases, issue #516 packed path).
/// f32 values (a widened / dequantized weight, the proven path), f16 values (an
/// f16-resident projection, issue #572), or raw bytes (an MLX packed-U32 weight or
/// its f16 scales / biases, issue #516 packed path).
enum WeightBuf {
/// f32 values: a widened / dequantized weight.
F32(Vec<f32>),
/// f16 values held as `u16` bit patterns (an f16-resident projection, issue #572;
/// filled from `pack_f16` in `load_weights`).
F16(Vec<u16>),
/// Raw bytes: an MLX packed-U32 weight or its f16 scales / biases (issue #516
/// packed path).
Raw(Vec<u8>),
}

Expand All @@ -91,6 +94,7 @@ impl WeightBuf {
fn as_u8_ptr(&self) -> *const u8 {
    // Every variant owns a `Vec`; expose the start of its storage as a byte
    // pointer. The typed vectors are reinterpreted, the raw one is passed through.
    match self {
        WeightBuf::F32(values) => values.as_ptr().cast::<u8>(),
        WeightBuf::F16(halves) => halves.as_ptr().cast::<u8>(),
        WeightBuf::Raw(bytes) => bytes.as_ptr(),
    }
}
Expand Down Expand Up @@ -369,6 +373,7 @@ fn resolve_weight_shards(model_dir: &Path, names: &[String]) -> Result<Vec<PathB
fn load_weights(
model_dir: &Path,
cfg: &Config,
resident_f16: bool,
) -> Result<(Vec<WeightBuf>, Vec<c_int>, Vec<c_int>, Vec<i64>), String> {
let specs = weight_specs(cfg);
let names: Vec<String> = specs.iter().map(|s| s.tensor_name().to_string()).collect();
Expand Down Expand Up @@ -538,6 +543,22 @@ fn load_weights(
}
bufs[i] = WeightBuf::F32(data);
}
// issue #572: a linear projection. On the f16 GPU path pack it to an
// f16 resident buffer (WDT_F16), matching the emitter's f16 weight arg
// and halving its per-step DRAM read; otherwise upload f32, identical
// to the old `Whole`. `data` / `shape` are the widened row-major weight.
WeightSpec::Proj(_) => {
ranks[i] = shape.len() as c_int;
for (k, &s) in shape.iter().enumerate() {
dims[i * 4 + k] = s as i64;
}
if resident_f16 {
dtypes[i] = WDT_F16;
bufs[i] = WeightBuf::F16(pack_f16(&data));
} else {
bufs[i] = WeightBuf::F32(data);
}
}
WeightSpec::Rows { start, end, .. } => {
if shape.len() != 2 {
return Err(format!(
Expand Down Expand Up @@ -577,7 +598,11 @@ fn create_ctx(
prefill_vmfb: &Path,
decode_vmfb: &Path,
) -> Result<*mut XlaCtx, String> {
let (bufs, dtypes, ranks, dims) = load_weights(model_dir, cfg)?;
// issue #572: the f16 GPU path uploads the projection weights f16-resident to
// match the emitter's f16 args. Uses the same resolve_precision(device) the graph
// emit uses, so the uploaded buffer dtype always lines up with the emitted arg.
let resident_f16 = resolve_precision(device) == Precision::F16 && cfg.supports_f16_resident();
let (bufs, dtypes, ranks, dims) = load_weights(model_dir, cfg, resident_f16)?;
let ptrs: Vec<*const c_void> = bufs
.iter()
.map(|b| b.as_u8_ptr() as *const c_void)
Expand Down
132 changes: 124 additions & 8 deletions src/lib/mlxcel-xla/src/weights.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ pub(crate) enum WeightSpec {
/// Load the whole checkpoint tensor `name` (widened to f32, an MLX-quantized
/// U32 tensor dequantized to f32).
Whole(String),
/// Load the whole checkpoint tensor `name` for a linear projection the emitter
/// took via `take_weight` (issue #572). Widened to f32 like [`WeightSpec::Whole`],
/// but the loader packs it to an f16 resident buffer when the f16 GPU path is
/// active (`Config::supports_f16_resident` + f16 precision), matching the emitter's
/// f16 weight arg; otherwise it uploads f32, byte-identical to the old `Whole`.
/// Kept distinct from `Whole` (embed / norm / lm_head), which stays f32-resident.
Proj(String),
/// Load rows `[start, end)` of the checkpoint tensor `name` (a fused Phi3
/// projection, split into an emitter arg). Row-major `[out, in]`, so this is
/// the `[start, end)` slice of the `out` axis.
Expand Down Expand Up @@ -65,6 +72,7 @@ impl WeightSpec {
pub(crate) fn tensor_name(&self) -> &str {
match self {
WeightSpec::Whole(n) => n,
WeightSpec::Proj(n) => n,
WeightSpec::Rows { name, .. } => name,
WeightSpec::QuantRaw { name, .. } => name,
}
Expand Down Expand Up @@ -113,7 +121,7 @@ fn push_proj(out: &mut Vec<WeightSpec>, name: String, quant: bool) {
part: QuantPart::Biases,
});
} else {
out.push(WeightSpec::Whole(name));
out.push(WeightSpec::Proj(name));
}
}

Expand Down Expand Up @@ -368,6 +376,65 @@ pub(crate) fn f32_le_to_f32(bytes: &[u8]) -> Vec<f32> {
.collect()
}

/// Converts one f32 to the bit pattern of the nearest IEEE 754 half (f16), using
/// round-to-nearest, ties-to-even (the IEEE default rounding).
///
/// This is meant to agree with a `stablehlo.convert` f32 -> f16, so a projection
/// weight packed on the host (issue #572, f16-resident) carries the same bits as
/// the same f32 weight demoted inside the graph. It inverts [`half_to_f32`]:
/// `f32_to_f16_bits(half_to_f32(h)) == h` for every finite, non-NaN f16 `h`.
///
/// Special cases: any NaN becomes the canonical quiet NaN `0x7e00` (sign kept),
/// infinities and values too large for f16 become signed infinity, and values below
/// half the smallest f16 subnormal become signed zero.
pub(crate) fn f32_to_f16_bits(x: f32) -> u16 {
    const F32_INF: u32 = 0x7f80_0000;
    const F16_INF: u16 = 0x7c00;
    const F16_QUIET_NAN: u16 = 0x7e00;

    let raw = x.to_bits();
    let sign_bit = ((raw >> 16) & 0x8000) as u16;
    let magnitude = raw & 0x7fff_ffff;

    // f32 exponent all ones: a non-zero fraction is NaN, a zero fraction is Inf.
    if magnitude > F32_INF {
        return sign_bit | F16_QUIET_NAN;
    }
    if magnitude == F32_INF {
        return sign_bit | F16_INF;
    }

    // Rebias the exponent from f32 (bias 127) to f16 (bias 15).
    let half_exp = (magnitude >> 23) as i32 - 112;
    let fraction = magnitude & 0x007f_ffff;

    // Round-to-nearest-even: bump `kept` when the dropped bits exceed the midpoint,
    // or sit exactly on it while `kept` is odd.
    let round_nearest_even = |kept: u32, dropped: u32, midpoint: u32| -> u32 {
        if dropped > midpoint || (dropped == midpoint && kept & 1 == 1) {
            kept + 1
        } else {
            kept
        }
    };

    match half_exp {
        // Too large for f16: saturate to infinity.
        e if e >= 0x1f => sign_bit | F16_INF,
        // Below half the smallest subnormal: flush to a signed zero.
        e if e < -10 => sign_bit,
        // Subnormal result: shift the 24-bit significand (implicit leading 1 made
        // explicit) down into the subnormal range. `e` is in [-10, 0], so the
        // shift is in [14, 24]. A round-up to 0x400 lands exactly on the smallest
        // normal, which is the correct result.
        e if e <= 0 => {
            let significand = fraction | 0x0080_0000;
            let shift = (14 - e) as u32;
            let kept = significand >> shift;
            let dropped = significand & ((1u32 << shift) - 1);
            sign_bit | round_nearest_even(kept, dropped, 1u32 << (shift - 1)) as u16
        }
        // Normal result: keep the top 10 fraction bits and round on the 13 dropped
        // ones. A fraction carry rolls into the exponent, and an exponent carry
        // yields 0x7c00 (Inf); both are the correct outcomes.
        e => {
            let kept = ((e as u32) << 10) | (fraction >> 13);
            sign_bit | round_nearest_even(kept, fraction & 0x1fff, 0x1000) as u16
        }
    }
}

/// Packs an f32 weight into f16 bit patterns for an f16-resident device upload
/// (issue #572), converting element by element with [`f32_to_f16_bits`]
/// (round-to-nearest-even, matching the in-graph demotion). Element order is kept.
///
/// Each output element is a native-endian `u16`. The shim copies the raw bytes, so
/// on a little-endian host they land as the little-endian f16 buffer IREE expects.
pub(crate) fn pack_f16(data: &[f32]) -> Vec<u16> {
    let mut halves = Vec::with_capacity(data.len());
    for &value in data {
        halves.push(f32_to_f16_bits(value));
    }
    halves
}

/// Dequantize one MLX affine-quantized weight to row-major `[out, in]` f32.
///
/// `packed` is the row-major `[out, in_packed]` u32 weight (little-endian bytes,
Expand Down Expand Up @@ -512,13 +579,23 @@ pub(crate) fn dequantize_affine_stacked(
mod tests {
use super::*;

/// The Llama family weight order is the legacy all-`Whole` sequence (embed,
/// norm, then 9 per layer), so the #498 spec loader is byte-identical for it.
/// The Llama family weight order is embed + norm (f32-resident `Whole`) then, per
/// layer, the 7 linear projections as `Proj` (f16-resident-capable, issue #572)
/// and the 2 norms as `Whole`, in the fixed order below. Names / order unchanged.
#[test]
fn weight_specs_llama_is_the_legacy_whole_order() {
fn weight_specs_llama_projections_are_proj_norms_are_whole() {
let c = Config::llama_3_2_1b();
let specs = weight_specs(&c);
assert!(specs.iter().all(|s| matches!(s, WeightSpec::Whole(_))));
// Every projection weight (`*_proj.weight`) is `Proj`; embed / norm / the
// per-layer norms are `Whole`. No fused rows or quant parts for plain Llama.
for s in &specs {
let n = s.tensor_name();
match s {
WeightSpec::Proj(_) => assert!(n.ends_with("_proj.weight"), "Proj {n}"),
WeightSpec::Whole(_) => assert!(!n.ends_with("_proj.weight"), "Whole {n}"),
other => panic!("unexpected spec {other:?}"),
}
}
assert_eq!(specs.len(), 2 + 9 * c.n_layers);
let names: Vec<&str> = specs.iter().map(WeightSpec::tensor_name).collect();
assert_eq!(names[0], "model.embed_tokens.weight");
Expand All @@ -539,6 +616,43 @@ mod tests {
);
}

#[test]
fn f32_to_f16_round_trips_every_finite_half() {
    // Every finite f16 widens to an exactly representable f32, so narrowing it
    // again must give back the same bits (signed zeros, subnormals, normals).
    // Patterns with an all-ones exponent (Inf / NaN) are left out here: NaN is
    // non-canonical, and Inf is checked elsewhere.
    let finite_halves = (0u16..=u16::MAX).filter(|bits| (bits >> 10) & 0x1f != 0x1f);
    for h in finite_halves {
        let widened = half_to_f32(h);
        assert_eq!(
            f32_to_f16_bits(widened),
            h,
            "round-trip failed for f16 bits {h:#06x}"
        );
    }
}

#[test]
fn f32_to_f16_rounds_ties_to_even_and_saturates() {
    // Exact f32 midpoint between two adjacent f16 values.
    let midpoint = |lo: u16, hi: u16| (half_to_f32(lo) + half_to_f32(hi)) / 2.0;

    let cases: [(f32, u16); 7] = [
        // Tie between 1.0 (0x3c00, even mantissa) and 0x3c01 goes down to the even
        // neighbour; the next tie up goes to 0x3c02.
        (midpoint(0x3c00, 0x3c01), 0x3c00),
        (midpoint(0x3c01, 0x3c02), 0x3c02),
        // The f16 maximum (65504) is exact; anything beyond saturates to +/-Inf.
        (65504.0, 0x7bff),
        (70000.0, 0x7c00),
        (-70000.0, 0xfc00),
        // Signed zero keeps its sign.
        (0.0, 0x0000),
        (-0.0, 0x8000),
    ];

    for (input, expected) in cases {
        assert_eq!(f32_to_f16_bits(input), expected);
    }
}

/// Phi3 row-slices the fused `qkv_proj` ([Q|K|V]) and `gate_up_proj` (gate then
/// up) into the emitter's separate args, and is untied (`lm_head` after norm).
#[test]
Expand Down Expand Up @@ -886,12 +1000,14 @@ mod tests {
assert!(
matches!(&specs[i + 2], WeightSpec::QuantRaw { name, part: QuantPart::Biases } if name.ends_with("q_proj.biases"))
);
// quant = false is the legacy all-Whole order (purely additive packed path).
// quant = false is the unquantized layout: projections are `Proj` (issue
// #572), everything else `Whole`, and there are no packed parts (the packed
// path is purely additive).
assert!(
weight_specs_q(&c, false)
.iter()
.all(|s| matches!(s, WeightSpec::Whole(_))),
"unquantized layout is unchanged"
.all(|s| matches!(s, WeightSpec::Whole(_) | WeightSpec::Proj(_))),
"unquantized layout has no packed parts"
);
}
}