Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,8 @@ def matmul_4bit(
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
):
assert quant_state is not None
if quant_state is None:
raise AssertionError
if A.device.type == "cpu":
if getattr(quant_state, "packing_format_for_cpu", False):
out = F.gemv_4bit(A, B, out, state=quant_state)
Expand Down
3 changes: 2 additions & 1 deletion bitsandbytes/diagnostics/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def sanity_check():
loss.backward()
adam.step()
p2 = p.data.sum().item()
assert p1 != p2
if p1 == p2:
raise AssertionError


def get_package_version(name: str) -> str:
Expand Down
54 changes: 36 additions & 18 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):


def prefetch_tensor(A: torch.Tensor, to_cpu=False):
assert A.is_paged, "Only paged tensors can be prefetched!"
if not A.is_paged:
raise AssertionError("Only paged tensors can be prefetched!")
if to_cpu:
deviceid = -1
else:
Expand Down Expand Up @@ -218,7 +219,8 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
values = values.sort().values
values /= values.max()

assert values.numel() == 256
if values.numel() != 256:
raise AssertionError

return values

Expand Down Expand Up @@ -254,7 +256,8 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
e = exponent_bits
p = precision_bits
has_sign = 1 if signed else 0
assert e + p == total_bits - has_sign
if e + p != total_bits - has_sign:
raise AssertionError
# the exponent is biased to 2^(e-1) -1 == 0
evalues = []
for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)):
Expand All @@ -279,7 +282,8 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
if signed:
values.append(-value)

assert len(values) == 2**total_bits
if len(values) != 2**total_bits:
raise AssertionError
values.sort()
if total_bits < 8:
gap = 256 - len(values)
Expand Down Expand Up @@ -337,7 +341,8 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
data.append(0)
data.append(1.0)

assert len(data) == 2**total_bits
if len(data) != 2**total_bits:
raise AssertionError

gap = 256 - len(data)
for i in range(gap):
Expand Down Expand Up @@ -516,7 +521,8 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState
qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key)))

qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes
assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)
if not set(qs_dict.keys()).issubset(cls.valid_qs_keys):
raise AssertionError

