fix(grouped_linear): handle all-zero-token forward and backward #3019

Open
jubick1337 wants to merge 2 commits into NVIDIA:main from jubick1337:mnovikov/fix-grouped-linear-zero-tokens

@jubick1337

When a rank receives zero tokens for every local expert (sum(m_splits) == 0), _GroupedLinear.forward previously fed an empty input into general_grouped_gemm, which can crash or hang in the underlying cuBLAS grouped GEMM. Skipping the call entirely would also remove the weights from the autograd graph on that rank, desynchronizing DDP / EP collectives across ranks that did have tokens.

Short-circuit _GroupedLinear.forward when sum(m_splits) == 0: produce an empty output of the correct shape and dtype, register the weights and biases via ctx.save_for_backward so DDP sees matching shapes across ranks, and record an empty_input flag on ctx.

Mirror this in _GroupedLinear.backward: when ctx.empty_input is set, synthesize correctly-shaped zero gradients (dgrad, wgrad, bgrad) without touching the quantization / GEMM machinery. The fuse_wgrad_accumulation + Mcore-DDP path is preserved by flipping grad_added_to_main_grad on weights that opt into it, matching the existing handle_custom_ddp_from_mcore semantics.

Add an "all" value to the existing empty_split parametrization in test_sanity_grouped_linear, exercising the zero-token forward and backward path across the same dtype / FP8 recipe / single_param matrix already in use, and assert that every weight / bias gradient comes back with the correct full shape (zeros are allowed but None is not, since that would desynchronize DDP / EP allreduce).

The te.ops.GroupedLinear (basic-ops API) was verified to already handle the all-zero-token case correctly without changes, so this PR is scoped to _GroupedLinear in module/grouped_linear.py.

Reported in NVIDIA/Megatron-LM#4851; per @Victarry's review, the fix belongs in TE (analogous to TE PR #648 for the FP8 path).

Signed-off-by: mnovikov <mnovikov@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 21, 2026

Greptile Summary

This PR is now a test-only change. An initial commit added both a Python-level short-circuit in _GroupedLinear.forward/backward for the sum(m_splits) == 0 case and a matching regression test; the second commit reverted the grouped_linear.py changes after @ptrendx confirmed the existing per-group skip in the C++ GEMM layer (gemm.cpp:509-516) already handles all-zero-token inputs end-to-end.

  • Adds "all" to the empty_split parametrization in test_sanity_grouped_linear, setting num_tokens = 0 and m_splits = [0] * num_gemms to explicitly exercise the Megatron-LM MoE scenario where a rank receives no locally-routed tokens for a microbatch.
  • Adds gradient-shape assertions for all empty_split values when not single_param, verifying that each weight/bias grad is non-None and has the expected shape after backward (preventing DDP/EP allreduce desynchronization).
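The split layouts described in this summary can be sketched in plain Python. This is only an illustration, not the code in tests/pytorch/test_sanity.py: the helper name `make_m_splits`, the `tokens_per_gemm` parameter, and the choice of `num_gemms // 2` as the "middle" index are all assumptions.

```python
def make_m_splits(empty_split, num_gemms, tokens_per_gemm):
    """Build a per-GEMM token-count list for one empty_split value.

    Hypothetical helper: the real test may build its splits differently.
    """
    if empty_split == "all":
        # Every local expert receives zero tokens, so num_tokens == 0.
        return [0] * num_gemms

    # Otherwise exactly one group is empty and the rest are full.
    m_splits = [tokens_per_gemm] * num_gemms
    empty_index = {
        "first": 0,
        "last": -1,
        "middle": num_gemms // 2,  # assumed position of the "middle" group
    }[empty_split]
    m_splits[empty_index] = 0
    return m_splits


for mode in ("first", "middle", "last", "all"):
    splits = make_m_splits(mode, num_gemms=4, tokens_per_gemm=8)
    print(mode, splits, "num_tokens =", sum(splits))
```

In this sketch the number of input rows is always `sum(m_splits)`, so the "all" case yields an input with zero rows while every weight keeps its full shape.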

Confidence Score: 5/5

Safe to merge — the only changed file is a test, and the change adds regression coverage for an edge case already handled by the C++ layer.

The production code in grouped_linear.py is unchanged. The test correctly exercises the all-zero-token forward/backward path, asserts output shape, and checks per-GEMM parameter gradients. The noted gaps do not affect production correctness.

No files require special attention; tests/pytorch/test_sanity.py is the only changed file and the additions are straightforward.

Important Files Changed

| Filename | Overview |
| --- | --- |
| tests/pytorch/test_sanity.py | Adds "all" to empty_split parametrization and gradient-shape assertions; the underlying grouped_linear.py fix was reverted as the C++ layer already handles all-zero-token inputs |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[test_sanity_grouped_linear] --> B{empty_split?}
    B -->|first/last/middle| C[num_tokens = bs * max_seqlen * num_gemms-1]
    B -->|all| D[num_tokens = 0]
    C --> E[m_splits: one zero entry]
    D --> F[m_splits: all zeros]
    E --> G[forward + backward]
    F --> G
    G --> H[assert out.shape == num_tokens x ffn_hidden_size]
    H --> I{not single_param?}
    I -->|Yes| J[assert each param.grad is not None AND has correct shape]
    I -->|No| K[skip grad assertions - GroupedTensorStorage path not exercised]
    J --> L[Pass]
    K --> L

Reviews (2): Last reviewed commit: "Drop sum(m_splits)==0 short-circuit; kee..."

Comment on lines +399 to +401
if ctx.weights_requires_grad:
    wgrad_list = [torch.zeros_like(w) for w in weights]
    if ctx.fuse_wgrad_accumulation:
Contributor

P1 When fuse_wgrad_accumulation=False and FP8 primary weights are used (primary_weights_in_fp8=True), weights here may be QuantizedTensorStorage objects with an FP8 element type (e.g., torch.float8_e4m3fn). torch.zeros_like(w) would inherit that FP8 dtype, but the normal backward path explicitly allocates wgrads with ctx.activation_dtype via tex.bulk_allocate. Returning FP8-typed zero wgrads from the empty path while the normal path returns activation_dtype wgrads is inconsistent and will break optimizers that expect higher-precision gradients for FP8 weight parameters.

Suggested change
- if ctx.weights_requires_grad:
-     wgrad_list = [torch.zeros_like(w) for w in weights]
-     if ctx.fuse_wgrad_accumulation:
+ if ctx.weights_requires_grad:
+     wgrad_list = [
+         torch.zeros(w.size(), dtype=ctx.activation_dtype, device=device)
+         for w in weights
+     ]
+     if ctx.fuse_wgrad_accumulation:

Comment on lines +401 to +409
if ctx.fuse_wgrad_accumulation:
    # Mirror the mcore custom-DDP path: when weight gradient
    # accumulation is fused into main_grad on the weight, mark
    # main_grad as already updated (we accumulated zero) and
    # let downstream skip its own add.
    for w in weights:
        if hasattr(w, "grad_added_to_main_grad"):
            w.grad_added_to_main_grad = True
    wgrad_list = [None] * N
Contributor

P2 zero_out_wgrad not mirrored in empty-input path

The normal handle_custom_ddp_from_mcore checks getattr(weight, "zero_out_wgrad", False) and, when set, returns a zero dummy_wgrad shaped to main_grad (so the DDP reducer zeros main_grad before accumulation). The empty backward skips this: it sets grad_added_to_main_grad = True and then returns [None] * N, which tells the reducer "I already accumulated" but never zeros main_grad. In a gradient-accumulation scenario where a rank hits the empty path on the first microbatch, stale main_grad values from the previous step would survive.

Comment on lines +399 to +409
if ctx.weights_requires_grad:
    wgrad_list = [torch.zeros_like(w) for w in weights]
    if ctx.fuse_wgrad_accumulation:
        # Mirror the mcore custom-DDP path: when weight gradient
        # accumulation is fused into main_grad on the weight, mark
        # main_grad as already updated (we accumulated zero) and
        # let downstream skip its own add.
        for w in weights:
            if hasattr(w, "grad_added_to_main_grad"):
                w.grad_added_to_main_grad = True
        wgrad_list = [None] * N
Contributor

P2 Unnecessary allocation when fuse_wgrad_accumulation=True

wgrad_list = [torch.zeros_like(w) for w in weights] allocates N zero tensors on every empty backward, but they are immediately discarded two lines later when fuse_wgrad_accumulation=True overrides wgrad_list with [None] * N. Consider guarding the allocation behind an if not ctx.fuse_wgrad_accumulation: check to avoid the redundant GPU memory traffic.
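The guard suggested in this comment can be sketched without any tensor library. This is a rough illustration of the control flow only, assuming the branch structure shown in the quoted lines; `empty_path_wgrads` and the `make_zeros` allocator callback are hypothetical names, and the sketch does not address the separate zero_out_wgrad concern raised above.

```python
from types import SimpleNamespace


def empty_path_wgrads(weights, fuse_wgrad_accumulation, make_zeros):
    """Return per-weight gradients for an all-zero-token backward.

    make_zeros is a hypothetical allocator callback (for example a wrapper
    that builds a zero tensor in the activation dtype). It is called only
    when its result is actually returned.
    """
    if fuse_wgrad_accumulation:
        # Fused path: flag weights that opt in and return no gradients,
        # so nothing is allocated here.
        for w in weights:
            if hasattr(w, "grad_added_to_main_grad"):
                w.grad_added_to_main_grad = True
        return [None] * len(weights)

    # Non-fused path: one full-shape zero gradient per weight.
    return [make_zeros(w) for w in weights]


# Stand-in weights; real code would pass parameter tensors instead.
weights = [SimpleNamespace(shape=(4, 2), grad_added_to_main_grad=False)]
print(empty_path_wgrads(weights, True, lambda w: ("zeros", w.shape)))
print(empty_path_wgrads(weights, False, lambda w: ("zeros", w.shape)))
```

Returning early from the fused branch keeps the allocation and the `[None] * N` override from ever coexisting, which is the redundancy the comment points out.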

Member

@ptrendx ptrendx left a comment

There are a few issues with this, but the most important one is that we cannot really do that check. We are moving towards the device-based m_splits in order to have CUDA graphs compatibility and the comparison sum(m_splits) == 0 is not going to work in that context. We need to make sure that cuBLAS works in that case too.

@ptrendx pointed out that `sum(m_splits) == 0` won't survive the move
to device-side m_splits / CUDA-graphs compatibility. Re-running the new
`empty_split="all"` test against current TE main *without* the Python
short-circuit shows 2568 pass / 0 fail across bf16, fp32, and every
supported FP8 recipe my Blackwell box can exercise. The per-group skip
at gemm.cpp:509-516 plus the existing handling of an empty
`nvte_multi_tensor_gemm` call already cover the all-zero-tokens path
end-to-end.

Keep the new test alone as regression coverage for the path our
Megatron-LM SFT depends on.
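Why a per-group skip is enough can be seen from the arithmetic alone: the weight gradient of a group is a sum over that group's rows, and an empty sum is a full-shape zero matrix. The pure-Python sketch below illustrates that idea only. It is not the logic in gemm.cpp, which is not reproduced here, and `grouped_wgrad` is a hypothetical name.

```python
def grouped_wgrad(x_rows, dy_rows, m_splits, in_features, out_features):
    """Per-group weight gradients dW_g = dY_g^T @ X_g for row-split inputs.

    x_rows:  list of input rows, each of length in_features
    dy_rows: list of output-gradient rows, each of length out_features
    Each returned dW_g has shape (out_features, in_features).
    """
    wgrads = []
    start = 0
    for m in m_splits:
        # Full-shape zero accumulator, allocated regardless of m.
        dw = [[0.0] * in_features for _ in range(out_features)]
        # When m == 0 this loop never runs: the group is effectively
        # skipped, and dw is still returned as a full-shape zero matrix.
        for r in range(start, start + m):
            for o in range(out_features):
                for i in range(in_features):
                    dw[o][i] += dy_rows[r][o] * x_rows[r][i]
        wgrads.append(dw)
        start += m
    return wgrads


# All-zero-token case: no rows at all, three groups.
grads = grouped_wgrad([], [], [0, 0, 0], in_features=2, out_features=3)
print(len(grads), len(grads[0]), len(grads[0][0]))
```

No global check on the sum of the splits is needed in this formulation; each group decides for itself based on its own row count, which is also what makes the approach compatible with splits that are not known on the host.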

Signed-off-by: mnovikov <mnovikov@nvidia.com>
@github-actions github-actions Bot added the community-contribution label May 21, 2026
@jubick1337
Author

> There are a few issues with this, but the most important one is that we cannot really do that check. We are moving towards the device-based m_splits in order to have CUDA graphs compatibility and the comparison sum(m_splits) == 0 is not going to work in that context. We need to make sure that cuBLAS works in that case too.

Thanks @ptrendx!
Dropping sum(m_splits) == 0.

Some context: we hit this on a Megatron-LM training run (32-node H100 MoE SFT) where some EP ranks got zero tokens for a microbatch and the whole job deadlocked at the next NCCL allreduce. We worked around it in Megatron first (#4851 there), then @Victarry pointed us to TE as the proper place for the fix.

But re-running my coverage on current TE main without the Python check, the all-zero case already works end-to-end on bf16, fp32, and every supported FP8 recipe my Blackwell can execute (2568 pass, 0 fail). The per-group early return at gemm.cpp:509 + the empty nvte_multi_tensor_gemm handling seem to cover it.

So I'm reducing this PR to just the regression test: new empty_split="all" parametrize value + a check that weight grads come back with full shape (not None) so DDP/EP allreduce doesn't desync. Keeps this from regressing without adding any code surface.

What do you think?
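The gradient check described in this comment (full shape, never None) could look roughly like the following. It is a duck-typed sketch rather than the assertion added to test_sanity_grouped_linear: `find_bad_grads` and `fake_param` are hypothetical names, and the stand-in objects only mimic the `.shape` and `.grad` attributes of a real parameter.

```python
from types import SimpleNamespace


def find_bad_grads(named_params):
    """Return names of parameters whose .grad is None or mis-shaped."""
    bad = []
    for name, param in named_params:
        grad = param.grad
        if grad is None or tuple(grad.shape) != tuple(param.shape):
            bad.append(name)
    return bad


def fake_param(shape, grad_shape):
    """Stand-in for a parameter; grad_shape=None means no gradient."""
    grad = None if grad_shape is None else SimpleNamespace(shape=grad_shape)
    return SimpleNamespace(shape=shape, grad=grad)


params = [
    ("weight0", fake_param((4, 2), (4, 2))),  # full-shape grad: accepted
    ("weight1", fake_param((4, 2), None)),    # missing grad: flagged
    ("bias0", fake_param((4,), (0,))),        # wrong shape: flagged
]
bad = find_bad_grads(params)
print(bad)
```

With a real PyTorch module the same helper could be fed `module.named_parameters()` after backward, and the test would assert that the returned list is empty.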

@ptrendx ptrendx dismissed their stale review May 21, 2026 20:03

No longer applies.

@ptrendx
Member

ptrendx commented May 21, 2026

Sounds good. The changes look good to me, the only point I'm not sure about is not testing the single parameter path. Adding @ksivaman to look at that.

@ptrendx ptrendx added org-contribution and removed community-contribution labels May 21, 2026