From bcd6d35ab12958bc886447613df44009880384d5 Mon Sep 17 00:00:00 2001 From: rapsealk Date: Thu, 23 Apr 2026 16:20:15 +0900 Subject: [PATCH] Rewrite asserts as exceptions Mechanically rewrite `assert cond[, msg]` in library code as `if not cond: raise AssertionError[(msg)]` so the checks still fire under `python -O`, which strips asserts. Test files are left alone since pytest relies on `assert` for introspection. Closes #1408. Assisted-by: Claude Opus 4.7 (1M context) --- bitsandbytes/autograd/_functions.py | 3 +- bitsandbytes/diagnostics/main.py | 3 +- bitsandbytes/functional.py | 54 +++++++++++++++++++---------- bitsandbytes/nn/modules.py | 24 ++++++++----- bitsandbytes/nn/parametrize.py | 6 ++-- bitsandbytes/optim/lars.py | 3 +- bitsandbytes/optim/optimizer.py | 9 +++-- bitsandbytes/utils.py | 6 ++-- check_bnb_install.py | 3 +- 9 files changed, 74 insertions(+), 37 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 95a7d9090..d90d783f8 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -381,7 +381,8 @@ def matmul_4bit( out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, ): - assert quant_state is not None + if quant_state is None: + raise AssertionError if A.device.type == "cpu": if getattr(quant_state, "packing_format_for_cpu", False): out = F.gemv_4bit(A, B, out, state=quant_state) diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 74da662b6..c51768c83 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -39,7 +39,8 @@ def sanity_check(): loss.backward() adam.step() p2 = p.data.sum().item() - assert p1 != p2 + if p1 == p2: + raise AssertionError def get_package_version(name: str) -> str: diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 0165a1288..5ce61390d 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -99,7 +99,8 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): def prefetch_tensor(A: torch.Tensor, to_cpu=False): - assert A.is_paged, "Only paged tensors can be prefetched!" + if not A.is_paged: + raise AssertionError("Only paged tensors can be prefetched!") if to_cpu: deviceid = -1 else: @@ -218,7 +219,8 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): values = values.sort().values values /= values.max() - assert values.numel() == 256 + if values.numel() != 256: + raise AssertionError return values @@ -254,7 +256,8 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) e = exponent_bits p = precision_bits has_sign = 1 if signed else 0 - assert e + p == total_bits - has_sign + if e + p != total_bits - has_sign: + raise AssertionError # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)): @@ -279,7 +282,8 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) if signed: values.append(-value) - assert len(values) == 2**total_bits + if len(values) != 2**total_bits: + raise AssertionError values.sort() if total_bits < 8: gap = 256 - len(values) @@ -337,7 +341,8 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): data.append(0) data.append(1.0) - assert len(data) == 2**total_bits + if len(data) != 2**total_bits: + raise AssertionError gap = 256 - len(data) for i in range(gap): @@ -516,7 +521,8 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes - assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) + if not set(qs_dict.keys()).issubset(cls.valid_qs_keys): + raise AssertionError if "nested_absmax" in qs_dict: offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) @@ -721,7 +727,8 @@ def dequantize_blockwise( The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. """ - assert quant_state is not None or absmax is not None + if quant_state is None and absmax is None: + raise AssertionError if code is None and quant_state is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -842,7 +849,8 @@ def get_4bit_type(typename, device=None, blocksize=64): data = torch.tensor(data, device=device) data.div_(data.abs().max()) - assert data.numel() == 16 + if data.numel() != 16: + raise AssertionError return data @@ -1009,7 +1017,8 @@ def dequantize_4bit( blocksize = 64 if quant_state is None: - assert absmax is not None and out is not None + if absmax is None or out is None: + raise AssertionError quant_state = QuantState( absmax=absmax, @@ -1365,7 +1374,8 @@ def igemm( ldc = sB[1] elif len(sB) == 3: # special case - assert len(sA) == 3 + if len(sA) != 3: + raise AssertionError if not (sA[0] == sB[0] and sA[1] == sB[1]): raise ValueError( f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}", @@ -1658,10 +1668,13 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat unpacked_w[::2] = qweight >> 4 qweight_final = unpacked_w.reshape(quant_state.shape).to(torch.uint8) # (*, N, K) # pack weight: [*, N, K] -> [*, N, K/2] combine low and high bit - assert len(qweight_final.shape) == 2 + if len(qweight_final.shape) != 2: + raise AssertionError N, K = qweight_final.shape[0], qweight_final.shape[1] - assert N % block_n == 0, "N must be divisible by block_n" - assert K % 2 == 0, "K must be even" + if N % block_n != 0: + raise AssertionError("N must be divisible by block_n") + if K % 2 != 0: + raise AssertionError("K must be even") BLOCK_N = block_n BIT_COUNT = 32 # (=32 low +32 high) new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2] @@ -1706,9 +1719,12 @@ def _convert_weight_packed_for_cpu_inverse( qweight: [*, N, K] uint8, original qweight shape (quant_state.shape) recovered_state: QuantState with partially restored fields (best-effort inverse) """ - assert quant_state.packing_format_for_cpu, "only for packing format" - assert packed_weight.dtype == torch.uint8 - assert len(packed_weight.shape) == 2, "packed_weight should be [N, K/2]" + if not quant_state.packing_format_for_cpu: + raise AssertionError("only for packing format") + if packed_weight.dtype != torch.uint8: + raise AssertionError + if len(packed_weight.shape) != 2: + raise AssertionError("packed_weight should be [N, K/2]") N, K_half = packed_weight.shape K = K_half * 2 @@ -1716,8 +1732,10 @@ def _convert_weight_packed_for_cpu_inverse( BLOCK_N = block_n BIT_COUNT = 32 # (=32 low + 32 high) - assert N % BLOCK_N == 0, "N must be divisible by block_n" - assert K % 2 == 0, "K must be even" + if N % BLOCK_N != 0: + raise AssertionError("N must be divisible by block_n") + if K % 2 != 0: + raise AssertionError("K must be even") # [N, K/2] -> [-1, 64] (32 low + 32 high) packed = packed_weight.reshape(-1, BIT_COUNT) # [-1, 64] diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index bfd41d5dd..431b1d26a 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -495,7 +495,8 @@ def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Line # the quant state got lost when the parameter got converted. This happens for example for fsdp # since we registered the module, we can recover the state here - assert module.weight.shape[1] == 1 + if module.weight.shape[1] != 1: + raise AssertionError if not isinstance(module.weight, Params4bit): module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True) module.weight.quant_state = module.quant_state @@ -866,8 +867,10 @@ def forward(self, input: Tensor) -> Tensor: rows = self.weight.data row_stats = self.weight.SCB - assert rows.shape == (self.num_embeddings, self.embedding_dim) - assert row_stats.shape == (self.num_embeddings,) + if rows.shape != (self.num_embeddings, self.embedding_dim): + raise AssertionError + if row_stats.shape != (self.num_embeddings,): + raise AssertionError compressed_output = F.embedding(input, rows) compressed_output_stats = F.embedding(input, row_stats.view(self.num_embeddings, 1)) @@ -928,7 +931,8 @@ def __init__( ) def _forward_with_partial_dequantize(self, input: Tensor): - assert self.embedding_dim % self.weight.quant_state.blocksize == 0 + if self.embedding_dim % self.weight.quant_state.blocksize != 0: + raise AssertionError w_4bit_uint8 = self.weight.data.view(torch.uint8).view(self.num_embeddings * self.embedding_dim // 2, 1) @@ -936,12 +940,14 @@ def _forward_with_partial_dequantize(self, input: Tensor): weight=w_4bit_uint8.view(self.num_embeddings, self.embedding_dim // 2), input=input, ).view(-1, 1) - assert output_4bit.shape == (input.numel() * self.embedding_dim // 2, 1) + if output_4bit.shape != (input.numel() * self.embedding_dim // 2, 1): + raise AssertionError blocks_per_emb = self.embedding_dim // self.weight.blocksize absmax = self.weight.quant_state.absmax - assert absmax.shape == (self.num_embeddings * blocks_per_emb,) + if absmax.shape != (self.num_embeddings * blocks_per_emb,): + raise AssertionError output_absmax = torch.nn.functional.embedding( weight=absmax.view(self.num_embeddings, blocks_per_emb), @@ -949,14 +955,16 @@ def _forward_with_partial_dequantize(self, input: Tensor): ).view( -1, ) - assert output_absmax.shape == (input.numel() * blocks_per_emb,) + if output_absmax.shape != (input.numel() * blocks_per_emb,): + raise AssertionError output_quant_state = copy.deepcopy(self.weight.quant_state) output_quant_state.absmax = output_absmax output_quant_state.shape = torch.Size((*input.shape, self.embedding_dim)) output = bnb.functional.dequantize_4bit(output_4bit, output_quant_state) - assert output.shape == (*input.shape, self.embedding_dim) + if output.shape != (*input.shape, self.embedding_dim): + raise AssertionError return output.to(self.dtype) diff --git a/bitsandbytes/nn/parametrize.py b/bitsandbytes/nn/parametrize.py index 4a956c7fa..4877f1d52 100644 --- a/bitsandbytes/nn/parametrize.py +++ b/bitsandbytes/nn/parametrize.py @@ -175,14 +175,16 @@ def _parametrized_state_dict_post_hook( clean_key = f"{prefix}{param_name}" state_dict[clean_key] = state_dict.pop(original_key) - assert P.is_parametrized(module, param_name) + if not P.is_parametrized(module, param_name): + raise AssertionError # Find the parametrization, which should have the quantization state. parametrization: Bnb4bitParametrization = next( filter(lambda x: isinstance(x, Bnb4bitParametrization), module.parametrizations[param_name]), None ) - assert parametrization is not None, "Parametrization not found for the parameter." + if parametrization is None: + raise AssertionError("Parametrization not found for the parameter.") quant_state = parametrization.quant_state diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index c2f5aa784..339d21c37 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -248,7 +248,8 @@ def step(self, closure=None): update_scale = 1.0 if max_unorm > 0.0: - assert p.dtype == torch.float32 + if p.dtype != torch.float32: + raise AssertionError pnorm = torch.norm(p.detach()) unorm = torch.norm(update) if unorm > max_unorm * pnorm: diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index dfc6e5d65..e8e6adaf3 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -100,7 +100,8 @@ def override_config(self, parameters, key=None, value=None, key_value_dict=None) if isinstance(parameters, torch.Tensor): parameters = [parameters] if key is not None and value is not None: - assert key_value_dict is None + if key_value_dict is not None: + raise AssertionError key_value_dict = {key: value} if key_value_dict is not None: @@ -286,8 +287,10 @@ def to_gpu(self): def check_overrides(self): for module, attr, config in self.mng.module_weight_config_triple: pmodule = getattr(module, attr) - assert pmodule is not None - assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter) + if pmodule is None: + raise AssertionError + if not (isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)): + raise AssertionError found = False for gindex, group in enumerate(self.param_groups): if found: diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 513baceab..92d7c5d28 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -9,7 +9,8 @@ def outlier_hook(module, input): - assert isinstance(module, torch.nn.Linear) + if not isinstance(module, torch.nn.Linear): + raise AssertionError tracer = OutlierTracer.get_instance() hvalue = tracer.get_hvalue(module.weight) if hvalue not in tracer.hvalue2outlier_idx: @@ -20,7 +21,8 @@ def outlier_hook(module, input): # assign the current layer the outlier idx found from the weight # of the previous linear layer if tracer.outliers[-1].numel() > 0: - assert tracer.outliers[-1].max() < module.weight.shape[1] + if tracer.outliers[-1].max() >= module.weight.shape[1]: + raise AssertionError tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1] else: diff --git a/check_bnb_install.py b/check_bnb_install.py index 7a9dc93fc..d6326177d 100644 --- a/check_bnb_install.py +++ b/check_bnb_install.py @@ -16,6 +16,7 @@ p2 = p.data.sum().item() -assert p1 != p2 +if p1 == p2: + raise AssertionError print("SUCCESS!") print("Installation was successful!")