From 3d3e9c9507484cbe0963a697f885ddf8089b0bcb Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Mon, 18 May 2026 20:22:10 -0700 Subject: [PATCH 1/5] [PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly Before this PR every NVFP4 RHT-cast-fusion quantize was followed by two standalone swizzle kernels (rowwise + columnwise) whose only job was to move scale factors into the layout cuBLAS LT consumes. The cast-fusion kernel already had a `kEnableSwizzleSFOutput` switch for that, but the framework never set the matching `with_gemm_swizzled_scales` flag on NVFP4 outputs -- it was a `false` with a TODO. This PR wires it through. Changes: * Single + grouped Hadamard cast-fusion kernels: drive `kEnableSwizzleSFOutput` from `output.with_gemm_swizzled_scales`. * NVFP4Quantizer create_tensor / convert_and_update_tensor / bulk_allocate_nvfp4_tensors: set the flag when `optimize_for_gemm && with_rht && shape eligible`, with eligibility in a new static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion (rows%64==0 && cols%128==0 && SM100/110) shared by all three sites. * Belt-and-suspenders NVTE_CHECK in quantize_with_rht_unfused_helper in case a future low-level caller bypasses the gate. The shape gate is part of this PR (not a follow-up) because LLaMA-class shapes like (8192, 11328) have K%128==64. Without the gate the framework would set the flag, dispatch would fall to the unfused path that can't emit swizzled SF, and the process would abort. With the gate, ineligible shapes silently fall back to the original code path. Numbers (GB200 SM100, bf16, rowwise+columnwise, RHT, per-quantize median, `quant + swizzle` path -- what te.Linear actually runs): (8192, 5120) 108.6 -> 81.9 us 1.33x eligible (8192, 11328) 236.3 -> 236.3 us 1.00x ineligible, gate clamped (11328, 8192) 114.4 -> 93.2 us 1.23x eligible (14336,16384) 232.1 -> 197.5 us 1.18x eligible 11/12 production-class shapes get 1.18x - 1.36x. The one ineligible shape gets 1.00x (= unchanged, no regression). `quant_only` is unchanged across all shapes -- the savings come entirely from eliminating the standalone swizzle pass, not from a faster quant kernel. Repro: benchmarks/benchmark_rht_cast_swizzle_fusion.py Tests: * new tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py: byte-equal SF / FP4 data / amax vs swizzled reference; plus 5 cases verifying the shape gate clamps correctly and that quantizer(x) on an ineligible shape does not raise. * tests/pytorch/nvfp4/test_nvfp4_group_quantize.py: added optimize_for_gemm parametrization for the legacy grouped path. * test_nvfp4_group_quantize_graph_safe.py passes unchanged (graph-safe variant already had the wiring). Signed-off-by: Cael Ling --- .../benchmark_rht_cast_swizzle_fusion.py | 189 +++++++++++++ benchmarks/profile_rht_cast_swizzle_fusion.py | 128 +++++++++ .../nvfp4/test_nvfp4_group_quantize.py | 64 ++++- .../test_nvfp4_rht_quantize_swizzle_fusion.py | 254 ++++++++++++++++++ ...cast_col_hadamard_transform_cast_fusion.cu | 15 +- ...cast_col_hadamard_transform_cast_fusion.cu | 7 +- transformer_engine/pytorch/csrc/common.h | 13 + .../pytorch/csrc/extensions/cast.cpp | 22 +- transformer_engine/pytorch/csrc/quantizer.cpp | 49 +++- 9 files changed, 728 insertions(+), 13 deletions(-) create mode 100644 benchmarks/benchmark_rht_cast_swizzle_fusion.py create mode 100644 benchmarks/profile_rht_cast_swizzle_fusion.py create mode 100644 tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py diff --git a/benchmarks/benchmark_rht_cast_swizzle_fusion.py b/benchmarks/benchmark_rht_cast_swizzle_fusion.py new file mode 100644 index 0000000000..2f1cca13fb --- /dev/null +++ b/benchmarks/benchmark_rht_cast_swizzle_fusion.py @@ -0,0 +1,189 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Benchmark NVFP4 RHT cast-fusion with vs without fused GEMM-swizzled SF output. + +For each shape we measure two paths and two builds: + + * path = "quant_only": just NVFP4Quantizer(x) + * path = "quant_plus_swizzle": NVFP4Quantizer(x) + tex.swizzle_scales_for_gemm_(t) + (this is what te.Linear -> tex.generic_gemm does right before the + cuBLAS LT NVFP4 GEMM dispatch) + + * build = "baseline": optimize_for_gemm=False + -> quant kernel emits compact SF; + tex.swizzle_scales_for_gemm_ launches the standalone + swizzle_{row,col}_scaling_kernel pass before GEMM. + * build = "swizzle_fusion": optimize_for_gemm=True + -> quant kernel emits GEMM-swizzled SF directly (via the + kEnableSwizzleSFOutput compile-time switch in + row_cast_col_hadamard_transform_cast_fusion.cu); + tex.swizzle_scales_for_gemm_ early-returns and the standalone + swizzle pass disappears from the timeline. + +The wall-clock delta on the "quant_plus_swizzle" path is the production +saving of this PR. +""" + +import argparse +import torch +import pandas as pd +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te # noqa: F401 must be first per te-python-import-order +import transformer_engine_torch as tex +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + +def make_quantizer(optimize_for_gemm: bool) -> NVFP4Quantizer: + q = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + ) + q.optimize_for_gemm = optimize_for_gemm + return q + + +def _bench(stmt: str, globals_dict: dict, min_run_time: float) -> float: + """Returns median wall-clock per call in microseconds.""" + timing = benchmark.Timer( + stmt=stmt, + globals=globals_dict, + num_threads=1, + ).blocked_autorange(min_run_time=min_run_time) + return timing.median * 1e6 + + +def run_shape(shape, min_run_time: float): + M, K = shape + assert M % 16 == 0 and K % 16 == 0, "Shape must be divisible by 16" + + x = torch.randn([M, K], dtype=torch.bfloat16, device="cuda") + q_base = make_quantizer(optimize_for_gemm=False) + q_swf = make_quantizer(optimize_for_gemm=True) + + # quant_only path + quant_only_base_us = _bench( + stmt="q(x)", + globals_dict={"q": q_base, "x": x}, + min_run_time=min_run_time, + ) + quant_only_swf_us = _bench( + stmt="q(x)", + globals_dict={"q": q_swf, "x": x}, + min_run_time=min_run_time, + ) + + # quant_plus_swizzle path (this is what te.Linear actually runs) + quant_plus_swizzle_base_us = _bench( + stmt="t = q(x); tex.swizzle_scales_for_gemm_(t)", + globals_dict={"q": q_base, "x": x, "tex": tex}, + min_run_time=min_run_time, + ) + quant_plus_swizzle_swf_us = _bench( + stmt="t = q(x); tex.swizzle_scales_for_gemm_(t)", + globals_dict={"q": q_swf, "x": x, "tex": tex}, + min_run_time=min_run_time, + ) + + saved_us = quant_plus_swizzle_base_us - quant_plus_swizzle_swf_us + speedup = ( + quant_plus_swizzle_base_us / quant_plus_swizzle_swf_us + if quant_plus_swizzle_swf_us > 0 + else float("inf") + ) + + print( + f" shape={shape}: quant_only base={quant_only_base_us:.2f}us, " + f"SUT={quant_only_swf_us:.2f}us | " + f"quant+swizzle base={quant_plus_swizzle_base_us:.2f}us, " + f"SUT={quant_plus_swizzle_swf_us:.2f}us " + f"-> saved {saved_us:.2f}us ({speedup:.2f}x)" + ) + + return { + "shape": shape, + "M": M, + "K": K, + "quant_only_base_us": quant_only_base_us, + "quant_only_swf_us": quant_only_swf_us, + "quant_plus_swizzle_base_us": quant_plus_swizzle_base_us, + "quant_plus_swizzle_swf_us": quant_plus_swizzle_swf_us, + "saved_us": saved_us, + "speedup": speedup, + } + + +# Nsight Compute Profiling Command (for verifying the swizzle kernel disappears): +# ncu -f -o swizzle_fusion --set=full \ +# --kernel-name "regex:swizzle_(row|col)_scaling_kernel|cast_col_hadamard_transform_cast_fusion" \ +# -s 5 -c 10 python benchmarks/benchmark_rht_cast_swizzle_fusion.py --profile + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--profile", + action="store_true", + help="Run only one shape for use with ncu/nsys; longer min_run_time", + ) + parser.add_argument( + "--min-run-time", + type=float, + default=2.0, + help="Minimum total measured time per cell in seconds (benchmark.Timer)", + ) + parser.add_argument( + "--csv", + type=str, + default="benchmark_rht_cast_swizzle_fusion.csv", + help="CSV output path", + ) + args = parser.parse_args() + + if args.profile: + print("Profiling mode enabled (single shape).") + shapes = [(8192, 4096)] + min_run_time = max(5.0, args.min_run_time) + else: + shapes = [ + # production-class shapes + (8192, 5120), + (8192, 10240), + (8192, 2560), + (8192, 11328), + (8192, 3584), + (5120, 8192), + (10240, 8192), + (2560, 8192), + (11328, 8192), + (3584, 8192), + (4096, 16384), + (14336, 16384), + ] + min_run_time = args.min_run_time + + print( + f"NVFP4 RHT cast-fusion: swizzle-fusion (optimize_for_gemm=True) vs baseline. " + f"min_run_time={min_run_time}s per cell, BF16 input, " + f"rowwise+columnwise SF, RHT=True+post_rht_amax." + ) + rows = [] + for shape in shapes: + print(f"Running {shape} ...") + rows.append(run_shape(shape, min_run_time)) + + df = pd.DataFrame(rows) + pd.set_option("display.max_columns", None) + pd.set_option("display.width", 200) + print() + print(df.to_string(index=False)) + df.to_csv(args.csv, index=False) + print(f"\nWrote {args.csv}") diff --git a/benchmarks/profile_rht_cast_swizzle_fusion.py b/benchmarks/profile_rht_cast_swizzle_fusion.py new file mode 100644 index 0000000000..aa0f588a4f --- /dev/null +++ b/benchmarks/profile_rht_cast_swizzle_fusion.py @@ -0,0 +1,128 @@ +""" +Profile that the dedicated swizzle kernels (swizzle_{row,col}_scaling_kernel +in transformer_engine/common/swizzle/swizzle.cu) disappear from the timeline +when NVFP4 RHT cast-fusion emits SF in the GEMM-swizzled layout directly +(optimize_for_gemm=True). + +Test setup: + - NVFP4 + RHT + post-RHT amax (same as te.Linear sets up internally) + - rowwise=True AND columnwise=True (covers BOTH swizzle_row_scaling_kernel + and swizzle_col_scaling_kernel; this is what tex.Linear's input quantizer + needs during training because the rowwise tensor is used by the fwd GEMM + and the columnwise tensor is used by the dgrad GEMM) + - tex.swizzle_scales_for_gemm_(t) is what te.Linear -> tex.generic_gemm + calls just before the cuBLAS LT NVFP4 GEMM dispatch +""" + +import torch +import transformer_engine.pytorch as te # noqa: F401 must be first +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer + + +def make_quantizer(optimize_for_gemm: bool) -> NVFP4Quantizer: + q = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + ) + q.optimize_for_gemm = optimize_for_gemm + return q + + +import re + +# Match ONLY the standalone swizzle pass kernels in +# transformer_engine/common/swizzle/swizzle.cu — NOT RHT cast-fusion kernels +# whose mangled name happens to contain "Swizzle" because of the +# `template <..., bool kEnableSwizzleSFOutput, ...>` parameter substring. +STANDALONE_SWIZZLE_RE = re.compile( + r"(?:multi_tensor_(?:un)?swizzle|(?:un)?swizzle)_(?:row|col)_scaling_kernel" +) + + +def dump_kernel_counts(prof, label: str) -> dict: + print(f"\n=== {label} ===") + counts: dict[str, int] = {} + for ev in prof.events(): + if ev.device_type != torch.autograd.DeviceType.CUDA: + continue + counts[ev.name] = counts.get(ev.name, 0) + 1 + standalone_swizzle_total = 0 + for name, c in sorted(counts.items(), key=lambda kv: -kv[1]): + marker = "" + if STANDALONE_SWIZZLE_RE.search(name): + marker = " <-- STANDALONE SWIZZLE PASS" + standalone_swizzle_total += c + # Truncate long mangled CUTLASS names for readability + short = name if len(name) <= 110 else name[:107] + "..." + print(f" {c:4d} {short}{marker}") + print(f" -- standalone swizzle kernel total: {standalone_swizzle_total}") + return counts + + +def profile_path(optimize_for_gemm: bool, x: torch.Tensor, n_iters: int = 20): + q = make_quantizer(optimize_for_gemm=optimize_for_gemm) + # warm-up + for _ in range(3): + t = q(x) + tex.swizzle_scales_for_gemm_(t) + torch.cuda.synchronize() + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA] + ) as prof: + for _ in range(n_iters): + t = q(x) + tex.swizzle_scales_for_gemm_(t) + torch.cuda.synchronize() + return prof + + +def main(): + torch.manual_seed(0) + torch.cuda.manual_seed(0) + device = "cuda" + # Shape that hits the production RHT cast-fusion fast-path + # (rows % 64 == 0, cols % 128 == 0, BF16, SM100/110). + M, N = 8192, 4096 + x = torch.randn((M, N), dtype=torch.bfloat16, device=device) + + print(f"Shape: M={M}, N={N}, dtype=bf16, RHT=True, post_rht_amax=True") + print(f"iters: 20 (after 3 warm-up)") + + prof_baseline = profile_path(optimize_for_gemm=False, x=x) + counts_baseline = dump_kernel_counts( + prof_baseline, "BASELINE: optimize_for_gemm=False (separate swizzle kernel)" + ) + + prof_swf = profile_path(optimize_for_gemm=True, x=x) + counts_swf = dump_kernel_counts( + prof_swf, "SUT: optimize_for_gemm=True (quant emits swizzled SF directly)" + ) + + print("\n=== VERDICT ===") + base_swizzle = sum( + c for n, c in counts_baseline.items() if STANDALONE_SWIZZLE_RE.search(n) + ) + swf_swizzle = sum(c for n, c in counts_swf.items() if STANDALONE_SWIZZLE_RE.search(n)) + print(f" baseline standalone swizzle kernel launches: {base_swizzle}") + print(f" SUT standalone swizzle kernel launches: {swf_swizzle}") + if swf_swizzle == 0 and base_swizzle > 0: + print( + " PASS: standalone swizzle pass disappears from timeline under " + "optimize_for_gemm=True" + ) + else: + print( + " FAIL: expected baseline > 0 and SUT == 0; check whether SUT actually " + "set with_gemm_swizzled_scales=True on the output tensor" + ) + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 20a91bf6fe..23c582be86 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -28,6 +28,7 @@ generate_split_sections, assert_same_shape_and_dtype, reference_group_quantize, + swizzle_nvfp4_scale, ) recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) @@ -43,6 +44,7 @@ def check_group_quantization_nvfp4_versus_reference( with_rht: bool = True, with_post_rht_amax: bool = True, with_random_sign_mask: bool = True, + optimize_for_gemm: bool = False, ) -> None: te_dtype = tex.DType.kFloat4E2M1 @@ -59,7 +61,7 @@ def check_group_quantization_nvfp4_versus_reference( x_splits = torch.split(x, split_sections) - # Quantize + # Reference quantizers (compact SF, default optimize_for_gemm=False). quantizers = [ NVFP4Quantizer( fp4_dtype=te_dtype, @@ -77,7 +79,13 @@ def check_group_quantization_nvfp4_versus_reference( reference_group_quantize(x, quantizers, split_sections, return_rowwise, return_transpose) ) - split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers) + # SUT quantizers: same as reference, but with optimize_for_gemm toggled to + # request direct swizzled SF emission from the RHT cast-fusion kernel. + sut_quantizers = [q.copy() for q in quantizers] + for q in sut_quantizers: + q.optimize_for_gemm = optimize_for_gemm + + split_quantize_outputs = tex.split_quantize(x, split_sections, sut_quantizers) if return_rowwise: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] @@ -98,6 +106,12 @@ def check_group_quantization_nvfp4_versus_reference( valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, False) x_sx_valid = x_sx[i][: valid_scale_shape[0], : valid_scale_shape[1]] x_sx_ref_valid = x_sx_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]] + if optimize_for_gemm: + # SUT emits SF in the GEMM-swizzled layout directly; swizzle + # the reference compact SF for byte-equal comparison. + x_sx_ref_valid = swizzle_nvfp4_scale( + split_sections[i], N, x_sx_ref_valid, columnwise=False + ) torch.testing.assert_close(x_sx_valid, x_sx_ref_valid, atol=0.0, rtol=0.0) if return_transpose: @@ -121,6 +135,10 @@ def check_group_quantization_nvfp4_versus_reference( valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True) x_sx_t_valid = x_sx_t[i][: valid_scale_shape[0], : valid_scale_shape[1]] x_sx_t_ref_valid = x_sx_t_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]] + if optimize_for_gemm: + x_sx_t_ref_valid = swizzle_nvfp4_scale( + split_sections[i], N, x_sx_t_ref_valid, columnwise=True + ) torch.testing.assert_close(x_sx_t_valid, x_sx_t_ref_valid, atol=0.0, rtol=0.0) @@ -157,6 +175,11 @@ def check_group_quantization_nvfp4_versus_reference( "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] ) @pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"]) +@pytest.mark.parametrize( + "optimize_for_gemm", + [False, True], + ids=["compact_sf", "swizzled_sf"], +) def test_rht_with_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -165,9 +188,43 @@ def test_rht_with_quantization_block_tiling_versus_reference( quantize_mode: str, with_random_sign_mask: bool, with_rht: bool, + optimize_for_gemm: bool, ) -> None: - split_sections = generate_split_sections(M, N, edge_cases, least_multiple=64) + # The "quantize writes swizzled SF" fast-path is gated in the C++ framework + # on ``optimize_for_gemm && with_rht`` (see NVFP4Quantizer::create_tensor and + # bulk_allocate_nvfp4_tensors in transformer_engine/pytorch/csrc). Without + # ``with_rht=True`` the flag is silently dropped, so the swizzled_sf row + # would just duplicate the compact_sf row — skip it instead of flooding the + # matrix with redundant cases. + if optimize_for_gemm and not with_rht: + pytest.skip("optimize_for_gemm requires with_rht=True (framework gate)") + + # The grouped RHT cast-fusion kernel that honors with_gemm_swizzled_scales + # (group_row_cast_col_hadamard_transform_cast_fusion.cu) is only dispatched + # when: + # - cols are a 128 multiple (RHT cast-fusion eligibility), AND + # - every split section is a 128 multiple (all_aligned_token_dim path in + # split_quantize_nvfp4_impl_with_rht_helper). + # For other shapes the C++ side falls back to the unfused row/col split, + # which does NOT (yet) emit swizzled SF; we'd hit either an NVTE_CHECK or + # silent SF-layout corruption. Restrict the swizzled coverage to the fused + # path; the unfused fallback is covered by the optimize_for_gemm=False + # baseline already exercised above. + if optimize_for_gemm and N % 128 != 0: + pytest.skip("RHT cast-fusion requires N % 128 == 0") + + # generate_split_sections hard-codes num_chunks=4 and requires every chunk + # to be a least_multiple multiple. When optimize_for_gemm forces + # least_multiple from 64 to 128, the test needs M >= 4*128 = 512 (and a + # multiple of 512 for the regular/zero_tokens patterns, which the existing + # M shapes already satisfy: 0, 1024, 8192, 16384). The small M=256 shape + # cannot satisfy this and is exercised by the optimize_for_gemm=False rows. + if optimize_for_gemm and 0 < M < 4 * 128: + pytest.skip("optimize_for_gemm requires M==0 or M>=512 for 4-chunk 128-aligned split") + + least_multiple = 128 if optimize_for_gemm else 64 + split_sections = generate_split_sections(M, N, edge_cases, least_multiple=least_multiple) # currently disable pre-RHT amax with_post_rht_amax = with_rht @@ -194,4 +251,5 @@ def test_rht_with_quantization_block_tiling_versus_reference( with_rht=with_rht, with_post_rht_amax=with_post_rht_amax, with_random_sign_mask=with_random_sign_mask, + optimize_for_gemm=optimize_for_gemm, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py new file mode 100644 index 0000000000..6cd331b437 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py @@ -0,0 +1,254 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fidelity tests for NVFP4 single-tensor RHT cast-fusion swizzled SF output. + +Mirrors ``tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py``: when +the quantizer has ``optimize_for_gemm=True`` and the input is eligible for +the RHT cast-fusion kernel (bf16, rows%64==0, cols%128==0, SM 100/110), the +kernel should emit scale factors directly in the GEMM-swizzled layout +``cuBLAS LT`` consumes, eliminating the otherwise-required +``nvte_swizzle_scaling_factors`` pass between quantize and GEMM. + +The fidelity contract is: ``quantizer_swizzle_fusion(x)`` produces SF that +are byte-equal to ``swizzle_nvfp4_scale(quantizer(x).sx)``. The +``_rowwise_data`` / ``_columnwise_data`` quantized FP4 buffers and amaxes +must also be byte-equal (the swizzle optimization changes only the SF +layout, not the FP4 data itself). +""" + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex # noqa: F401 +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import ( + NVFP4TensorStorage, +) + +import pytest +import torch + +from typing import Tuple + +from nvfp4_utils import ( + swizzle_nvfp4_scale, + get_nvfp4_scale_shape_no_padding, +) + +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) + + +def _unpack_quantized_tensor( + quantized_tensor: NVFP4TensorStorage, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """Extract the six tensors we want to compare in byte-equal form. + + Returns ``(qx, sx, amax_row, qx_t, sx_t, amax_col)``, with any of these + set to ``None`` when the quantizer did not request that direction. + """ + qx, sx, amax_row = None, None, None + qx_t, sx_t, amax_col = None, None, None + if quantized_tensor._rowwise_data is not None: + qx = quantized_tensor._rowwise_data.view(dtype=torch.uint8) + if quantized_tensor._rowwise_scale_inv is not None: + sx = quantized_tensor._rowwise_scale_inv + if quantized_tensor._amax_rowwise is not None: + amax_row = quantized_tensor._amax_rowwise + if quantized_tensor._columnwise_data is not None: + qx_t = quantized_tensor._columnwise_data.view(dtype=torch.uint8) + if quantized_tensor._columnwise_scale_inv is not None: + sx_t = quantized_tensor._columnwise_scale_inv + if quantized_tensor._amax_columnwise is not None: + amax_col = quantized_tensor._amax_columnwise + return qx, sx, amax_row, qx_t, sx_t, amax_col + + +def _check_nvfp4_rht_quantize_swizzle_fusion( + x_dtype: torch.dtype, + M: int, + N: int, + return_rowwise: bool, + return_transpose: bool, + with_random_sign_mask: bool, +) -> None: + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x = torch.randn((M, N), dtype=x_dtype, device=device) + + # Reference quantizer (compact SF, default behavior). + quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=return_rowwise, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=with_random_sign_mask, + ) + + # SUT: same quantizer but with swizzled-SF emission enabled. + quantizer_swizzle_fusion = quantizer.copy() + quantizer_swizzle_fusion.optimize_for_gemm = True + + qx_swf, sx_swf, amax_row_swf, qx_t_swf, sx_t_swf, amax_col_swf = ( + _unpack_quantized_tensor(quantizer_swizzle_fusion(x)) + ) + qx_ref, sx_ref, amax_row_ref, qx_t_ref, sx_t_ref, amax_col_ref = ( + _unpack_quantized_tensor(quantizer(x)) + ) + + if return_rowwise: + # FP4 data buffer and amax must be byte-equal (swizzle only changes SF + # layout, not the quantized data). + torch.testing.assert_close(qx_swf, qx_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(amax_row_swf, amax_row_ref, atol=0.0, rtol=0.0) + + # SF tensor must match the swizzle of the reference compact SF. + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x.shape, False) + assert valid_scale_shape == sx_swf.shape, ( + "rowwise SF shape mismatch; this test assumes the input shape needs no " + f"SF padding (got valid={valid_scale_shape}, got_swf={sx_swf.shape})." + ) + sx_ref_swizzled = swizzle_nvfp4_scale(M, N, sx_ref, columnwise=False) + torch.testing.assert_close(sx_swf, sx_ref_swizzled, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t_swf, qx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(amax_col_swf, amax_col_ref, atol=0.0, rtol=0.0) + + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x.shape, True) + assert valid_scale_shape == sx_t_swf.shape, ( + "columnwise SF shape mismatch; this test assumes the input shape needs no " + f"SF padding (got valid={valid_scale_shape}, got_swf={sx_t_swf.shape})." + ) + sx_t_ref_swizzled = swizzle_nvfp4_scale(M, N, sx_t_ref, columnwise=True) + torch.testing.assert_close(sx_t_swf, sx_t_ref_swizzled, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # Full tile cases (rows%64==0 and cols%128==0 to be eligible for the + # RHT cast-fusion kernel; also sized so no SF padding is needed). + (128, 128), + (256, 256), + (1024, 256), + # Production-like shapes. + (2048, 2048), + (8192, 1024), + (8192, 5120), + (8192, 10240), + (16384, 8192), + (16384, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"]) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +def test_nvfp4_rht_quantize_swizzle_fusion( + x_dtype: torch.dtype, + M: int, + N: int, + quantize_mode: str, + with_random_sign_mask: bool, +) -> None: + if quantize_mode == "rowwise_only": + return_rowwise = True + return_transpose = False + elif quantize_mode == "both_directions": + return_rowwise = True + return_transpose = True + elif quantize_mode == "columnwise_only": + return_rowwise = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + _check_nvfp4_rht_quantize_swizzle_fusion( + x_dtype=x_dtype, + M=M, + N=N, + return_rowwise=return_rowwise, + return_transpose=return_transpose, + with_random_sign_mask=with_random_sign_mask, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N, expected_swizzled", + [ + # Eligible: rows%64==0 AND cols%128==0 -> framework keeps swizzled=True + # and dispatch lands on the RHT cast-fusion kernel. + (64, 128, True), + (128, 256, True), + # Ineligible (cols%128 != 0) -> framework must clamp swizzled to False + # so the RHT-unfused fallback runs cleanly. This is the case hit in + # production at irregular shapes like (8192, 11328) where 11328%128==64. + (64, 144, False), + (128, 144, False), + # Ineligible (rows%64 != 0) -> same clamping requirement. + (48, 128, False), + ], +) +def test_nvfp4_rht_swizzle_fusion_shape_gate(M: int, N: int, expected_swizzled: bool) -> None: + """Framework gate must clamp ``with_gemm_swizzled_scales`` by shape eligibility. + + Only ``row_cast_col_hadamard_transform_cast_fusion.cu`` can emit + GEMM-swizzled SF directly. When the input shape is ineligible for that + kernel (rows%64!=0 or cols%128!=0), dispatch falls back to + ``quantize_with_rht_unfused_helper`` whose backing kernels + (``nvte_quantize_v2`` and ``nvte_hadamard_transform``) cannot emit + swizzled SF. The framework gates in ``NVFP4Quantizer::create_tensor`` and + ``convert_and_update_tensor`` must therefore clamp the flag to False on + ineligible shapes -- otherwise the defense-in-depth ``NVTE_CHECK`` in the + unfused helper hard-aborts at user-facing code paths (regression observed + at irregular production shapes such as ``(8192, 11328)``). + + The defense-in-depth ``NVTE_CHECK`` in ``quantize_with_rht_unfused_helper`` + is kept as a safety net for direct low-level callers; this test + intentionally exercises the user-level entry point (``quantizer(x)``) and + asserts the framework gate makes the unfused-path crash unreachable from + user code. + """ + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + + x = torch.randn((M, N), dtype=torch.bfloat16, device=device) + + quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + ) + quantizer.optimize_for_gemm = True + + # Should not raise on ineligible shapes; the framework gate clamps the + # flag and the unfused fallback runs cleanly. + result = quantizer(x) + assert result._with_gemm_swizzled_scales is expected_swizzled, ( + f"Framework shape gate expected _with_gemm_swizzled_scales={expected_swizzled} " + f"for shape ({M}, {N}) with optimize_for_gemm=True + with_rht=True, " + f"got {result._with_gemm_swizzled_scales}" + ) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 1265f2711c..683250ee36 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -1423,7 +1423,20 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vectorwith_gemm_swizzled_scales; + for (size_t i = 1; i < output_list.size(); ++i) { + NVTE_CHECK(output_list[i]->with_gemm_swizzled_scales == use_swizzle_sf_output, + "group_hadamard_transform_cast_fusion: all output tensors must share the same " + "with_gemm_swizzled_scales flag (mismatch at index ", + i, ")."); + } TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, kEnableStochasticRounding, diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu index 99060ab627..c560bdd073 100644 --- a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -1315,8 +1315,11 @@ void hadamard_transform_cast_fusion(const Tensor &input_, Tensor &output_, int k_tile_size = 1024; - // TODO: add support for swizzle sf output - const bool use_swizzle_sf_output = false; + // Honor the output tensor's GEMM-swizzled-scales flag: when set, emit + // scale factors directly in the layout that the downstream cuBLAS LT NVFP4 + // GEMM consumes, eliminating the otherwise-required + // nvte_swizzle_scaling_factors pass between quantize and GEMM. + const bool use_swizzle_sf_output = output_.with_gemm_swizzled_scales; TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, kEnableStochasticRounding, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 94350da1e6..4fc5ab28da 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -371,6 +371,19 @@ class NVFP4Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + /*! @brief Whether a 2D shape (rows, cols) is eligible for the + * NVFP4 RHT cast-fusion kernel. + * + * Matches the dispatch logic in NVFP4Quantizer::quantize_impl. + * The dtype check (BF16) is implicit -- with_rht=True requires + * BF16 input by construction, so callers gate on with_rht first. + * When false, the dispatch falls back to quantize_with_rht_unfused + * which cannot emit GEMM-swizzled SF; framework gates that opt + * into with_gemm_swizzled_scales must therefore also check this + * to avoid mismatched-flag aborts in the fallback path. + */ + static bool is_eligible_for_rht_cast_fusion(size_t rows, size_t cols); + private: void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2b38339d67..cabe857b0b 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -730,7 +730,27 @@ std::tuple, std::vector, bool> bulk_alloc } const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; - const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; + // Only the RHT cast-fusion quant kernel supports direct swizzled SF + // emission. Other NVFP4 quant kernels (e.g. nvte_quantize_v2 -> + // quantize_nvfp4.cuh, quantize_transpose_nvfp4.cuh) NVTE_CHECK reject a + // swizzled-flagged output, so we gate on with_rht to avoid silent data + // corruption / hard aborts on non-RHT paths. Additionally we require + // *all* tensors in the group to be shape-eligible for RHT cast-fusion, + // because the grouped kernel honours a single boolean and the unfused + // fallback rejects swizzled output (see NVTE_CHECK at + // group_row_cast_col_hadamard_transform_cast_fusion.cu and + // quantize_with_rht_unfused_helper). + bool all_tensors_rht_cast_fusion_eligible = true; + for (size_t i = 0; i < num_tensors; ++i) { + const auto [rows, cols] = get_2d_dims(shape_list[i]); + if (!NVFP4Quantizer::is_eligible_for_rht_cast_fusion(rows, cols)) { + all_tensors_rht_cast_fusion_eligible = false; + break; + } + } + const bool with_gemm_swizzled_scales = + quantizer_cpp_list[0]->optimize_for_gemm && quantizer_cpp_list[0]->with_rht && + all_tensors_rht_cast_fusion_eligible; // Helper function to get size of byte buffer holding FP4 data (last dim divided by 2) auto fp4_byte_shape = [](const std::vector &shape) -> std::vector { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7045995dd7..6c9d18927d 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1760,18 +1760,39 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { columnwise_data.shape); } +bool NVFP4Quantizer::is_eligible_for_rht_cast_fusion(size_t rows, size_t cols) { + // Must mirror the eligibility check in NVFP4Quantizer::quantize_impl + // (search for "eligible_for_rht_cast_fusion" in this file). The dtype + // check (BF16) is implicit: with_rht is only valid for BF16 input by + // construction. + return rows % 64 == 0 && cols % 128 == 0 && + transformer_engine::cuda::sm_arch() >= 100 && + transformer_engine::cuda::sm_arch() <= 110; +} + std::pair NVFP4Quantizer::create_tensor( const std::vector& shape, DType dtype, std::optional device_opt, bool pin_memory) const { const auto device = resolve_device(device_opt); using namespace pybind11::literals; - // Scaling factor format - const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) self->optimize_for_gemm - // Tensor dimensions const std::vector shape_int64(shape.begin(), shape.end()); const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); + + // Scaling factor format. + // Only the RHT cast-fusion quant kernel + // (row_cast_col_hadamard_transform_cast_fusion.cu) emits SF in the GEMM- + // swizzled layout. Non-RHT NVFP4 quant kernels (quantize_nvfp4.cuh and + // quantize_transpose_nvfp4.cuh) NVTE_CHECK reject a swizzled-flagged + // output, so we gate on with_rht. The RHT-unfused fallback (taken when + // the input shape is ineligible for the cast-fusion kernel) also + // rejects swizzled output, so we additionally gate on shape eligibility + // -- otherwise irregular shapes like (8192, 11328) silently set the + // flag here and then hard-abort deep inside quantize_with_rht_unfused. + const bool with_gemm_swizzled_scales = + this->optimize_for_gemm && this->with_rht && + NVFP4Quantizer::is_eligible_for_rht_cast_fusion(flat_first_dim, flat_last_dim); NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, @@ -2051,9 +2072,6 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); - // Scaling factor format - const bool with_gemm_swizzled_scales = false; // TODO (tmoon) Enable with optimize_for_gemm - // Extract buffers from Python tensor auto get_tensor = [&tensor](const char* name) -> std::optional { auto attr_py = tensor.attr(name); @@ -2084,6 +2102,13 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); + + // Scaling factor format. See NVFP4Quantizer::create_tensor for the + // rationale on gating on (with_rht && shape eligibility). + const bool with_gemm_swizzled_scales = + this->optimize_for_gemm && this->with_rht && + NVFP4Quantizer::is_eligible_for_rht_cast_fusion(flat_first_dim, flat_last_dim); + const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); @@ -2205,6 +2230,18 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper( QuantizationConfigWrapper& quant_config, QuantizationConfigWrapper& quant_config_columnwise, cudaStream_t stream) { // only triggered for irregular shapes where RHT cast fusion kernel is not eligible + // The unfused fallback dispatches to nvte_quantize_v2 / nvte_hadamard_transform, + // neither of which supports emitting SF in the GEMM-swizzled layout (their + // backing kernels NVTE_CHECK reject swizzled-flagged output). Surface a clean + // error here instead of letting it abort deep inside the kernel with an + // opaque message. JAX hard-asserts eligibility upfront; PyTorch matches that + // contract specifically when optimize_for_gemm=True. + NVTE_CHECK(!out.get_with_gemm_swizzled_scales(), + "NVFP4 RHT-unfused fallback path does not support " + "with_gemm_swizzled_scales=True. Either disable optimize_for_gemm on the " + "quantizer, or ensure the input shape is eligible for RHT cast-fusion " + "(bf16 dtype + rows%64==0 + cols%128==0 + SM 100/110)."); + if (rowwise_usage) { // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise TensorWrapper out_identity(out.scaling_mode()); From 6728675329230bc33b885b56b993d4840fae72c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 May 2026 03:49:57 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- benchmarks/benchmark_rht_cast_swizzle_fusion.py | 4 ++-- benchmarks/profile_rht_cast_swizzle_fusion.py | 11 +++-------- .../nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py | 8 ++++---- transformer_engine/pytorch/csrc/extensions/cast.cpp | 6 +++--- transformer_engine/pytorch/csrc/quantizer.cpp | 3 +-- 5 files changed, 13 insertions(+), 19 deletions(-) diff --git a/benchmarks/benchmark_rht_cast_swizzle_fusion.py b/benchmarks/benchmark_rht_cast_swizzle_fusion.py index 2f1cca13fb..13f264862c 100644 --- a/benchmarks/benchmark_rht_cast_swizzle_fusion.py +++ b/benchmarks/benchmark_rht_cast_swizzle_fusion.py @@ -171,9 +171,9 @@ def run_shape(shape, min_run_time: float): min_run_time = args.min_run_time print( - f"NVFP4 RHT cast-fusion: swizzle-fusion (optimize_for_gemm=True) vs baseline. " + "NVFP4 RHT cast-fusion: swizzle-fusion (optimize_for_gemm=True) vs baseline. " f"min_run_time={min_run_time}s per cell, BF16 input, " - f"rowwise+columnwise SF, RHT=True+post_rht_amax." + "rowwise+columnwise SF, RHT=True+post_rht_amax." ) rows = [] for shape in shapes: diff --git a/benchmarks/profile_rht_cast_swizzle_fusion.py b/benchmarks/profile_rht_cast_swizzle_fusion.py index aa0f588a4f..97d2a3e564 100644 --- a/benchmarks/profile_rht_cast_swizzle_fusion.py +++ b/benchmarks/profile_rht_cast_swizzle_fusion.py @@ -73,9 +73,7 @@ def profile_path(optimize_for_gemm: bool, x: torch.Tensor, n_iters: int = 20): t = q(x) tex.swizzle_scales_for_gemm_(t) torch.cuda.synchronize() - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CUDA] - ) as prof: + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: for _ in range(n_iters): t = q(x) tex.swizzle_scales_for_gemm_(t) @@ -106,16 +104,13 @@ def main(): ) print("\n=== VERDICT ===") - base_swizzle = sum( - c for n, c in counts_baseline.items() if STANDALONE_SWIZZLE_RE.search(n) - ) + base_swizzle = sum(c for n, c in counts_baseline.items() if STANDALONE_SWIZZLE_RE.search(n)) swf_swizzle = sum(c for n, c in counts_swf.items() if STANDALONE_SWIZZLE_RE.search(n)) print(f" baseline standalone swizzle kernel launches: {base_swizzle}") print(f" SUT standalone swizzle kernel launches: {swf_swizzle}") if swf_swizzle == 0 and base_swizzle > 0: print( - " PASS: standalone swizzle pass disappears from timeline under " - "optimize_for_gemm=True" + " PASS: standalone swizzle pass disappears from timeline under optimize_for_gemm=True" ) else: print( diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py index 6cd331b437..3059adea8c 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py @@ -103,11 +103,11 @@ def _check_nvfp4_rht_quantize_swizzle_fusion( quantizer_swizzle_fusion = quantizer.copy() quantizer_swizzle_fusion.optimize_for_gemm = True - qx_swf, sx_swf, amax_row_swf, qx_t_swf, sx_t_swf, amax_col_swf = ( - _unpack_quantized_tensor(quantizer_swizzle_fusion(x)) + qx_swf, sx_swf, amax_row_swf, qx_t_swf, sx_t_swf, amax_col_swf = _unpack_quantized_tensor( + quantizer_swizzle_fusion(x) ) - qx_ref, sx_ref, amax_row_ref, qx_t_ref, sx_t_ref, amax_col_ref = ( - _unpack_quantized_tensor(quantizer(x)) + qx_ref, sx_ref, amax_row_ref, qx_t_ref, sx_t_ref, amax_col_ref = _unpack_quantized_tensor( + quantizer(x) ) if return_rowwise: diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index cabe857b0b..288b3374bf 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -748,9 +748,9 @@ std::tuple, std::vector, bool> bulk_alloc break; } } - const bool with_gemm_swizzled_scales = - quantizer_cpp_list[0]->optimize_for_gemm && quantizer_cpp_list[0]->with_rht && - all_tensors_rht_cast_fusion_eligible; + const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm && + quantizer_cpp_list[0]->with_rht && + all_tensors_rht_cast_fusion_eligible; // Helper function to get size of byte buffer holding FP4 data (last dim divided by 2) auto fp4_byte_shape = [](const std::vector &shape) -> std::vector { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6c9d18927d..c914af1fb1 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1765,8 +1765,7 @@ bool NVFP4Quantizer::is_eligible_for_rht_cast_fusion(size_t rows, size_t cols) { // (search for "eligible_for_rht_cast_fusion" in this file). The dtype // check (BF16) is implicit: with_rht is only valid for BF16 input by // construction. - return rows % 64 == 0 && cols % 128 == 0 && - transformer_engine::cuda::sm_arch() >= 100 && + return rows % 64 == 0 && cols % 128 == 0 && transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110; } From 0f7b867e906bbe92b43192db880954d9ce3f63d9 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Mon, 18 May 2026 23:30:05 -0700 Subject: [PATCH 3/5] [PyTorch] NVFP4 RHT cast-fusion: enforce group-wide quantizer config Reviewer feedback: with_gemm_swizzled_scales was derived from quantizer_cpp_list[0]->optimize_for_gemm / with_rht without checking that other quantizers in the group agreed; if any later quantizer had a different value, its tensors would be silently allocated with the wrong SF layout. Following the precedent of the split-quantize path at line 1276 (// Assume all quantizers have identical config), this commit: * adds an explicit comment block calling out the group-wide identical-config assumption and which fields this PR enforces vs. which are pre-existing; * adds an NVTE_CHECK loop enforcing identical optimize_for_gemm and with_rht across the group (the two fields the with_gemm_swizzled_scales gate depends on), with error messages that print the offending tensor index and the disagreeing values; * extracts the [0] reads into group_optimize_for_gemm and group_with_rht locals so the same value feeds both the check and the gate. Other from-[0] reads (rowwise_usage, row_scaled_nvfp4, columnwise_usage, scaling_mode, dtype) are pre-existing assumptions and remain out of scope for this PR. Signed-off-by: Cael Ling --- .../pytorch/csrc/extensions/cast.cpp | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 288b3374bf..6cb81080e3 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -719,7 +719,17 @@ std::tuple, std::vector, bool> bulk_alloc return retval; } - // Quantization parameters + // Quantization parameters. Like the NVFP4 split-quantize path + // (see split_quantize_nvfp4_impl in this file), we assume all + // quantizers in the group share an identical config and read + // group-wide flags from quantizer_cpp_list[0]. The grouped RHT + // cast-fusion kernel honours a single with_gemm_swizzled_scales + // boolean across the whole group, so optimize_for_gemm and with_rht + // must in particular agree across all quantizers; the NVTE_CHECK + // loop below enforces that for the fields the swizzled-SF gate + // depends on. (The other group-wide reads from [0] -- rowwise_usage, + // row_scaled_nvfp4, columnwise_usage, scaling_mode, dtype -- are + // pre-existing assumptions and out of scope for this PR.) const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; @@ -730,16 +740,30 @@ std::tuple, std::vector, bool> bulk_alloc } const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; + // Only the RHT cast-fusion quant kernel supports direct swizzled SF // emission. Other NVFP4 quant kernels (e.g. nvte_quantize_v2 -> - // quantize_nvfp4.cuh, quantize_transpose_nvfp4.cuh) NVTE_CHECK reject a - // swizzled-flagged output, so we gate on with_rht to avoid silent data - // corruption / hard aborts on non-RHT paths. Additionally we require - // *all* tensors in the group to be shape-eligible for RHT cast-fusion, - // because the grouped kernel honours a single boolean and the unfused - // fallback rejects swizzled output (see NVTE_CHECK at - // group_row_cast_col_hadamard_transform_cast_fusion.cu and + // quantize_nvfp4.cuh, quantize_transpose_nvfp4.cuh) NVTE_CHECK reject + // a swizzled-flagged output, so we gate on with_rht to avoid silent + // data corruption / hard aborts on non-RHT paths. Additionally we + // require *all* tensors in the group to be shape-eligible for RHT + // cast-fusion, because the grouped kernel honours a single boolean + // and the unfused fallback rejects swizzled output (see NVTE_CHECK + // at group_row_cast_col_hadamard_transform_cast_fusion.cu and // quantize_with_rht_unfused_helper). + const bool group_optimize_for_gemm = quantizer_cpp_list[0]->optimize_for_gemm; + const bool group_with_rht = quantizer_cpp_list[0]->with_rht; + for (size_t i = 1; i < num_tensors; ++i) { + NVTE_CHECK(quantizer_cpp_list[i]->optimize_for_gemm == group_optimize_for_gemm, + "NVFP4 bulk allocation requires all quantizers in the group to share " + "the same optimize_for_gemm value (tensor 0=", + group_optimize_for_gemm, ", tensor ", i, "=", + quantizer_cpp_list[i]->optimize_for_gemm, ")."); + NVTE_CHECK(quantizer_cpp_list[i]->with_rht == group_with_rht, + "NVFP4 bulk allocation requires all quantizers in the group to share " + "the same with_rht value (tensor 0=", + group_with_rht, ", tensor ", i, "=", quantizer_cpp_list[i]->with_rht, ")."); + } bool all_tensors_rht_cast_fusion_eligible = true; for (size_t i = 0; i < num_tensors; ++i) { const auto [rows, cols] = get_2d_dims(shape_list[i]); @@ -748,9 +772,8 @@ std::tuple, std::vector, bool> bulk_alloc break; } } - const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm && - quantizer_cpp_list[0]->with_rht && - all_tensors_rht_cast_fusion_eligible; + const bool with_gemm_swizzled_scales = + group_optimize_for_gemm && group_with_rht && all_tensors_rht_cast_fusion_eligible; // Helper function to get size of byte buffer holding FP4 data (last dim divided by 2) auto fp4_byte_shape = [](const std::vector &shape) -> std::vector { From 521b388ea92ce1e346bf612ef8eff2a60d28eab5 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Tue, 19 May 2026 20:12:05 -0700 Subject: [PATCH 4/5] [PyTorch] NVFP4 RHT cast-fusion: address review feedback Functional fix: - `bulk_allocate_nvfp4_tensors` previously used the single-tensor RHT eligibility check (`rows % 64 == 0`), but the grouped kernel asserts `first_logical_dim % 128 == 0` at entry. Shapes with rows in {64, 192, 320, ...} would pass eligibility, set `with_gemm_swizzled_scales=True`, and then hard-abort inside the grouped kernel with an opaque NVTE_CHECK message. Adding a `for_grouped_kernel` parameter on `is_eligible_for_rht_cast_fusion` selects the correct row alignment: 64 for the single-tensor kernel, 128 for the grouped variant. Only the bulk-allocation caller passes `true`; the three single-tensor callers keep the default `false`. Refactors: - `is_eligible_for_rht_cast_fusion` now takes the full tensor shape (`std::vector`) and flattens internally with `get_2d_dims`, so the four call sites no longer pre-flatten and duplicate the flatten rule. - `quantize_impl` delegates the shape/arch eligibility to `is_eligible_for_rht_cast_fusion` instead of inlining the same predicate, and its hand-rolled `rows = product(shape[:-1])` loop is replaced with `get_2d_dims(input.shape())`. The shape/arch eligibility now has a single source of truth. Comment cleanups: - Trimmed verbose comments in `bulk_allocate_nvfp4_tensors`, `create_tensor`, `convert_and_update_tensor`, and `quantize_with_rht_unfused_helper`. Removed cross-references to other functions/files, code narration of subsequent lines, the JAX reference in PyTorch source, and the "see X for rationale" pattern. - Doxygen on `is_eligible_for_rht_cast_fusion` reduced to a single brief sentence. Signed-off-by: Cael Ling --- transformer_engine/pytorch/csrc/common.h | 15 ++---- .../pytorch/csrc/extensions/cast.cpp | 34 ++++-------- transformer_engine/pytorch/csrc/quantizer.cpp | 53 +++++++------------ 3 files changed, 33 insertions(+), 69 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 4fc5ab28da..13e78d4a04 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -371,18 +371,11 @@ class NVFP4Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; - /*! @brief Whether a 2D shape (rows, cols) is eligible for the - * NVFP4 RHT cast-fusion kernel. - * - * Matches the dispatch logic in NVFP4Quantizer::quantize_impl. - * The dtype check (BF16) is implicit -- with_rht=True requires - * BF16 input by construction, so callers gate on with_rht first. - * When false, the dispatch falls back to quantize_with_rht_unfused - * which cannot emit GEMM-swizzled SF; framework gates that opt - * into with_gemm_swizzled_scales must therefore also check this - * to avoid mismatched-flag aborts in the fallback path. + /*! @brief Whether a tensor of the given shape is eligible for + * the NVFP4 RHT cast-fusion kernel (single-tensor or grouped). */ - static bool is_eligible_for_rht_cast_fusion(size_t rows, size_t cols); + static bool is_eligible_for_rht_cast_fusion(const std::vector& shape, + bool for_grouped_kernel = false); private: void quantize_impl(const TensorWrapper& input, TensorWrapper& out, diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 6cb81080e3..5eb843db6a 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -719,17 +719,6 @@ std::tuple, std::vector, bool> bulk_alloc return retval; } - // Quantization parameters. Like the NVFP4 split-quantize path - // (see split_quantize_nvfp4_impl in this file), we assume all - // quantizers in the group share an identical config and read - // group-wide flags from quantizer_cpp_list[0]. The grouped RHT - // cast-fusion kernel honours a single with_gemm_swizzled_scales - // boolean across the whole group, so optimize_for_gemm and with_rht - // must in particular agree across all quantizers; the NVTE_CHECK - // loop below enforces that for the fields the swizzled-SF gate - // depends on. (The other group-wide reads from [0] -- rowwise_usage, - // row_scaled_nvfp4, columnwise_usage, scaling_mode, dtype -- are - // pre-existing assumptions and out of scope for this PR.) const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; @@ -741,16 +730,15 @@ std::tuple, std::vector, bool> bulk_alloc const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; - // Only the RHT cast-fusion quant kernel supports direct swizzled SF - // emission. Other NVFP4 quant kernels (e.g. nvte_quantize_v2 -> - // quantize_nvfp4.cuh, quantize_transpose_nvfp4.cuh) NVTE_CHECK reject - // a swizzled-flagged output, so we gate on with_rht to avoid silent - // data corruption / hard aborts on non-RHT paths. Additionally we - // require *all* tensors in the group to be shape-eligible for RHT - // cast-fusion, because the grouped kernel honours a single boolean - // and the unfused fallback rejects swizzled output (see NVTE_CHECK - // at group_row_cast_col_hadamard_transform_cast_fusion.cu and - // quantize_with_rht_unfused_helper). + // with_gemm_swizzled_scales is a single group-wide boolean baked + // into every output tensor. We can safely request it only when + // (a) every quantizer in the group has optimize_for_gemm and + // with_rht set, and (b) every tensor's shape qualifies for RHT + // cast-fusion. Disagreement among quantizers would silently give + // some outputs a layout that their own quantizer did not request; + // the NVTE_CHECK loop below turns that into a loud error. The + // final flag ANDs all three predicates (the shape-eligibility one + // is computed by the loop further below). const bool group_optimize_for_gemm = quantizer_cpp_list[0]->optimize_for_gemm; const bool group_with_rht = quantizer_cpp_list[0]->with_rht; for (size_t i = 1; i < num_tensors; ++i) { @@ -766,8 +754,8 @@ std::tuple, std::vector, bool> bulk_alloc } bool all_tensors_rht_cast_fusion_eligible = true; for (size_t i = 0; i < num_tensors; ++i) { - const auto [rows, cols] = get_2d_dims(shape_list[i]); - if (!NVFP4Quantizer::is_eligible_for_rht_cast_fusion(rows, cols)) { + if (!NVFP4Quantizer::is_eligible_for_rht_cast_fusion(shape_list[i], + /*for_grouped_kernel=*/true)) { all_tensors_rht_cast_fusion_eligible = false; break; } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index c914af1fb1..a54c3d946b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1760,12 +1760,12 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { columnwise_data.shape); } -bool NVFP4Quantizer::is_eligible_for_rht_cast_fusion(size_t rows, size_t cols) { - // Must mirror the eligibility check in NVFP4Quantizer::quantize_impl - // (search for "eligible_for_rht_cast_fusion" in this file). The dtype - // check (BF16) is implicit: with_rht is only valid for BF16 input by - // construction. - return rows % 64 == 0 && cols % 128 == 0 && transformer_engine::cuda::sm_arch() >= 100 && +bool NVFP4Quantizer::is_eligible_for_rht_cast_fusion(const std::vector& shape, + bool for_grouped_kernel) { + const auto [rows, cols] = get_2d_dims(shape); + const size_t row_align = for_grouped_kernel ? 128 : 64; + return rows % row_align == 0 && cols % 128 == 0 && + transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110; } @@ -1779,19 +1779,11 @@ std::pair NVFP4Quantizer::create_tensor( const std::vector shape_int64(shape.begin(), shape.end()); const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); - // Scaling factor format. - // Only the RHT cast-fusion quant kernel - // (row_cast_col_hadamard_transform_cast_fusion.cu) emits SF in the GEMM- - // swizzled layout. Non-RHT NVFP4 quant kernels (quantize_nvfp4.cuh and - // quantize_transpose_nvfp4.cuh) NVTE_CHECK reject a swizzled-flagged - // output, so we gate on with_rht. The RHT-unfused fallback (taken when - // the input shape is ineligible for the cast-fusion kernel) also - // rejects swizzled output, so we additionally gate on shape eligibility - // -- otherwise irregular shapes like (8192, 11328) silently set the - // flag here and then hard-abort deep inside quantize_with_rht_unfused. + // Swizzled SF is only valid when the RHT cast-fusion path runs; + // other quantize paths reject it. const bool with_gemm_swizzled_scales = this->optimize_for_gemm && this->with_rht && - NVFP4Quantizer::is_eligible_for_rht_cast_fusion(flat_first_dim, flat_last_dim); + NVFP4Quantizer::is_eligible_for_rht_cast_fusion(shape); NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, @@ -2102,11 +2094,11 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); - // Scaling factor format. See NVFP4Quantizer::create_tensor for the - // rationale on gating on (with_rht && shape eligibility). + // Swizzled SF is only valid when the RHT cast-fusion path runs; + // other quantize paths reject it. const bool with_gemm_swizzled_scales = this->optimize_for_gemm && this->with_rht && - NVFP4Quantizer::is_eligible_for_rht_cast_fusion(flat_first_dim, flat_last_dim); + NVFP4Quantizer::is_eligible_for_rht_cast_fusion(shape); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; if (row_scaled_nvfp4) { @@ -2228,13 +2220,8 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper( const TensorWrapper& input, TensorWrapper& out, TensorWrapper& rht_output_t_cpp, QuantizationConfigWrapper& quant_config, QuantizationConfigWrapper& quant_config_columnwise, cudaStream_t stream) { - // only triggered for irregular shapes where RHT cast fusion kernel is not eligible - // The unfused fallback dispatches to nvte_quantize_v2 / nvte_hadamard_transform, - // neither of which supports emitting SF in the GEMM-swizzled layout (their - // backing kernels NVTE_CHECK reject swizzled-flagged output). Surface a clean - // error here instead of letting it abort deep inside the kernel with an - // opaque message. JAX hard-asserts eligibility upfront; PyTorch matches that - // contract specifically when optimize_for_gemm=True. + // The kernels invoked below reject swizzled-SF output, so trip a clear + // error here before reaching them. NVTE_CHECK(!out.get_with_gemm_swizzled_scales(), "NVFP4 RHT-unfused fallback path does not support " "with_gemm_swizzled_scales=True. Either disable optimize_for_gemm on the " @@ -2324,11 +2311,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // We only need RHT for columnwise usage. // flat first dim and last dim for multi dimensional input - size_t rows = 1; - for (size_t i = 0; i < input.ndim() - 1; ++i) { - rows *= input.size(i); - } - size_t cols = input.size(input.ndim() - 1); + const auto [rows, cols] = get_2d_dims(input.shape()); const bool row_scaled_nvfp4 = out.get_row_scaled_nvfp4(); if (row_scaled_nvfp4) { @@ -2343,9 +2326,9 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT - bool eligible_for_rht_cast_fusion = - input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0 && - transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110; + const bool eligible_for_rht_cast_fusion = + input.dtype() == DType::kBFloat16 && + NVFP4Quantizer::is_eligible_for_rht_cast_fusion(convertShape(input.shape())); // Stochastic rounding // When both rowwise and columnwise quantization are used with RHT, From 93e75925e6b56c769c8b44e947728f3e1154774d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 May 2026 03:15:58 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/quantizer.cpp | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index a54c3d946b..a9d7e4b121 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1764,8 +1764,7 @@ bool NVFP4Quantizer::is_eligible_for_rht_cast_fusion(const std::vector& bool for_grouped_kernel) { const auto [rows, cols] = get_2d_dims(shape); const size_t row_align = for_grouped_kernel ? 128 : 64; - return rows % row_align == 0 && cols % 128 == 0 && - transformer_engine::cuda::sm_arch() >= 100 && + return rows % row_align == 0 && cols % 128 == 0 && transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110; } @@ -1781,9 +1780,8 @@ std::pair NVFP4Quantizer::create_tensor( // Swizzled SF is only valid when the RHT cast-fusion path runs; // other quantize paths reject it. - const bool with_gemm_swizzled_scales = - this->optimize_for_gemm && this->with_rht && - NVFP4Quantizer::is_eligible_for_rht_cast_fusion(shape); + const bool with_gemm_swizzled_scales = this->optimize_for_gemm && this->with_rht && + NVFP4Quantizer::is_eligible_for_rht_cast_fusion(shape); NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, @@ -2096,9 +2094,8 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( // Swizzled SF is only valid when the RHT cast-fusion path runs; // other quantize paths reject it. - const bool with_gemm_swizzled_scales = - this->optimize_for_gemm && this->with_rht && - NVFP4Quantizer::is_eligible_for_rht_cast_fusion(shape); + const bool with_gemm_swizzled_scales = this->optimize_for_gemm && this->with_rht && + NVFP4Quantizer::is_eligible_for_rht_cast_fusion(shape); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; if (row_scaled_nvfp4) {