53 changes: 53 additions & 0 deletions mlx/backend/metal/quantized.cpp
@@ -1408,6 +1408,32 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {

int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
auto mode = quantization_mode_to_string(mode_);

// Numerical strict mode (MLX_NUMERICAL_STRICT_MODE=1): bypass the
// shape-dependent fast paths (qmv, qmm_splitk, qvm_split_k) so output is
// bit-identical regardless of M. Decode at M=1 gets ~1.5-2.3x slower
// (qmv is heavily optimized for the M=1 case), but in exchange we get the
// path-independence required for prefix-cache reuse, batched-vs-streaming
// eval comparison, and distillation/RLHF teacher-student equality. See
// env::numerical_strict_mode() in mlx/utils.h.
if (env::numerical_strict_mode()) {
qmm(x,
w,
scales,
biases,
out,
transpose_,
group_size_,
bits_,
M,
N,
K,
d,
s,
mode);
return;
}

// It is a matrix matrix product.
if (M >= vector_limit) {
// Use split-K qmm for small M with transposed weights (non-batched only)
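For reviewers who want to check the guarantee end to end, here is a minimal harness (not part of this PR). It assumes the long-standing C++ signatures of mx::quantize and mx::quantized_matmul with their group_size/bits defaults; on this branch both may have grown a mode argument, so treat the calls as schematic:

// check_strict_mode.cpp -- sketch only, signatures assumed as noted above.
// With MLX_NUMERICAL_STRICT_MODE=1, a single-row quantized_matmul (which
// would otherwise take the qmv fast path) must be bit-identical to the
// corresponding row of an M=8 call.
#include <cstdlib>
#include <iostream>
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // Set before the first quantized op: env::numerical_strict_mode() reads
  // the variable once and caches it in a function-local static.
  setenv("MLX_NUMERICAL_STRICT_MODE", "1", 1);

  auto x = mx::random::normal({8, 512});        // M=8, K=512
  auto w = mx::random::normal({512, 512});      // N=512 after transpose
  auto [wq, scales, biases] = mx::quantize(w);  // default group_size/bits

  auto full = mx::quantized_matmul(x, wq, scales, biases);
  auto row = mx::quantized_matmul(
      mx::slice(x, {7, 0}, {8, 512}), wq, scales, biases);

  auto same = mx::array_equal(mx::slice(full, {7, 0}, {8, 512}), row);
  std::cout << "bit-identical: " << same.item<bool>() << std::endl;
}

Without the env var the second call dispatches to qmv and the comparison typically fails in the last bits, which is exactly the discrepancy this mode eliminates.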
@@ -1477,6 +1503,33 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
auto mode = quantization_mode_to_string(mode_);

// Numerical strict mode (MLX_NUMERICAL_STRICT_MODE=1): bypass the
// shape-dependent fast paths (gather_qmv, gather_qvm, gather_qmm_rhs) so
// GatherQMM output is bit-identical regardless of M. Same justification as
// QuantizedMatmul::eval_gpu — gather_qmm is the reference path that uses
// sequential register-fma accumulation matching qmm. Necessary for
// bit-equivalence in MoE models (Mixtral-MoE, DeepSeek-MoE, etc.).
if (env::numerical_strict_mode()) {
gather_qmm(
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
out,
transpose_,
group_size_,
bits_,
M,
N,
K,
d,
s,
mode);
return;
}

// We are walking x in order and w is also in order so we can batch up the
// matmuls and reuse reading x and w.
//
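The same property can be exercised through the gather path for the MoE case the comment mentions. A sketch, with the caveat that gather_qmm's public C++ signature (argument order, optional index arrays, index construction) is assumed here from the eval_gpu call above rather than checked against this branch:

// moe_strict_check.cpp -- sketch only; gather_qmm signature assumed.
// A token routed through expert 2 inside a batch of 8 must produce the
// same bits as the same token routed through expert 2 on its own.
#include <cstdlib>
#include <iostream>
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  setenv("MLX_NUMERICAL_STRICT_MODE", "1", 1);

  const int E = 4, K = 256;                     // 4 experts, K=256, N=256
  auto x = mx::random::normal({8, 1, K});       // 8 tokens
  auto w = mx::random::normal({E, 256, K});     // per-expert weights
  auto [wq, scales, biases] = mx::quantize(w);

  auto rhs = mx::array({0, 1, 2, 2, 0, 3, 1, 2});  // token -> expert routing
  auto batched =
      mx::gather_qmm(x, wq, scales, biases, std::nullopt, rhs);

  auto solo = mx::gather_qmm(
      mx::slice(x, {3, 0, 0}, {4, 1, K}),
      wq, scales, biases, std::nullopt, mx::array({2}));

  auto same =
      mx::array_equal(mx::slice(batched, {3, 0, 0}, {4, 1, 256}), solo);
  std::cout << "bit-identical: " << same.item<bool>() << std::endl;
}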
19 changes: 19 additions & 0 deletions mlx/utils.h
@@ -173,6 +173,25 @@ inline bool enable_tf32() {
return enable_tf32_;
}

// When set, QuantizedMatmul forces the no-split qmm path for all 2D shapes,
// bypassing the qmv (M < vector_limit) and qmm_splitk fast paths. This
// guarantees that quantized_matmul output is independent of input shape:
// q_proj(x[:, -L:]) is bit-identical to q_proj(x)[:, -L:] for any L.
//
// The fast paths use parallel reductions across K (simd-butterfly in qmv,
// partition-then-sum in splitk) that produce different fp32 sums than qmm's
// sequential register-level accumulation. Even when both paths use fp32
// throughout, fp32 addition is non-associative, so the bit patterns differ
// by ~1 ULP.
//
// This bites workloads that compare two supposedly equivalent paths:
// prefix-cache reuse, batched-vs-streaming eval, distillation/RLHF
// teacher-student equality. For plain inference/training the difference is
// invisible. Off by default.
inline bool numerical_strict_mode() {
static bool numerical_strict_mode_ =
get_var("MLX_NUMERICAL_STRICT_MODE", 0);
return numerical_strict_mode_;
}

inline int nccl_timeout(int default_value) {
static int nccl_timeout = get_var("MLX_NCCL_TIMEOUT", default_value);
return nccl_timeout;
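The non-associativity claim in the numerical_strict_mode() comment is easy to reproduce without MLX. The standalone snippet below (illustration only) sums the same 4096 fp32 values once sequentially, as in qmm's register fma loop, and once as two independently summed partitions combined at the end, as in the split-K reduction; on typical inputs the two results differ in the last bits:

// fp32_order_demo.cpp -- standalone illustration: the same fp32 values
// summed in two different orders yield different bit patterns.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <random>
#include <vector>

static uint32_t bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return u;
}

int main() {
  std::mt19937 gen(0);
  std::normal_distribution<float> dist(0.f, 1.f);
  std::vector<float> v(4096);
  for (auto& f : v) f = dist(gen);

  // Sequential accumulation, like qmm's register-level fma chain.
  float seq = 0.f;
  for (float f : v) seq += f;

  // Partition-then-sum, like the split-K reduction: two halves summed
  // independently, partial sums combined at the end.
  float lo = 0.f, hi = 0.f;
  for (size_t i = 0; i < v.size() / 2; ++i) lo += v[i];
  for (size_t i = v.size() / 2; i < v.size(); ++i) hi += v[i];
  float split = lo + hi;

  std::printf("seq   = %.9g (0x%08x)\n", seq, (unsigned)bits(seq));
  std::printf("split = %.9g (0x%08x)\n", split, (unsigned)bits(split));
  std::printf("same bits: %s\n", bits(seq) == bits(split) ? "yes" : "no");
}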