diff --git a/quartet2/csrc/binding.cpp b/quartet2/csrc/binding.cpp
index 77eb9bf..04c1f6a 100644
--- a/quartet2/csrc/binding.cpp
+++ b/quartet2/csrc/binding.cpp
@@ -49,6 +49,7 @@ void rht128_requant(
 void eden_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, long seed, long rows, long cols);
 void rtn_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, long rows, long cols);
 void four_six_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, long rows, long cols);
+void gridflip_four_six_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, float grid_shift, long rows, long cols);
 void dequant_tp_had_quant(
     __nv_fp4x2_storage_t* y, __nv_fp8_e4m3* scales_fp8, float* global_scale_ptr,
     nv_bfloat16* scratch_scales, unsigned* max_scale, const nv_bfloat16* h,
@@ -272,6 +273,38 @@ void four_six_fp4_binding(
         scale_override, inp.shape(0), inp.shape(1));
 }
 
+void gridflip_four_six_fp4_binding(
+        const CudaArray<>& out,
+        const CudaArray<>& scales,
+        const CudaArray<float>& global_scale,
+        const CudaArray<>& inp,
+        const CudaArray<float>& amax_ptr,
+        float scale_override,
+        float grid_shift
+    )
+{
+    nb::dlpack::dtype bf16_dt{static_cast<uint8_t>(nb::dlpack::dtype_code::Bfloat), 16, 1};
+
+    CHECK_EQ(inp.ndim(), 2ul);
+    CHECK_EQ(out.ndim(), 2ul);
+
+    CHECK_EQ(out.shape(0), inp.shape(0));
+    CHECK_EQ(out.size(), inp.size() / 2);
+    CHECK_EQ(out.dtype().bits, static_cast<uint8_t>(8));
+    CHECK_EQ(scales.size(), inp.size() / 16);
+    CHECK_EQ(scales.dtype().bits, static_cast<uint8_t>(8));
+    CHECK_EQ(inp.dtype(), bf16_dt);
+    CHECK(global_scale.data() != amax_ptr.data());
+
+    gridflip_four_six_fp4(
+        reinterpret_cast<__nv_fp4x4_e2m1*>(out.data()),
+        reinterpret_cast<__nv_fp8_e4m3*>(scales.data()),
+        global_scale.data(),
+        reinterpret_cast<const nv_bfloat16*>(inp.data()),
+        amax_ptr.data(),
+        scale_override, grid_shift, inp.shape(0), inp.shape(1));
+}
+
 void rtn_fp4_binding(
     const CudaArray<>& out,
     const CudaArray<>& scales,
@@ -422,5 +455,6 @@ NB_MODULE(_quartet2, m) {
     m.def("eden_fp4", &eden_fp4_binding, nb::arg("out"), nb::arg("scales"), nb::arg("global_scale"), nb::arg("input"), nb::arg("amax"), nb::arg("scale_override"), nb::arg("seed"));
     m.def("four_six_fp4", &four_six_fp4_binding, nb::arg("out"), nb::arg("scales"), nb::arg("global_scale"), nb::arg("input"), nb::arg("amax"), nb::arg("scale_override"));
+    m.def("gridflip_four_six_fp4", &gridflip_four_six_fp4_binding, nb::arg("out"), nb::arg("scales"), nb::arg("global_scale"), nb::arg("input"), nb::arg("amax"), nb::arg("scale_override"), nb::arg("grid_shift"));
     m.def("rtn_fp4", &rtn_fp4_binding, nb::arg("out"), nb::arg("scales"), nb::arg("global_scale"), nb::arg("input"), nb::arg("amax"), nb::arg("scale_override"));
 }
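For reference, a minimal sketch of driving the new binding directly from Python. This is an illustration only; the supported entry point is `quant_gridflip_fp4` in `quartet2/python/quartet2/quant.py`, added later in this diff. Buffer shapes follow the `CHECK_EQ` constraints above, and the trailing scalars are `scale_override` and `grid_shift`.

```python
import torch
from quartet2 import _quartet2  # compiled extension from this repo

rows, cols = 128, 256  # both must be multiples of 128
x = torch.randn(rows, cols, device="cuda", dtype=torch.bfloat16)

out = torch.empty(rows, cols // 2, device="cuda", dtype=torch.uint8)              # 2 packed FP4 values per byte
scales = torch.empty(rows, cols // 16, device="cuda", dtype=torch.float8_e4m3fn)  # one micro-scale per 16 elements
global_scale = torch.empty((), device="cuda", dtype=torch.float32)
amax = x.abs().max().to(torch.float32)

_quartet2.gridflip_four_six_fp4(out, scales.view(torch.uint8), global_scale, x, amax, 1.0, 0.25)
```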
diff --git a/quartet2/csrc/round_four_six.cu b/quartet2/csrc/round_four_six.cu
index ea024f2..ea3baec 100644
--- a/quartet2/csrc/round_four_six.cu
+++ b/quartet2/csrc/round_four_six.cu
@@ -41,6 +41,27 @@ __device__ __forceinline__ QuantResult quantize(float abs_max, float inv_val_max
     return QuantResult{result, s_round_fp8, s_as_fp8};
 }
 
+__device__ __forceinline__ QuantResult quantize_gridflip(float abs_max, float inv_val_max, float scale, float grid_shift, bf16x8& x) {
+    float s_group = abs_max * inv_val_max;
+    float inv_scale = reciprocal_approximate_ftz(scale);
+    __nv_fp8_e4m3 s_as_fp8 = static_cast<__nv_fp8_e4m3>(s_group * inv_scale);
+    float s_round_fp8 = static_cast<float>(s_as_fp8);
+    if (s_round_fp8 == 0) s_round_fp8 = 1.f;
+
+    float factor = reciprocal_approximate_ftz(s_round_fp8 * scale);
+    float2 factor2 = {-factor, -factor};
+    fp4x8 result;
+    for (int k = 0; k < bf16x8::size; k += 2) {
+        float2 src = make_float2(static_cast<float>(x[k+0]), static_cast<float>(x[k+1]));
+        float2 prod = __fmul2_rn(src, factor2);
+        float2 scaled = {prod.x - grid_shift, prod.y - grid_shift};
+        unsigned char bits = __nv_cvt_float2_to_fp4x2(scaled, __nv_fp4_interpretation_t::__NV_E2M1, cudaRoundMode::cudaRoundNearest);
+        result[k/2] = bits;
+    }
+
+    return QuantResult{result, s_round_fp8, s_as_fp8};
+}
+
 __forceinline__ __device__ float quant_error(bf16x8 x, const QuantResult& q, float scale) {
     const float descale = static_cast<float>(q.fp8s) * scale;
     float2 sum = {0.f, 0.f};
@@ -56,6 +77,24 @@ __forceinline__ __device__ float quant_error(bf16x8 x, const QuantResult& q, flo
     return local_error;
 }
 
+__forceinline__ __device__ float gridflip_quant_error(bf16x8 x, const QuantResult& q, float scale, float grid_shift) {
+    const float descale = static_cast<float>(q.fp8s) * scale;
+    float2 sum = {0.f, 0.f};
+    for (int i = 0; i < 4; ++i) {
+        float2 dq = __nv_cvt_fp4x2_to_float2(q.bits[i]);
+        float2 xv = {static_cast<float>(x[2*i+0]), static_cast<float>(x[2*i+1])};
+        float2 recon = {
+            -(dq.x + grid_shift) * descale,
+            -(dq.y + grid_shift) * descale,
+        };
+        float2 d = {recon.x - xv.x, recon.y - xv.y};
+        sum = __ffma2_rn(d, d, sum);
+    }
+    float local_error = sum.x + sum.y;
+    local_error += __shfl_xor_sync(0xffffffff, local_error, 1);
+    return local_error;
+}
+
 template<float... Candidates>
 struct get_candidate_helper;
 
@@ -76,6 +115,24 @@ struct get_candidate_helper<> {
     }
 };
 
+template<float... Values>
+struct get_candidate_value_helper;
+
+template<float Value, float... Rest>
+struct get_candidate_value_helper<Value, Rest...> {
+    static constexpr __forceinline__ __device__ float get(int i) {
+        if (i == 0) return Value;
+        return get_candidate_value_helper<Rest...>::get(i - 1);
+    }
+};
+
+template<>
+struct get_candidate_value_helper<> {
+    static constexpr __forceinline__ __device__ float get(int i) {
+        __builtin_unreachable();
+    }
+};
+
 template<float... Candidates>
 __global__ void four_six_fp4_kernel(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, int nvecs, int cols) {
     constexpr int NumCandidates = sizeof...(Candidates);
@@ -122,6 +179,63 @@ __global__ void four_six_fp4_kernel(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale
     }
 }
 
+template<float... Candidates>
+__global__ void gridflip_four_six_fp4_kernel(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, float grid_shift, int nvecs, int cols) {
+    constexpr int NumCandidates = sizeof...(Candidates);
+    float global_abs_max = *amax_ptr;
+    int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if(idx >= nvecs) return;
+
+    bf16x8 x = bf16x8::load(x_ptr + 8 * idx);
+
+    constexpr float inv_scales_max = NumCandidates > 1 ? 1.f / 256.f : 1.f / 448.f;
+    constexpr float one_over_six = 1.f / 6.f;
+    float inv_val_max = scale_override * one_over_six;
+    float scale = global_abs_max == 0 ? 1.f : global_abs_max * inv_scales_max * inv_val_max;
+    if (idx == 0) {
+        global_scale_ptr[0] = scale;
+    }
+
+    nv_bfloat16 local_abs_max = vecReduceAbsMax(x);
+    nv_bfloat16 other_abs_max = __shfl_xor_sync(0xffffffff, local_abs_max, 1);
+    float full_abs_max = static_cast<float>(__hmax(local_abs_max, other_abs_max));
+
+    QuantResult best_standard_res;
+    float best_standard = INFINITY;
+    QuantResult best_gridflip_res;
+    float best_gridflip = INFINITY;
+    for (int i = 0; i < NumCandidates; ++i) {
+        float inv_val = get_candidate_helper<Candidates...>::get_inv(i);
+        QuantResult standard_res = quantize(full_abs_max, inv_val * scale_override, scale, x);
+        float standard_score = quant_error(x, standard_res, scale);
+        if (standard_score < best_standard) {
+            best_standard = standard_score;
+            best_standard_res = standard_res;
+        }
+
+        float value = get_candidate_value_helper<Candidates...>::get(i);
+        QuantResult gridflip_res = quantize_gridflip(full_abs_max, scale_override / (value + grid_shift), scale, grid_shift, x);
+        float gridflip_score = gridflip_quant_error(x, gridflip_res, scale, grid_shift);
+        if (gridflip_score < best_gridflip) {
+            best_gridflip = gridflip_score;
+            best_gridflip_res = gridflip_res;
+        }
+    }
+
+    bool use_gridflip = best_gridflip < best_standard;
+    QuantResult res = use_gridflip ? best_gridflip_res : best_standard_res;
+    res.bits.store(reinterpret_cast<unsigned char*>(y_ptr) + 4 * idx);
+    if (idx % 2 == 0) {
+        int col = (idx / 2) % cols;
+        int row = (idx / 2) / cols;
+        unsigned char scale_bits = *reinterpret_cast<unsigned char*>(&res.fp8s);
+        if (use_gridflip) {
+            scale_bits |= 0x80;
+        }
+        reinterpret_cast<unsigned char*>(scale_ptr)[row * cols + col] = scale_bits;
+    }
+}
+
 void four_six_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, long rows, long cols) {
     if (cols % 128 != 0) throw std::invalid_argument("four_six_fp4: cols must be divisible by 128");
     CHECK_POINTER(y_ptr);
@@ -137,6 +251,22 @@ void four_six_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* globa
     CUDA_CHECK(cudaGetLastError());
 }
 
+void gridflip_four_six_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, float grid_shift, long rows, long cols) {
+    if (cols % 128 != 0) throw std::invalid_argument("gridflip_four_six_fp4: cols must be divisible by 128");
+    if (grid_shift < 0.f) throw std::invalid_argument("gridflip_four_six_fp4: grid_shift must be non-negative");
+    CHECK_POINTER(y_ptr);
+    CHECK_POINTER(scale_ptr);
+    CHECK_POINTER(x_ptr);
+    CHECK_POINTER_NOT_NULL(global_scale_ptr);
+    CHECK_POINTER_NOT_NULL(amax_ptr);
+
+    int n_vecs = rows * cols / 8;
+    int block_size = 256;
+    int n_blocks = (n_vecs + block_size - 1) / block_size;
+    gridflip_four_six_fp4_kernel<6.f, 4.f><<<n_blocks, block_size>>>(y_ptr, scale_ptr, global_scale_ptr, x_ptr, amax_ptr, scale_override, grid_shift, n_vecs, cols / 16);
+    CUDA_CHECK(cudaGetLastError());
+}
+
 void rtn_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, long rows, long cols) {
     if (cols % 128 != 0) throw std::invalid_argument("rtn_fp4: cols must be divisible by 128");
     CHECK_POINTER(y_ptr);
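A short sketch of how the kernel's per-block flag is encoded and what it means for dequantization. The kernel ORs `0x80` (the E4M3 sign bit, unused because micro-scales are non-negative) into the scale byte of any block that chose the mirrored-and-shifted grid; this mirrors the decode in `_dq_gridflip_weight_fp4` later in this diff.

```python
import torch

def split_gridflip_scales(scales_u8: torch.Tensor):
    """Separate the per-block gridflip flag from the E4M3 micro-scale."""
    flags = (scales_u8 & 0x80).bool()                   # True => block used the shifted grid
    clean = (scales_u8 & 0x7F).view(torch.float8_e4m3fn)
    return flags, clean

# With FP4 code value q, per-block scale s, and tensor-wide global_scale:
#   standard block:   q * s * global_scale
#   gridflip block:  -(q + grid_shift) * s * global_scale
```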
diff --git a/quartet2/python/quartet2/linear.py b/quartet2/python/quartet2/linear.py
index dc989bf..577c4ed 100644
--- a/quartet2/python/quartet2/linear.py
+++ b/quartet2/python/quartet2/linear.py
@@ -2,9 +2,30 @@
 from flashinfer import mm_fp4
 from scipy.linalg import hadamard
-from .quant import quant_fp4, rht128_quant_eden, rht128_requant, NVFP4QuantMode
+from .quant import NVFP4Quant, quant_fp4, quant_gridflip_fp4, rht128_quant_eden, rht128_requant, NVFP4QuantMode
 import nvtx
 import contextlib
+from typing import Literal
+
+FP4MatmulBackend = Literal["flashinfer", "qutlass", "dequantized"]
+FP4_MATMUL_BACKENDS = ("flashinfer", "qutlass", "dequantized")
+_fp4_mm_backend: FP4MatmulBackend = "flashinfer"
+FP4WeightQuantizer = Literal["four_six", "gridflip"]
+FP4_WEIGHT_QUANTIZERS = ("four_six", "gridflip")
+_fp4_weight_quantizer: FP4WeightQuantizer = "four_six"
+_gridflip_shift: float = 0.25
+
+
+def _import_qutlass():
+    try:
+        import qutlass
+    except ImportError as exc:
+        raise ImportError(
+            "The qutlass FP4 matmul backend requires the qutlass package. "
+            "Install the local GridFlip/QuTLASS package before selecting "
+            "quartet2.linear.set_fp4_mm_backend('qutlass')."
+        ) from exc
+    return qutlass
 
 def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
     return torch.tensor(
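A usage sketch of the backend registry introduced in the next hunk (names are the ones added by this diff). The process-wide default only changes through `set_fp4_mm_backend`, and the context manager restores the previous backend even on error:

```python
import quartet2.linear as qlinear

old = qlinear.set_fp4_mm_backend("qutlass")   # validates that qutlass is importable
assert qlinear.get_fp4_mm_backend() == "qutlass"

# Scoped override, restored on exit:
with qlinear.fp4_mm_backend("dequantized"):
    ...  # matmuls inside run through the slow reference path

qlinear.set_fp4_weight_quantizer("gridflip", gridflip_shift=0.25)
qlinear.set_fp4_mm_backend(old)
```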
@@ -31,8 +52,174 @@ def rerotate_hadamard(hadamard_matrix):
     return hadamard_matrix * signs[None, :] # NOTE: rerotate along last dim, inner dim for TN GEMM
 
-@torch.library.custom_op("quartet2::fp4_mm", mutates_args=())
-def _fp4_mm(x_fp4: torch.Tensor, w_fp4: torch.Tensor, x_mx: torch.Tensor, w_mx: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
+def set_fp4_mm_backend(backend: FP4MatmulBackend) -> FP4MatmulBackend:
+    if backend not in FP4_MATMUL_BACKENDS:
+        raise ValueError(f"backend must be one of {FP4_MATMUL_BACKENDS}, got {backend!r}")
+    if backend == "qutlass":
+        _import_qutlass()
+
+    global _fp4_mm_backend
+    old_backend = _fp4_mm_backend
+    _fp4_mm_backend = backend
+    return old_backend
+
+
+def get_fp4_mm_backend() -> FP4MatmulBackend:
+    return _fp4_mm_backend
+
+
+def set_fp4_weight_quantizer(
+    quantizer: FP4WeightQuantizer,
+    *,
+    gridflip_shift: float = 0.25,
+) -> FP4WeightQuantizer:
+    if quantizer not in FP4_WEIGHT_QUANTIZERS:
+        raise ValueError(f"quantizer must be one of {FP4_WEIGHT_QUANTIZERS}, got {quantizer!r}")
+    if gridflip_shift < 0:
+        raise ValueError("gridflip_shift must be non-negative")
+
+    global _fp4_weight_quantizer, _gridflip_shift
+    old_quantizer = _fp4_weight_quantizer
+    _fp4_weight_quantizer = quantizer
+    _gridflip_shift = float(gridflip_shift)
+    return old_quantizer
+
+
+def get_fp4_weight_quantizer() -> FP4WeightQuantizer:
+    return _fp4_weight_quantizer
+
+
+def get_gridflip_shift() -> float:
+    return _gridflip_shift
+
+
+@contextlib.contextmanager
+def fp4_mm_backend(backend: FP4MatmulBackend):
+    old_backend = set_fp4_mm_backend(backend)
+    try:
+        yield
+    finally:
+        set_fp4_mm_backend(old_backend)
+
+
+def _resolve_fp4_mm_backend(backend: FP4MatmulBackend | None) -> FP4MatmulBackend:
+    return _fp4_mm_backend if backend is None else backend
+
+
+@torch.compiler.disable()
+def apply_block_transform(x: torch.Tensor, transform: torch.Tensor | None) -> torch.Tensor:
+    if transform is None:
+        return x
+
+    rows, cols = x.shape
+    group_size = transform.shape[-1]
+    groups = cols // group_size
+    x_grouped = x.reshape(rows, groups, group_size).permute(1, 0, 2)
+    y_grouped = torch.bmm(x_grouped, transform.to(dtype=x.dtype))
+    return y_grouped.permute(1, 0, 2).reshape(rows, cols).contiguous()
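A sanity-check sketch of why the transform pair produced further down cancels inside the TN matmul: the forward assigns `input_transform = t_w_inv` and `weight_transform = t_w.transpose(-1, -2)`, so blockwise `x @ T_g^{-1}` against `w @ T_g^T` reproduces `x @ w.T` exactly (up to quantization). Shapes and the orthogonal stand-in for `t_w` are illustrative only:

```python
import torch

g, gs = 2, 128                       # two blocks of width 128
rows, out_features = 4, 8
x = torch.randn(rows, g * gs)
w = torch.randn(out_features, g * gs)
# Stand-in per-block transforms; orthogonal so the inverse is numerically exact.
t_w = torch.linalg.qr(torch.randn(g, gs, gs)).Q

input_transform = torch.linalg.inv(t_w)    # mirrors input_transform = t_w_inv
weight_transform = t_w.transpose(-1, -2)   # mirrors weight_transform = t_w.mT

def blockwise(a, t):
    n = a.shape[0]
    a = a.reshape(n, g, gs).permute(1, 0, 2)
    return torch.bmm(a, t).permute(1, 0, 2).reshape(n, g * gs)

y = blockwise(x, input_transform) @ blockwise(w, weight_transform).T
torch.testing.assert_close(y, x @ w.T, rtol=1e-4, atol=1e-4)
```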
+
+
+@torch.compiler.disable()
+def update_wush_moments(
+    sigma_x: torch.Tensor,
+    sigma_w: torch.Tensor | None,
+    ema_count: torch.Tensor,
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    group_size: int,
+    ema_decay: float,
+) -> None:
+    if x.shape[-1] % group_size != 0:
+        raise ValueError(f"WUSH requires input dimension divisible by {group_size}, got {x.shape[-1]}")
+
+    rows = x.shape[0]
+    out = weight.shape[0]
+    groups = sigma_x.shape[0]
+    beta = 0.0 if ema_count.item() == 0 else ema_decay
+
+    x_blocks = x.to(torch.float32).reshape(rows, groups, group_size).permute(1, 2, 0)
+    sig_x = torch.bmm(x_blocks, x_blocks.transpose(-1, -2)) / max(rows, 1)
+    sigma_x.mul_(beta).add_(sig_x, alpha=1.0 - beta)
+
+    if sigma_w is not None:
+        w_blocks = weight.detach().to(torch.float32).reshape(out, groups, group_size).permute(1, 2, 0)
+        sig_w = torch.bmm(w_blocks, w_blocks.transpose(-1, -2)) / max(out, 1)
+        sigma_w.mul_(beta).add_(sig_w, alpha=1.0 - beta)
+
+    ema_count.add_(1)
+
+
+@torch.compiler.disable()
+def _psd_factor_from_moments(
+    sigma: torch.Tensor,
+    damp: float,
+    eye_batch: torch.Tensor,
+    eig_floor: float = 1e-8,
+) -> torch.Tensor:
+    sigma = sigma.float()
+    sigma = 0.5 * (sigma + sigma.transpose(-1, -2))
+    sigma = torch.nan_to_num(sigma, nan=0.0, posinf=1e4, neginf=-1e4)
+    damped = sigma + damp * eye_batch
+
+    factor, info = torch.linalg.cholesky_ex(damped)
+    bad = info != 0
+    if not bad.any():
+        return factor
+
+    try:
+        eigvals, eigvecs = torch.linalg.eigh(damped[bad])
+        eigvals_sqrt = eigvals.clamp_min(eig_floor).sqrt()
+        factor[bad] = eigvecs @ torch.diag_embed(eigvals_sqrt) @ eigvecs.transpose(-1, -2)
+    except RuntimeError:
+        factor[bad] = eye_batch[bad]
+    return factor
+
+
+@torch.compiler.disable()
+def get_wush_transforms_from_moments(
+    sigma_x: torch.Tensor,
+    sigma_w: torch.Tensor | None,
+    group_size: int,
+    damp: float,
+    s_min: float,
+    max_cond: float,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    groups = sigma_x.shape[0]
+    device = sigma_x.device
+    had = get_hadamard_matrix(group_size, torch.float32, device)
+    eye = torch.eye(group_size, device=device, dtype=torch.float32)
+    eye_batch = eye.unsqueeze(0).expand(groups, -1, -1)
+
+    if sigma_w is None:
+        w_prime = eye_batch.clone()
+    else:
+        w_prime = _psd_factor_from_moments(sigma_w, damp, eye_batch)
+
+    x_prime = _psd_factor_from_moments(sigma_x, damp, eye_batch)
+    _, singular_values, vt = torch.linalg.svd(w_prime.transpose(-1, -2) @ x_prime)
+    singular_values = singular_values.clamp_min(s_min)
+    s_inv_sqrt = torch.diag_embed(singular_values.rsqrt())
+
+    t_w = had.unsqueeze(0) @ s_inv_sqrt @ vt @ x_prime.transpose(-1, -2)
+    t_w_inv, inv_info = torch.linalg.inv_ex(t_w)
+
+    try:
+        conds = torch.linalg.cond(t_w)
+    except RuntimeError:
+        conds = torch.full((groups,), float("inf"), device=device, dtype=torch.float32)
+    bad = (inv_info != 0) | (conds > max_cond) | ~torch.isfinite(conds)
+    if bad.any():
+        t_w[bad] = had
+        t_w_inv[bad] = had.T
+
+    # Quartet-II forward applies x @ input_transform and weight @ weight_transform.
+ input_transform = t_w_inv + weight_transform = t_w.transpose(-1, -2) + return input_transform.to(torch.bfloat16), weight_transform.to(torch.bfloat16), conds + + +@torch.library.custom_op("quartet2::fp4_mm_flashinfer", mutates_args=()) +def _fp4_mm_flashinfer(x_fp4: torch.Tensor, w_fp4: torch.Tensor, x_mx: torch.Tensor, w_mx: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor: m, packed_k = x_fp4.shape k = packed_k * 2 n = w_fp4.shape[0] @@ -56,8 +243,48 @@ def _fp4_mm(x_fp4: torch.Tensor, w_fp4: torch.Tensor, x_mx: torch.Tensor, w_mx: return out -@_fp4_mm.register_fake -def _fp4_mm_fake(x_fp4: torch.Tensor, w_fp4: torch.Tensor, x_mx: torch.Tensor, w_mx: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor: +@_fp4_mm_flashinfer.register_fake +def _fp4_mm_flashinfer_fake(x_fp4: torch.Tensor, w_fp4: torch.Tensor, x_mx: torch.Tensor, w_mx: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor: + return torch.empty((x_fp4.shape[0], w_fp4.shape[0]), device=x_fp4.device, dtype=torch.bfloat16) + + +@torch.library.custom_op("quartet2::fp4_mm_qutlass", mutates_args=()) +def _fp4_mm_qutlass(x_fp4: torch.Tensor, w_fp4: torch.Tensor, x_mx: torch.Tensor, w_mx: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor: + qutlass = _import_qutlass() + + # quartet2 quantization kernels already write micro-scales in Blackwell's + # blocked layout, even though tensors keep the logical [rows, K / 16] shape. + return qutlass.matmul_nvf4_bf16_tn( + x_fp4.contiguous(), + w_fp4.contiguous(), + x_mx.contiguous(), + w_mx.contiguous(), + alpha.reshape(1).contiguous(), + ) + + +def _fp4_mm_gridflip_qutlass( + x_fp4: torch.Tensor, + w_fp4: torch.Tensor, + x_mx: torch.Tensor, + w_mx_rowmajor: torch.Tensor, + alpha: torch.Tensor, + grid_shift: float, +) -> torch.Tensor: + qutlass = _import_qutlass() + x_mx_rowmajor = unblock(x_mx, x_fp4.shape[0], x_fp4.shape[1] * 2).contiguous() + return qutlass.matmul_nvf4_gridflip_bf16_tn( + x_fp4.contiguous(), + w_fp4.contiguous(), + x_mx_rowmajor, + w_mx_rowmajor.contiguous(), + alpha.reshape(1).contiguous(), + grid_shift, + ) + + +@_fp4_mm_qutlass.register_fake +def _fp4_mm_qutlass_fake(x_fp4: torch.Tensor, w_fp4: torch.Tensor, x_mx: torch.Tensor, w_mx: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor: return torch.empty((x_fp4.shape[0], w_fp4.shape[0]), device=x_fp4.device, dtype=torch.bfloat16) @@ -143,6 +370,285 @@ def _dq_fp4(x_e2m1: torch.Tensor, x_e4m3: torch.Tensor, alpha: float): ) * alpha return x_dq.to(torch.bfloat16) + +def _round_fp4_values_and_codes(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x_abs = x.abs() + magnitude = torch.where( + x_abs >= 5.0, + torch.full_like(x, 6.0), + torch.where( + x_abs >= 3.5, + torch.full_like(x, 4.0), + torch.where( + x_abs >= 2.5, + torch.full_like(x, 3.0), + torch.where( + x_abs >= 1.75, + torch.full_like(x, 2.0), + torch.where( + x_abs >= 1.25, + torch.full_like(x, 1.5), + torch.where( + x_abs >= 0.75, + torch.full_like(x, 1.0), + torch.where( + x_abs >= 0.25, + torch.full_like(x, 0.5), + torch.zeros_like(x), + ), + ), + ), + ), + ), + ), + ) + magnitude_code = torch.where( + x_abs >= 5.0, + torch.full(x.shape, 7, device=x.device, dtype=torch.uint8), + torch.where( + x_abs >= 3.5, + torch.full(x.shape, 6, device=x.device, dtype=torch.uint8), + torch.where( + x_abs >= 2.5, + torch.full(x.shape, 5, device=x.device, dtype=torch.uint8), + torch.where( + x_abs >= 1.75, + torch.full(x.shape, 4, device=x.device, dtype=torch.uint8), + torch.where( + x_abs >= 1.25, + torch.full(x.shape, 3, device=x.device, dtype=torch.uint8), + 
torch.where( + x_abs >= 0.75, + torch.full(x.shape, 2, device=x.device, dtype=torch.uint8), + torch.where( + x_abs >= 0.25, + torch.ones(x.shape, device=x.device, dtype=torch.uint8), + torch.zeros(x.shape, device=x.device, dtype=torch.uint8), + ), + ), + ), + ), + ), + ), + ) + sign = x < 0 + values = torch.where(sign, -magnitude, magnitude) + codes = magnitude_code | sign.to(torch.uint8) * 0x8 + return values, codes + + +def _pack_fp4_codes(codes: torch.Tensor) -> torch.Tensor: + codes = codes.to(torch.uint8) + return ((codes[..., 1::2] & 0xF) << 4 | (codes[..., ::2] & 0xF)).flatten(start_dim=-2) + + +@torch.compiler.disable() +def _quant_gridflip_weight_fp4( + weight: torch.Tensor, + *, + amax: torch.Tensor, + scale_override: float, + grid_shift: float, + mode: NVFP4QuantMode, +) -> NVFP4Quant: + if mode != NVFP4QuantMode.FOUR_SIX: + raise ValueError("GridFlip weight quantization currently supports only FOUR_SIX mode") + if weight.dtype != torch.bfloat16: + raise TypeError("GridFlip weight quantization requires bfloat16 weights") + if weight.dim() != 2 or weight.shape[0] % 128 != 0 or weight.shape[1] % 128 != 0: + raise ValueError("GridFlip weight quantization requires a 2D [rows, cols] tensor with both dimensions divisible by 128") + return quant_gridflip_fp4( + weight, + amax=amax, + scale_override=scale_override, + grid_shift=grid_shift, + ) + + +@torch.compiler.disable() +def _quant_gridflip_weight_fp4_reference( + weight: torch.Tensor, + *, + amax: torch.Tensor, + scale_override: float, + grid_shift: float, + mode: NVFP4QuantMode, +) -> NVFP4Quant: + if mode != NVFP4QuantMode.FOUR_SIX: + raise ValueError("GridFlip weight quantization currently supports only FOUR_SIX mode") + if weight.dtype != torch.bfloat16: + raise TypeError("GridFlip weight quantization requires bfloat16 weights") + if weight.dim() != 2 or weight.shape[0] % 128 != 0 or weight.shape[1] % 128 != 0: + raise ValueError("GridFlip weight quantization requires a 2D [rows, cols] tensor with both dimensions divisible by 128") + standard = quant_fp4( + weight, + amax=amax, + scale_override=scale_override, + mode=mode, + ) + rows, cols = weight.shape + k_blocks = cols // 16 + blocks = weight.float().reshape(rows, k_blocks, 16) + global_scale = standard.tensor_scale.float().reshape(()) + + standard_scale_rowmajor = unblock( + standard.micro_scales, + rows, + cols, + ).contiguous() + standard_codes = standard.fp4.view(torch.uint8).to(torch.int32) + standard_codes = torch.stack( + [standard_codes & 0xF, (standard_codes >> 4) & 0xF], + dim=-1, + ).flatten(start_dim=-2) + grid_dq = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float32, + device=weight.device, + ) + standard_dequant = ( + grid_dq[standard_codes] + .reshape(rows, k_blocks, 16) + .mul(standard_scale_rowmajor.float().unsqueeze(-1)) + .mul(global_scale) + ) + standard_error = ( + standard_dequant.reshape(rows, k_blocks, 16) - blocks + ).square().sum(dim=-1) + + best_error = torch.full((rows, k_blocks), float("inf"), device=weight.device) + best_codes = torch.empty((rows, k_blocks, 16), device=weight.device, dtype=torch.uint8) + best_scale = torch.empty((rows, k_blocks), device=weight.device, dtype=torch.float8_e4m3fn) + block_abs_max = blocks.abs().amax(dim=-1, keepdim=True) + + for candidate in (6.0, 4.0): + shifted_val_max = candidate + grid_shift + raw_scale = block_abs_max * scale_override / (shifted_val_max * global_scale) + scale_fp8 = 
raw_scale.to(torch.float8_e4m3fn) + scale_dq_f32 = scale_fp8.float() + scale_quant_f32 = torch.where(scale_dq_f32 == 0.0, torch.ones_like(scale_dq_f32), scale_dq_f32) + scaled = -blocks / (scale_quant_f32 * global_scale) - grid_shift + q_values, q_codes = _round_fp4_values_and_codes(scaled) + dequant = -(q_values + grid_shift) * scale_dq_f32 * global_scale + error = (dequant - blocks).square().sum(dim=-1) + improve = error < best_error + best_error = torch.where(improve, error, best_error) + best_codes = torch.where(improve[..., None], q_codes, best_codes) + best_scale = torch.where(improve, scale_fp8.reshape(rows, k_blocks), best_scale) + + flags = best_error < standard_error + shifted_fp4 = _pack_fp4_codes(best_codes) + fp4 = torch.where(flags.repeat_interleave(8, dim=1), shifted_fp4, standard.fp4) + + scale_rowmajor = torch.where(flags, best_scale, standard_scale_rowmajor) + scale_u8 = scale_rowmajor.view(torch.uint8) + scale_u8 |= flags.to(torch.uint8) * 0x80 + return NVFP4Quant(fp4.contiguous(), scale_rowmajor.contiguous(), standard.tensor_scale) + + +def _dq_gridflip_weight_fp4( + w_fp4: torch.Tensor, + w_mx_rowmajor: torch.Tensor, + grid_shift: float, +) -> torch.Tensor: + scale_u8 = w_mx_rowmajor.view(torch.uint8) + flags = (scale_u8 & 0x80).bool() + clean_scale_u8 = scale_u8 & 0x7F + clean_scale_rowmajor = clean_scale_u8.view(torch.float8_e4m3fn) + clean_scale_blocked = to_blocked(clean_scale_rowmajor).view_as(w_mx_rowmajor) + standard = _dq_fp4(w_fp4, clean_scale_blocked, 1.0).float() + correction = clean_scale_rowmajor.float().repeat_interleave(16, dim=1) + flags = flags.repeat_interleave(16, dim=1) + return torch.where(flags, -standard - grid_shift * correction, standard) + + +def _fp4_mm_dequantized(x_fp4: torch.Tensor, w_fp4: torch.Tensor, x_mx: torch.Tensor, w_mx: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor: + x = _dq_fp4(x_fp4, x_mx, 1.0).to(torch.float32) + w = _dq_fp4(w_fp4, w_mx, 1.0).to(torch.float32) + return (x @ w.T * alpha.reshape(())).to(torch.bfloat16) + + +def _fp4_mm_gridflip_dequantized( + x_fp4: torch.Tensor, + w_fp4: torch.Tensor, + x_mx: torch.Tensor, + w_mx_rowmajor: torch.Tensor, + alpha: torch.Tensor, + grid_shift: float, +) -> torch.Tensor: + x = _dq_fp4(x_fp4, x_mx, 1.0).float() + w = _dq_gridflip_weight_fp4(w_fp4, w_mx_rowmajor, grid_shift).float() + return (x @ w.T * alpha.reshape(())).to(torch.bfloat16) + + +def _fp4_mm_gridflip( + x_fp4: torch.Tensor, + w_fp4: torch.Tensor, + x_mx: torch.Tensor, + w_mx_rowmajor: torch.Tensor, + alpha: torch.Tensor, + *, + grid_shift: float, + backend: FP4MatmulBackend | None = None, +) -> torch.Tensor: + backend = _resolve_fp4_mm_backend(backend) + if backend == "qutlass": + return _fp4_mm_gridflip_qutlass( + x_fp4, + w_fp4, + x_mx, + w_mx_rowmajor, + alpha, + grid_shift, + ) + if backend == "dequantized": + return _fp4_mm_gridflip_dequantized( + x_fp4, + w_fp4, + x_mx, + w_mx_rowmajor, + alpha, + grid_shift, + ) + raise ValueError("GridFlip weight quantization requires the qutlass or dequantized FP4 matmul backend") + + +def _fp4_mm( + x_fp4: torch.Tensor, + w_fp4: torch.Tensor, + x_mx: torch.Tensor, + w_mx: torch.Tensor, + alpha: torch.Tensor, + *, + backend: FP4MatmulBackend | None = None, +) -> torch.Tensor: + backend = _resolve_fp4_mm_backend(backend) + if backend == "flashinfer": + return _fp4_mm_flashinfer(x_fp4, w_fp4, x_mx, w_mx, alpha) + if backend == "qutlass": + return _fp4_mm_qutlass(x_fp4, w_fp4, x_mx, w_mx, alpha) + if backend == "dequantized": + return _fp4_mm_dequantized(x_fp4, w_fp4, x_mx, 
w_mx, alpha) + raise ValueError(f"backend must be one of {FP4_MATMUL_BACKENDS}, got {backend!r}") + + @torch.compile(dynamic=False) def abs_max(x): return x.abs().max().to(torch.float32) @@ -154,7 +660,8 @@ class Quartet_II_fn(torch.autograd.Function): @staticmethod def forward(ctx, input, weight, had, mode: NVFP4QuantMode, disable_backward_quant: bool = False, weight_amax: torch.Tensor = None, input_amax: torch.Tensor = None, - scratch_amax: torch.Tensor = None): + scratch_amax: torch.Tensor = None, wush_input_transform: torch.Tensor = None, + wush_weight_transform: torch.Tensor = None): ctx.batch = input.shape[0] ctx.seq = input.shape[1] ctx.in_dim = weight.shape[1] @@ -162,6 +669,9 @@ def forward(ctx, input, weight, had, mode: NVFP4QuantMode, disable_backward_quan ctx.disable_backward_quant = disable_backward_quant ctx.scratch_amax = scratch_amax ctx.had = had + ctx.wush_input_transform = wush_input_transform + ctx.wush_weight_transform = wush_weight_transform + ctx.fp4_mm_backend = get_fp4_mm_backend() autocast_enabled = torch.is_autocast_enabled("cuda") if autocast_enabled: @@ -175,6 +685,9 @@ def forward(ctx, input, weight, had, mode: NVFP4QuantMode, disable_backward_quan forward_scale_override = 1.0 flat_input = input.reshape(-1, input.shape[-1]) + if wush_input_transform is not None: + flat_input = apply_block_transform(flat_input.contiguous(), wush_input_transform) + weight = apply_block_transform(weight.contiguous(), wush_weight_transform) with nvtx_annotate("Abs-max", color="red"): if input_amax is None: @@ -184,14 +697,37 @@ def forward(ctx, input, weight, had, mode: NVFP4QuantMode, disable_backward_quan with nvtx_annotate("Quant", color="yellow"): input_fp4 = quant_fp4(flat_input, amax=input_amax, scale_override=forward_scale_override, mode=mode) - weight_fp4 = quant_fp4(weight, amax=weight_amax, scale_override=forward_scale_override, mode=mode) + weight_backward_fp4 = quant_fp4(weight, amax=weight_amax, scale_override=forward_scale_override, mode=mode) + use_gridflip_weight = get_fp4_weight_quantizer() == "gridflip" + if use_gridflip_weight: + weight_fp4 = _quant_gridflip_weight_fp4( + weight, + amax=weight_amax, + scale_override=forward_scale_override, + grid_shift=get_gridflip_shift(), + mode=mode, + ) + else: + weight_fp4 = weight_backward_fp4 ctx.save_for_backward(input_fp4.fp4, input_fp4.micro_scales, input_fp4.tensor_scale, - weight_fp4.fp4, weight_fp4.micro_scales, weight_fp4.tensor_scale) + weight_backward_fp4.fp4, weight_backward_fp4.micro_scales, weight_backward_fp4.tensor_scale) with nvtx_annotate("Matmul", color="blue"): - res = _fp4_mm( - input_fp4.fp4, weight_fp4.fp4, - input_fp4.micro_scales, weight_fp4.micro_scales, - alpha=input_fp4.tensor_scale * weight_fp4.tensor_scale) + if use_gridflip_weight: + res = _fp4_mm_gridflip( + input_fp4.fp4, + weight_fp4.fp4, + input_fp4.micro_scales, + weight_fp4.micro_scales, + alpha=input_fp4.tensor_scale * weight_fp4.tensor_scale, + grid_shift=get_gridflip_shift(), + backend=ctx.fp4_mm_backend, + ) + else: + res = _fp4_mm( + input_fp4.fp4, weight_fp4.fp4, + input_fp4.micro_scales, weight_fp4.micro_scales, + alpha=input_fp4.tensor_scale * weight_fp4.tensor_scale, + backend=ctx.fp4_mm_backend) return res.reshape(ctx.batch, ctx.seq, ctx.out_dim) @@ -215,22 +751,42 @@ def backward(ctx, grad_output): wr = _dq_fp4(wfp4, ws, wm) grad_input = flat_grad_output @ wr grad_weight = flat_grad_output.T @ xr - return grad_input.reshape(ctx.batch, ctx.seq, ctx.in_dim), grad_weight, None, None, None, None, None, None + if 
ctx.wush_input_transform is not None: + grad_input = apply_block_transform(grad_input.contiguous(), ctx.wush_input_transform.transpose(-1, -2)) + grad_weight = apply_block_transform(grad_weight.contiguous(), ctx.wush_weight_transform.transpose(-1, -2)) + return grad_input.reshape(ctx.batch, ctx.seq, ctx.in_dim), grad_weight, None, None, None, None, None, None, None, None # EW with nvtx_annotate("Quant", color="yellow"): e_ht_fp4, e_ht_ms, e_ht_ts = rht128_quant_eden(x=flat_grad_output, h=had[:16, :], scale_override=backward_scale_override, scratch_amax=ctx.scratch_amax) wt_ht_fp4, wt_ht_ms, wt_ht_ts = rht128_requant(x=wfp4, x_group_scales=ws, x_tensor_scale=wm, h=had[:16, :], scale_override=backward_scale_override, scratch_amax=ctx.scratch_amax) with nvtx_annotate("Matmul", color="blue"): - grad_input = _fp4_mm(e_ht_fp4, wt_ht_fp4, e_ht_ms, wt_ht_ms, alpha=e_ht_ts*wt_ht_ts) + grad_input = _fp4_mm( + e_ht_fp4, + wt_ht_fp4, + e_ht_ms, + wt_ht_ms, + alpha=e_ht_ts*wt_ht_ts, + backend=ctx.fp4_mm_backend, + ) # EtX with nvtx_annotate("Quant", color="yellow"): et_ht_fp4, et_ht_ms, et_ht_ts = rht128_quant_eden(x=flat_grad_output, h=had[:16, :], scale_override=backward_scale_override, transpose=True, scratch_amax=ctx.scratch_amax) xt_ht_fp4, xt_ht_ms, xt_ht_ts = rht128_requant(x=xfp4, x_group_scales=xs, x_tensor_scale=xm, h=had[:16, :], scale_override=backward_scale_override, scratch_amax=ctx.scratch_amax) with nvtx_annotate("Matmul", color="blue"): - grad_weight = _fp4_mm(et_ht_fp4, xt_ht_fp4, et_ht_ms, xt_ht_ms, alpha=et_ht_ts*xt_ht_ts) - return grad_input.reshape(ctx.batch, ctx.seq, ctx.in_dim), grad_weight, None, None, None, None, None, None + grad_weight = _fp4_mm( + et_ht_fp4, + xt_ht_fp4, + et_ht_ms, + xt_ht_ms, + alpha=et_ht_ts*xt_ht_ts, + backend=ctx.fp4_mm_backend, + ) + if ctx.wush_input_transform is not None: + grad_input = apply_block_transform(grad_input.contiguous(), ctx.wush_input_transform.transpose(-1, -2)) + grad_weight = apply_block_transform(grad_weight.contiguous(), ctx.wush_weight_transform.transpose(-1, -2)) + return grad_input.reshape(ctx.batch, ctx.seq, ctx.in_dim), grad_weight, None, None, None, None, None, None, None, None class Quartet_II_linear(torch.nn.Linear): @@ -238,10 +794,26 @@ def __init__(self, *args, four_over_six=True, **kwargs): super().__init__(*args, **kwargs) self.mode = NVFP4QuantMode.FOUR_SIX if four_over_six else NVFP4QuantMode.RNE self.weight_abs_max = None + self.wush_enabled = False + self.wush_update_freq = 200 + self.wush_damp = 1e-3 + self.wush_s_min = 1e-2 + self.wush_max_cond = 1e4 + self.wush_ema_decay = 0.99 + self.wush_group_size = 128 + self.wush_g_identity = True + self.wush_step = -1 + self.wush_last_update_step = -1 # initialize hadamard matrix. 
# *if* we are on meta device, initialization will be deferred until we move to real device (handled in _apply) had = get_hadamard_matrix(128, torch.bfloat16, self.weight.device) if self.weight.device.type != 'meta' else None self.register_buffer("had", had, persistent=False) + self.register_buffer("wush_input_transform", None, persistent=False) + self.register_buffer("wush_weight_transform", None, persistent=False) + self.register_buffer("wush_sigma_x", None, persistent=False) + self.register_buffer("wush_sigma_w", None, persistent=False) + self.register_buffer("wush_ema_count", torch.zeros((), dtype=torch.long, device=self.weight.device), persistent=False) + self.register_buffer("wush_last_conds", None, persistent=False) self.register_buffer("scratch_amax", torch.empty((), dtype=torch.uint32, device=self.weight.device), persistent=False) def _apply(self, fn): @@ -252,8 +824,85 @@ def _apply(self, fn): self.had = get_hadamard_matrix(128, torch.bfloat16, self.weight.device) return self + @torch.no_grad() + def configure_wush(self, enabled: bool, update_freq: int = 200, damp: float = 1e-3, + s_min: float = 1e-2, max_cond: float = 1e4, + ema_decay: float = 0.99, group_size: int = 128, + g_identity: bool = True): + self.wush_enabled = enabled + self.wush_update_freq = update_freq + self.wush_damp = damp + self.wush_s_min = s_min + self.wush_max_cond = max_cond + self.wush_ema_decay = ema_decay + self.wush_group_size = group_size + self.wush_g_identity = g_identity + self.wush_last_update_step = -1 + if not enabled: + self.wush_input_transform = None + self.wush_weight_transform = None + self.wush_sigma_x = None + self.wush_sigma_w = None + self.wush_ema_count.zero_() + self.wush_last_conds = None + return + + if self.weight.shape[1] % group_size != 0: + raise ValueError(f"WUSH requires in_features divisible by {group_size}, got {self.weight.shape[1]}") + + groups = self.weight.shape[1] // group_size + device = self.weight.device + had = get_hadamard_matrix(group_size, torch.bfloat16, device) + self.wush_input_transform = had.T.unsqueeze(0).expand(groups, -1, -1).clone() + self.wush_weight_transform = had.T.unsqueeze(0).expand(groups, -1, -1).clone() + self.wush_sigma_x = torch.zeros((groups, group_size, group_size), device=device, dtype=torch.float32) + self.wush_sigma_w = None if g_identity else torch.zeros_like(self.wush_sigma_x) + self.wush_ema_count.zero_() + self.wush_last_conds = torch.ones((groups,), device=device, dtype=torch.float32) + + def set_wush_step(self, step: int): + self.wush_step = step + + @torch.compiler.disable() + @torch.no_grad() + def update_wush_moments(self, x: torch.Tensor): + if not self.wush_enabled or not self.training: + return + + x = x.detach().reshape(-1, x.shape[-1]).contiguous() + update_wush_moments( + self.wush_sigma_x, + self.wush_sigma_w, + self.wush_ema_count, + x, + self.weight.detach(), + self.wush_group_size, + self.wush_ema_decay, + ) + + @torch.compiler.disable() + @torch.no_grad() + def recompute_wush_transform(self): + if not self.wush_enabled or self.wush_ema_count.item() == 0: + return False + + input_transform, weight_transform, conds = get_wush_transforms_from_moments( + self.wush_sigma_x, + self.wush_sigma_w, + self.wush_group_size, + self.wush_damp, + self.wush_s_min, + self.wush_max_cond, + ) + self.wush_input_transform = input_transform + self.wush_weight_transform = weight_transform + self.wush_last_conds = conds.to(dtype=torch.float32) + return True + def forward(self, x, disable_backward_quant=False, input_abs_max=None): - return 
Quartet_II_fn.apply(x, self.weight[...], self.had, self.mode, disable_backward_quant, self.weight_abs_max, input_abs_max, self.scratch_amax) + if self.wush_enabled: + self.update_wush_moments(x) + return Quartet_II_fn.apply(x, self.weight[...], self.had, self.mode, disable_backward_quant, self.weight_abs_max, input_abs_max, self.scratch_amax, self.wush_input_transform, self.wush_weight_transform) def register_optimizer_hook(model: torch.nn.Module, optimizer: torch.optim.Optimizer): @@ -268,3 +917,44 @@ def hook(opt, args, kwargs): m.weight_abs_max = abs_max(m.weight) return optimizer.register_step_post_hook(hook) + + +def configure_wush(model: torch.nn.Module, enabled: bool, update_freq: int = 200, damp: float = 1e-3, + s_min: float = 1e-2, max_cond: float = 1e4, ema_decay: float = 0.99, + group_size: int = 128, g_identity: bool = True): + for module in model.modules(): + if isinstance(module, Quartet_II_linear): + module.configure_wush(enabled, update_freq, damp, s_min, max_cond, ema_decay, group_size, g_identity) + + +def set_wush_step(model: torch.nn.Module, step: int): + for module in model.modules(): + if isinstance(module, Quartet_II_linear): + module.set_wush_step(step) + + +@torch.no_grad() +def update_wush_transforms(model: torch.nn.Module, sync_distributed: bool = False): + modules = [module for module in model.modules() if isinstance(module, Quartet_II_linear) and module.wush_enabled] + if not modules: + return 0 + + if sync_distributed and torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + for module in modules: + torch.distributed.all_reduce(module.wush_sigma_x, op=torch.distributed.ReduceOp.SUM) + module.wush_sigma_x.div_(world_size) + + updated = 0 + for module in modules: + updated += int(module.recompute_wush_transform()) + + if sync_distributed and torch.distributed.is_initialized(): + for module in modules: + for attr in ("wush_input_transform", "wush_weight_transform"): + tensor = getattr(module, attr) + flat = tensor.data.clone().contiguous() + torch.distributed.broadcast(flat, src=0) + tensor.data.copy_(flat) + + return updated diff --git a/quartet2/python/quartet2/quant.py b/quartet2/python/quartet2/quant.py index 2701597..068bd27 100644 --- a/quartet2/python/quartet2/quant.py +++ b/quartet2/python/quartet2/quant.py @@ -11,6 +11,11 @@ def _four_six_fp4_op(o: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x: torch _quartet2.four_six_fp4(o, s.view(torch.uint8), t, x.detach(), amax, scale_override) +@torch.library.custom_op("quartet2::gridflip_four_six", mutates_args=("o", "s", "t")) +def _gridflip_four_six_fp4_op(o: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x: torch.Tensor, amax: torch.Tensor, scale_override: float, grid_shift: float) -> None: + _quartet2.gridflip_four_six_fp4(o, s.view(torch.uint8), t, x.detach(), amax, scale_override, grid_shift) + + @torch.library.custom_op("quartet2::rtn", mutates_args=("o", "s", "t")) def _rtn_fp4_op(o: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x: torch.Tensor, amax: torch.Tensor, scale_override: float) -> None: _quartet2.rtn_fp4(o, s.view(torch.uint8), t, x.detach(), amax, scale_override) @@ -138,6 +143,24 @@ def quant_fp4(x, *, scale_override: float, amax: torch.Tensor = None, mode=NVFP4 return NVFP4Quant(q, s, global_scale) +def quant_gridflip_fp4(x, *, scale_override: float, grid_shift: float, amax: torch.Tensor = None) -> NVFP4Quant: + q = torch.empty((x.shape[0], x.shape[1] // 2), device=x.device, dtype=torch.uint8) + s = torch.empty((x.shape[0], x.shape[1] // 16), 
device=x.device, dtype=torch.float8_e4m3fn) + assert x.dtype == torch.bfloat16 + assert x.is_cuda + assert x.is_contiguous() + assert x.shape[0] % 128 == 0 + assert x.shape[1] % 128 == 0 + + if amax is None: + amax = torch.max(torch.abs(x)).to(torch.float32) + else: + assert amax.dtype == torch.float32 + global_scale = torch.empty((), device=x.device, dtype=torch.float32) + _gridflip_four_six_fp4_op(q, s, global_scale, x, amax, scale_override, grid_shift) + return NVFP4Quant(q, s, global_scale) + + def quant_had_eden( *, out: torch.Tensor = None, @@ -489,4 +512,4 @@ def rht128_requant( group_scales = group_scales.reshape(rows, cols // 16) - return NVFP4Quant(out, group_scales, tensor_scale) \ No newline at end of file + return NVFP4Quant(out, group_scales, tensor_scale) diff --git a/quartet2/test/test_qutlass_backend.py b/quartet2/test/test_qutlass_backend.py new file mode 100644 index 0000000..0a3c52f --- /dev/null +++ b/quartet2/test/test_qutlass_backend.py @@ -0,0 +1,255 @@ +import pytest +import torch + +pytest.importorskip("qutlass") +from quartet2.linear import ( + Quartet_II_linear, + _dq_gridflip_weight_fp4, + _fp4_mm, + _fp4_mm_gridflip, + _quant_gridflip_weight_fp4, + _quant_gridflip_weight_fp4_reference, + fp4_mm_backend, + get_fp4_mm_backend, + get_gridflip_shift, + get_fp4_weight_quantizer, + set_fp4_weight_quantizer, +) +from quartet2.quant import NVFP4QuantMode, quant_fp4 + + +if not torch.cuda.is_available(): + pytest.skip("CUDA required for these tests.", allow_module_level=True) + + +def _quantize_pair(m: int, n: int, k: int): + x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) + w = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) + xq = quant_fp4( + x, + amax=x.abs().max().float(), + scale_override=1.0, + mode=NVFP4QuantMode.FOUR_SIX, + ) + wq = quant_fp4( + w, + amax=w.abs().max().float(), + scale_override=1.0, + mode=NVFP4QuantMode.FOUR_SIX, + ) + return xq, wq + + +@pytest.mark.parametrize("shape", [(128, 256, 128), (512, 256, 128), (128, 384, 256)]) +@torch.inference_mode() +def test_qutlass_backend_matches_flashinfer_and_dequantized(shape): + torch.manual_seed(0) + m, n, k = shape + xq, wq = _quantize_pair(m, n, k) + alpha = xq.tensor_scale * wq.tensor_scale + + old_backend = get_fp4_mm_backend() + with fp4_mm_backend("flashinfer"): + flashinfer_out = _fp4_mm( + xq.fp4, + wq.fp4, + xq.micro_scales, + wq.micro_scales, + alpha, + ) + with fp4_mm_backend("qutlass"): + qutlass_out = _fp4_mm( + xq.fp4, + wq.fp4, + xq.micro_scales, + wq.micro_scales, + alpha, + ) + with fp4_mm_backend("dequantized"): + dequantized_out = _fp4_mm( + xq.fp4, + wq.fp4, + xq.micro_scales, + wq.micro_scales, + alpha, + ) + + assert get_fp4_mm_backend() == old_backend + assert qutlass_out.equal(flashinfer_out) + torch.testing.assert_close(qutlass_out, dequantized_out, rtol=2e-2, atol=2.0) + + +@torch.inference_mode() +def test_qutlass_backend_matches_flashinfer_linear_forward(): + torch.manual_seed(1) + linear = Quartet_II_linear(128, 256, device="cuda", dtype=torch.bfloat16) + x = torch.randn((1, 512, 128), device="cuda", dtype=torch.bfloat16) + + with fp4_mm_backend("flashinfer"): + flashinfer_out = linear(x) + with fp4_mm_backend("qutlass"): + qutlass_out = linear(x) + with fp4_mm_backend("dequantized"): + dequantized_out = linear(x) + + assert qutlass_out.equal(flashinfer_out) + torch.testing.assert_close(qutlass_out, dequantized_out, rtol=2e-2, atol=2.0) + + +@torch.inference_mode() +def test_qutlass_backend_compiles_forward(): + torch.manual_seed(2) + linear = 
Quartet_II_linear(128, 256, device="cuda", dtype=torch.bfloat16) + x = torch.randn((1, 128, 128), device="cuda", dtype=torch.bfloat16) + + def fwd(inp): + return linear(inp) + + with fp4_mm_backend("qutlass"): + out = torch.compile(fwd, fullgraph=True)(x) + + assert out.shape == (1, 128, 256) + + +def test_qutlass_backend_matches_flashinfer_backward(): + torch.manual_seed(3) + weight = torch.randn((256, 128), device="cuda", dtype=torch.bfloat16) + x_ref = torch.randn((1, 128, 128), device="cuda", dtype=torch.bfloat16) + grad = torch.randn((1, 128, 256), device="cuda", dtype=torch.bfloat16) + + def run(backend: str, backward_backend: str | None = None): + linear = Quartet_II_linear(128, 256, bias=False, device="cuda", dtype=torch.bfloat16) + with torch.no_grad(): + linear.weight.copy_(weight) + x = x_ref.clone().requires_grad_(True) + torch.manual_seed(100) + with fp4_mm_backend(backend): + y = linear(x) + torch.manual_seed(200) + if backward_backend is None: + y.backward(grad) + else: + with fp4_mm_backend(backward_backend): + y.backward(grad) + return y.detach(), x.grad.detach(), linear.weight.grad.detach() + + flashinfer_out, flashinfer_x_grad, flashinfer_w_grad = run("flashinfer") + qutlass_out, qutlass_x_grad, qutlass_w_grad = run("qutlass") + pinned_out, pinned_x_grad, pinned_w_grad = run("qutlass", backward_backend="dequantized") + + assert qutlass_out.equal(flashinfer_out) + assert qutlass_x_grad.equal(flashinfer_x_grad) + assert qutlass_w_grad.equal(flashinfer_w_grad) + assert pinned_out.equal(qutlass_out) + assert pinned_x_grad.equal(qutlass_x_grad) + assert pinned_w_grad.equal(qutlass_w_grad) + + +@pytest.mark.parametrize("scale_override", [1.0, 0.875]) +@torch.inference_mode() +def test_gridflip_quantizer_does_not_exceed_python_reference_error(scale_override): + torch.manual_seed(4) + n, k = 256, 128 + w = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) + wq = _quant_gridflip_weight_fp4( + w, + amax=w.abs().max().float(), + scale_override=scale_override, + grid_shift=0.25, + mode=NVFP4QuantMode.FOUR_SIX, + ) + ref_wq = _quant_gridflip_weight_fp4_reference( + w, + amax=w.abs().max().float(), + scale_override=scale_override, + grid_shift=0.25, + mode=NVFP4QuantMode.FOUR_SIX, + ) + flags = (wq.micro_scales.view(torch.uint8) & 0x80) != 0 + assert flags.any() + assert wq.tensor_scale.equal(ref_wq.tensor_scale) + cuda_dq = _dq_gridflip_weight_fp4(wq.fp4, wq.micro_scales, 0.25).float() * wq.tensor_scale.float() + ref_dq = _dq_gridflip_weight_fp4(ref_wq.fp4, ref_wq.micro_scales, 0.25).float() * ref_wq.tensor_scale.float() + cuda_error = (cuda_dq - w.float()).square().reshape(n, k // 16, 16).sum(dim=-1) + ref_error = (ref_dq - w.float()).square().reshape(n, k // 16, 16).sum(dim=-1) + assert torch.all(cuda_error <= ref_error + 1e-5) + + +@torch.inference_mode() +def test_gridflip_matmul_matches_dequantized_reference(): + torch.manual_seed(4) + m, n, k = 128, 256, 128 + x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) + w = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) + xq = quant_fp4( + x, + amax=x.abs().max().float(), + scale_override=1.0, + mode=NVFP4QuantMode.FOUR_SIX, + ) + wq = _quant_gridflip_weight_fp4( + w, + amax=w.abs().max().float(), + scale_override=1.0, + grid_shift=0.25, + mode=NVFP4QuantMode.FOUR_SIX, + ) + alpha = xq.tensor_scale * wq.tensor_scale + + with fp4_mm_backend("qutlass"): + qutlass_out = _fp4_mm_gridflip( + xq.fp4, + wq.fp4, + xq.micro_scales, + wq.micro_scales, + alpha, + grid_shift=0.25, + ) + with 
fp4_mm_backend("dequantized"): + dequantized_out = _fp4_mm_gridflip( + xq.fp4, + wq.fp4, + xq.micro_scales, + wq.micro_scales, + alpha, + grid_shift=0.25, + ) + + torch.testing.assert_close(qutlass_out, dequantized_out, rtol=2e-2, atol=2.0) + + +def test_gridflip_linear_forward_matches_dequantized_and_pins_backward_backend(): + torch.manual_seed(5) + weight = torch.randn((256, 128), device="cuda", dtype=torch.bfloat16) + x_ref = torch.randn((1, 128, 128), device="cuda", dtype=torch.bfloat16) + grad = torch.randn((1, 128, 256), device="cuda", dtype=torch.bfloat16) + + def run(backend: str, backward_backend: str | None = None): + old_quantizer = get_fp4_weight_quantizer() + old_shift = get_gridflip_shift() + set_fp4_weight_quantizer("gridflip", gridflip_shift=0.25) + linear = Quartet_II_linear(128, 256, bias=False, device="cuda", dtype=torch.bfloat16) + with torch.no_grad(): + linear.weight.copy_(weight) + x = x_ref.clone().requires_grad_(True) + try: + with fp4_mm_backend(backend): + y = linear(x) + torch.manual_seed(200) + if backward_backend is None: + y.backward(grad) + else: + with fp4_mm_backend(backward_backend): + y.backward(grad) + finally: + set_fp4_weight_quantizer(old_quantizer, gridflip_shift=old_shift) + return y.detach(), x.grad.detach(), linear.weight.grad.detach() + + qutlass_out, qutlass_x_grad, qutlass_w_grad = run("qutlass") + pinned_out, pinned_x_grad, pinned_w_grad = run("qutlass", backward_backend="dequantized") + dequantized_out, _, _ = run("dequantized") + + torch.testing.assert_close(qutlass_out, dequantized_out, rtol=2e-2, atol=2.0) + assert pinned_out.equal(qutlass_out) + assert pinned_x_grad.equal(qutlass_x_grad) + assert pinned_w_grad.equal(qutlass_w_grad) diff --git a/src/train.py b/src/train.py index c270245..0b05a24 100644 --- a/src/train.py +++ b/src/train.py @@ -89,18 +89,50 @@ parser.add_argument("--quartet", help="quartet2.linear.Quartet_II_linear instead of torch.nn.Linear", type=utils.str_to_bool, default=True) parser.add_argument("--fake_quartet", help="Fake (simulated) NVFP4 quantization", type=utils.str_to_bool, default=False) +parser.add_argument("--quartet_matmul_backend", help="Matmul backend for real Quartet-II linears", choices=["flashinfer", "qutlass", "dequantized"], default="flashinfer") +parser.add_argument("--quartet_weight_quantizer", help="Weight quantizer for real Quartet-II linears", choices=["four_six", "gridflip"], default="four_six") +parser.add_argument("--gridflip_shift", help="GridFlip shifted-grid correction value", type=float, default=0.25) +parser.add_argument("--wush", help="Apply blockwise WUSH transforms in Quartet-II forward quantization", type=utils.str_to_bool, default=False) +parser.add_argument("--wush_update_freq", help="Every how many optimizer steps to refresh WUSH transforms", type=int, default=200) +parser.add_argument("--wush_damp", help="Tikhonov damping for KFAC WUSH second-moment estimates", type=float, default=1e-3) +parser.add_argument("--wush_s_min", help="Singular value floor for KFAC WUSH transform updates", type=float, default=1e-2) +parser.add_argument("--wush_max_cond", help="Per-block condition-number fallback threshold for KFAC WUSH", type=float, default=1e4) +parser.add_argument("--wush_ema_decay", help="EMA decay for KFAC WUSH activation second moments", type=float, default=0.99) +parser.add_argument("--wush_group_size", help="Block size for KFAC WUSH transforms", type=int, default=128) +parser.add_argument("--wush_g_identity", help="Use identity output KFAC factor G when recomputing WUSH 
transforms", type=utils.str_to_bool, default=True) parser.add_argument("--num_blocks", help="Number of Transformer blocks", type=int, default=4) parser.add_argument("--heads", help="Number of Q heads in the MHSA", type=int, default=6) parser.add_argument("--ratio", help="Ratio between Q heads and KV heads", type=int, default=3) parser.add_argument("--tied_embeddings", help="Tie input and output embeddings", type=utils.str_to_bool, default=True) parser.add_argument("--dataset_path", help="If passed, overrides where the dataset is loaded from", type=os.path.abspath, default=None) parser.add_argument("--dataset_seed", help="Seed to use for dataset sampling.", type=int, default=-1) +parser.add_argument("--seed", help="Seed for model initialization and torch RNGs. Negative means do not set it.", type=int, default=-1) parser.add_argument("--wandb_kwargs", help="Keyword arguments for wandb.init()", type=json.loads, default=None) parser.add_argument("--val_fixed", help="As an extra, evaluate the loss on fixed validation batches", type=utils.str_to_bool, default=True) args=parser.parse_args() if args.quartet and args.fake_quartet: parser.error("--quartet and --fake_quartet are mutually exclusive") +if not args.quartet and args.quartet_matmul_backend != "flashinfer": + parser.error("--quartet_matmul_backend applies only when --quartet true") +if not args.quartet and args.quartet_weight_quantizer != "four_six": + parser.error("--quartet_weight_quantizer applies only when --quartet true") +if args.quartet_weight_quantizer == "gridflip" and args.quartet_matmul_backend == "flashinfer": + parser.error("--quartet_weight_quantizer gridflip requires --quartet_matmul_backend qutlass or dequantized") +if args.gridflip_shift < 0: + parser.error("--gridflip_shift must be non-negative") +if args.quartet_weight_quantizer == "gridflip" and args.comp: + if "MASTER_ADDR" not in os.environ or int(os.getenv("RANK", 0)) == 0: + print("📌 GridFlip's fused matmul adapter is not torch.compile-ready in this integration, so disabling --comp for this run.") + args.comp = False +if args.wush and not args.quartet: + parser.error("--wush currently applies to real Quartet-II linears; use --quartet true") +if args.wush and (args.wush_group_size & (args.wush_group_size - 1)): + parser.error("--wush_group_size must be a power of two for the Hadamard transform") +if args.wush and args.comp: + if "MASTER_ADDR" not in os.environ or int(os.getenv("RANK", 0)) == 0: + print("📌 WUSH uses dynamic eigendecompositions, so disabling --comp for this run.") + args.comp = False if torch.distributed.is_torchelastic_launched(): # Get environment variables set by torchrun @@ -129,6 +161,14 @@ accumulation = args.batch_size//args.micro_batch_size torch.cuda.set_device(model_device) +gpu_reservation_gib = max(0, int(os.getenv("CLOVERLM_GPU_RESERVATION_GIB", "0"))) +if gpu_reservation_gib: + # Optional keepalive for external GPU reservation monitors during CPU-side + # dataset loading. The allocation is released before model initialization. 
+ _gpu_reservation_touch = torch.empty(gpu_reservation_gib * 1024 ** 3, device=model_device, dtype=torch.uint8) +else: + _gpu_reservation_touch = None + subpath_dir = os.path.dirname(os.path.abspath(args.NAME)) if master: os.makedirs(subpath_dir, exist_ok=True) checkpoint_path = args.NAME+"_checkpoint" @@ -202,10 +242,26 @@ elif args.dataset_device_type == "cuda": dataset_device = model_device +if args.seed >= 0: + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + if master: print("💾 Loading dataset") train_iterator = data.utils_data.get_iterator(args.dataset, "train", dataset_device, args.micro_batch_size, args.context, RANK, args.dataset_path, args.dataset_seed) val_iterator = data.utils_data.get_iterator(args.dataset, "val", dataset_device, args.micro_batch_size, args.context, RANK, args.dataset_path, args.dataset_seed) fixed_val_batch = next(val_iterator) +if _gpu_reservation_touch is not None: + del _gpu_reservation_touch + torch.cuda.empty_cache() + +if args.quartet: + import quartet2.linear + quartet2.linear.set_fp4_mm_backend(args.quartet_matmul_backend) + quartet2.linear.set_fp4_weight_quantizer( + args.quartet_weight_quantizer, + gridflip_shift=args.gridflip_shift, + ) if master: print("🧠 Initializing model") model_or_ddp, opts = models.utils_models.get_model_opts( @@ -217,6 +273,12 @@ num_blocks=args.num_blocks, heads=args.heads, ratio=args.ratio, tied_embeddings=args.tied_embeddings) model = model_or_ddp.module if torch.distributed.is_initialized() else model_or_ddp +if args.wush: + import quartet2.linear + quartet2.linear.configure_wush(model, True, args.wush_update_freq, args.wush_damp, args.wush_s_min, + args.wush_max_cond, args.wush_ema_decay, args.wush_group_size, + args.wush_g_identity) + checkpoint_dict["checkpoint"].model = model checkpoint_dict["checkpoint"].opts = opts @@ -326,7 +388,8 @@ with torch.autocast(device_type=model_device_type, dtype=args.dtype): micro_train_loss = get_loss(args.dataset, model_or_ddp, batch_train_X, batch_train_Y, args.label_smoothing)[1] * loss_scale_acc train_loss += micro_train_loss.detach() - scaler.scale(micro_train_loss).backward() + retain_accumulation_graph = args.quartet and micro_train_batch < accumulation - 1 + scaler.scale(micro_train_loss).backward(retain_graph=retain_accumulation_graph) step_timer.end() @@ -450,6 +513,13 @@ scaler.update() opt_timer.end() + next_train_batch = checkpoint_dict["checkpoint"].train_batch + 1 + if args.wush and next_train_batch > 0 and next_train_batch % args.wush_update_freq == 0: + import quartet2.linear + n_updated = quartet2.linear.update_wush_transforms(model, sync_distributed=torch.distributed.is_initialized()) + if master: + print(f"📌 WUSH recomputed transforms for {n_updated} Quartet-II layers at step {next_train_batch}") + lr = schedulers[0].get_last_lr()[0] for scheduler in schedulers: scheduler.step() @@ -464,7 +534,7 @@ checkpoint_dict["checkpoint"].train_batch += 1 current_batch = checkpoint_dict["checkpoint"].train_batch - if current_batch > 1 and ((current_batch % args.checkpoint_freq == 0) or (current_batch == last_batch)): + if args.checkpoint_freq != utils.INF and current_batch > 1 and ((current_batch % args.checkpoint_freq == 0) or (current_batch == last_batch)): checkpoint_id = f"{checkpoint_path}/{current_batch}" torch.distributed.checkpoint.save(checkpoint_dict, checkpoint_id=checkpoint_id) if master:
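
For orientation beyond the flags, a hypothetical end-to-end setup mirroring what the new `train.py` arguments do, using only functions added by this diff (`build_model` is a placeholder for any module tree containing `Quartet_II_linear` layers):

```python
import quartet2.linear as qlinear

# Equivalent of --quartet_matmul_backend / --quartet_weight_quantizer / --gridflip_shift.
qlinear.set_fp4_mm_backend("qutlass")  # or "flashinfer" / "dequantized"
qlinear.set_fp4_weight_quantizer("gridflip", gridflip_shift=0.25)

model = build_model()  # placeholder: model containing Quartet_II_linear layers

# Equivalent of --wush and its hyperparameter flags.
qlinear.configure_wush(model, True, update_freq=200, damp=1e-3,
                       s_min=1e-2, max_cond=1e4, ema_decay=0.99,
                       group_size=128, g_identity=True)

# In the training loop, every wush_update_freq optimizer steps:
qlinear.update_wush_transforms(model, sync_distributed=False)
```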