if "nested_absmax" in qs_dict:
offset = torch.tensor(float(qs_dict["nested_offset"])).to(device)
Expand Down Expand Up @@ -721,7 +727,8 @@ def dequantize_blockwise(
The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`.
"""

assert quant_state is not None or absmax is not None
if quant_state is None and absmax is None:
raise AssertionError
if code is None and quant_state is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
Expand Down Expand Up @@ -842,7 +849,8 @@ def get_4bit_type(typename, device=None, blocksize=64):
data = torch.tensor(data, device=device)
data.div_(data.abs().max())

assert data.numel() == 16
if data.numel() != 16:
raise AssertionError

return data

Expand Down Expand Up @@ -1009,7 +1017,8 @@ def dequantize_4bit(
blocksize = 64

if quant_state is None:
assert absmax is not None and out is not None
if absmax is None or out is None:
raise AssertionError

quant_state = QuantState(
absmax=absmax,
Expand Down Expand Up @@ -1365,7 +1374,8 @@ def igemm(
ldc = sB[1]
elif len(sB) == 3:
# special case
assert len(sA) == 3
if len(sA) != 3:
raise AssertionError
if not (sA[0] == sB[0] and sA[1] == sB[1]):
raise ValueError(
f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}",
Expand Down Expand Up @@ -1658,10 +1668,13 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat
unpacked_w[::2] = qweight >> 4
qweight_final = unpacked_w.reshape(quant_state.shape).to(torch.uint8) # (*, N, K)
# pack weight: [*, N, K] -> [*, N, K/2] combine low and high bit
assert len(qweight_final.shape) == 2
if len(qweight_final.shape) != 2:
raise AssertionError
N, K = qweight_final.shape[0], qweight_final.shape[1]
assert N % block_n == 0, "N must be divisible by block_n"
assert K % 2 == 0, "K must be even"
if N % block_n != 0:
raise AssertionError("N must be divisible by block_n")
if K % 2 != 0:
raise AssertionError("K must be even")
BLOCK_N = block_n
BIT_COUNT = 32 # (=32 low +32 high)
new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2]
Expand Down Expand Up @@ -1706,18 +1719,23 @@ def _convert_weight_packed_for_cpu_inverse(
qweight: [*, N, K] uint8, original qweight shape (quant_state.shape)
recovered_state: QuantState with partially restored fields (best-effort inverse)
"""
assert quant_state.packing_format_for_cpu, "only for packing format"
assert packed_weight.dtype == torch.uint8
assert len(packed_weight.shape) == 2, "packed_weight should be [N, K/2]"
if not quant_state.packing_format_for_cpu:
raise AssertionError("only for packing format")
if packed_weight.dtype != torch.uint8:
raise AssertionError
if len(packed_weight.shape) != 2:
raise AssertionError("packed_weight should be [N, K/2]")
N, K_half = packed_weight.shape
K = K_half * 2

# 1) packed [N, K/2] -> [N//BLOCK_N, BLOCK_N, K/2, 2]
BLOCK_N = block_n
BIT_COUNT = 32 # (=32 low + 32 high)

assert N % BLOCK_N == 0, "N must be divisible by block_n"
assert K % 2 == 0, "K must be even"
if N % BLOCK_N != 0:
raise AssertionError("N must be divisible by block_n")
if K % 2 != 0:
raise AssertionError("K must be even")

# [N, K/2] -> [-1, 64] (32 low + 32 high)
packed = packed_weight.reshape(-1, BIT_COUNT) # [-1, 64]
Expand Down
24 changes: 16 additions & 8 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,8 @@ def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Line

# the quant state got lost when the parameter got converted. This happens for example for fsdp
# since we registered the module, we can recover the state here
assert module.weight.shape[1] == 1
if module.weight.shape[1] != 1:
raise AssertionError
if not isinstance(module.weight, Params4bit):
module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True)
module.weight.quant_state = module.quant_state
Expand Down Expand Up @@ -866,8 +867,10 @@ def forward(self, input: Tensor) -> Tensor:
rows = self.weight.data
row_stats = self.weight.SCB

assert rows.shape == (self.num_embeddings, self.embedding_dim)
assert row_stats.shape == (self.num_embeddings,)
if rows.shape != (self.num_embeddings, self.embedding_dim):
raise AssertionError
if row_stats.shape != (self.num_embeddings,):
raise AssertionError

compressed_output = F.embedding(input, rows)
compressed_output_stats = F.embedding(input, row_stats.view(self.num_embeddings, 1))
Expand Down Expand Up @@ -928,35 +931,40 @@ def __init__(
)

def _forward_with_partial_dequantize(self, input: Tensor):
assert self.embedding_dim % self.weight.quant_state.blocksize == 0
if self.embedding_dim % self.weight.quant_state.blocksize != 0:
raise AssertionError

w_4bit_uint8 = self.weight.data.view(torch.uint8).view(self.num_embeddings * self.embedding_dim // 2, 1)

output_4bit = torch.nn.functional.embedding(
weight=w_4bit_uint8.view(self.num_embeddings, self.embedding_dim // 2),
input=input,
).view(-1, 1)
assert output_4bit.shape == (input.numel() * self.embedding_dim // 2, 1)
if output_4bit.shape != (input.numel() * self.embedding_dim // 2, 1):
raise AssertionError

blocks_per_emb = self.embedding_dim // self.weight.blocksize

absmax = self.weight.quant_state.absmax
assert absmax.shape == (self.num_embeddings * blocks_per_emb,)
if absmax.shape != (self.num_embeddings * blocks_per_emb,):
raise AssertionError

output_absmax = torch.nn.functional.embedding(
weight=absmax.view(self.num_embeddings, blocks_per_emb),
input=input,
).view(
-1,
)
assert output_absmax.shape == (input.numel() * blocks_per_emb,)
if output_absmax.shape != (input.numel() * blocks_per_emb,):
raise AssertionError

output_quant_state = copy.deepcopy(self.weight.quant_state)
output_quant_state.absmax = output_absmax
output_quant_state.shape = torch.Size((*input.shape, self.embedding_dim))

output = bnb.functional.dequantize_4bit(output_4bit, output_quant_state)
assert output.shape == (*input.shape, self.embedding_dim)
if output.shape != (*input.shape, self.embedding_dim):
raise AssertionError

return output.to(self.dtype)

Expand Down
6 changes: 4 additions & 2 deletions bitsandbytes/nn/parametrize.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,16 @@ def _parametrized_state_dict_post_hook(
clean_key = f"{prefix}{param_name}"
state_dict[clean_key] = state_dict.pop(original_key)

assert P.is_parametrized(module, param_name)
if not P.is_parametrized(module, param_name):
raise AssertionError

# Find the parametrization, which should have the quantization state.
parametrization: Bnb4bitParametrization = next(
filter(lambda x: isinstance(x, Bnb4bitParametrization), module.parametrizations[param_name]), None
)

assert parametrization is not None, "Parametrization not found for the parameter."
if parametrization is None:
raise AssertionError("Parametrization not found for the parameter.")

quant_state = parametrization.quant_state

Expand Down
3 changes: 2 additions & 1 deletion bitsandbytes/optim/lars.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ def step(self, closure=None):

update_scale = 1.0
if max_unorm > 0.0:
assert p.dtype == torch.float32
if p.dtype != torch.float32:
raise AssertionError
pnorm = torch.norm(p.detach())
unorm = torch.norm(update)
if unorm > max_unorm * pnorm:
Expand Down
9 changes: 6 additions & 3 deletions bitsandbytes/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def override_config(self, parameters, key=None, value=None, key_value_dict=None)
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
if key is not None and value is not None:
assert key_value_dict is None
if key_value_dict is not None:
raise AssertionError
key_value_dict = {key: value}

if key_value_dict is not None:
Expand Down Expand Up @@ -286,8 +287,10 @@ def to_gpu(self):
def check_overrides(self):
for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr)
assert pmodule is not None
assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
if pmodule is None:
raise AssertionError
if not (isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)):
raise AssertionError
found = False
for gindex, group in enumerate(self.param_groups):
if found:
Expand Down
6 changes: 4 additions & 2 deletions bitsandbytes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


def outlier_hook(module, input):
assert isinstance(module, torch.nn.Linear)
if not isinstance(module, torch.nn.Linear):
raise AssertionError
tracer = OutlierTracer.get_instance()
hvalue = tracer.get_hvalue(module.weight)
if hvalue not in tracer.hvalue2outlier_idx:
Expand All @@ -20,7 +21,8 @@ def outlier_hook(module, input):
# assign the current layer the outlier idx found from the weight
# of the previous linear layer
if tracer.outliers[-1].numel() > 0:
assert tracer.outliers[-1].max() < module.weight.shape[1]
if tracer.outliers[-1].max() >= module.weight.shape[1]:
raise AssertionError
tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1]

else:
Expand Down
3 changes: 2 additions & 1 deletion check_bnb_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

p2 = p.data.sum().item()

assert p1 != p2
if p1 == p2:
raise AssertionError
print("SUCCESS!")
print("Installation was successful!")