diff --git a/torchao/optim/quant_utils.py b/torchao/optim/quant_utils.py index a4035fde1c..4839b1deab 100644 --- a/torchao/optim/quant_utils.py +++ b/torchao/optim/quant_utils.py @@ -5,6 +5,13 @@ # LICENSE file in the root directory of this source tree. import torch from torch import Tensor +try: + from torch.distributed.tensor import DTensor +except Exception: + try: + from torch.distributed._tensor import DTensor + except Exception: + DTensor = tuple() # https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391 @@ -117,7 +124,7 @@ def dequant_with_qmap(codes: Tensor, qmap: Tensor, scale: Tensor): return out.view(codes.shape) -def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor: +def _fp32_to_bf16_sr(_x_f32: Tensor) -> Tensor: # For an FP32 number [a31, ..., a16, a15, ..., a0] to be converted to BF16 # - Round towards zero: [a31, ..., a16, 0, ..., 0] # - Round away from zero: [a31, ..., a16+1, 0, ..., 0] @@ -127,6 +134,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor: # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16 # # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16 + is_dt = isinstance(_x_f32, DTensor) + x_f32 = _x_f32.to_local() if is_dt else _x_f32 + rand_16bit = torch.randint( 0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32 ) @@ -142,4 +152,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor: ) # alternative, slightly faster # x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000 - return x_f32_bits.view(torch.float32).bfloat16() + x_bf16_trunc = x_f32_bits.view(torch.float32).bfloat16() + + return DTensor.from_local( + x_bf16_trunc, _x_f32.device_mesh, _x_f32.placements, + run_check=False, shape=tuple(_x_f32.shape), stride=tuple(_x_f32.stride()), + ) if is_dt else x_bf16_trunc