diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index a718ea2a8a..a682fb2f19 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -2797,6 +2797,60 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
         os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
 
 
+@pytest.mark.skipif(
+    torch.cuda.get_device_capability() != (9, 0),  # run only when capability is exactly (9, 0)
+    reason="Only enable CUTLASS grouped gemm on Hopper",
+)
+@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
+def test_grouped_gemm_cutlass_empty_groups(layout):  # grouped GEMM where every split is 0
+    dtype = torch.bfloat16
+    z, k, n = 1, 2048, 1536  # z: number of groups; k, n: the non-empty tensor dimensions
+    m_splits = [0] * z  # every group contributes zero rows
+
+    if layout == "TN":
+        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
+        B = [torch.empty(0, k, dtype=dtype, device="cuda") for _ in range(z)]  # input
+        out = [torch.empty(0, n, dtype=dtype, device="cuda")]  # output (zero rows)
+        grad = False
+        single_output = True
+    elif layout == "NN":
+        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
+        B = [torch.empty(0, n, dtype=dtype, device="cuda") for _ in range(z)]  # grad_output
+        out = [torch.empty(0, k, dtype=dtype, device="cuda")]  # dgrad (zero rows)
+        grad = True
+        single_output = True
+    else:
+        A = [torch.empty(0, k, dtype=dtype, device="cuda") for _ in range(z)]  # input
+        B = [torch.empty(0, n, dtype=dtype, device="cuda") for _ in range(z)]  # grad_output
+        out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # wgrad
+        grad = True
+        single_output = False  # out holds one (n, k) tensor per group
+
+    old_cutlass_env = os.environ.get("NVTE_USE_CUTLASS_GROUPED_GEMM")  # kept for restore
+    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"  # read as use_cutlass in the C++ GEMM
+    try:
+        general_grouped_gemm(
+            A,
+            B,
+            out,
+            [None] * z,
+            dtype,
+            m_splits=m_splits,
+            grad=grad,
+            layout=layout,
+            single_output=single_output,
+        )
+        torch.cuda.synchronize()  # wait for queued GPU work before inspecting out
+    finally:  # put the env var back to its prior state even if the call raised
+        if old_cutlass_env is None:
+            os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
+        else:
+            os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = old_cutlass_env
+
+    for tensor in out:  # TN/NN outputs have zero rows; NT outputs were randn-initialised
+        torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0)  # exact
+
+
 def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
     data = grouped_tensor.rowwise_data
     if data is None:
diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu
index 8589d7045d..68efd9793b 100644
--- a/transformer_engine/common/gemm/cublaslt_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -1058,6 +1058,10 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
                             cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_gemm);
 
+  if (num_gemms <= 0) {
+    return;
+  }
+
   const int current_device = transformer_engine::cuda::current_device();
   const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
   const bool use_cutlass = transformer_engine::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM", false);
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
index 9cb1fb7f54..dd1ad905fd 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -535,6 +535,10 @@ std::optional> te_general_grouped_gemm(
     te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out));
   }
 
+  if (te_A_wrappers.empty()) {
+    return bias;
+  }
+
   // Keep the swizzled scaling factor tensors alive during the GEMM.
   std::vector> swizzled_scale_inverses_list;