Add MLX_NUMERICAL_STRICT_MODE for shape-independent quantized_matmul #3473

Open
rakshith48 wants to merge 1 commit into ml-explore:main from rakshith48:numerical-strict-mode

Conversation

@rakshith48

Upstream issue + PR for ml-explore/mlx

Issue title

quantized_matmul output depends on input shape (M dimension); add MLX_NUMERICAL_STRICT_MODE opt-in for path-independent output

Summary

mx.quantized_matmul dispatches to one of three GPU kernels based on the M (batch/sequence) dimension of the input:

  • qmv for M < vector_limit (~10–32 depending on K, N, arch)
  • qmm_splitk for vector_limit ≤ M < ~65 and transposed weights (B=1)
  • qmm for everything else (see the sketch below)
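
A rough mental model of the dispatch, as a Python sketch (not the actual C++ logic; vector_limit and the ~65 boundary are approximate and vary with K, N, and architecture):

def pick_kernel(M, vector_limit=32, splitk_limit=65, transposed=True, batch=1):
    # Hypothetical stand-in for the real dispatch in QuantizedMatmul::eval_gpu.
    if M < vector_limit:
        return "qmv"
    if M < splitk_limit and transposed and batch == 1:
        return "qmm_splitk"
    return "qmm"

# e.g. pick_kernel(8) -> "qmv", pick_kernel(48) -> "qmm_splitk", pick_kernel(65) -> "qmm"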

All three kernels accumulate K in fp32 internally (BlockMMA's AccumType=float, qmv uses fp32 accumulators). However, they use different reduction trees across K:

  • qmm: sequential register-fma chain over BK=32 fragments, in order
  • qmm_splitk: partition K into split_k blocks, each accumulated independently in registers, summed in a separate reduction kernel
  • qmv: parallelize K across simd lanes, collapse with simd_sum (hardware butterfly)

fp32 floating-point sum is not associative: (a + b) + c ≠ a + (b + c) in general. So three kernels computing the same dot product over the same K=4096 produce three different bit patterns, differing by ~1.5–2.7 × 10⁻⁵ per element.
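
The non-associativity itself is easy to see in isolation (toy values, not the kernels; any fp32 arithmetic behaves this way):

import mlx.core as mx

a = mx.array(1e-8, dtype=mx.float32)
b = mx.array(1.0, dtype=mx.float32)
c = mx.array(-1.0, dtype=mx.float32)

print(float((a + b) + c))  # 0.0    -> 1e-8 is absorbed when added to 1.0 first
print(float(a + (b + c)))  # 1e-08  -> a different grouping keeps it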

For straight inference and training this is invisible. But it silently breaks any workload that compares two equivalent execution paths:

  • Prefix-cache reuse (e.g. vLLM-style inference engines): q_proj(x_full) vs q_proj(x_full[:, -L:]) should match for the overlapping tokens. They don't.
  • Batched-vs-streaming eval (lm-evaluation-harness, custom eval frameworks): batching changes M, output diverges.
  • Distillation / RLHF: teacher and student forward passes are expected to be deterministic functions of the same inputs.

Reproducer

Tested on MLX v0.31.2, Apple M1 16GB, Qwen3-8B-4bit (mlx-community/Qwen3-8B-4bit).

import mlx.core as mx
from mlx_lm import load

m, _ = load('mlx-community/Qwen3-8B-4bit')
attn = m.model.layers[0].self_attn

x_full = mx.random.normal(shape=(1, 2200, m.args.hidden_size)).astype(mx.float16)
full = attn.q_proj(x_full)

for L in [8, 16, 32, 48, 64, 65, 128]:
    sliced = attn.q_proj(x_full[:, -L:, :])
    diff = float(mx.max(mx.abs(full[:, -L:, :] - sliced)))
    print(f"q_proj L={L:3d}: max_abs_diff = {diff:.4e}")

Output (without strict mode):

q_proj L=  8: max_abs_diff = 1.9073e-05    ← qmv path
q_proj L= 16: max_abs_diff = 2.1458e-05    ← qmv path
q_proj L= 32: max_abs_diff = 2.1458e-05    ← qmv path
q_proj L= 48: max_abs_diff = 2.0504e-05    ← qmm_splitk path
q_proj L= 64: max_abs_diff = 2.0504e-05    ← qmm_splitk path
q_proj L= 65: max_abs_diff = 0.0000e+00    ← qmm path (matches reference)
q_proj L=128: max_abs_diff = 0.0000e+00    ← qmm path

Same pattern in k_proj with the boundary at L=256→257 (because N=1024 instead of 4096 changes the splitk threshold).
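
The same sweep can be pointed at any projection; a small helper along these lines (hypothetical, reusing mx, attn, and x_full from the reproducer above) locates the boundary for k_proj or any other layer:

def sweep_boundary(proj, x_full, Ls):
    # Compare the full-sequence output against re-running on a suffix slice;
    # a non-zero diff means the two calls took different kernel paths.
    full = proj(x_full)
    for L in Ls:
        sliced = proj(x_full[:, -L:, :])
        diff = float(mx.max(mx.abs(full[:, -L:, :] - sliced)))
        print(f"L={L:4d}: max_abs_diff = {diff:.4e}")

# sweep_boundary(attn.k_proj, x_full, [128, 255, 256, 257, 512])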

Important: this is NOT just an fp16 ULP issue

An earlier analysis suggested the bug was fp16 partial-sum rounding in splitk's intermediate buffer, with the fix being to store partials in fp32 instead of fp16. I implemented that fix and verified that its output is bit-identical to the unpatched build on the reproducer above, i.e. it changes nothing. The reason: with bf16 scales the model output is auto-promoted to fp32 by MLX, so there is no fp16 cast anywhere in the splitk path that would benefit from precision promotion. The ~2 × 10⁻⁵ diff is purely fp32 non-associativity from differing reduction trees.

The implication: promoting partial-sums to fp32 does not fix the bit-equivalence problem for fp32-output paths. Only matching the reduction tree (or skipping the fast paths entirely) gives bit-equivalence.

Proposed fix: MLX_NUMERICAL_STRICT_MODE opt-in

Add an opt-in env-var-controlled flag matching MLX's existing convention (MLX_ENABLE_TF32, MLX_METAL_FAST_SYNCH, etc.):

mlx/utils.h — add helper:

inline bool numerical_strict_mode() {
  // Read MLX_NUMERICAL_STRICT_MODE once and cache it for the process lifetime.
  static bool numerical_strict_mode_ = get_var("MLX_NUMERICAL_STRICT_MODE", 0);
  return numerical_strict_mode_;
}

mlx/backend/metal/quantized.cpp::QuantizedMatmul::eval_gpu — gate at top of dispatch:

if (env::numerical_strict_mode()) {
  qmm(x, w, scales, biases, out, transpose_, group_size_, bits_,
      M, N, K, d, s, mode);
  return;
}
// ... existing qmv / qmm_splitk / qmm dispatch logic ...

That single gate at the top of eval_gpu covers all of the shape-dependent fast paths (qmv, qmm_splitk, qvm_split_k), since qmm is the canonical reference.
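
If the helper caches the variable on first use as sketched above, opting in from Python before the first quantized call should also work (untested sketch; a shell-level MLX_NUMERICAL_STRICT_MODE=1 export is the straightforward route):

import os
os.environ["MLX_NUMERICAL_STRICT_MODE"] = "1"  # must be set before the first quantized_matmul runs

import mlx.core as mx
from mlx_lm import load
# ... load the model and re-run the reproducer; every L should now be bit-identical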

Validation

With MLX_NUMERICAL_STRICT_MODE=1:

q_proj L=  8: max_abs_diff = 0.0000e+00
q_proj L= 16: max_abs_diff = 0.0000e+00
q_proj L= 32: max_abs_diff = 0.0000e+00
q_proj L= 48: max_abs_diff = 0.0000e+00
q_proj L= 64: max_abs_diff = 0.0000e+00
q_proj L= 65: max_abs_diff = 0.0000e+00
q_proj L=128: max_abs_diff = 0.0000e+00

Bit-identical at every L for every projection.

Performance cost (honest numbers)

Measured on M1 16GB with Qwen3-8B-4bit:

  • Short-prompt decode (~10 tok prefix): 2.18 tok/s (OFF) → 0.96 tok/s (ON), 2.3× slower
  • Long-prompt decode (~80 tok prefix): 2.39 tok/s (OFF) → 1.51 tok/s (ON), 1.6× slower

The decode-loop slowdown is significant because qmv is heavily optimized for M=1 generation; bypassing it forces qmm to use a 32×32 tile for a single output row.
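
Something along these lines reproduces the OFF/ON comparison (a sketch assuming mlx_lm's generate API; exact tok/s varies by machine, prompt, and prompt length):

import time
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Qwen3-8B-4bit")
prompt = "Write one sentence about the weather."  # short prefix, comparable to the ~10-token row above

t0 = time.perf_counter()
generate(model, tokenizer, prompt=prompt, max_tokens=64)
elapsed = time.perf_counter() - t0
print(f"~{64 / elapsed:.2f} tok/s (rough; includes prompt processing)")
# Run once normally and once with MLX_NUMERICAL_STRICT_MODE=1 to compare.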

This is why the flag is opt-in, not default-on. For users running:

  • Single-stream chat inference: don't enable. Speed matters; bit-equivalence doesn't.
  • Eval frameworks comparing batched and streaming runs: enable. Decode tok/s isn't the bottleneck; correctness is.
  • Distillation/RLHF teacher-student loops: enable. You need deterministic forward passes.
  • Prefix-cache reuse engines: enable for any path-comparison testing; can disable in production hot loops.

Files changed

  • mlx/utils.h — +20 LoC (env helper + comment block)
  • mlx/backend/metal/quantized.cpp — +18 LoC (gate at QuantizedMatmul::eval_gpu)

Total: ~40 LoC. No new kernels, no new tests of existing behavior, no breaking changes. Off-by-default → zero impact on any existing user.

Test

Reproducer in mac-llm-bench/eval/test_numerical_strict_mode.py — script that runs the boundary check in both modes:

  • Without flag: prints diffs at all small-L boundaries (informational, no assertion)
  • With flag: asserts diff == 0.0 at every L; exits 1 if any L fails (sketched below)
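
The with-flag half is essentially the reproducer plus an assertion; a hypothetical sketch (reusing mx, attn, x_full, and full from the reproducer above; the actual script lives in mac-llm-bench, not in this PR):

import os, sys

strict = os.environ.get("MLX_NUMERICAL_STRICT_MODE") == "1"
failures = []
for L in [8, 16, 32, 48, 64, 65, 128]:
    sliced = attn.q_proj(x_full[:, -L:, :])
    diff = float(mx.max(mx.abs(full[:, -L:, :] - sliced)))
    print(f"L={L:3d}: max_abs_diff = {diff:.4e}")
    if strict and diff != 0.0:
        failures.append(L)

if failures:
    sys.exit(1)  # with the flag set, any non-zero diff is a failure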

Reference

Hardware: Apple M1 (16 GB unified memory)
MLX version: 0.31.2
Model: mlx-community/Qwen3-8B-4bit

Discovered while building an Apple Silicon eval harness for Qwen3-8B (mac-llm-bench). The original symptom was a ~0.5pp MMLU accuracy regression when prefix-cache reuse was enabled; it was root-caused to this path-dependence by sweeping L and observing that the boundaries at L=64 (q/o_proj) and L=256 (k/v_proj) match the dispatcher's split_k threshold.

Commit message: Add MLX_NUMERICAL_STRICT_MODE for shape-independent quantized_matmul

quantized_matmul currently dispatches to one of three GPU kernels based on
input M (qmv, qmm_splitk, qmm), each using a different K-reduction tree.
fp32 sum is non-associative so the three paths produce slightly different
bit patterns (~1.5-2.7e-5 per element) for the same dot product.

For inference and training this is invisible. But it silently breaks any
workload that compares two equivalent execution paths:
- prefix-cache reuse (vLLM-style engines)
- batched-vs-streaming eval comparison (lm-evaluation-harness)
- distillation/RLHF teacher-student forward-pass equality

Adds an opt-in env-var-gated flag matching MLX's existing convention
(MLX_ENABLE_TF32, etc.). When set, QuantizedMatmul and GatherQMM force
the no-split qmm/gather_qmm reference paths so output is bit-identical
regardless of M.

Cost: ~1.5-2.3x slower decode at M=1 (qmv is heavily optimized for the
single-token case). Off by default; users opting in are explicitly
trading throughput for correctness.

Files changed:
- mlx/utils.h (+19 LoC): env::numerical_strict_mode() helper
- mlx/backend/metal/quantized.cpp (+53 LoC): gates in QuantizedMatmul::eval_gpu
  and GatherQMM::eval_gpu
@AirRunner

I hit a severe manifestation of this bug in quantized_scaled_dot_product_attention when using GQA with expand_dims broadcasting. The error magnitude is orders of magnitude larger than the ~2e-5 observed in the q_proj case, enough to completely corrupt token distributions and cause repetition loops.

Setup: a 2-token verification pass (M=2). Queries shape (1, 4, 4, 2, 256) (B, n_kv_heads, n_repeats, M=2, D), quantized keys (1, 4, 1, N, ...) with stride-0 expand_dims on the n_repeats dim (GQA broadcasting). 8-bit keys, 4-bit values.

Reproducer (no model required, tested on MLX 0.31.1 and 0.31.2)

import mlx.core as mx
from mlx.utils import tree_map

B, n_kv_heads, n_repeats, D = 1, 4, 4, 256
n_q_heads = n_kv_heads * n_repeats
key_bits, value_bits, group_size = 8, 4, 64

mx.random.seed(42)

def run(N, M):
    keys_f = mx.random.normal((B, n_kv_heads, N, D)).astype(mx.float16)
    values_f = mx.random.normal((B, n_kv_heads, N, D)).astype(mx.float16)
    q_keys = mx.quantize(keys_f, group_size=group_size, bits=key_bits)
    q_values = mx.quantize(values_f, group_size=group_size, bits=value_bits)
    queries = (mx.random.normal((B, n_q_heads, M, D)) * D**-0.5).astype(mx.float16)
    mx.eval(q_keys, q_values, queries)

    # reference: dequantize then float matmul
    keys_dq = mx.dequantize(*q_keys, group_size=group_size, bits=key_bits)
    values_dq = mx.dequantize(*q_values, group_size=group_size, bits=value_bits)
    qr = queries.reshape(B, n_kv_heads, n_repeats, M, D)
    s_ref = mx.softmax(qr @ keys_dq[:,:,None,:,:].transpose(0,1,2,4,3), axis=-1)
    out_ref = (s_ref @ values_dq[:,:,None,:,:]).reshape(B, n_q_heads, M, D)

    # quantized_matmul with expand_dims broadcast (GQA)
    qk_e = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
    qv_e = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
    s_qmm = mx.softmax(mx.quantized_matmul(qr, *qk_e, transpose=True, group_size=group_size, bits=key_bits), axis=-1)
    out_qmm = mx.quantized_matmul(s_qmm, *qv_e, transpose=False, group_size=group_size, bits=value_bits).reshape(B, n_q_heads, M, D)

    mx.eval(out_ref, out_qmm)
    diff = mx.max(mx.abs(out_ref.astype(mx.float32) - out_qmm.astype(mx.float32))).item()
    print(f"N={N:5d} M={M}: max_diff={diff:.6f}")

for N in [512, 2048, 4096, 7358]:
    for M in [1, 2]:
        run(N, M)

Output on M4 Pro

N=  512 M=1: max_diff=0.000244   # within quantization noise
N=  512 M=2: max_diff=0.000244   # ok, qmm path
N= 2048 M=1: max_diff=0.000122
N= 2048 M=2: max_diff=0.140656   # qmv triggered, error grows
N= 4096 M=1: max_diff=0.000092
N= 4096 M=2: max_diff=0.118958   # wrong
N= 7358 M=1: max_diff=0.000061
N= 7358 M=2: max_diff=0.079422   # wrong, mean_diff ~0.009

In actual inference logs (real model, N=7358), I measured max_diff up to 12.46 between the two paths in certain attention layers, far past the threshold where softmax distributions are corrupted and the model produces severely degraded output (repetition loops, wrong task execution).

On the fix

MLX_NUMERICAL_STRICT_MODE as written would cover this case, but the 2.3x decode slowdown from bypassing qmv applies globally, even to calls with no stride-0 tensors.

For the GQA case specifically, a more targeted fix (detecting zero strides on the batch dimensions of w in eval_gpu, before the dispatch) would handle it with no env var and no penalty on standard M=1 decode. The full root-cause analysis and proposed fix are in #3480.
