From 3670b8f77b96cfe05cad5a3f2122b929560f1b01 Mon Sep 17 00:00:00 2001 From: Christian Butterweck Date: Sat, 2 May 2026 19:26:40 +0200 Subject: [PATCH] fix: prevent F.linear from saving dequantized weights MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MatMul4Bit and MatMul8bitLt dequantize NF4/INT8 weights to bf16, then call F.linear which internally saves the bf16 weight for backward. But both backprop functions re-dequantize from the stored quantization state anyway. The saved bf16 weight (~0.5 GB per layer for 9B models) accumulates across all layers during forward, causing excessive VRAM usage during QLoRA training. Fix: wrap F.linear in torch.no_grad() to prevent intermediate autograd node creation. Backward path unchanged — already correct. Result: ~0.5 GB per layer VRAM reduction with zero quality impact. --- bitsandbytes/autograd/_functions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 95a7d9090..91d8e09a4 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -262,7 +262,8 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): B = state.CB CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) - output = torch.nn.functional.linear(A, CB, bias) + with torch.no_grad(): + output = torch.nn.functional.linear(A, CB, bias) ctx.state = state ctx.dtype_A = A.dtype ctx.grad_shape = A.shape @@ -317,7 +318,10 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] # 1. Dequantize # 2. MatmulnN - output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias) + # no_grad prevents F.linear from saving the dequantized weight + # (MatMul4Bit.backward re-dequantizes from NF4 anyway) + with torch.no_grad(): + output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias) if out is not None: out.copy_(output) output = out