Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,9 @@ def _real_quantize(self, inputs):
inputs,
axis=self._axis,
block_sizes=self._block_sizes,
scales=self.amax / 448.0 if self.amax is not None else None,
scales=self.amax / 448.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if self._block_sizes is present, should the scales be None here?

And will it be recomputed later?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct, we should set scales as None here if self._block_sizes is present. The scales will be computed inside QTensor.

if (self.amax is not None and not self._block_sizes)
else None,
)
buffer_to_register["_scale"] = _scale
elif self._num_bits == 8:
Expand Down
33 changes: 33 additions & 0 deletions tests/gpu/torch/quantization/test_qtensor_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,36 @@ def test_nvfp4_dequantize_fast(self, shape, input_dtype):
f"Fast and standard dequantization differ: "
f"max diff = {(dequant_fast - dequant_standard).abs().max()}"
)

@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
("input_shape", "block_sizes"),
[
((128, 1152), {-1: 128}),
((256, 256), {-1: 64, -2: 64}), # 2D block sizes
],
)
def test_fp8_with_amax_and_block_sizes(self, device, input_dtype, input_shape, block_sizes):
"""Test FP8 quantization with both amax and block_sizes specified."""
quant_cfg = QuantizerAttributeConfig(
num_bits=(4, 3),
block_sizes=block_sizes,
fake_quant=False,
)
quantizer = TensorQuantizer(quant_cfg).to(device)

# Set a mock amax (scalar) - this was causing the bug
mock_amax = torch.tensor(1.5, device=device)
quantizer.amax = mock_amax

# Create input tensor
x = torch.randn(input_shape, dtype=input_dtype, device=device)

# QDQ
q_x = quantizer(x)
deq_x = quantizer(q_x)

assert torch.allclose(deq_x, x, rtol=1e-1, atol=1e-1)
assert hasattr(quantizer, "_scale")
assert quantizer._scale.numel() > 1