[fix] Fix CUTLASS grouped GEMM segfault for empty groups #3037
Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
Greptile Summary
This PR fixes a native segfault in the CUTLASS grouped GEMM path when MoE routing produces a microbatch where every expert receives zero tokens. The PyTorch wrapper already filtered zero-token GEMMs individually, but did not handle the case where all groups were filtered, allowing
Confidence Score: 4/5
The fix is narrowly scoped, targets a real crash path, and does not change behaviour for non-empty inputs. Safe to merge. Both changes are small, isolated guards with no effect on the non-empty GEMM path. The early return in te_general_grouped_gemm correctly mirrors the existing return value semantics. The only notable weakness is that the TN/NN test assertions are trivially true (empty-vs-empty compare), so those sub-cases are purely a no-crash check rather than a semantic postcondition. No files require special attention beyond the test assertion coverage noted in the inline comment on test_numerics.py.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["te_general_grouped_gemm(A, B, D, ...)"] --> B["Loop over all GEMM groups"]
B --> C{"te_A.numel()==0 or\nte_B.numel()==0?"}
C -- Yes --> D["zero_() output if non-empty\nzero_() bias/gelu if grad\ncontinue"]
D --> B
C -- No --> E["Build te_A/B/D wrappers\nappend to vectors"]
E --> B
B -- loop done --> F{"te_A_wrappers.empty()?\n(ALL groups filtered)"}
F -- Yes --> G["return bias ← NEW early return\n(prevents null deref in CUTLASS path)"]
F -- No --> H["swizzle scales\nbuild NVTETensor vectors"]
H --> I["nvte_multi_tensor_gemm(...)"]
I --> J{"num_gemms <= 0?\n← NEW guard"}
J -- Yes --> K["return (no-op)"]
J -- No --> L{"is_hopper &&\nuse_cutlass?"}
L -- No --> M["multi_stream_cublas_gemm"]
L -- Yes --> N["CUTLASS grouped GEMM\n(accesses A[0]/B[0]/D[0])"]
Reviews (1): Last reviewed commit: "[fix] fix grouped GEMM zero-work bug."
for tensor in out:
    torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0)

def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
    data = grouped_tensor.rowwise_data
Zero-assertion is trivially true for TN and NN layouts
For TN and NN, out is constructed as a list containing a single 0-element tensor (torch.empty(0, n/k, ...)). torch.testing.assert_close on two empty tensors passes unconditionally regardless of any computation, so those two sub-cases only serve as crash/segfault guards. The meaningful assertion only fires for NT, where out[0] is a full (n, k) buffer that the C++ code zeros in-place. Consider either documenting this in a comment or, for TN/NN, adding a small non-empty output tensor and asserting it is zero to provide the same level of postcondition coverage as NT.
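The point about vacuous assertions can be illustrated without a GPU. The sketch below is a pure-Python stand-in, not torch.testing.assert_close itself: plain lists play the role of flattened output tensors, and the helper names all_zero and all_zero_nonvacuous are hypothetical, introduced only for this illustration.

```python
from typing import Sequence


def all_zero(values: Sequence[float]) -> bool:
    """Element-wise zero check. With no elements there is nothing to compare,
    so the result is True regardless of what the kernel did."""
    return all(v == 0.0 for v in values)


def all_zero_nonvacuous(values: Sequence[float]) -> bool:
    """Stricter check: require at least one element, so that a pass
    actually says something about the output buffer."""
    return len(values) > 0 and all_zero(values)


# Stand-in for the 0-element TN/NN output: passes no matter what.
empty_out: list = []
# Stand-in for a non-empty NT-style buffer that was never zeroed.
stale_out = [3.0, -1.0]
# Stand-in for a non-empty NT-style buffer that was zeroed in place.
zeroed_out = [0.0, 0.0]

vacuous_pass = all_zero(empty_out)                 # True, but proves nothing
stale_detected = not all_zero(stale_out)           # True: stale data is caught
strict_empty = all_zero_nonvacuous(empty_out)      # False: empty is rejected
strict_zeroed = all_zero_nonvacuous(zeroed_out)    # True: meaningful pass
```

Under this model, only a non-empty buffer can distinguish "zeroed" from "untouched", which is why the comment suggests either documenting the TN/NN cases as crash guards or giving them a small non-empty output to check.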
Description
Handle grouped GEMM calls where all groups are empty.
MoE routing can legally produce a microbatch where no local expert receives
tokens. The PyTorch grouped GEMM wrapper filters those zero-token GEMMs, but
the CUTLASS grouped GEMM path could still be reached with num_gemms == 0 and
then dereference A[0]/B[0]/D[0], causing a native segfault.
Return early after filtering all GEMMs in
te_general_grouped_gemm, and add a defensive
num_gemms <= 0 guard in nvte_multi_tensor_gemm.
Add a Hopper/CUTLASS regression test covering all-empty grouped GEMM inputs for
TN, NN, and NT layouts.
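The two guards described above can be sketched in plain Python. This is a simplified model of the control flow only, not the actual C++ in this PR: nested lists stand in for tensors, and the names matmul, backend_grouped_gemm, and grouped_gemm are hypothetical stand-ins for the real wrapper and native entry point.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive dense matrix product, used only to give the sketch real work."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def backend_grouped_gemm(a_list: List[Matrix], b_list: List[Matrix],
                         out_list: List[Matrix]) -> None:
    """Stand-in for the native grouped GEMM entry point."""
    num_gemms = len(a_list)
    if num_gemms <= 0:
        return  # defensive guard: no groups means no work
    # The grouped path inspects the first group; without the guard above
    # this indexing would fail on an empty list.
    _first = (a_list[0], b_list[0], out_list[0])
    for a, b, out in zip(a_list, b_list, out_list):
        out[:] = matmul(a, b)  # write the result into the caller's buffer


def grouped_gemm(a_list: List[Matrix], b_list: List[Matrix],
                 out_list: List[Matrix]) -> int:
    """Stand-in for the wrapper. Returns how many GEMMs were dispatched."""
    kept_a, kept_b, kept_out = [], [], []
    for a, b, out in zip(a_list, b_list, out_list):
        if len(a) == 0 or len(b) == 0:
            # Zero-work group: zero any non-empty output, then skip it.
            for row in out:
                row[:] = [0.0] * len(row)
            continue
        kept_a.append(a)
        kept_b.append(b)
        kept_out.append(out)
    if not kept_a:
        return 0  # every group was filtered: return before the backend call
    backend_grouped_gemm(kept_a, kept_b, kept_out)
    return len(kept_a)
```

The early return in the wrapper is the primary fix in this model; the backend guard is redundant for this caller but protects any other caller that passes zero groups.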
Type of change
Changes
Please list the changes introduced in this PR:
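- Return early from te_general_grouped_gemm when all GEMMs were filtered.
- Add a defensive num_gemms <= 0 guard in nvte_multi_tensor_gemm.
- Add a Hopper/CUTLASS regression test that runs with NVTE_USE_CUTLASS_GROUPED_GEMM=1.
- Cover all-empty grouped GEMM inputs for TN, NN, and NT layouts.
Testing
pytest -q tests/pytorch/test_numerics.py::test_grouped_gemm_cutlass_empty_groups -s
Checklist: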