From 50b2b483f1dbf9412c1e38b2f92637c25d663143 Mon Sep 17 00:00:00 2001
From: Tai An
Date: Thu, 28 May 2026 09:27:39 -0700
Subject: [PATCH] fix(nn/Linear4bit): consume QuantState keys in _load_from_state_dict (#1946)

Linear4bit overrides _save_to_state_dict to write weight.absmax /
weight.quant_map / weight.nested_* / weight.quant_state.bitsandbytes__*
alongside the packed weight, but inherits nn.Linear._load_from_state_dict
which only consumes weight and bias. Result:

- strict=True load raises Unexpected key(s) in state_dict for every
  QuantState component.
- strict=False silently drops them and the destination layer keeps the
  freshly-quantized quant_state from the prior .to('cuda') call, which
  does not match the packed bytes that were just loaded.

This follows what Linear8bitLt already does for SCB
(_load_from_state_dict at modules.py:1119): collect the entries under
'weight.' into a qs_dict, reconstruct via QuantState.from_dict, install
on self.weight, and remove the consumed keys from unexpected_keys.

The entries are read from state_dict directly rather than from
unexpected_keys, because nn.Module._load_from_state_dict only populates
unexpected_keys when strict is true; relying on that list would skip the
QuantState for callers that pass strict=False.

Fixes #1946
---
 bitsandbytes/nn/modules.py | 68 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ebc0b0943..1f9d86949 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -606,6 +606,74 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         for k, v in self.weight.quant_state.as_dict(packed=True).items():
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()
 
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        """
+        Consume the QuantState components written by ``_save_to_state_dict``.
+
+        Without this, the ``weight.absmax`` / ``weight.quant_map`` / ... keys
+        land in ``unexpected_keys`` and ``load_state_dict(strict=True)`` raises;
+        with ``strict=False`` they are silently dropped and the loaded layer
+        keeps a freshly-quantized ``quant_state`` that does not match the
+        packed bytes that were just loaded.
+
+        The keys are read from ``state_dict`` itself rather than from
+        ``unexpected_keys``: ``nn.Module._load_from_state_dict`` only fills
+        that list when ``strict`` is true, so relying on it would skip the
+        QuantState whenever a caller passes ``strict=False``.
+
+        If the QuantState cannot be reconstructed, the failure is appended to
+        ``error_msgs`` and the current ``quant_state`` and ``unexpected_keys``
+        are left as they were.
+        """
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+        qs_prefix = prefix + "weight."
+        qs_dict: dict[str, torch.Tensor] = {
+            key[len(qs_prefix) :]: value
+            for key, value in state_dict.items()
+            if key.startswith(qs_prefix)
+        }
+        if not qs_dict:
+            return
+
+        try:
+            quant_state = QuantState.from_dict(qs_dict=qs_dict, device=self.weight.device)
+        except Exception as exc:
+            # Report through error_msgs so load_state_dict can aggregate the failure.
+            error_msgs.append(
+                f"Linear4bit: failed to reconstruct QuantState from state_dict "
+                f"with prefix '{prefix}': {exc}"
+            )
+            return
+
+        self.weight.quant_state = quant_state
+        self.weight.bnb_quantized = True
+        self.weight.blocksize = quant_state.blocksize
+        self.weight.compress_statistics = quant_state.nested
+        self.weight.quant_type = quant_state.quant_type
+        self.quant_state = quant_state
+
+        # Prune by prefix (qs_dict may no longer hold the original key names) and
+        # in place, because the caller reads this same list after we return.
+        unexpected_keys[:] = [key for key in unexpected_keys if not key.startswith(qs_prefix)]
+
     def forward(self, x: torch.Tensor):
         fix_4bit_weight_quant_state_from_module(self)
         quant_state = self.weight.quant_state