
Commit 1179289

Armand Sauzay authored and meta-codesync[bot] committed
Enable specifying output dtype for fp8 quantized communication (#5154)
Summary:
X-link: meta-pytorch/torchrec#3568
Pull Request resolved: #5154
X-link: https://github.com/facebookresearch/FBGEMM/pull/2154

Adds an fp8_output_dtype parameter to the qcomms config, allowing FP8 quantized communication to dequantize into float formats other than FP32.

Reviewed By: spcyppt
Differential Revision: D86890315
fbshipit-source-id: 1cbfdabd63ad4dc0a1c3d47990aa591a567fc9d0
1 parent eb1ae89 · commit 1179289


fbgemm_gpu/fbgemm_gpu/quantize_comm.py

Lines changed: 15 additions & 2 deletions
@@ -123,6 +123,7 @@ def _dequantize_tensor(
     comm_precision: SparseType,
     ctx: Optional[QuantizationContext] = None,
     is_fwd: bool = True,
+    fp8_output_dtype: Optional[SparseType] = None,
 ) -> torch.Tensor:
     if comm_precision == SparseType.FP32:
         assert quantized_tensor.dtype == torch.float
@@ -137,8 +138,14 @@ def _dequantize_tensor(
         if ctx is not None and ctx.row_dim > 0:
             row_dim_quant = ctx.row_dim_quant
             quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
+            # use provided fp8_output_dtype or default to FP32 (0)
+            output_dtype_int = (
+                fp8_output_dtype.as_int() if fp8_output_dtype is not None else 0
+            )
             dequant_tensor = torch.ops.fbgemm.FP8RowwiseQuantizedToFloat(
-                quantized_tensor_2d, is_fwd
+                quantized_tensor_2d,
+                is_fwd,
+                output_dtype_int,
             )
             return dequant_tensor.view(-1)
         else:
@@ -168,6 +175,7 @@ def __init__(
         row_dim: Optional[int] = None,
         is_fwd: bool = True,
         rounding_mode: Optional[RoundingMode] = None,
+        fp8_output_dtype: Optional[SparseType] = None,
     ) -> None:
         if loss_scale is not None:
             if comm_precision not in [SparseType.FP16, SparseType.BF16]:
@@ -185,6 +193,7 @@ def __init__(
         self._is_fwd = is_fwd
         self._row_dim: int = -1 if row_dim is None else row_dim
         self._rounding_mode: Optional[RoundingMode] = rounding_mode
+        self._fp8_output_dtype: Optional[SparseType] = fp8_output_dtype
         if self._comm_precision == SparseType.MX4:
             self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
             self._rounding_mode = (
@@ -216,7 +225,11 @@ def decode(
             f"## decoder {self._comm_precision} {self._loss_scale} ##"
         ):
             dequantized_tensor = _dequantize_tensor(
-                input_tensor, self._comm_precision, ctx, self._is_fwd
+                input_tensor,
+                self._comm_precision,
+                ctx,
+                self._is_fwd,
+                fp8_output_dtype=self._fp8_output_dtype,
             )
             return dequantized_tensor
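For context, a minimal usage sketch of the new parameter. The QuantizedCommCodec constructor and the decode path come from this diff; the import paths, the create_context/encode helpers, and the concrete values (row_dim=32, BF16 output) are assumptions for illustration, not taken from this commit:

import torch
from fbgemm_gpu.quantize_comm import QuantizedCommCodec  # file changed in this commit
from fbgemm_gpu.split_embedding_configs import SparseType  # import path is an assumption

# Quantize communication payloads to FP8, but dequantize into BF16
# instead of the previous FP32-only behavior.
codec = QuantizedCommCodec(
    comm_precision=SparseType.FP8,
    row_dim=32,  # hypothetical row dimension for FP8 row-wise quantization
    fp8_output_dtype=SparseType.BF16,  # new parameter added by this commit
)

ctx = codec.create_context()  # assumed helper defined elsewhere in quantize_comm.py
payload = torch.randn(64 * 32, device="cuda")
encoded = codec.encode(payload, ctx)  # FP8 row-wise quantized payload (assumed API)
decoded = codec.decode(encoded, ctx)  # now dequantizes to bfloat16, not float32

When fp8_output_dtype is left unset, output_dtype_int falls back to 0, which the diff's comment identifies as FP32, so existing callers keep the old dequantize-to-float32 behavior.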
