Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 189 additions & 0 deletions benchmarks/benchmark_rht_cast_swizzle_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Benchmark NVFP4 RHT cast-fusion with vs without fused GEMM-swizzled SF output.

For each shape we measure two paths and two builds:

* path = "quant_only": just NVFP4Quantizer(x)
* path = "quant_plus_swizzle": NVFP4Quantizer(x) + tex.swizzle_scales_for_gemm_(t)
(this is what te.Linear -> tex.generic_gemm does right before the
cuBLAS LT NVFP4 GEMM dispatch)

* build = "baseline": optimize_for_gemm=False
-> quant kernel emits compact SF;
tex.swizzle_scales_for_gemm_ launches the standalone
swizzle_{row,col}_scaling_kernel pass before GEMM.
* build = "swizzle_fusion": optimize_for_gemm=True
-> quant kernel emits GEMM-swizzled SF directly (via the
kEnableSwizzleSFOutput compile-time switch in
row_cast_col_hadamard_transform_cast_fusion.cu);
tex.swizzle_scales_for_gemm_ early-returns and the standalone
swizzle pass disappears from the timeline.

The wall-clock delta on the "quant_plus_swizzle" path is the production
saving of this PR.
"""

import argparse
import torch
import pandas as pd
import torch.utils.benchmark as benchmark

import transformer_engine.pytorch as te # noqa: F401 must be first per te-python-import-order
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer


def make_quantizer(optimize_for_gemm: bool) -> NVFP4Quantizer:
q = NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=True,
with_post_rht_amax=True,
with_random_sign_mask=True,
)
q.optimize_for_gemm = optimize_for_gemm
return q


def _bench(stmt: str, globals_dict: dict, min_run_time: float) -> float:
"""Returns median wall-clock per call in microseconds."""
timing = benchmark.Timer(
stmt=stmt,
globals=globals_dict,
num_threads=1,
).blocked_autorange(min_run_time=min_run_time)
return timing.median * 1e6


def run_shape(shape, min_run_time: float):
M, K = shape
assert M % 16 == 0 and K % 16 == 0, "Shape must be divisible by 16"

x = torch.randn([M, K], dtype=torch.bfloat16, device="cuda")
q_base = make_quantizer(optimize_for_gemm=False)
q_swf = make_quantizer(optimize_for_gemm=True)

# quant_only path
quant_only_base_us = _bench(
stmt="q(x)",
globals_dict={"q": q_base, "x": x},
min_run_time=min_run_time,
)
quant_only_swf_us = _bench(
stmt="q(x)",
globals_dict={"q": q_swf, "x": x},
min_run_time=min_run_time,
)

# quant_plus_swizzle path (this is what te.Linear actually runs)
quant_plus_swizzle_base_us = _bench(
stmt="t = q(x); tex.swizzle_scales_for_gemm_(t)",
globals_dict={"q": q_base, "x": x, "tex": tex},
min_run_time=min_run_time,
)
quant_plus_swizzle_swf_us = _bench(
stmt="t = q(x); tex.swizzle_scales_for_gemm_(t)",
globals_dict={"q": q_swf, "x": x, "tex": tex},
min_run_time=min_run_time,
)

saved_us = quant_plus_swizzle_base_us - quant_plus_swizzle_swf_us
speedup = (
quant_plus_swizzle_base_us / quant_plus_swizzle_swf_us
if quant_plus_swizzle_swf_us > 0
else float("inf")
)

print(
f" shape={shape}: quant_only base={quant_only_base_us:.2f}us, "
f"SUT={quant_only_swf_us:.2f}us | "
f"quant+swizzle base={quant_plus_swizzle_base_us:.2f}us, "
f"SUT={quant_plus_swizzle_swf_us:.2f}us "
f"-> saved {saved_us:.2f}us ({speedup:.2f}x)"
)

return {
"shape": shape,
"M": M,
"K": K,
"quant_only_base_us": quant_only_base_us,
"quant_only_swf_us": quant_only_swf_us,
"quant_plus_swizzle_base_us": quant_plus_swizzle_base_us,
"quant_plus_swizzle_swf_us": quant_plus_swizzle_swf_us,
"saved_us": saved_us,
"speedup": speedup,
}


# Nsight Compute Profiling Command (for verifying the swizzle kernel disappears):
# ncu -f -o swizzle_fusion --set=full \
# --kernel-name "regex:swizzle_(row|col)_scaling_kernel|cast_col_hadamard_transform_cast_fusion" \
# -s 5 -c 10 python benchmarks/benchmark_rht_cast_swizzle_fusion.py --profile


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--profile",
action="store_true",
help="Run only one shape for use with ncu/nsys; longer min_run_time",
)
parser.add_argument(
"--min-run-time",
type=float,
default=2.0,
help="Minimum total measured time per cell in seconds (benchmark.Timer)",
)
parser.add_argument(
"--csv",
type=str,
default="benchmark_rht_cast_swizzle_fusion.csv",
help="CSV output path",
)
args = parser.parse_args()

if args.profile:
print("Profiling mode enabled (single shape).")
shapes = [(8192, 4096)]
min_run_time = max(5.0, args.min_run_time)
else:
shapes = [
# production-class shapes
(8192, 5120),
(8192, 10240),
(8192, 2560),
(8192, 11328),
(8192, 3584),
(5120, 8192),
(10240, 8192),
(2560, 8192),
(11328, 8192),
(3584, 8192),
(4096, 16384),
(14336, 16384),
]
min_run_time = args.min_run_time

print(
"NVFP4 RHT cast-fusion: swizzle-fusion (optimize_for_gemm=True) vs baseline. "
f"min_run_time={min_run_time}s per cell, BF16 input, "
"rowwise+columnwise SF, RHT=True+post_rht_amax."
)
rows = []
for shape in shapes:
print(f"Running {shape} ...")
rows.append(run_shape(shape, min_run_time))

df = pd.DataFrame(rows)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 200)
print()
print(df.to_string(index=False))
df.to_csv(args.csv, index=False)
print(f"\nWrote {args.csv}")
123 changes: 123 additions & 0 deletions benchmarks/profile_rht_cast_swizzle_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
Profile that the dedicated swizzle kernels (swizzle_{row,col}_scaling_kernel
in transformer_engine/common/swizzle/swizzle.cu) disappear from the timeline
when NVFP4 RHT cast-fusion emits SF in the GEMM-swizzled layout directly
(optimize_for_gemm=True).

Test setup:
- NVFP4 + RHT + post-RHT amax (same as te.Linear sets up internally)
- rowwise=True AND columnwise=True (covers BOTH swizzle_row_scaling_kernel
and swizzle_col_scaling_kernel; this is what tex.Linear's input quantizer
needs during training because the rowwise tensor is used by the fwd GEMM
and the columnwise tensor is used by the dgrad GEMM)
- tex.swizzle_scales_for_gemm_(t) is what te.Linear -> tex.generic_gemm
calls just before the cuBLAS LT NVFP4 GEMM dispatch
"""

