Commit 0c2bb76

update block_size args to granularity

1 parent b516304 commit 0c2bb76

5 files changed: +62 −52 lines changed

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 22 additions & 9 deletions
@@ -17,6 +17,7 @@
     Int8WeightOnlyConfig,
     quantize_,
 )
+from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase
 
@@ -160,24 +161,35 @@ def test_slice(self, config, device, dtype):
     @common_utils.parametrize(
         "config",
         [
-            Int8DynamicActivationInt8WeightConfig(version=2),
-            Int8WeightOnlyConfig(version=2),
+            Int8DynamicActivationInt8WeightConfig,
+            Int8WeightOnlyConfig,
         ],
     )
-    def test_index_select(self, config):
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    def test_index_select(self, config, granularity):
        """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
         N, K = 256, 512
         x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
         linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
         linear.weight.data = x
+
+        config = config(version=2, granularity=granularity)
         quantize_(linear, config)
 
         x_int8 = linear.weight
         x_int8_0 = x_int8[0]
+
+        # Test dequantization consistency
         torch.testing.assert_close(
             x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0
         )
 
+        # Test block_size granularity
+        if isinstance(granularity, PerRow):
+            self.assertEqual(x_int8.block_size, [1, K])
+        elif isinstance(granularity, PerTensor):
+            self.assertEqual(x_int8.block_size, [N, K])
+
     @common_utils.parametrize(
         "config",
         [
@@ -187,16 +199,17 @@ def test_index_select(self, config):
     )
     def test_dequantization_accuracy(self, config):
         """Test dequantization accuracy separately"""
-        test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16, device="cuda")
-        linear = torch.nn.Linear(2, 1, bias=False, dtype=torch.bfloat16, device="cuda")
-        linear.weight.data = test_data
+        linear = torch.nn.Linear(
+            256, 512, bias=False, dtype=torch.bfloat16, device="cuda"
+        )
+        weight_fp = copy.deepcopy(linear.weight)
         quantize_(linear, config)
 
         tensor = linear.weight
         dequantized = tensor.dequantize()
-        self.assertEqual(dequantized.shape, test_data.shape)
-        assert compute_error(dequantized, test_data) > 20, (
-            f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, test_data)}"
+        self.assertEqual(dequantized.shape, weight_fp.shape)
+        assert compute_error(dequantized, weight_fp) > 20, (
+            f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, weight_fp)}"
         )
 
     @common_utils.parametrize(
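For context, a minimal sketch of the workflow this test now exercises (assuming a CUDA device and the version-2 config path): the caller picks a granularity and the quantized tensor derives its own block_size from it.

import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerRow

N, K = 256, 512
linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")

# PerRow() yields one scale per output channel, i.e. block_size [1, K];
# PerTensor() would instead yield a single scale, i.e. block_size [N, K].
quantize_(linear, Int8WeightOnlyConfig(version=2, granularity=PerRow()))
assert linear.weight.block_size == [1, K]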

torchao/float8/inference.py

Lines changed: 7 additions & 18 deletions
@@ -140,18 +140,7 @@ def _slice_scale_for_dimension(
     """
     aten = torch.ops.aten
 
-    # Per-tensor quantization (scalar scale)
-    if scale.numel() == 1:
-        return scale
-
-    # Per-row quantization (1D scale)
-    if scale.ndim == 1:
-        if dim == 0:
-            return aten.slice.Tensor(scale, 0, start, end, step)
-        else:
-            return scale
-
-    # Block-wise quantization (2D scale)
+    # Unsupported case for now, this would be 1 scale per data element
     if scale.shape == data_shape:
         return aten.slice.Tensor(scale, dim, start, end, step)
 
@@ -169,12 +158,6 @@ def _slice_scale_for_dimension(
         # Slice away as normal
         return aten.slice.Tensor(scale, dim, start, end, step)
     else:
-        # Error on Step > 1
-        if step > 1:
-            raise NotImplementedError(
-                "Slicing with step > 1 is not implemented for scale tensors."
-            )
-
         # There is blocking in this dimension
         # Calculate which scale elements correspond to the sliced data
         scale_start = start // block_size_for_dim if start is not None else None
@@ -184,6 +167,12 @@ def _slice_scale_for_dimension(
             else None
         )
 
+        # Error on Step > 1
+        if step > 1:
+            raise NotImplementedError(
+                "Slicing with step > 1 is not implemented for scale tensors."
+            )
+
         return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)
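The reordering matters because `scale_start`/`scale_end` are only meaningful in the blocked branch, so the `step > 1` guard now sits next to the slice it protects. A toy illustration of the index arithmetic (hypothetical values; the ceil-division formula for `scale_end` is assumed from the elided lines):

block = 4                                # block_size_for_dim
start, end = 2, 6                        # data slice [2:6] over a dim of length 8
scale_start = start // block             # 0
scale_end = (end + block - 1) // block   # 2, i.e. ceil(6 / 4)
# data[2:6] touches both blocks, so the sliced tensor keeps scale[0:2]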
torchao/quantization/quant_api.py

Lines changed: 18 additions & 16 deletions
@@ -1346,13 +1346,16 @@ class Int8WeightOnlyConfig(AOBaseConfig):
     Configuration for applying int8 weight-only symmetric per-channel quantization to linear layers.
 
     Args:
-        group_size: Optional[int] = None - Controls the granularity of quantization. If None, applies per-channel quantization.
-            Otherwise, applies per-group quantization with the specified group size.
+        group_size (version 1) - Controls the granularity of quantization.
+            If None, applies per-channel quantization. Otherwise, applies per-group quantization with the specified group size.
+        granularity (version 2) - Quantization granularity.
+            PerRow() for per-channel quantization, PerTensor() for per-tensor quantization.
         set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values
             for better performance with this quantization scheme.
     """
 
     group_size: Optional[int] = None
+    granularity: Optional[Union[PerRow, PerTensor]] = PerRow()
     set_inductor_config: bool = True
     version: int = 1
 
@@ -1387,11 +1390,7 @@ def _int8_weight_only_quantize_tensor(weight, config):
         )
     else:
         assert config.version == 2, f"Unexpected version: {config.version}"
-        group_size = config.group_size
-        if group_size is None:
-            group_size = weight.shape[-1]
-        block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size])
-        new_weight = Int8Tensor.from_hp(weight, block_size=block_size)
+        new_weight = Int8Tensor.from_hp(weight, granularity=config.granularity)
         return new_weight
 
 
@@ -1572,17 +1571,17 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
     else:
         input_quant_func = _int8_asymm_per_token_quant
 
-    if isinstance(config.granularity, PerTensor):
-        # Tensor granularity
-        block_size = weight.shape
-    else:
-        # Per row granularity
-        block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]])
-
     if config.version == 1:
         warnings.warn(
             "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details"
         )
+        if isinstance(config.granularity, PerTensor):
+            block_size = weight.shape
+        else:
+            block_size = tuple(
+                [1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]]
+            )
+
         quantized_weight = to_affine_quantized_intx(
             weight,
             mapping_type,
@@ -1602,10 +1601,13 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
         )
 
     assert config.version == 2, f"Unexpected version: {config.version}"
+    # Compute block_size from granularity for activation quantization kwargs
+    block_size = get_block_size(weight.shape, config.granularity)
+
     quantized_weight = Int8Tensor.from_hp(
         weight,
-        block_size,
-        act_quant_kwargs=QuantizeTensorToInt8Kwargs(block_size=block_size),
+        granularity=config.granularity,
+        act_quant_kwargs=QuantizeTensorToInt8Kwargs(block_size=list(block_size)),
     )
 
     return quantized_weight
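The helper doing the work here is `get_block_size`, which turns a granularity into a block shape. A quick sketch of the mapping (expected values taken from the test assertions above):

from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.utils import get_block_size

shape = (256, 512)                         # (N, K) weight
print(get_block_size(shape, PerRow()))     # (1, 512): one scale per row
print(get_block_size(shape, PerTensor()))  # (256, 512): one scale overall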

torchao/quantization/quantize_/common/quantize_tensor_kwargs.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def _choose_quant_func_and_quantize_tensor(
     elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs):
         return Int8Tensor.from_hp(
             tensor,
-            quant_kwargs.block_size,
+            granularity=quant_kwargs.granularity,
             act_quant_kwargs=quant_kwargs,
         )
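With `granularity` carried on the kwargs, dynamically quantized activations route through the same `from_hp` path as weights. Roughly (a sketch, assuming the module path from this commit and the `PerRow()` default shown in the next file):

import torch
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    Int8Tensor,
    QuantizeTensorToInt8Kwargs,
)

act = torch.randn(16, 512, dtype=torch.bfloat16)
kwargs = QuantizeTensorToInt8Kwargs(block_size=[1, 512])  # granularity defaults to PerRow()
act_int8 = Int8Tensor.from_hp(act, granularity=kwargs.granularity)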

torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 14 additions & 8 deletions
@@ -14,6 +14,7 @@
     _slice_scale_for_dimension,
 )
 from torchao.kernel import int_scaled_matmul
+from torchao.quantization.granularity import PerRow
 from torchao.quantization.quant_primitives import (
     MappingType,
     _maybe_expand_scale_to_tensor_shape,
@@ -24,6 +25,7 @@
     QuantizeTensorKwargs,
     _choose_quant_func_and_quantize_tensor,
 )
+from torchao.quantization.utils import get_block_size
 from torchao.utils import TorchAOBaseTensor, fill_defaults
 
 __all__ = ["Int8Tensor", "QuantizeTensorToInt8Kwargs"]
@@ -37,10 +39,12 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs):
 
     Args:
         block_size (list[int]): block size for quantization granularity
+        granularity: the granularity for the Tensor, currently either PerRow() or PerTensor()
         # TODO: Static quantization support using `static_scale`, `static_zero_point`
     """
 
     block_size: list[int]
+    granularity = PerRow()
 
 
 class Int8Tensor(TorchAOBaseTensor):
@@ -101,26 +105,28 @@ def __repr__(self):
     @classmethod
     def from_hp(
         cls,
-        w: torch.Tensor,
-        block_size: list[int],
+        w_hp: torch.Tensor,
+        granularity=PerRow(),
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
     ):
-        if w.dim() not in [2, 3] or len(block_size) != w.dim():
+        block_size = list(get_block_size(w_hp.shape, granularity))
+
+        if w_hp.dim() not in [2, 3] or len(block_size) != w_hp.dim():
             raise ValueError("Expected 2D or 3D tensor with same block_size length")
 
         scale, zero_point = choose_qparams_affine(
-            input=w,
+            input=w_hp,
             mapping_type=MappingType.SYMMETRIC,
             block_size=block_size,
             target_dtype=torch.int8,
             quant_min=-128,
             quant_max=127,
-            scale_dtype=w.dtype,
+            scale_dtype=w_hp.dtype,
             zero_point_dtype=torch.int8,
         )
 
         int_data = quantize_affine(
-            w,
+            w_hp,
             block_size=block_size,
             scale=scale,
             zero_point=zero_point,
@@ -132,7 +138,7 @@ def from_hp(
             scale,
             block_size,
             act_quant_kwargs=act_quant_kwargs,
-            dtype=w.dtype,
+            dtype=w_hp.dtype,
         )
 
     def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
@@ -290,7 +296,7 @@ def _(func, types, args, kwargs):
         Int8Tensor(
             selected_qdata,
             selected_scale,
-            [selected_qdata.shape[-1]],
+            self.block_size[1:],
             self.act_quant_kwargs,
             self.dtype,
         ),
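The final hunk changes what `x[i]` reports as its block_size: instead of hardcoding the row length, the selected tensor inherits the parent's block_size minus the indexed leading dimension, which stays correct if blocking along the last dimension is ever finer than a full row. A sketch (hypothetical shapes):

# 2D parent quantized PerRow over a [256, 512] weight: block_size [1, 512]
parent_block_size = [1, 512]
print(parent_block_size[1:])   # [512], same as the old [shape[-1]] here

# hypothetical grouped case with 32-wide blocks: block_size [1, 32]
grouped_block_size = [1, 32]
print(grouped_block_size[1:])  # [32]; the old hardcoded [shape[-1]] would say [512]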
