Commit 1cfd3b5

Introduce SINQ calibration-free quantization algorithm (#3156)
* feat: SINQ quantization algorithm
* update sinq algorithm
* update sinq ops and add unit test
* update device to cpu in SINQ test
* fix scale dtype, device
* update device to direct override
* add qmin, qmax args similar to HQQ
1 parent: 4903c55 · commit: 1cfd3b5

File tree

2 files changed: +116 additions, −0 deletions

test/quantization/test_quant_primitives.py

Lines changed: 40 additions & 0 deletions
@@ -15,6 +15,7 @@
     MappingType,
     ZeroPointDomain,
     _choose_qparams_affine_tinygemm,
+    _choose_qparams_and_quantize_scale_only_sinq,
     _choose_scale_float8,
     _fake_quantize_affine,
     _fake_quantize_affine_cachemask,
@@ -823,6 +824,45 @@ def test_maybe_expand_scale_to_tensor_shape(self):
         self.assertEqual(new_scale5.shape, torch.Size([3, 2, 8]))
         self.assertEqual(new_scale5.unique(dim=-1).shape, torch.Size([3, 2, 2]))

+    def test_choose_qparams_and_quantize_scale_only_sinq(self):
+        """Test that SINQ quantization produces valid outputs and a finite reconstruction."""
+        torch.manual_seed(self.SEED)
+        input = torch.randn(128, 256, dtype=torch.float32)
+        group_size = 64
+        qmin = -(2 ** (4 - 1))
+        qmax = 2 ** (4 - 1) - 1
+
+        # Run SINQ
+        qdata, scale_row, scale_col = _choose_qparams_and_quantize_scale_only_sinq(
+            input,
+            group_size=group_size,
+            qmin=qmin,
+            qmax=qmax,
+            niter=20,
+        )
+
+        # Check the quantized weights are valid
+        self.assertEqual(qdata.dtype, torch.int8)
+        self.assertEqual(qdata.shape, input.shape)
+        self.assertTrue((qdata >= qmin).all() and (qdata <= qmax).all())
+
+        # Check the scale factors are valid
+        num_groups = input.shape[1] // group_size
+        self.assertEqual(scale_row.shape, (input.shape[0], num_groups))
+        self.assertEqual(scale_col.shape, (input.shape[1],))
+        self.assertTrue((scale_row > 0).all() and (scale_col > 0).all())
+
+        # Check that dequantization with the two scale factors produces finite values
+        qdata_fp32 = qdata.to(torch.float32)
+        qdata_reshaped = qdata_fp32.reshape(-1, group_size)
+        scale_row_expanded = scale_row.reshape(-1, 1)
+        scale_col_reshaped = scale_col.reshape(num_groups, group_size)
+        scale_col_expanded = scale_col_reshaped.repeat(input.shape[0], 1)
+        reconstructed = (
+            qdata_reshaped * scale_row_expanded * scale_col_expanded
+        ).reshape(input.shape)
+        self.assertFalse(torch.isnan(reconstructed).any())
+
     def test_float8_blockwise_scaling(self):
         M, K = 512, 1024
         hp_tensor = torch.randn(M, K, dtype=torch.float)
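For reference, a minimal standalone usage sketch of the new primitive (illustrative only, not part of this commit). It assumes the function is imported from torchao.quantization.quant_primitives, where the diff below adds it to __all__; the 4-bit qmin/qmax values match the function defaults, and the dequantization step mirrors the reconstruction in the test above.

import torch

from torchao.quantization.quant_primitives import (
    _choose_qparams_and_quantize_scale_only_sinq,
)

# Quantize a weight matrix to signed 4-bit values with SINQ's dual scales.
weight = torch.randn(128, 256, dtype=torch.float32)
group_size = 64
qdata, scale_row, scale_col = _choose_qparams_and_quantize_scale_only_sinq(
    weight,
    qmin=-8,  # signed 4-bit range, same as the function defaults
    qmax=7,
    group_size=group_size,
    niter=20,
)

# Dequantize: qdata is tiled into [rows * num_groups, group_size] blocks; each
# block is scaled by its per-tile row factor and the per-column factors,
# exactly as the test above reconstructs it.
num_groups = weight.shape[1] // group_size
w_hat = (
    qdata.to(torch.float32).reshape(-1, group_size)
    * scale_row.reshape(-1, 1)
    * scale_col.reshape(num_groups, group_size).repeat(weight.shape[0], 1)
).reshape(weight.shape)
print((w_hat - weight).abs().mean())  # mean absolute reconstruction error

Because the returned row scale already folds the per-group RTN scale together with the Sinkhorn row factor, dequantization needs only the two returned scale tensors.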

torchao/quantization/quant_primitives.py

Lines changed: 76 additions & 0 deletions
@@ -33,6 +33,7 @@
     "_choose_qparams_affine_floatx",
     "_choose_qparams_and_quantize_affine_hqq",
     "_choose_qparams_and_quantize_scale_only_hqq",
+    "_choose_qparams_and_quantize_scale_only_sinq",
     "_choose_qparams_and_quantize_affine_qqq",
     "_choose_scale_float8",
     "_choose_qparams_gguf",
@@ -2219,6 +2220,81 @@ def round_stoch(x: torch.Tensor) -> torch.Tensor:
     return qdata, scale


+def _choose_qparams_and_quantize_scale_only_sinq(
+    tensor: torch.Tensor,
+    qmin: int = -(2 ** (4 - 1)),
+    qmax: int = 2 ** (4 - 1) - 1,
+    group_size: int = 64,
+    niter: int = 20,
+    compute_dtype: torch.dtype = torch.float16,
+) -> tuple:
+    """
+    SINQ: Sinkhorn-Normalized Quantization (https://www.arxiv.org/abs/2509.22944)
+
+    Iteratively normalizes row and column standard deviations to minimize
+    matrix imbalance before quantization with dual scales.
+
+    Args:
+        tensor: Input weight tensor
+        qmin: Minimum quantized value (default: -8, i.e. signed 4-bit)
+        qmax: Maximum quantized value (default: 7, i.e. signed 4-bit)
+        group_size: Quantization group size (default: 64)
+        niter: Number of Sinkhorn iterations (default: 20)
+        compute_dtype: Target compute dtype (default: torch.float16)
+
+    Returns:
+        Tuple of (qdata, scale_row, scale_col)
+    """
+    if group_size is not None:
+        assert _is_divisible(tensor.numel(), group_size), (
+            f"group_size must divide tensor elements. shape: {tensor.shape}, group_size: {group_size}"
+        )
+
+    W = tensor.to(dtype=compute_dtype)
+    shape = W.shape
+
+    # Reshape for 1D tiling
+    W = W.reshape(-1, group_size)  # [N*num_groups, group_size]
+
+    # Algorithm 1: Sinkhorn normalization
+    q_min = min(W.std(dim=0).min().item(), W.std(dim=1).min().item())
+    q_min = max(q_min, 1e-8)
+
+    W_hat = W.clone()
+    scale_col_sinkhorn = torch.ones(W.shape[1], device=W.device, dtype=compute_dtype)
+    scale_row_sinkhorn = torch.ones(W.shape[0], device=W.device, dtype=compute_dtype)
+
+    for _ in range(niter):
+        # Normalize columns (dim=0)
+        q_col = W_hat.std(dim=0) / q_min
+        q_col = torch.clamp(q_col, min=1e-8)
+        W_hat = W_hat / q_col.unsqueeze(0)
+        scale_col_sinkhorn = scale_col_sinkhorn * q_col
+
+        # Normalize rows (dim=1)
+        q_row = W_hat.std(dim=1) / q_min
+        q_row = torch.clamp(q_row, min=1e-8)
+        W_hat = W_hat / q_row.unsqueeze(1)
+        scale_row_sinkhorn = scale_row_sinkhorn * q_row
+
+    # INT8 symmetric quantization
+    # TODO: Consider a custom bitwidth for SIMD acceleration (e.g. vadd4)
+    scale_s = (W_hat.abs().amax(dim=1, keepdim=True) / float(qmax)).clamp_min(1e-8)
+    # TODO: Find a better rounding strategy, e.g. stochastic rounding
+    Q = _Round.apply(W_hat / scale_s).clamp(qmin, qmax)
+    # TODO: PERF-test the scale factor dtype (FP16 vs. INT8).
+    # Although FP16 has higher accuracy, FP16×INT8 cannot be computed
+    # on Tensor Cores directly and requires INT8-to-FP16 conversion ops.
+    qdata = Q.view(shape).contiguous().to(torch.int8)
+
+    # Combine the RTN scale with the row Sinkhorn factor
+    scale_row = (
+        (scale_s.view(-1) * scale_row_sinkhorn).view(shape[0], -1).to(compute_dtype)
+    )
+    num_groups = shape[1] // group_size
+    scale_col = scale_col_sinkhorn.repeat(num_groups)[: shape[1]].to(compute_dtype)
+
+    return qdata, scale_row, scale_col
+
+
 def _choose_qparams_affine_floatx(
     tensor: torch.Tensor, ebits: int, mbits: int
 ) -> torch.Tensor:
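As a side note on the Sinkhorn loop above, the following standalone sketch (illustrative only, not part of the commit) shows its effect in isolation. The imbalance metric here is an ad-hoc proxy, defined as the ratio between the largest and smallest per-row/per-column standard deviation; the alternating normalization drives it toward 1, which is what conditions the subsequent per-group round-to-nearest step.

import torch

def std_imbalance(m: torch.Tensor) -> float:
    # Ad-hoc imbalance proxy: spread between the largest and smallest
    # per-column / per-row standard deviation.
    stds = torch.cat([m.std(dim=0), m.std(dim=1)])
    return (stds.max() / stds.min()).item()

torch.manual_seed(0)
# Deliberately give the columns very different scales.
W = torch.randn(64, 64) * torch.logspace(-1, 1, steps=64)

# Same normalization target and update rule as in the function above.
target = max(min(W.std(dim=0).min().item(), W.std(dim=1).min().item()), 1e-8)
W_hat = W.clone()
print(f"imbalance before: {std_imbalance(W_hat):8.2f}")
for _ in range(20):
    W_hat = W_hat / (W_hat.std(dim=0) / target).clamp(min=1e-8).unsqueeze(0)
    W_hat = W_hat / (W_hat.std(dim=1) / target).clamp(min=1e-8).unsqueeze(1)
print(f"imbalance after:  {std_imbalance(W_hat):8.2f}")

After the loop, every row and column of W_hat has roughly the same standard deviation, so a single per-group scale_s captures the remaining dynamic range and the accumulated row/column factors are returned as the two dequantization scales.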
