From d41657207dfc8d70d75720be9ef6b9de6ee39805 Mon Sep 17 00:00:00 2001 From: lizamd <161388580+lizamd@users.noreply.github.com> Date: Tue, 5 May 2026 00:02:50 +0000 Subject: [PATCH 1/2] [ROCm] Allow bf16/bf16/fp32 in nvte_multi_tensor_gemm dispatcher The is_supported_dtype check in nvte_multi_tensor_gemm previously required A==B==D for the fp16/bf16 path, which rejected the common bf16/bf16/fp32 case where the GEMM output is fp32 for gradient accumulation. This forced a fallback to multi_stream_cublas_gemm (a per-expert hipblaslt loop), bypassing the CK grouped GEMM kernel entirely on ROCm. The CK FP16 dispatcher (ck_tile_grouped_gemm_fp16_dispatch) already supports an independent D dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY (fp32, fp16, bf16). The wrapper check is the only thing that prevents it from being reached. The check is relaxed to require A==B in fp16/bf16 and D in {fp32, fp16, bf16}, which matches what the CK dispatcher actually accepts. Verified on Qwen3-30B-A3B MoE training on MI355X (gfx950): fallback warning rate drops from ~1040/step (every GEMM) to ~28/step (~3% of shapes that the CK kernel itself rejects via Kernel::IsSupportedArgument). Throughput is essentially unchanged in this workload because hipblaslt's per-shape autotuning happens to be competitive with the hardcoded CK tile configs for these MoE shapes; the gain will materialize once the CK dispatcher gains more tile configs (or shape-aware tile selection by aggregate M). This edit is to the CUDA-path source file (cublaslt_gemm.cu); the same change reaches the AMD path via hipify. No CUDA-side behavior change since cuBLAS/cutlass dispatch on NVIDIA still requires A==B==D in the cutlass fast-path pre-conditions. Follow-ups (out of scope for this PR): - Add more CK tile configs (e.g. TileCfg_64x256x64, TileCfg_128x256x64) and shape-aware tile selection by aggregate M per call. 
Currently throughput is unchanged on this workload because the existing hipblaslt fallback is well-tuned and the 3 hardcoded CK tile configs (TileCfg_256x256x64, TileCfg_256x128x64, TileCfg_256x128x64_padding) don't fit MoE shapes (highly variable per-expert M) optimally. Real CK-grouped-GEMM perf wins will materialize once tile selection adapts to M. - Investigate the ~3% of GEMMs that hit Kernel::IsSupportedArgument rejection (likely small per-expert M values that fail tile-size constraints in the current TileCfg_256x* instantiations). --- tests/pytorch/test_numerics.py | 62 +++++++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 5 +- 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 4a768377e..33847f5b6 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -3078,6 +3078,68 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) +@pytest.mark.skipif( + torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION, + reason="Only enable CUTLASS / CK grouped gemm on Hopper or ROCm", +) +@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("layout", ["TN", "NT"]) +def test_grouped_gemm_fp32_output(input_dtype, layout): + """Verify grouped GEMM with fp16/bf16 inputs and fp32 output goes through + the CUTLASS / CK grouped GEMM path (not the per-expert fallback). 
Exercises + the dispatcher is_supported_dtype check for the common bf16/bf16/fp32 case + used during training with fp32 gradient accumulation.""" + if input_dtype == torch.bfloat16 and not is_bf16_available(): + pytest.skip("bf16 requires sm_80+") + torch.manual_seed(0) + z, m, k, n = 8, 1027, 128, 512 + + dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist() + m_splits = (torch.tensor(dist + [m]) - torch.tensor([0] + dist)).tolist() + + if layout == "TN": + A = [torch.randn(n, k, dtype=input_dtype, device="cuda") for _ in range(z)] + B = list(torch.split(torch.randn(m, k, dtype=input_dtype, device="cuda"), m_splits)) + out = [torch.empty(m, n, dtype=torch.float32, device="cuda")] + out_ref = [o.clone() for o in torch.split(out[0], m_splits)] + single_output = True + grad = False + else: # "NT" wgrad: weight gradient in fp32 + A = list(torch.split(torch.randn(m, k, dtype=input_dtype, device="cuda"), m_splits)) + B = list(torch.split(torch.randn(m, n, dtype=input_dtype, device="cuda"), m_splits)) + out = [torch.empty(n, k, dtype=torch.float32, device="cuda") for _ in range(z)] + out_ref = [o.clone() for o in out] + single_output = False + grad = True + + os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + try: + for i in range(z): + general_gemm( + A[i], B[i], + out_dtype=torch.float32, + grad=grad, + layout=layout, + out=out_ref[i], + ) + if single_output: + out_ref = [torch.cat(out_ref)] + + general_grouped_gemm( + A, B, out, [None] * z, + out_dtype=torch.float32, + m_splits=m_splits, + grad=grad, + layout=layout, + single_output=single_output, + ) + + for o, o_ref in zip(out, out_ref): + torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) + finally: + os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) + + @pytest.mark.parametrize("N", [32]) @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu 
b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7326f330f..0a19ef1e3 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1163,12 +1163,13 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor auto A_dt = inputA->data.dtype; auto B_dt = inputB->data.dtype; auto D_dt = OutputD->data.dtype; + // CK fp16 dispatcher accepts D in {fp32, fp16, bf16} when A==B is fp16/bf16. return ( (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) ) || ( - (A_dt == B_dt) && (A_dt == D_dt) && - (is_fp16_dtype(A_dt)) + (A_dt == B_dt) && is_fp16_dtype(A_dt) && + (is_fp16_dtype(D_dt) || D_dt == transformer_engine::DType::kFloat32) ); #else auto A_type = get_cuda_dtype(inputA->data.dtype); From 4c031eb43afdf3c36f5e5122f6152cbc09443df7 Mon Sep 17 00:00:00 2001 From: lizamd <161388580+lizamd@users.noreply.github.com> Date: Tue, 12 May 2026 21:16:20 +0000 Subject: [PATCH 2/2] [ROCm] Address review feedback on bf16/bf16/fp32 dispatcher - Drop the inline comment in cublaslt_gemm.cu (rationale moved to PR body). - Fold test_grouped_gemm_fp32_output into test_grouped_gemm via a new fp32_output parametrize, removing the standalone test function. - Use pytest's monkeypatch fixture for NVTE_USE_CUTLASS_GROUPED_GEMM instead of mutating os.environ directly, so the test no longer assumes the user had the env var unset. 
--- tests/pytorch/test_numerics.py | 89 ++++--------------- .../common/gemm/cublaslt_gemm.cu | 1 - 2 files changed, 17 insertions(+), 73 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 33847f5b6..46da76448 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2999,7 +2999,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) @pytest.mark.parametrize("accumulate", [False, True]) @pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) -def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): +@pytest.mark.parametrize("fp32_output", [False, True], ids=["out=input", "out=fp32"]) +def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass, fp32_output, monkeypatch): + # Mixed-precision output (bf16/bf16/fp32, fp16/fp16/fp32) only goes + # through the CUTLASS / CK grouped GEMM path; the multi-stream cublasLt + # fallback requires A_dt == B_dt == D_dt, and accumulate is incompatible + # with the mixed-precision output path. 
+ if fp32_output and (not use_cutlass or accumulate): + pytest.skip("fp32 output requires use_cutlass=True and accumulate=False") + torch.manual_seed(0) z, m, k, n = shape @@ -3008,10 +3016,12 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): assert m_splits.sum() == m and len(m_splits) == z m_splits = m_splits.tolist() + out_dtype = torch.float32 if fp32_output else dtype + if layout == "TN": A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input - out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output + out = [torch.randn(m, n, dtype=out_dtype, device="cuda")] # output out_ref = [o.clone() for o in torch.split(out[0], m_splits)] grad = False single_output = True @@ -3020,7 +3030,7 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): B = list( torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) ) # grad_output - out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad + out = [torch.randn(m, k, dtype=out_dtype, device="cuda")] # dgrad out_ref = [o.clone() for o in torch.split(out[0], m_splits)] grad = True single_output = True @@ -3029,19 +3039,19 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): B = list( torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) ) # grad_output - out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + out = [torch.randn(n, k, dtype=out_dtype, device="cuda") for _ in range(z)] # wgrad out_ref = [o.clone() for o in out] grad = True single_output = False if use_cutlass: - os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") for i in range(z): general_gemm( A[i], B[i], - dtype, + out_dtype, grad=grad, accumulate=accumulate, layout=layout, @@ -3055,7 +3065,7 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): B, out, 
[None] * z, - dtype, + out_dtype, m_splits=m_splits, grad=grad, accumulate=accumulate, @@ -3074,71 +3084,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): else: torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) - if use_cutlass: - os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) - - -@pytest.mark.skipif( - torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION, - reason="Only enable CUTLASS / CK grouped gemm on Hopper or ROCm", -) -@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("layout", ["TN", "NT"]) -def test_grouped_gemm_fp32_output(input_dtype, layout): - """Verify grouped GEMM with fp16/bf16 inputs and fp32 output goes through - the CUTLASS / CK grouped GEMM path (not the per-expert fallback). Exercises - the dispatcher is_supported_dtype check for the common bf16/bf16/fp32 case - used during training with fp32 gradient accumulation.""" - if input_dtype == torch.bfloat16 and not is_bf16_available(): - pytest.skip("bf16 requires sm_80+") - torch.manual_seed(0) - z, m, k, n = 8, 1027, 128, 512 - - dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist() - m_splits = (torch.tensor(dist + [m]) - torch.tensor([0] + dist)).tolist() - - if layout == "TN": - A = [torch.randn(n, k, dtype=input_dtype, device="cuda") for _ in range(z)] - B = list(torch.split(torch.randn(m, k, dtype=input_dtype, device="cuda"), m_splits)) - out = [torch.empty(m, n, dtype=torch.float32, device="cuda")] - out_ref = [o.clone() for o in torch.split(out[0], m_splits)] - single_output = True - grad = False - else: # "NT" wgrad: weight gradient in fp32 - A = list(torch.split(torch.randn(m, k, dtype=input_dtype, device="cuda"), m_splits)) - B = list(torch.split(torch.randn(m, n, dtype=input_dtype, device="cuda"), m_splits)) - out = [torch.empty(n, k, dtype=torch.float32, device="cuda") for _ in range(z)] - out_ref = [o.clone() for o in out] - single_output = False - grad = True - 
- os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" - try: - for i in range(z): - general_gemm( - A[i], B[i], - out_dtype=torch.float32, - grad=grad, - layout=layout, - out=out_ref[i], - ) - if single_output: - out_ref = [torch.cat(out_ref)] - - general_grouped_gemm( - A, B, out, [None] * z, - out_dtype=torch.float32, - m_splits=m_splits, - grad=grad, - layout=layout, - single_output=single_output, - ) - - for o, o_ref in zip(out, out_ref): - torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) - finally: - os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) - @pytest.mark.parametrize("N", [32]) @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 0a19ef1e3..44735d9bb 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1163,7 +1163,6 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor auto A_dt = inputA->data.dtype; auto B_dt = inputB->data.dtype; auto D_dt = OutputD->data.dtype; - // CK fp16 dispatcher accepts D in {fp32, fp16, bf16} when A==B is fp16/bf16. return ( (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) ) ||