
log = logging.getLogger(__name__)

-# TODO: unpin after resolving the `quant_state` format breaking changes
-_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes==0.41.0")
+_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes>=0.42.0")


class BitsandbytesPrecision(Precision):
@@ -344,7 +343,7 @@ def quantize(
    def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
        if self.weight.dtype == torch.uint8:  # was quantized
            # cannot init the quantized params directly
-            weight = torch.empty(self.weight.quant_state[1], device=device, dtype=torch.half)
+            weight = torch.empty(self.weight.quant_state.shape, device=device, dtype=torch.half)
        else:
            weight = torch.empty_like(self.weight.data, device=device)
        device = torch.device(device)
@@ -366,7 +365,7 @@ def reset_parameters(self) -> None:
        linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
        if linear_init_finished and self.weight.dtype == torch.uint8:  # was quantized
            # cannot init the quantized params directly
-            weight = torch.empty(self.weight.quant_state[1], device=self.weight.device, dtype=torch.half)
+            weight = torch.empty(self.weight.quant_state.shape, device=self.weight.device, dtype=torch.half)
        else:
            weight = self.weight.data
        torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
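Note on the change: bitsandbytes 0.42 replaced the sequence-style `quant_state` (where the original tensor shape was stored at index 1) with a `QuantState` object that exposes the shape as an attribute, which is why both lookups above move from `quant_state[1]` to `quant_state.shape` alongside the requirement bump. A minimal sketch of the same pattern, assuming bitsandbytes >= 0.42; the helper name `empty_like_maybe_quantized` is hypothetical and not part of this codebase:

import torch

def empty_like_maybe_quantized(weight: torch.nn.Parameter, device: torch.device) -> torch.Tensor:
    # Hypothetical helper mirroring the `to_empty` / `reset_parameters` pattern above.
    if weight.dtype == torch.uint8:  # uint8 storage means the weight was 4-bit quantized
        # The quantized storage is flattened, so the original shape has to come from the
        # attached quant state (`.shape` on the QuantState object in bitsandbytes >= 0.42).
        return torch.empty(weight.quant_state.shape, device=device, dtype=torch.half)
    # Not quantized: the stored data already has the right shape and dtype.
    return torch.empty_like(weight.data, device=device)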