diff --git a/benchmarks/benchmark_rht_cast_swizzle_fusion.py b/benchmarks/benchmark_rht_cast_swizzle_fusion.py new file mode 100644 index 0000000000..13f264862c --- /dev/null +++ b/benchmarks/benchmark_rht_cast_swizzle_fusion.py @@ -0,0 +1,189 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Benchmark NVFP4 RHT cast-fusion with vs without fused GEMM-swizzled SF output. + +For each shape we measure two paths and two builds: + + * path = "quant_only": just NVFP4Quantizer(x) + * path = "quant_plus_swizzle": NVFP4Quantizer(x) + tex.swizzle_scales_for_gemm_(t) + (this is what te.Linear -> tex.generic_gemm does right before the + cuBLAS LT NVFP4 GEMM dispatch) + + * build = "baseline": optimize_for_gemm=False + -> quant kernel emits compact SF; + tex.swizzle_scales_for_gemm_ launches the standalone + swizzle_{row,col}_scaling_kernel pass before GEMM. + * build = "swizzle_fusion": optimize_for_gemm=True + -> quant kernel emits GEMM-swizzled SF directly (via the + kEnableSwizzleSFOutput compile-time switch in + row_cast_col_hadamard_transform_cast_fusion.cu); + tex.swizzle_scales_for_gemm_ early-returns and the standalone + swizzle pass disappears from the timeline. + +The wall-clock delta on the "quant_plus_swizzle" path is the production +saving of this PR. +""" + +import argparse +import torch +import pandas as pd +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te # noqa: F401 must be first per te-python-import-order +import transformer_engine_torch as tex +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + +def make_quantizer(optimize_for_gemm: bool) -> NVFP4Quantizer: + q = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + ) + q.optimize_for_gemm = optimize_for_gemm + return q + + +def _bench(stmt: str, globals_dict: dict, min_run_time: float) -> float: + """Returns median wall-clock per call in microseconds.""" + timing = benchmark.Timer( + stmt=stmt, + globals=globals_dict, + num_threads=1, + ).blocked_autorange(min_run_time=min_run_time) + return timing.median * 1e6 + + +def run_shape(shape, min_run_time: float): + M, K = shape + assert M % 16 == 0 and K % 16 == 0, "Shape must be divisible by 16" + + x = torch.randn([M, K], dtype=torch.bfloat16, device="cuda") + q_base = make_quantizer(optimize_for_gemm=False) + q_swf = make_quantizer(optimize_for_gemm=True) + + # quant_only path + quant_only_base_us = _bench( + stmt="q(x)", + globals_dict={"q": q_base, "x": x}, + min_run_time=min_run_time, + ) + quant_only_swf_us = _bench( + stmt="q(x)", + globals_dict={"q": q_swf, "x": x}, + min_run_time=min_run_time, + ) + + # quant_plus_swizzle path (this is what te.Linear actually runs) + quant_plus_swizzle_base_us = _bench( + stmt="t = q(x); tex.swizzle_scales_for_gemm_(t)", + globals_dict={"q": q_base, "x": x, "tex": tex}, + min_run_time=min_run_time, + ) + quant_plus_swizzle_swf_us = _bench( + stmt="t = q(x); tex.swizzle_scales_for_gemm_(t)", + globals_dict={"q": q_swf, "x": x, "tex": tex}, + min_run_time=min_run_time, + ) + + saved_us = quant_plus_swizzle_base_us - quant_plus_swizzle_swf_us + speedup = ( + quant_plus_swizzle_base_us / quant_plus_swizzle_swf_us + if quant_plus_swizzle_swf_us > 0 + else float("inf") + ) + + print( + f" shape={shape}: quant_only base={quant_only_base_us:.2f}us, " + f"SUT={quant_only_swf_us:.2f}us | " + f"quant+swizzle base={quant_plus_swizzle_base_us:.2f}us, " + f"SUT={quant_plus_swizzle_swf_us:.2f}us " + f"-> saved {saved_us:.2f}us ({speedup:.2f}x)" + ) + + return { + "shape": shape, + "M": M, + "K": K, + "quant_only_base_us": quant_only_base_us, + "quant_only_swf_us": quant_only_swf_us, + "quant_plus_swizzle_base_us": quant_plus_swizzle_base_us, + "quant_plus_swizzle_swf_us": quant_plus_swizzle_swf_us, + "saved_us": saved_us, + "speedup": speedup, + } + + +# Nsight Compute Profiling Command (for verifying the swizzle kernel disappears): +# ncu -f -o swizzle_fusion --set=full \ +# --kernel-name "regex:swizzle_(row|col)_scaling_kernel|cast_col_hadamard_transform_cast_fusion" \ +# -s 5 -c 10 python benchmarks/benchmark_rht_cast_swizzle_fusion.py --profile + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--profile", + action="store_true", + help="Run only one shape for use with ncu/nsys; longer min_run_time", + ) + parser.add_argument( + "--min-run-time", + type=float, + default=2.0, + help="Minimum total measured time per cell in seconds (benchmark.Timer)", + ) + parser.add_argument( + "--csv", + type=str, + default="benchmark_rht_cast_swizzle_fusion.csv", + help="CSV output path", + ) + args = parser.parse_args() + + if args.profile: + print("Profiling mode enabled (single shape).") + shapes = [(8192, 4096)] + min_run_time = max(5.0, args.min_run_time) + else: + shapes = [ + # production-class shapes + (8192, 5120), + (8192, 10240), + (8192, 2560), + (8192, 11328), + (8192, 3584), + (5120, 8192), + (10240, 8192), + (2560, 8192), + (11328, 8192), + (3584, 8192), + (4096, 16384), + (14336, 16384), + ] + min_run_time = args.min_run_time + + print( + "NVFP4 RHT cast-fusion: swizzle-fusion (optimize_for_gemm=True) vs baseline. " + f"min_run_time={min_run_time}s per cell, BF16 input, " + "rowwise+columnwise SF, RHT=True+post_rht_amax." + ) + rows = [] + for shape in shapes: + print(f"Running {shape} ...") + rows.append(run_shape(shape, min_run_time)) + + df = pd.DataFrame(rows) + pd.set_option("display.max_columns", None) + pd.set_option("display.width", 200) + print() + print(df.to_string(index=False)) + df.to_csv(args.csv, index=False) + print(f"\nWrote {args.csv}") diff --git a/benchmarks/profile_rht_cast_swizzle_fusion.py b/benchmarks/profile_rht_cast_swizzle_fusion.py new file mode 100644 index 0000000000..97d2a3e564 --- /dev/null +++ b/benchmarks/profile_rht_cast_swizzle_fusion.py @@ -0,0 +1,123 @@ +""" +Profile that the dedicated swizzle kernels (swizzle_{row,col}_scaling_kernel +in transformer_engine/common/swizzle/swizzle.cu) disappear from the timeline +when NVFP4 RHT cast-fusion emits SF in the GEMM-swizzled layout directly +(optimize_for_gemm=True). + +Test setup: + - NVFP4 + RHT + post-RHT amax (same as te.Linear sets up internally) + - rowwise=True AND columnwise=True (covers BOTH swizzle_row_scaling_kernel + and swizzle_col_scaling_kernel; this is what tex.Linear's input quantizer + needs during training because the rowwise tensor is used by the fwd GEMM + and the columnwise tensor is used by the dgrad GEMM) + - tex.swizzle_scales_for_gemm_(t) is what te.Linear -> tex.generic_gemm + calls just before the cuBLAS LT NVFP4 GEMM dispatch +""" + +import torch +import transformer_engine.pytorch as te # noqa: F401 must be first +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer + + +def make_quantizer(optimize_for_gemm: bool) -> NVFP4Quantizer: + q = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + ) + q.optimize_for_gemm = optimize_for_gemm + return q + + +import re + +# Match ONLY the standalone swizzle pass kernels in +# transformer_engine/common/swizzle/swizzle.cu — NOT RHT cast-fusion kernels +# whose mangled name happens to contain "Swizzle" because of the +# `template <..., bool kEnableSwizzleSFOutput, ...>` parameter substring. +STANDALONE_SWIZZLE_RE = re.compile( + r"(?:multi_tensor_(?:un)?swizzle|(?:un)?swizzle)_(?:row|col)_scaling_kernel" +) + + +def dump_kernel_counts(prof, label: str) -> dict: + print(f"\n=== {label} ===") + counts: dict[str, int] = {} + for ev in prof.events(): + if ev.device_type != torch.autograd.DeviceType.CUDA: + continue + counts[ev.name] = counts.get(ev.name, 0) + 1 + standalone_swizzle_total = 0 + for name, c in sorted(counts.items(), key=lambda kv: -kv[1]): + marker = "" + if STANDALONE_SWIZZLE_RE.search(name): + marker = " <-- STANDALONE SWIZZLE PASS" + standalone_swizzle_total += c + # Truncate long mangled CUTLASS names for readability + short = name if len(name) <= 110 else name[:107] + "..." + print(f" {c:4d} {short}{marker}") + print(f" -- standalone swizzle kernel total: {standalone_swizzle_total}") + return counts + + +def profile_path(optimize_for_gemm: bool, x: torch.Tensor, n_iters: int = 20): + q = make_quantizer(optimize_for_gemm=optimize_for_gemm) + # warm-up + for _ in range(3): + t = q(x) + tex.swizzle_scales_for_gemm_(t) + torch.cuda.synchronize() + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + for _ in range(n_iters): + t = q(x) + tex.swizzle_scales_for_gemm_(t) + torch.cuda.synchronize() + return prof + + +def main(): + torch.manual_seed(0) + torch.cuda.manual_seed(0) + device = "cuda" + # Shape that hits the production RHT cast-fusion fast-path + # (rows % 64 == 0, cols % 128 == 0, BF16, SM100/110). + M, N = 8192, 4096 + x = torch.randn((M, N), dtype=torch.bfloat16, device=device) + + print(f"Shape: M={M}, N={N}, dtype=bf16, RHT=True, post_rht_amax=True") + print(f"iters: 20 (after 3 warm-up)") + + prof_baseline = profile_path(optimize_for_gemm=False, x=x) + counts_baseline = dump_kernel_counts( + prof_baseline, "BASELINE: optimize_for_gemm=False (separate swizzle kernel)" + ) + + prof_swf = profile_path(optimize_for_gemm=True, x=x) + counts_swf = dump_kernel_counts( + prof_swf, "SUT: optimize_for_gemm=True (quant emits swizzled SF directly)" + ) + + print("\n=== VERDICT ===") + base_swizzle = sum(c for n, c in counts_baseline.items() if STANDALONE_SWIZZLE_RE.search(n)) + swf_swizzle = sum(c for n, c in counts_swf.items() if STANDALONE_SWIZZLE_RE.search(n)) + print(f" baseline standalone swizzle kernel launches: {base_swizzle}") + print(f" SUT standalone swizzle kernel launches: {swf_swizzle}") + if swf_swizzle == 0 and base_swizzle > 0: + print( + " PASS: standalone swizzle pass disappears from timeline under optimize_for_gemm=True" + ) + else: + print( + " FAIL: expected baseline > 0 and SUT == 0; check whether SUT actually " + "set with_gemm_swizzled_scales=True on the output tensor" + ) + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 20a91bf6fe..23c582be86 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -28,6 +28,7 @@ generate_split_sections, assert_same_shape_and_dtype, reference_group_quantize, + swizzle_nvfp4_scale, ) recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) @@ -43,6 +44,7 @@ def check_group_quantization_nvfp4_versus_reference( with_rht: bool = True, with_post_rht_amax: bool = True, with_random_sign_mask: bool = True, + optimize_for_gemm: bool = False, ) -> None: te_dtype = tex.DType.kFloat4E2M1 @@ -59,7 +61,7 @@ def check_group_quantization_nvfp4_versus_reference( x_splits = torch.split(x, split_sections) - # Quantize + # Reference quantizers (compact SF, default optimize_for_gemm=False). quantizers = [ NVFP4Quantizer( fp4_dtype=te_dtype, @@ -77,7 +79,13 @@ def check_group_quantization_nvfp4_versus_reference( reference_group_quantize(x, quantizers, split_sections, return_rowwise, return_transpose) ) - split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers) + # SUT quantizers: same as reference, but with optimize_for_gemm toggled to + # request direct swizzled SF emission from the RHT cast-fusion kernel. + sut_quantizers = [q.copy() for q in quantizers] + for q in sut_quantizers: + q.optimize_for_gemm = optimize_for_gemm + + split_quantize_outputs = tex.split_quantize(x, split_sections, sut_quantizers) if return_rowwise: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] @@ -98,6 +106,12 @@ def check_group_quantization_nvfp4_versus_reference( valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, False) x_sx_valid = x_sx[i][: valid_scale_shape[0], : valid_scale_shape[1]] x_sx_ref_valid = x_sx_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]] + if optimize_for_gemm: + # SUT emits SF in the GEMM-swizzled layout directly; swizzle + # the reference compact SF for byte-equal comparison. + x_sx_ref_valid = swizzle_nvfp4_scale( + split_sections[i], N, x_sx_ref_valid, columnwise=False + ) torch.testing.assert_close(x_sx_valid, x_sx_ref_valid, atol=0.0, rtol=0.0) if return_transpose: @@ -121,6 +135,10 @@ def check_group_quantization_nvfp4_versus_reference( valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True) x_sx_t_valid = x_sx_t[i][: valid_scale_shape[0], : valid_scale_shape[1]] x_sx_t_ref_valid = x_sx_t_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]] + if optimize_for_gemm: + x_sx_t_ref_valid = swizzle_nvfp4_scale( + split_sections[i], N, x_sx_t_ref_valid, columnwise=True + ) torch.testing.assert_close(x_sx_t_valid, x_sx_t_ref_valid, atol=0.0, rtol=0.0) @@ -157,6 +175,11 @@ def check_group_quantization_nvfp4_versus_reference( "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] ) @pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"]) +@pytest.mark.parametrize( + "optimize_for_gemm", + [False, True], + ids=["compact_sf", "swizzled_sf"], +) def test_rht_with_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -165,9 +188,43 @@ def test_rht_with_quantization_block_tiling_versus_reference( quantize_mode: str, with_random_sign_mask: bool, with_rht: bool, + optimize_for_gemm: bool, ) -> None: - split_sections = generate_split_sections(M, N, edge_cases, least_multiple=64) + # The "quantize writes swizzled SF" fast-path is gated in the C++ framework + # on ``optimize_for_gemm && with_rht`` (see NVFP4Quantizer::create_tensor and + # bulk_allocate_nvfp4_tensors in transformer_engine/pytorch/csrc). Without + # ``with_rht=True`` the flag is silently dropped, so the swizzled_sf row + # would just duplicate the compact_sf row — skip it instead of flooding the + # matrix with redundant cases. + if optimize_for_gemm and not with_rht: + pytest.skip("optimize_for_gemm requires with_rht=True (framework gate)") + + # The grouped RHT cast-fusion kernel that honors with_gemm_swizzled_scales + # (group_row_cast_col_hadamard_transform_cast_fusion.cu) is only dispatched + # when: + # - cols are a 128 multiple (RHT cast-fusion eligibility), AND + # - every split section is a 128 multiple (all_aligned_token_dim path in + # split_quantize_nvfp4_impl_with_rht_helper). + # For other shapes the C++ side falls back to the unfused row/col split, + # which does NOT (yet) emit swizzled SF; we'd hit either an NVTE_CHECK or + # silent SF-layout corruption. Restrict the swizzled coverage to the fused + # path; the unfused fallback is covered by the optimize_for_gemm=False + # baseline already exercised above. + if optimize_for_gemm and N % 128 != 0: + pytest.skip("RHT cast-fusion requires N % 128 == 0") + + # generate_split_sections hard-codes num_chunks=4 and requires every chunk + # to be a least_multiple multiple. When optimize_for_gemm forces + # least_multiple from 64 to 128, the test needs M >= 4*128 = 512 (and a + # multiple of 512 for the regular/zero_tokens patterns, which the existing + # M shapes already satisfy: 0, 1024, 8192, 16384). The small M=256 shape + # cannot satisfy this and is exercised by the optimize_for_gemm=False rows. + if optimize_for_gemm and 0 < M < 4 * 128: + pytest.skip("optimize_for_gemm requires M==0 or M>=512 for 4-chunk 128-aligned split") + + least_multiple = 128 if optimize_for_gemm else 64 + split_sections = generate_split_sections(M, N, edge_cases, least_multiple=least_multiple) # currently disable pre-RHT amax with_post_rht_amax = with_rht @@ -194,4 +251,5 @@ def test_rht_with_quantization_block_tiling_versus_reference( with_rht=with_rht, with_post_rht_amax=with_post_rht_amax, with_random_sign_mask=with_random_sign_mask, + optimize_for_gemm=optimize_for_gemm, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py new file mode 100644 index 0000000000..3059adea8c --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py @@ -0,0 +1,254 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fidelity tests for NVFP4 single-tensor RHT cast-fusion swizzled SF output. + +Mirrors ``tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py``: when +the quantizer has ``optimize_for_gemm=True`` and the input is eligible for +the RHT cast-fusion kernel (bf16, rows%64==0, cols%128==0, SM 100/110), the +kernel should emit scale factors directly in the GEMM-swizzled layout +``cuBLAS LT`` consumes, eliminating the otherwise-required +``nvte_swizzle_scaling_factors`` pass between quantize and GEMM. + +The fidelity contract is: ``quantizer_swizzle_fusion(x)`` produces SF that +are byte-equal to ``swizzle_nvfp4_scale(quantizer(x).sx)``. The +``_rowwise_data`` / ``_columnwise_data`` quantized FP4 buffers and amaxes +must also be byte-equal (the swizzle optimization changes only the SF +layout, not the FP4 data itself). +""" + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex # noqa: F401 +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import ( + NVFP4TensorStorage, +) + +import pytest +import torch + +from typing import Tuple + +from nvfp4_utils import ( + swizzle_nvfp4_scale, + get_nvfp4_scale_shape_no_padding, +) + +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) + + +def _unpack_quantized_tensor( + quantized_tensor: NVFP4TensorStorage, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """Extract the six tensors we want to compare in byte-equal form. + + Returns ``(qx, sx, amax_row, qx_t, sx_t, amax_col)``, with any of these + set to ``None`` when the quantizer did not request that direction. + """ + qx, sx, amax_row = None, None, None + qx_t, sx_t, amax_col = None, None, None + if quantized_tensor._rowwise_data is not None: + qx = quantized_tensor._rowwise_data.view(dtype=torch.uint8) + if quantized_tensor._rowwise_scale_inv is not None: + sx = quantized_tensor._rowwise_scale_inv + if quantized_tensor._amax_rowwise is not None: + amax_row = quantized_tensor._amax_rowwise + if quantized_tensor._columnwise_data is not None: + qx_t = quantized_tensor._columnwise_data.view(dtype=torch.uint8) + if quantized_tensor._columnwise_scale_inv is not None: + sx_t = quantized_tensor._columnwise_scale_inv + if quantized_tensor._amax_columnwise is not None: + amax_col = quantized_tensor._amax_columnwise + return qx, sx, amax_row, qx_t, sx_t, amax_col + + +def _check_nvfp4_rht_quantize_swizzle_fusion( + x_dtype: torch.dtype, + M: int, + N: int, + return_rowwise: bool, + return_transpose: bool, + with_random_sign_mask: bool, +) -> None: + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x = torch.randn((M, N), dtype=x_dtype, device=device) + + # Reference quantizer (compact SF, default behavior). + quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=return_rowwise, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=with_random_sign_mask, + ) + + # SUT: same quantizer but with swizzled-SF emission enabled. + quantizer_swizzle_fusion = quantizer.copy() + quantizer_swizzle_fusion.optimize_for_gemm = True + + qx_swf, sx_swf, amax_row_swf, qx_t_swf, sx_t_swf, amax_col_swf = _unpack_quantized_tensor( + quantizer_swizzle_fusion(x) + ) + qx_ref, sx_ref, amax_row_ref, qx_t_ref, sx_t_ref, amax_col_ref = _unpack_quantized_tensor( + quantizer(x) + ) + + if return_rowwise: + # FP4 data buffer and amax must be byte-equal (swizzle only changes SF + # layout, not the quantized data). + torch.testing.assert_close(qx_swf, qx_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(amax_row_swf, amax_row_ref, atol=0.0, rtol=0.0) + + # SF tensor must match the swizzle of the reference compact SF. + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x.shape, False) + assert valid_scale_shape == sx_swf.shape, ( + "rowwise SF shape mismatch; this test assumes the input shape needs no " + f"SF padding (got valid={valid_scale_shape}, got_swf={sx_swf.shape})." + ) + sx_ref_swizzled = swizzle_nvfp4_scale(M, N, sx_ref, columnwise=False) + torch.testing.assert_close(sx_swf, sx_ref_swizzled, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t_swf, qx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(amax_col_swf, amax_col_ref, atol=0.0, rtol=0.0) + + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x.shape, True) + assert valid_scale_shape == sx_t_swf.shape, ( + "columnwise SF shape mismatch; this test assumes the input shape needs no " + f"SF padding (got valid={valid_scale_shape}, got_swf={sx_t_swf.shape})." + ) + sx_t_ref_swizzled = swizzle_nvfp4_scale(M, N, sx_t_ref, columnwise=True) + torch.testing.assert_close(sx_t_swf, sx_t_ref_swizzled, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # Full tile cases (rows%64==0 and cols%128==0 to be eligible for the + # RHT cast-fusion kernel; also sized so no SF padding is needed). + (128, 128), + (256, 256), + (1024, 256), + # Production-like shapes. + (2048, 2048), + (8192, 1024), + (8192, 5120), + (8192, 10240), + (16384, 8192), + (16384, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"]) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +def test_nvfp4_rht_quantize_swizzle_fusion( + x_dtype: torch.dtype, + M: int, + N: int, + quantize_mode: str, + with_random_sign_mask: bool, +) -> None: + if quantize_mode == "rowwise_only": + return_rowwise = True + return_transpose = False + elif quantize_mode == "both_directions": + return_rowwise = True + return_transpose = True + elif quantize_mode == "columnwise_only": + return_rowwise = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + _check_nvfp4_rht_quantize_swizzle_fusion( + x_dtype=x_dtype, + M=M, + N=N, + return_rowwise=return_rowwise, + return_transpose=return_transpose, + with_random_sign_mask=with_random_sign_mask, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N, expected_swizzled", + [ + # Eligible: rows%64==0 AND cols%128==0 -> framework keeps swizzled=True + # and dispatch lands on the RHT cast-fusion kernel. + (64, 128, True), + (128, 256, True), + # Ineligible (cols%128 != 0) -> framework must clamp swizzled to False + # so the RHT-unfused fallback runs cleanly. This is the case hit in + # production at irregular shapes like (8192, 11328) where 11328%128==64. + (64, 144, False), + (128, 144, False), + # Ineligible (rows%64 != 0) -> same clamping requirement. + (48, 128, False), + ], +) +def test_nvfp4_rht_swizzle_fusion_shape_gate(M: int, N: int, expected_swizzled: bool) -> None: + """Framework gate must clamp ``with_gemm_swizzled_scales`` by shape eligibility. + + Only ``row_cast_col_hadamard_transform_cast_fusion.cu`` can emit + GEMM-swizzled SF directly. When the input shape is ineligible for that + kernel (rows%64!=0 or cols%128!=0), dispatch falls back to + ``quantize_with_rht_unfused_helper`` whose backing kernels + (``nvte_quantize_v2`` and ``nvte_hadamard_transform``) cannot emit + swizzled SF. The framework gates in ``NVFP4Quantizer::create_tensor`` and + ``convert_and_update_tensor`` must therefore clamp the flag to False on + ineligible shapes -- otherwise the defense-in-depth ``NVTE_CHECK`` in the + unfused helper hard-aborts at user-facing code paths (regression observed + at irregular production shapes such as ``(8192, 11328)``). + + The defense-in-depth ``NVTE_CHECK`` in ``quantize_with_rht_unfused_helper`` + is kept as a safety net for direct low-level callers; this test + intentionally exercises the user-level entry point (``quantizer(x)``) and + asserts the framework gate makes the unfused-path crash unreachable from + user code. + """ + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + + x = torch.randn((M, N), dtype=torch.bfloat16, device=device) + + quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + ) + quantizer.optimize_for_gemm = True + + # Should not raise on ineligible shapes; the framework gate clamps the + # flag and the unfused fallback runs cleanly. + result = quantizer(x) + assert result._with_gemm_swizzled_scales is expected_swizzled, ( + f"Framework shape gate expected _with_gemm_swizzled_scales={expected_swizzled} " + f"for shape ({M}, {N}) with optimize_for_gemm=True + with_rht=True, " + f"got {result._with_gemm_swizzled_scales}" + ) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 1265f2711c..683250ee36 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -1423,7 +1423,20 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vectorwith_gemm_swizzled_scales; + for (size_t i = 1; i < output_list.size(); ++i) { + NVTE_CHECK(output_list[i]->with_gemm_swizzled_scales == use_swizzle_sf_output, + "group_hadamard_transform_cast_fusion: all output tensors must share the same " + "with_gemm_swizzled_scales flag (mismatch at index ", + i, ")."); + } TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, kEnableStochasticRounding, diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu index 99060ab627..c560bdd073 100644 --- a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -1315,8 +1315,11 @@ void hadamard_transform_cast_fusion(const Tensor &input_, Tensor &output_, int k_tile_size = 1024; - // TODO: add support for swizzle sf output - const bool use_swizzle_sf_output = false; + // Honor the output tensor's GEMM-swizzled-scales flag: when set, emit + // scale factors directly in the layout that the downstream cuBLAS LT NVFP4 + // GEMM consumes, eliminating the otherwise-required + // nvte_swizzle_scaling_factors pass between quantize and GEMM. + const bool use_swizzle_sf_output = output_.with_gemm_swizzled_scales; TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, kEnableStochasticRounding, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 94350da1e6..13e78d4a04 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -371,6 +371,12 @@ class NVFP4Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + /*! @brief Whether a tensor of the given shape is eligible for + * the NVFP4 RHT cast-fusion kernel (single-tensor or grouped). + */ + static bool is_eligible_for_rht_cast_fusion(const std::vector& shape, + bool for_grouped_kernel = false); + private: void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2b38339d67..5eb843db6a 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -719,7 +719,6 @@ std::tuple, std::vector, bool> bulk_alloc return retval; } - // Quantization parameters const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; @@ -730,7 +729,39 @@ std::tuple, std::vector, bool> bulk_alloc } const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; - const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; + + // with_gemm_swizzled_scales is a single group-wide boolean baked + // into every output tensor. We can safely request it only when + // (a) every quantizer in the group has optimize_for_gemm and + // with_rht set, and (b) every tensor's shape qualifies for RHT + // cast-fusion. Disagreement among quantizers would silently give + // some outputs a layout that their own quantizer did not request; + // the NVTE_CHECK loop below turns that into a loud error. The + // final flag ANDs all three predicates (the shape-eligibility one + // is computed by the loop further below). + const bool group_optimize_for_gemm = quantizer_cpp_list[0]->optimize_for_gemm; + const bool group_with_rht = quantizer_cpp_list[0]->with_rht; + for (size_t i = 1; i < num_tensors; ++i) { + NVTE_CHECK(quantizer_cpp_list[i]->optimize_for_gemm == group_optimize_for_gemm, + "NVFP4 bulk allocation requires all quantizers in the group to share " + "the same optimize_for_gemm value (tensor 0=", + group_optimize_for_gemm, ", tensor ", i, "=", + quantizer_cpp_list[i]->optimize_for_gemm, ")."); + NVTE_CHECK(quantizer_cpp_list[i]->with_rht == group_with_rht, + "NVFP4 bulk allocation requires all quantizers in the group to share " + "the same with_rht value (tensor 0=", + group_with_rht, ", tensor ", i, "=", quantizer_cpp_list[i]->with_rht, ")."); + } + bool all_tensors_rht_cast_fusion_eligible = true; + for (size_t i = 0; i < num_tensors; ++i) { + if (!NVFP4Quantizer::is_eligible_for_rht_cast_fusion(shape_list[i], + /*for_grouped_kernel=*/true)) { + all_tensors_rht_cast_fusion_eligible = false; + break; + } + } + const bool with_gemm_swizzled_scales = + group_optimize_for_gemm && group_with_rht && all_tensors_rht_cast_fusion_eligible; // Helper function to get size of byte buffer holding FP4 data (last dim divided by 2) auto fp4_byte_shape = [](const std::vector &shape) -> std::vector { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7045995dd7..a9d7e4b121 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1760,18 +1760,28 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { columnwise_data.shape); } +bool NVFP4Quantizer::is_eligible_for_rht_cast_fusion(const std::vector& shape, + bool for_grouped_kernel) { + const auto [rows, cols] = get_2d_dims(shape); + const size_t row_align = for_grouped_kernel ? 128 : 64; + return rows % row_align == 0 && cols % 128 == 0 && transformer_engine::cuda::sm_arch() >= 100 && + transformer_engine::cuda::sm_arch() <= 110; +} + std::pair NVFP4Quantizer::create_tensor( const std::vector& shape, DType dtype, std::optional device_opt, bool pin_memory) const { const auto device = resolve_device(device_opt); using namespace pybind11::literals; - // Scaling factor format - const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) self->optimize_for_gemm - // Tensor dimensions const std::vector shape_int64(shape.begin(), shape.end()); const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); + + // Swizzled SF is only valid when the RHT cast-fusion path runs; + // other quantize paths reject it. + const bool with_gemm_swizzled_scales = this->optimize_for_gemm && this->with_rht && + NVFP4Quantizer::is_eligible_for_rht_cast_fusion(shape); NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, @@ -2051,9 +2061,6 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); - // Scaling factor format - const bool with_gemm_swizzled_scales = false; // TODO (tmoon) Enable with optimize_for_gemm - // Extract buffers from Python tensor auto get_tensor = [&tensor](const char* name) -> std::optional { auto attr_py = tensor.attr(name); @@ -2084,6 +2091,12 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); + + // Swizzled SF is only valid when the RHT cast-fusion path runs; + // other quantize paths reject it. + const bool with_gemm_swizzled_scales = this->optimize_for_gemm && this->with_rht && + NVFP4Quantizer::is_eligible_for_rht_cast_fusion(shape); + const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); @@ -2204,7 +2217,14 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper( const TensorWrapper& input, TensorWrapper& out, TensorWrapper& rht_output_t_cpp, QuantizationConfigWrapper& quant_config, QuantizationConfigWrapper& quant_config_columnwise, cudaStream_t stream) { - // only triggered for irregular shapes where RHT cast fusion kernel is not eligible + // The kernels invoked below reject swizzled-SF output, so trip a clear + // error here before reaching them. + NVTE_CHECK(!out.get_with_gemm_swizzled_scales(), + "NVFP4 RHT-unfused fallback path does not support " + "with_gemm_swizzled_scales=True. Either disable optimize_for_gemm on the " + "quantizer, or ensure the input shape is eligible for RHT cast-fusion " + "(bf16 dtype + rows%64==0 + cols%128==0 + SM 100/110)."); + if (rowwise_usage) { // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise TensorWrapper out_identity(out.scaling_mode()); @@ -2288,11 +2308,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // We only need RHT for columnwise usage. // flat first dim and last dim for multi dimensional input - size_t rows = 1; - for (size_t i = 0; i < input.ndim() - 1; ++i) { - rows *= input.size(i); - } - size_t cols = input.size(input.ndim() - 1); + const auto [rows, cols] = get_2d_dims(input.shape()); const bool row_scaled_nvfp4 = out.get_row_scaled_nvfp4(); if (row_scaled_nvfp4) { @@ -2307,9 +2323,9 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT - bool eligible_for_rht_cast_fusion = - input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0 && - transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110; + const bool eligible_for_rht_cast_fusion = + input.dtype() == DType::kBFloat16 && + NVFP4Quantizer::is_eligible_for_rht_cast_fusion(convertShape(input.shape())); // Stochastic rounding // When both rowwise and columnwise quantization are used with RHT,