Bug: INT8 + Gradient Checkpointing state machine breaks during recompute
Environment
- bitsandbytes latest (tested 0.43.x / 0.44.x)
- transformers Gemma4 26B-A4B (MoE, 4B active params)
- PEFT LoRA r=16, task_type=CAUSAL_LM
- RTX 4090 (24GB) + 60GB CPU RAM offload via accelerate
- Gradient Checkpointing enabled
Problem 1: CB/SCB state machine — autograd/_functions.py
Gradient Checkpointing re-runs the forward pass during backward. On the recompute pass, state.CB is already populated from the first forward pass. The existing code path that builds CB/SCB only handles the case where the weight is still in BF16 — it doesn't handle pre-quantized Int8Params where B.dtype == torch.int8.
Result: NaN loss, a failed SCB lookup on the recompute pass, or silently wrong SCB values.
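The failure is easiest to see with a single checkpointed Linear8bitLt. A minimal repro sketch (layer sizes and batch shape are arbitrary; assumes a CUDA device with bitsandbytes installed):

# Repro sketch: checkpoint() re-runs the forward during backward, so the
# second pass reaches the quantization branch with an int8 weight.
import torch
import bitsandbytes as bnb
from torch.utils.checkpoint import checkpoint

layer = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False).cuda()
x = torch.randn(4, 64, device="cuda", dtype=torch.float16, requires_grad=True)

out = checkpoint(layer, x, use_reentrant=False)  # forward #1: weight is quantized here
out.sum().backward()                             # recompute: B.dtype == torch.int8 now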
Fix (3 sites in autograd/_functions.py — P1, P9, P2):
# Before (existing code):
else:
    state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))

# After:
if B.dtype == torch.int8:
    # Weight was already quantized on the first forward pass: reuse it.
    state.CB = B.data.clone()
    if hasattr(B, "SCB") and B.SCB is not None:
        state.SCB = B.SCB.to(state.CB.device)
    else:
        # Fallback when no quantization scale was ever attached to the weight.
        state.SCB = state.CB.float().abs().amax(dim=1).div(127.0).to(state.CB.device)
else:
    state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
The same pattern appears in 3 places (the matmul_4bit path, the forward path, and the recompute path). All three need the guard.
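Since the guard has to land at all three sites, one option is to factor it into a helper and call it from each site. A sketch (ensure_cb_scb is a hypothetical name, not part of the patch):

# Hypothetical helper consolidating the guard from the patch above.
import torch
import bitsandbytes.functional as F

def ensure_cb_scb(state, B):
    """Populate state.CB / state.SCB, tolerating an already-quantized Int8Params."""
    if B.dtype == torch.int8:
        # Recompute pass: reuse the weight quantized on the first forward.
        state.CB = B.data.clone()
        if getattr(B, "SCB", None) is not None:
            state.SCB = B.SCB.to(state.CB.device)
        else:
            # Fallback scale derived from the int8 data itself.
            state.SCB = state.CB.float().abs().amax(dim=1).div(127.0).to(state.CB.device)
    else:
        # Original path: quantize the still-BF16 weight on the fly.
        state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))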
Problem 2: Meta-device SCB AttributeError — nn/modules.py
During model init with device_map="auto" and INT8, some layers land on "meta" device. When nn/modules.py accesses B.SCB on a meta tensor, it raises:
AttributeError: 'Int8Params' object has no attribute 'SCB'
Fix (P3):
# In Int8Params.cuda() / to() — guard before SCB access:
if not hasattr(self, "SCB") or self.SCB is None:
    self.SCB = self.float().abs().amax(dim=1).div(127.0).to(self.device)
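If patching the installed package isn't an option, the same guard can be applied as a monkeypatch. A sketch (simplified and hypothetical; assumes the Int8Params.cuda override in nn/modules.py keeps its current shape):

# Monkeypatch sketch: wrap Int8Params.cuda so SCB always exists afterwards.
import bitsandbytes as bnb

_orig_cuda = bnb.nn.Int8Params.cuda

def _patched_cuda(self, device=None):
    out = _orig_cuda(self, device)
    if getattr(out, "SCB", None) is None and not out.is_meta:
        # Same fallback as the patch above: per-row absmax scale.
        out.SCB = out.float().abs().amax(dim=1).div(127.0).to(out.device)
    return out

bnb.nn.Int8Params.cuda = _patched_cuda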
Critical interaction: model.train() must precede gradient_checkpointing_enable()
This is not strictly a bitsandbytes bug, but worth documenting here because it interacts with INT8 GC:
# WRONG — GC hooks never register, CB accumulates for ALL layers → OOM
model.gradient_checkpointing_enable()
model.train()

# CORRECT
model.train()
model.gradient_checkpointing_enable()
Without model.train() first, requires_grad isn't set when GC scans the graph → GC silently does nothing → every layer's state.CB accumulates → OOM.
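A quick way to confirm the ordering took effect (assumes a transformers PreTrainedModel; the checkpoint path is hypothetical, and enable_input_require_grads is the stock transformers helper usually needed with LoRA so checkpointed inputs require grad):

# Sanity-check sketch: verify GC actually engaged after the correct ordering.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "path/to/model",                                        # hypothetical checkpoint
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
model.train()                                               # BEFORE enabling GC
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
assert model.is_gradient_checkpointing, "GC hooks did not register"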
Full working example
All patches + a complete training script for Gemma4 26B on RTX 4090:
https://github.com/sirfyyn/consumer-llm-patches
Benchmark: ~6.25s/step at 512 tokens with 10 CPU-offloaded layers. Step time nearly flat across seq lengths (CPU→GPU transfer dominates, not compute).
Happy to submit a PR if this approach looks right to you.