Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2797,6 +2797,60 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


@pytest.mark.skipif(
torch.cuda.get_device_capability() != (9, 0),
reason="Only enable CUTLASS grouped gemm on Hopper",
)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
def test_grouped_gemm_cutlass_empty_groups(layout):
dtype = torch.bfloat16
z, k, n = 1, 2048, 1536
m_splits = [0] * z

if layout == "TN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = [torch.empty(0, k, dtype=dtype, device="cuda") for _ in range(z)] # input
out = [torch.empty(0, n, dtype=dtype, device="cuda")] # output
grad = False
single_output = True
elif layout == "NN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = [torch.empty(0, n, dtype=dtype, device="cuda") for _ in range(z)] # grad_output
out = [torch.empty(0, k, dtype=dtype, device="cuda")] # dgrad
grad = True
single_output = True
else:
A = [torch.empty(0, k, dtype=dtype, device="cuda") for _ in range(z)] # input
B = [torch.empty(0, n, dtype=dtype, device="cuda") for _ in range(z)] # grad_output
out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad
grad = True
single_output = False

old_cutlass_env = os.environ.get("NVTE_USE_CUTLASS_GROUPED_GEMM")
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
try:
general_grouped_gemm(
A,
B,
out,
[None] * z,
dtype,
m_splits=m_splits,
grad=grad,
layout=layout,
single_output=single_output,
)
torch.cuda.synchronize()
finally:
if old_cutlass_env is None:
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
else:
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = old_cutlass_env

for tensor in out:
torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0)


def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
data = grouped_tensor.rowwise_data
Comment on lines +2850 to 2855
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Zero-assertion is trivially true for TN and NN layouts

For TN and NN, out is constructed as a list containing a single 0-element tensor (torch.empty(0, n/k, ...)). torch.testing.assert_close on two empty tensors passes unconditionally regardless of any computation, so those two sub-cases only serve as crash/segfault guards. The meaningful assertion only fires for NT, where out[0] is a full (n, k) buffer that the C++ code zeros in-place. Consider either documenting this in a comment or, for TN/NN, adding a small non-empty output tensor and asserting it is zero to provide the same level of postcondition coverage as NT.

if data is None:
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,10 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_gemm);

if (num_gemms <= 0) {
return;
}

const int current_device = transformer_engine::cuda::current_device();
const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false);
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out));
}

if (te_A_wrappers.empty()) {
return bias;
}

// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;

Expand Down