@@ -33,6 +33,7 @@
     "_choose_qparams_affine_floatx",
     "_choose_qparams_and_quantize_affine_hqq",
     "_choose_qparams_and_quantize_scale_only_hqq",
+    "_choose_qparams_and_quantize_scale_only_sinq",
     "_choose_qparams_and_quantize_affine_qqq",
     "_choose_scale_float8",
     "_choose_qparams_gguf",
@@ -2219,6 +2220,83 @@ def round_stoch(x: torch.Tensor) -> torch.Tensor:
     return qdata, scale


+def _choose_qparams_and_quantize_scale_only_sinq(
+    tensor: torch.Tensor,
+    qmin: int = -(2 ** (4 - 1)),
+    qmax: int = 2 ** (4 - 1) - 1,
+    group_size: int = 64,
+    niter: int = 20,
+    compute_dtype: torch.dtype = torch.float16,
+) -> tuple:
+    """
+    SINQ: Sinkhorn-Normalized Quantization (https://www.arxiv.org/abs/2509.22944)
+
+    Iteratively normalizes row and column standard deviations to minimize
+    matrix imbalance before quantization with dual scales.
+
+    Args:
+        tensor: Input weight tensor
+        qmin: Minimum quantized integer value (default: -8, i.e. 4-bit symmetric)
+        qmax: Maximum quantized integer value (default: 7, i.e. 4-bit symmetric)
+        group_size: Quantization group size (default: 64)
+        niter: Number of Sinkhorn iterations (default: 20)
+        compute_dtype: Target compute dtype (default: torch.float16)
+
+    Returns:
+        Tuple of (qdata, scale_row, scale_col)
+    """
+    if group_size is not None:
+        assert _is_divisible(tensor.numel(), group_size), (
+            f"group_size must divide the number of tensor elements. shape: {tensor.shape}, group_size: {group_size}"
+        )
+
+    W = tensor.to(dtype=compute_dtype)
+    shape = W.shape
+
+    # Reshape for 1D tiling
+    W = W.reshape(-1, group_size)  # [N*num_groups, group_size]
+
+    # Algorithm 1: Sinkhorn Normalization
+    q_min = min(W.std(dim=0).min().item(), W.std(dim=1).min().item())
+    q_min = max(q_min, 1e-8)
+
+    W_hat = W.clone()
+    scale_col_sinkhorn = torch.ones(W.shape[1], device=W.device, dtype=compute_dtype)
+    scale_row_sinkhorn = torch.ones(W.shape[0], device=W.device, dtype=compute_dtype)
+
+    for _ in range(niter):
+        # Normalize columns (dim=0)
+        q_col = W_hat.std(dim=0) / q_min
+        q_col = torch.clamp(q_col, min=1e-8)
+        W_hat = W_hat / q_col.unsqueeze(0)
+        scale_col_sinkhorn = scale_col_sinkhorn * q_col
+
+        # Normalize rows (dim=1)
+        q_row = W_hat.std(dim=1) / q_min
+        q_row = torch.clamp(q_row, min=1e-8)
+        W_hat = W_hat / q_row.unsqueeze(1)
+        scale_row_sinkhorn = scale_row_sinkhorn * q_row
+
+    # Symmetric integer quantization (4-bit range by default, stored as int8)
+    # TODO: Consider custom bitwidth for SIMD acceleration like vadd4
+    scale_s = (W_hat.abs().amax(dim=1, keepdim=True) / float(qmax)).clamp_min(1e-8)
+    # TODO: Find a better rounding strategy, e.g. stochastic rounding
+    Q = _Round.apply(W_hat / scale_s).clamp(qmin, qmax)
+    # TODO: PERF-test the scale factor dtype (FP16 vs. INT8).
+    # Although FP16 has high accuracy, FP16×INT8 can't be computed
+    # on Tensor Cores directly, requiring an INT8-to-FP16 conversion.
+    qdata = Q.view(shape).contiguous().to(torch.int8)
+
+    # Combine the RTN scale with the row Sinkhorn factor
+    scale_row = (
+        (scale_s.view(-1) * scale_row_sinkhorn).view(shape[0], -1).to(compute_dtype)
+    )
+    num_groups = shape[1] // group_size
+    scale_col = scale_col_sinkhorn.repeat(num_groups)[: shape[1]].to(compute_dtype)
+
+    return qdata, scale_row, scale_col
+
+
 def _choose_qparams_affine_floatx(
     tensor: torch.Tensor, ebits: int, mbits: int
 ) -> torch.Tensor:
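
For reference, here is a minimal usage sketch of the new helper. It assumes the function is importable from `torchao.quantization.quant_primitives` (the module this diff appears to touch); the tensor shape, the `group_size` value, and the dequantization at the end are illustrative only, reconstructing weights from the per-group row scales and the tiled column scales returned above.

```python
import torch

# Assumed import path; adjust if the helper lives elsewhere.
from torchao.quantization.quant_primitives import (
    _choose_qparams_and_quantize_scale_only_sinq,
)

group_size = 64
W = torch.randn(128, 256)  # [rows, cols], cols divisible by group_size

# float32 compute keeps this sketch runnable on CPU; the default is float16.
qdata, scale_row, scale_col = _choose_qparams_and_quantize_scale_only_sinq(
    W, group_size=group_size, niter=20, compute_dtype=torch.float32
)

R, C = W.shape
G = C // group_size
# qdata:     int8 [R, C], values clamped to the 4-bit range [-8, 7]
# scale_row: [R, G], per-group RTN scale fused with the Sinkhorn row factor
# scale_col: [C], Sinkhorn column factor tiled once per group

# Approximate dequantization (illustrative):
#   W[r, g*group_size + j] ≈ qdata[r, g*group_size + j]
#                            * scale_row[r, g] * scale_col[g*group_size + j]
W_deq = (
    qdata.to(scale_row.dtype).reshape(R, G, group_size)
    * scale_row.reshape(R, G, 1)
    * scale_col.reshape(1, G, group_size)
).reshape(R, C)

print((W - W_deq).abs().mean())  # small reconstruction error
```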
|