
INT8 + Gradient Checkpointing: CB/SCB state machine breaks on GC recompute (+ meta-device SCB AttributeError) #1927

@sirfyyn

Description


Bug: INT8 + Gradient Checkpointing state machine breaks during recompute

Environment

  • bitsandbytes latest (tested 0.43.x / 0.44.x)
  • transformers Gemma4 26B-A4B (MoE, 4B active params)
  • PEFT LoRA r=16, task_type=CAUSAL_LM
  • RTX 4090 (24GB) + 60GB CPU RAM offload via accelerate
  • Gradient Checkpointing enabled
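
For context, the model was loaded roughly along these lines (a sketch of the setup, not the exact script; the model id and memory budgets below are placeholders):

# Sketch of the loading setup (model id and memory budgets are placeholders)
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,   # allow offloaded layers to stay on CPU
)

model = AutoModelForCausalLM.from_pretrained(
    "<model-id>",                             # placeholder
    quantization_config=bnb_config,
    device_map="auto",
    max_memory={0: "22GiB", "cpu": "60GiB"},  # 24GB GPU + 60GB CPU RAM offload
    torch_dtype=torch.bfloat16,
)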

Problem 1: CB/SCB state machine — autograd/_functions.py

Gradient Checkpointing re-runs the forward pass during backward. On the recompute pass, state.CB is already populated from the first forward pass. The existing code path that builds CB/SCB only handles the case where the weight is still in BF16 — it doesn't handle pre-quantized Int8Params where B.dtype == torch.int8.

Result: NaN loss, a failed SCB lookup on recompute, or the wrong SCB values being used.
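
The recompute behaviour that triggers this is easy to see in isolation with torch.utils.checkpoint: the wrapped function really does run a second time during backward (small self-contained demo, unrelated to bitsandbytes):

import torch
from torch.utils.checkpoint import checkpoint

calls = []

def block(x):
    calls.append(1)        # record every execution of the "layer"
    return x * x

x = torch.randn(4, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
print(len(calls))          # 2: once in forward, once recomputed during backward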

Fix (3 sites in autograd/_functions.py — P1, P9, P2):

# Before (existing code):
else:
    state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))

# After:
if B.dtype == torch.int8:
    # Weight is already a quantized Int8Params (e.g. on the GC recompute pass):
    # reuse its int8 data and scale instead of re-quantizing.
    state.CB = B.data.clone()
    if hasattr(B, "SCB") and B.SCB is not None:
        state.SCB = B.SCB.to(state.CB.device)
    else:
        # Fallback when no SCB is attached: derive a per-row scale from the data.
        state.SCB = state.CB.float().abs().amax(dim=1).div(127.0).to(state.CB.device)
else:
    state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))

The same pattern appears in 3 places (the matmul_4bit path, the forward path, and the recompute path). All three need the guard.
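
Since the guard is identical at all three sites, it could also be factored into a small shared helper, along these lines (a sketch of the proposed logic only; the helper name is made up and this is not existing bitsandbytes code):

import torch

def _populate_cb_scb(B, state, F):
    # B: the weight (possibly a pre-quantized Int8Params), state: the MatmulLtState-style
    # object used in autograd/_functions.py, F: bitsandbytes.functional.
    if B.dtype == torch.int8:
        state.CB = B.data.clone()
        if getattr(B, "SCB", None) is not None:
            state.SCB = B.SCB.to(state.CB.device)
        else:
            # last-resort fallback: derive a per-row scale from the stored data
            state.SCB = state.CB.float().abs().amax(dim=1).div(127.0).to(state.CB.device)
    else:
        # original path: quantize the bf16/fp16 weight on the fly
        state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))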

Problem 2: Meta-device SCB AttributeError — nn/modules.py

During model init with device_map="auto" and INT8, some layers land on "meta" device. When nn/modules.py accesses B.SCB on a meta tensor, it raises:

AttributeError: 'Int8Params' object has no attribute 'SCB'

Fix (P3):

# In Int8Params.cuda() / to() — guard before SCB access:
if not hasattr(self, "SCB") or self.SCB is None:
    self.SCB = self.float().abs().amax(dim=1).div(127.0).to(self.device)
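
For reference, the fallback expression computes a per-output-row absmax scale; in standalone form (illustrative helper name, not part of the patch):

import torch

def fallback_scb(weight: torch.Tensor) -> torch.Tensor:
    # Per-row absmax scale matching 127-level symmetric int8 quantization.
    return weight.float().abs().amax(dim=1).div(127.0)

w = torch.randn(8, 16)
print(fallback_scb(w).shape)   # torch.Size([8]): one scale per output row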

Critical interaction: model.train() must precede gradient_checkpointing_enable()

This is not strictly a bitsandbytes bug, but worth documenting here because it interacts with INT8 GC:

# WRONG — GC hooks never register, CB accumulates for ALL layers → OOM
model.gradient_checkpointing_enable()
model.train()

# CORRECT
model.train()
model.gradient_checkpointing_enable()

Without model.train() first, requires_grad isn't set when GC scans the graph → GC silently does nothing → every layer's state.CB accumulates → OOM.
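
In a LoRA setup matching the environment above, the correct ordering looks roughly like this (sketch; the LoRA values mirror the environment and model is a placeholder continuing from the loading sketch):

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(r=16, task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)

model.train()                          # first, so trainable params are set when GC scans the graph
model.gradient_checkpointing_enable()  # then register the checkpointing hooks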

Full working example

All patches + a complete training script for Gemma4 26B on RTX 4090:
https://github.com/sirfyyn/consumer-llm-patches

Benchmark: ~6.25s/step at 512 tokens with 10 CPU-offloaded layers. Step time nearly flat across seq lengths (CPU→GPU transfer dominates, not compute).

Happy to submit a PR if this approach looks right to you.
