diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 95a7d9090..91d8e09a4 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -262,7 +262,8 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
             B = state.CB
 
         CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
-        output = torch.nn.functional.linear(A, CB, bias)
+        with torch.no_grad():
+            output = torch.nn.functional.linear(A, CB, bias)
         ctx.state = state
         ctx.dtype_A = A.dtype
         ctx.grad_shape = A.shape
@@ -317,7 +318,10 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+        # no_grad prevents F.linear from saving the dequantized weight
+        # (MatMul4Bit.backward re-dequantizes from NF4 anyway)
+        with torch.no_grad():
+            output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
 
         if out is not None:
             out.copy_(output)
             output = out
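
Note for reviewers: per the comment added in the second hunk, wrapping the matmul in no_grad keeps F.linear from saving the dequantized full-precision weight for backward (it would otherwise be held alive until backward to compute grad_input = grad_output @ W), which is pure overhead here because the custom backward re-dequantizes from the quantized payload anyway. Below is a minimal, self-contained sketch of the same pattern; ToyQuantLinear, dequantize, w_int8, and scale are hypothetical stand-ins, not bitsandbytes APIs, with int8 weights plus a per-row scale playing the role of NF4 + QuantState.

import torch

def dequantize(w_int8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Rebuild a float weight from the compact quantized form on demand.
    return w_int8.to(torch.float32) * scale.unsqueeze(1)

class ToyQuantLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, w_int8, scale, bias=None):
        # As in the diff: run dequantize + matmul under no_grad so autograd
        # cannot record the op and save the dequantized weight; the float
        # weight is a temporary that can be freed right after the matmul.
        with torch.no_grad():
            output = torch.nn.functional.linear(A, dequantize(w_int8, scale), bias)
        # Only the compact quantized tensors are kept between fwd and bwd.
        ctx.save_for_backward(w_int8, scale)
        ctx.bias_requires_grad = bias is not None and bias.requires_grad
        return output

    @staticmethod
    def backward(ctx, grad_output):
        w_int8, scale = ctx.saved_tensors
        # Re-dequantize on the fly (what MatMul4Bit.backward does from NF4):
        # a little extra compute in exchange for not holding the full-size
        # weight alive for the whole stretch between forward and backward.
        W = dequantize(w_int8, scale)
        grad_A = grad_output @ W  # linear computes A @ W.t(), so dA = dY @ W
        grad_bias = None
        if ctx.bias_requires_grad:
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
        return grad_A, None, None, grad_bias

# Usage: gradients still flow to A through the handwritten backward even
# though the forward matmul itself ran under no_grad.
A = torch.randn(8, 64, requires_grad=True)
w_int8 = torch.randint(-127, 128, (32, 64), dtype=torch.int8)
scale = torch.rand(32)
ToyQuantLinear.apply(A, w_int8, scale, None).sum().backward()
print(A.grad.shape)  # torch.Size([8, 64])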