[fix] Fix CUTLASS grouped GEMM segfault for empty groups #3037

Open
Baibaifan wants to merge 1 commit into NVIDIA:main from Baibaifan:zero_groupgemm

@Baibaifan
Description

Handle grouped GEMM calls where all groups are empty.

MoE routing can legally produce a microbatch where no local expert receives
tokens. The PyTorch grouped GEMM wrapper filters those zero-token GEMMs, but
the CUTLASS grouped GEMM path could still be reached with num_gemms == 0 and
then dereference A[0]/B[0]/D[0], causing a native segfault.

Return early after filtering all GEMMs in te_general_grouped_gemm, and add a
defensive num_gemms <= 0 guard in nvte_multi_tensor_gemm.

Add a Hopper/CUTLASS regression test covering all-empty grouped GEMM inputs for
TN, NN, and NT layouts.
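
The failure mode and the defensive guard described above can be modeled in a few lines of plain Python. This is only an illustrative sketch: `multi_tensor_gemm_unguarded` and `multi_tensor_gemm_guarded` are hypothetical stand-ins, not the Transformer Engine API, and Python's `IndexError` plays the role of the native segfault on A[0]/B[0]/D[0].

```python
from typing import List, Sequence

Matrix = List[List[float]]


def multi_tensor_gemm_unguarded(a: Sequence[Matrix], b: Sequence[Matrix]) -> int:
    """Hypothetical backend stand-in that inspects group 0 before doing any work."""
    # With zero groups this raises IndexError; the native analogue in the PR
    # description is an out-of-bounds read of A[0]/B[0]/D[0].
    _first_a, _first_b = a[0], b[0]
    return len(a)  # number of GEMMs that would be launched


def multi_tensor_gemm_guarded(a: Sequence[Matrix], b: Sequence[Matrix]) -> int:
    """Same stand-in with a zero-work guard placed before any group is touched."""
    num_gemms = len(a)
    if num_gemms <= 0:
        return 0  # nothing to launch, so return without reading group 0
    return multi_tensor_gemm_unguarded(a, b)
```

Calling the guarded variant with two empty lists returns 0, while the unguarded variant fails on the first element access.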

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Changes

Please list the changes introduced in this PR:

  • Return early from te_general_grouped_gemm when all GEMMs were filtered.
  • Add a defensive num_gemms <= 0 guard in nvte_multi_tensor_gemm.
  • Add a Hopper-only regression test for all-empty grouped GEMM inputs under
    NVTE_USE_CUTLASS_GROUPED_GEMM=1.
  • Cover TN, NN, and NT layouts.

Testing

pytest -q tests/pytorch/test_numerics.py::test_grouped_gemm_cutlass_empty_groups -s
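
The regression test runs with NVTE_USE_CUTLASS_GROUPED_GEMM=1. One common way to scope such a variable to a single test and restore the previous state afterwards is a small context manager. The helper name `scoped_env` is hypothetical and this is not the code in the PR; whether the library reads the variable at call time or earlier is also an assumption here.

```python
import os
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def scoped_env(name: str, value: str) -> Iterator[None]:
    """Set an environment variable for the duration of a block, then restore it."""
    previous = os.environ.get(name)  # None means the variable was unset
    os.environ[name] = value
    try:
        yield
    finally:
        # Restore the exact prior state, including "unset".
        if previous is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous
```

A test body would then be wrapped as `with scoped_env("NVTE_USE_CUTLASS_GROUPED_GEMM", "1"): ...`, so a failing assertion inside the block cannot leak the setting into later tests.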

Checklist:

  • I have read and followed the [contributing guidelines]
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Baibaifan Baibaifan requested a review from ksivaman as a code owner May 22, 2026 04:06
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 22, 2026
Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR fixes a native segfault in the CUTLASS grouped GEMM path when MoE routing produces a microbatch where every expert receives zero tokens. The PyTorch wrapper already filtered zero-token GEMMs individually, but did not handle the case where all groups were filtered, allowing nvte_multi_tensor_gemm to be invoked with an empty tensor array and dereference A[0]/B[0]/D[0].

  • gemm.cpp: After the per-group filtering loop, an early return is added when te_A_wrappers is empty, returning bias to match the function's existing normal return value.
  • cublaslt_gemm.cu: A num_gemms <= 0 guard is added at the top of nvte_multi_tensor_gemm as a secondary defence for direct C-API callers.
  • test_numerics.py: A Hopper-only regression test exercises TN, NN, and NT layouts with all-zero m_splits; the NT case meaningfully asserts that the wgrad buffer is zeroed in-place, while TN/NN serve primarily as no-crash guards.

Confidence Score: 4/5

The fix is narrowly scoped, targets a real crash path, and does not change behaviour for non-empty inputs. Safe to merge.

Both changes are small, isolated guards with no effect on the non-empty GEMM path. The early return in te_general_grouped_gemm correctly mirrors the existing return value semantics. The only notable weakness is that the TN/NN test assertions are trivially true (empty-vs-empty compare), so those sub-cases are purely a no-crash check rather than a semantic postcondition.

No files require special attention beyond the test assertion coverage noted in the inline comment on test_numerics.py.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/csrc/extensions/gemm.cpp | Adds an early return in te_general_grouped_gemm when all GEMM groups are filtered (te_A_wrappers is empty); returns the bias vector, consistent with the existing return at the bottom of the function. The fix is logically correct: empty groups are already handled by the in-loop continue + zero_(), and the early exit prevents a subsequent call to nvte_multi_tensor_gemm with a zero-length vector. |
| transformer_engine/common/gemm/cublaslt_gemm.cu | Adds a defensive num_gemms <= 0 guard at the entry of nvte_multi_tensor_gemm. This is a secondary safety net for callers that bypass te_general_grouped_gemm and invoke the C API directly with an empty array. |
| tests/pytorch/test_numerics.py | Adds a Hopper-only regression test for all-empty grouped GEMM. TN/NN out-tensors are 0-element so the zero-check is trivial (useful only as a no-crash guard); NT out-tensor is a real (n,k) buffer whose in-place zero_() is meaningfully exercised. Test env-var setup/teardown follows the existing pattern. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["te_general_grouped_gemm(A, B, D, ...)"] --> B["Loop over all GEMM groups"]
    B --> C{"te_A.numel()==0 or\nte_B.numel()==0?"}
    C -- Yes --> D["zero_() output if non-empty\nzero_() bias/gelu if grad\ncontinue"]
    D --> B
    C -- No --> E["Build te_A/B/D wrappers\nappend to vectors"]
    E --> B
    B -- loop done --> F{"te_A_wrappers.empty()?\n(ALL groups filtered)"}
    F -- Yes --> G["return bias  ← NEW early return\n(prevents null deref in CUTLASS path)"]
    F -- No --> H["swizzle scales\nbuild NVTETensor vectors"]
    H --> I["nvte_multi_tensor_gemm(...)"]
    I --> J{"num_gemms <= 0?\n← NEW guard"}
    J -- Yes --> K["return (no-op)"]
    J -- No --> L{"is_hopper &&\nuse_cutlass?"}
    L -- No --> M["multi_stream_cublas_gemm"]
    L -- Yes --> N["CUTLASS grouped GEMM\n(accesses A[0]/B[0]/D[0])"]
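
The control flow in the flowchart can be approximated with a pure-Python model: filter empty groups, zero their outputs, return early when nothing is left, and keep a second guard in the backend. All function names here are hypothetical, nested lists stand in for tensors, and a plain D = A @ B replaces the real kernels, so treat it as a sketch of the structure rather than the actual implementation.

```python
from typing import List

Matrix = List[List[float]]


def numel(m: Matrix) -> int:
    """Total number of elements in a nested-list matrix."""
    return sum(len(row) for row in m)


def zero_(m: Matrix) -> None:
    """Zero a matrix in place; a no-op when the matrix has no elements."""
    for row in m:
        for j in range(len(row)):
            row[j] = 0.0


def backend_grouped_gemm(a: List[Matrix], b: List[Matrix], d: List[Matrix]) -> int:
    """Stand-in backend: computes D = A @ B per group, returns groups launched."""
    if len(a) <= 0:  # secondary guard for callers that skip the wrapper
        return 0
    _ = a[0], b[0], d[0]  # the group-0 access that the guards protect
    for A, B, D in zip(a, b, d):
        for i, row in enumerate(A):
            for j in range(len(D[i])):
                D[i][j] = sum(row[p] * B[p][j] for p in range(len(row)))
    return len(a)


def grouped_gemm(a: List[Matrix], b: List[Matrix], d: List[Matrix]) -> int:
    """Wrapper: drop empty groups, zero their outputs, skip the backend if none remain."""
    kept_a, kept_b, kept_d = [], [], []
    for A, B, D in zip(a, b, d):
        if numel(A) == 0 or numel(B) == 0:
            zero_(D)  # an empty group contributes nothing, so its output is zero
            continue
        kept_a.append(A)
        kept_b.append(B)
        kept_d.append(D)
    if not kept_a:
        return 0  # every group was filtered: early return, backend never called
    return backend_grouped_gemm(kept_a, kept_b, kept_d)
```

With a single empty group and a non-empty output buffer holding stale values, `grouped_gemm` zeroes the buffer and returns 0 without entering the backend.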

Reviews (1): Last reviewed commit: "[fix] fix grouped GEMM zero-work bug."

Comment on lines +2850 to 2855
    for tensor in out:
        torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0)


def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
    data = grouped_tensor.rowwise_data

P2 Zero-assertion is trivially true for TN and NN layouts

For TN and NN, out is constructed as a list containing a single 0-element tensor (torch.empty(0, n/k, ...)). torch.testing.assert_close on two empty tensors passes unconditionally regardless of any computation, so those two sub-cases only serve as crash/segfault guards. The meaningful assertion only fires for NT, where out[0] is a full (n, k) buffer that the C++ code zeros in-place. Consider either documenting this in a comment or, for TN/NN, adding a small non-empty output tensor and asserting it is zero to provide the same level of postcondition coverage as NT.
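
To make the point about vacuous assertions concrete, the sketch below computes the output shape per layout when the token count m is zero and checks whether a zero-assertion has any element to inspect. The shape mapping is an assumption inferred from this comment (a 0-row output for TN and NN, an (n, k) output for NT), and the helper names are hypothetical.

```python
from math import prod
from typing import Tuple


def out_shape(layout: str, m: int, n: int, k: int) -> Tuple[int, int]:
    """Assumed output shape per layout; m is the number of tokens in the group."""
    shapes = {
        "TN": (m, n),  # assumed: one output row per token
        "NN": (m, k),  # assumed: one output row per token
        "NT": (n, k),  # assumed: weight-gradient buffer, independent of m
    }
    return shapes[layout]


def zero_check_is_vacuous(layout: str, m: int, n: int, k: int) -> bool:
    """True when the output has no elements, so any elementwise check passes."""
    return prod(out_shape(layout, m, n, k)) == 0


# An elementwise "is zero" check over no elements is always true.
assert all(value == 0.0 for value in [])
```

Under these assumptions, `zero_check_is_vacuous` is true for TN and NN with m = 0 and false for NT, which matches the observation that only the NT sub-case asserts a real postcondition.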
