torchao/optim/quant_utils.py (19 changes: 17 additions & 2 deletions)
@@ -5,6 +5,13 @@
# LICENSE file in the root directory of this source tree.
import torch
from torch import Tensor
try:
    from torch.distributed.tensor import DTensor
except Exception:
    try:
        from torch.distributed._tensor import DTensor
    except Exception:
        DTensor = tuple()

Contributor:
is this for different PyTorch versions? If yes, could you clarify which specific versions?

In general we support the 3 most recent stable PyTorch releases at most, so if it's older than that I'd just leave it out.
Contributor:
rethrow this instead of returning tuple



# https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391
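Regarding the rethrow suggestion above: a minimal sketch (not part of this PR) of what that import could look like, assuming the public `torch.distributed.tensor` location on newer PyTorch releases and the older private `torch.distributed._tensor` location as the only fallback:

```python
# Sketch only: try the public DTensor location first; if the private fallback
# is also missing, let the ImportError propagate instead of silently degrading.
try:
    from torch.distributed.tensor import DTensor  # newer public location
except ImportError:
    from torch.distributed._tensor import DTensor  # older private location; raises if absent
```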
@@ -117,7 +124,7 @@ def dequant_with_qmap(codes: Tensor, qmap: Tensor, scale: Tensor):
    return out.view(codes.shape)


def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
def _fp32_to_bf16_sr(_x_f32: Tensor) -> Tensor:
    # For an FP32 number [a31, ..., a16, a15, ..., a0] to be converted to BF16
    # - Round towards zero: [a31, ..., a16, 0, ..., 0]
    # - Round away from zero: [a31, ..., a16+1, 0, ..., 0]
@@ -127,6 +134,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
    # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
    #
    # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
    is_dt = isinstance(_x_f32, DTensor)
    x_f32 = _x_f32.to_local() if is_dt else _x_f32

    rand_16bit = torch.randint(
        0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
    )
@@ -142,4 +152,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
    )
    # alternative, slightly faster
    # x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000
    return x_f32_bits.view(torch.float32).bfloat16()
    x_bf16_trunc = x_f32_bits.view(torch.float32).bfloat16()

    return DTensor.from_local(
        x_bf16_trunc, _x_f32.device_mesh, _x_f32.placements,
        run_check=False, shape=tuple(_x_f32.shape), stride=tuple(_x_f32.stride()),
    ) if is_dt else x_bf16_trunc

Contributor:
looks reasonable, can we add a test to cover?
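As a starting point for such a test, here is a hypothetical single-rank sketch (not part of this PR). The test name, the gloo/CPU mesh setup, and the tolerances are illustrative assumptions; only `_fp32_to_bf16_sr` and its DTensor wrap/unwrap behaviour come from the diff above.

```python
# Illustrative sketch only: wrap a replicated fp32 DTensor, run stochastic
# rounding, and check that the result is still a bf16 DTensor close to the
# input. Assumes a recent PyTorch with the public DTensor API.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

from torchao.optim.quant_utils import _fp32_to_bf16_sr


def test_fp32_to_bf16_sr_dtensor():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    try:
        mesh = init_device_mesh("cpu", (1,))
        x = torch.randn(64, 32)
        x_dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)

        out = _fp32_to_bf16_sr(x_dt)

        # the DTensor wrapper, global shape, and placements should survive
        assert isinstance(out, DTensor)
        assert out.dtype == torch.bfloat16
        assert out.shape == x_dt.shape
        assert out.placements == x_dt.placements
        # stochastic rounding stays within one bf16 ulp of the fp32 input
        torch.testing.assert_close(
            out.to_local().float(), x, rtol=2**-6, atol=1e-5
        )
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    test_fp32_to_bf16_sr_dtensor()
```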
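Separately, as a quick sanity check of the stochastic-rounding comment inside `_fp32_to_bf16_sr` (round away from zero with probability equal to the discarded low 16 bits divided by 2^16), a small standalone sketch; the input value and sample count are arbitrary choices, not from the PR:

```python
# Sketch only: 1.0 + 2**-10 sets mantissa bit a13 of the fp32 word, so the
# discarded low 16 bits equal 2**13 and P(round away from zero) = 2**13 / 2**16 = 0.125.
import torch

from torchao.optim.quant_utils import _fp32_to_bf16_sr

x = torch.full((100_000,), 1.0 + 2**-10)        # fp32 by default
low16 = (x.view(torch.int32) & 0xFFFF).float()  # the discarded bits [a15, ..., a0]
expected = (low16 / 2**16).mean().item()        # ~0.125

y = _fp32_to_bf16_sr(x)
rounded_up = (y.float() > x).float().mean().item()  # fraction rounded away from zero

print(f"expected P(round up) ~ {expected:.4f}, observed ~ {rounded_up:.4f}")
```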