import torch
import transformer_engine.pytorch as te # noqa: F401 must be first
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer


def make_quantizer(optimize_for_gemm: bool) -> NVFP4Quantizer:
q = NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=True,
with_post_rht_amax=True,
with_random_sign_mask=True,
)
q.optimize_for_gemm = optimize_for_gemm
return q


import re

# Match ONLY the standalone swizzle pass kernels in
# transformer_engine/common/swizzle/swizzle.cu — NOT RHT cast-fusion kernels
# whose mangled name happens to contain "Swizzle" because of the
# `template <..., bool kEnableSwizzleSFOutput, ...>` parameter substring.
STANDALONE_SWIZZLE_RE = re.compile(
r"(?:multi_tensor_(?:un)?swizzle|(?:un)?swizzle)_(?:row|col)_scaling_kernel"
)


def dump_kernel_counts(prof, label: str) -> dict:
print(f"\n=== {label} ===")
counts: dict[str, int] = {}
for ev in prof.events():
if ev.device_type != torch.autograd.DeviceType.CUDA:
continue
counts[ev.name] = counts.get(ev.name, 0) + 1
standalone_swizzle_total = 0
for name, c in sorted(counts.items(), key=lambda kv: -kv[1]):
marker = ""
if STANDALONE_SWIZZLE_RE.search(name):
marker = " <-- STANDALONE SWIZZLE PASS"
standalone_swizzle_total += c
# Truncate long mangled CUTLASS names for readability
short = name if len(name) <= 110 else name[:107] + "..."
print(f" {c:4d} {short}{marker}")
print(f" -- standalone swizzle kernel total: {standalone_swizzle_total}")
return counts


def profile_path(optimize_for_gemm: bool, x: torch.Tensor, n_iters: int = 20):
q = make_quantizer(optimize_for_gemm=optimize_for_gemm)
# warm-up
for _ in range(3):
t = q(x)
tex.swizzle_scales_for_gemm_(t)
torch.cuda.synchronize()
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
for _ in range(n_iters):
t = q(x)
tex.swizzle_scales_for_gemm_(t)
torch.cuda.synchronize()
return prof


def main():
torch.manual_seed(0)
torch.cuda.manual_seed(0)
device = "cuda"
# Shape that hits the production RHT cast-fusion fast-path
# (rows % 64 == 0, cols % 128 == 0, BF16, SM100/110).
M, N = 8192, 4096
x = torch.randn((M, N), dtype=torch.bfloat16, device=device)

print(f"Shape: M={M}, N={N}, dtype=bf16, RHT=True, post_rht_amax=True")
print(f"iters: 20 (after 3 warm-up)")

prof_baseline = profile_path(optimize_for_gemm=False, x=x)
counts_baseline = dump_kernel_counts(
prof_baseline, "BASELINE: optimize_for_gemm=False (separate swizzle kernel)"
)

prof_swf = profile_path(optimize_for_gemm=True, x=x)
counts_swf = dump_kernel_counts(
prof_swf, "SUT: optimize_for_gemm=True (quant emits swizzled SF directly)"
)

print("\n=== VERDICT ===")
base_swizzle = sum(c for n, c in counts_baseline.items() if STANDALONE_SWIZZLE_RE.search(n))
swf_swizzle = sum(c for n, c in counts_swf.items() if STANDALONE_SWIZZLE_RE.search(n))
print(f" baseline standalone swizzle kernel launches: {base_swizzle}")
print(f" SUT standalone swizzle kernel launches: {swf_swizzle}")
if swf_swizzle == 0 and base_swizzle > 0:
print(
" PASS: standalone swizzle pass disappears from timeline under optimize_for_gemm=True"
)
else:
print(
" FAIL: expected baseline > 0 and SUT == 0; check whether SUT actually "
"set with_gemm_swizzled_scales=True on the output tensor"
)


if __name__ == "__main__":
main()
Loading