fix(grouped_linear): handle all-zero-token forward and backward#3019
fix(grouped_linear): handle all-zero-token forward and backward#3019jubick1337 wants to merge 2 commits into
Conversation
When a rank receives zero tokens for every local expert (`sum(m_splits) == 0`), `_GroupedLinear.forward` previously fed an empty input into `general_grouped_gemm`, which can crash or hang in the underlying cuBLAS grouped GEMM. Skipping the call entirely would also remove the weights from the autograd graph on that rank, desynchronizing DDP / EP collectives across ranks that did have tokens. Short-circuit `_GroupedLinear.forward` when `sum(m_splits) == 0`: produce an empty output of the correct shape and dtype, register the weights and biases via `ctx.save_for_backward` so DDP sees matching shapes across ranks, and record an `empty_input` flag on ctx. Mirror this in `_GroupedLinear.backward`: when `ctx.empty_input` is set, synthesize correctly-shaped zero gradients (dgrad, wgrad, bgrad) without touching the quantization / GEMM machinery. The `fuse_wgrad_accumulation` + Mcore-DDP path is preserved by flipping `grad_added_to_main_grad` on weights that opt into it, matching the existing `handle_custom_ddp_from_mcore` semantics. Add an `"all"` value to the existing `empty_split` parametrization in `test_sanity_grouped_linear`, exercising the zero-token forward and backward path across the same dtype / FP8 recipe / single_param matrix already in use, and assert that every weight / bias gradient comes back with the correct full shape (zeros are allowed but `None` is not, since that would desynchronize DDP / EP allreduce). The `te.ops.GroupedLinear` (basic-ops API) was verified to already handle the all-zero-token case correctly without changes, so this PR is scoped to `_GroupedLinear` in `module/grouped_linear.py`. Reported in NVIDIA/Megatron-LM#4851; per @Victarry's review, the fix belongs in TE (analogous to TE PR NVIDIA#648 for the FP8 path). Signed-off-by: mnovikov <mnovikov@nvidia.com>
Greptile SummaryThis PR is now a test-only change. An initial commit added both a Python-level short-circuit in
Confidence Score: 5/5Safe to merge — the only changed file is a test, and the change adds regression coverage for an edge case already handled by the C++ layer. The production code in grouped_linear.py is unchanged. The test correctly exercises the all-zero-token forward/backward path, asserts output shape, and checks per-GEMM parameter gradients. The noted gaps do not affect production correctness. No files require special attention; tests/pytorch/test_sanity.py is the only changed file and the additions are straightforward. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[test_sanity_grouped_linear] --> B{empty_split?}
B -->|first/last/middle| C[num_tokens = bs * max_seqlen * num_gemms-1]
B -->|all| D[num_tokens = 0]
C --> E[m_splits: one zero entry]
D --> F[m_splits: all zeros]
E --> G[forward + backward]
F --> G
G --> H[assert out.shape == num_tokens x ffn_hidden_size]
H --> I{not single_param?}
I -->|Yes| J[assert each param.grad is not None AND has correct shape]
I -->|No| K[skip grad assertions - GroupedTensorStorage path not exercised]
J --> L[Pass]
K --> L
Reviews (2): Last reviewed commit: "Drop sum(m_splits)==0 short-circuit; kee..." | Re-trigger Greptile |
| if ctx.weights_requires_grad: | ||
| wgrad_list = [torch.zeros_like(w) for w in weights] | ||
| if ctx.fuse_wgrad_accumulation: |
There was a problem hiding this comment.
When
fuse_wgrad_accumulation=False and FP8 primary weights are used (primary_weights_in_fp8=True), weights here may be QuantizedTensorStorage objects with an FP8 element type (e.g., torch.float8_e4m3fn). torch.zeros_like(w) would inherit that FP8 dtype, but the normal backward path explicitly allocates wgrads with ctx.activation_dtype via tex.bulk_allocate. Returning FP8-typed zero wgrads from the empty path while the normal path returns activation_dtype wgrads is inconsistent and will break optimizers that expect higher-precision gradients for FP8 weight parameters.
| if ctx.weights_requires_grad: | |
| wgrad_list = [torch.zeros_like(w) for w in weights] | |
| if ctx.fuse_wgrad_accumulation: | |
| if ctx.weights_requires_grad: | |
| wgrad_list = [ | |
| torch.zeros(w.size(), dtype=ctx.activation_dtype, device=device) | |
| for w in weights | |
| ] | |
| if ctx.fuse_wgrad_accumulation: |
| if ctx.fuse_wgrad_accumulation: | ||
| # Mirror the mcore custom-DDP path: when weight gradient | ||
| # accumulation is fused into main_grad on the weight, mark | ||
| # main_grad as already updated (we accumulated zero) and | ||
| # let downstream skip its own add. | ||
| for w in weights: | ||
| if hasattr(w, "grad_added_to_main_grad"): | ||
| w.grad_added_to_main_grad = True | ||
| wgrad_list = [None] * N |
There was a problem hiding this comment.
zero_out_wgrad not mirrored in empty-input path
The normal handle_custom_ddp_from_mcore checks getattr(weight, "zero_out_wgrad", False) and, when set, returns a zero dummy_wgrad shaped to main_grad (so the DDP reducer zeros main_grad before accumulation). The empty backward skips this: it sets grad_added_to_main_grad = True and then returns [None] * N, which tells the reducer "I already accumulated" but never zeros main_grad. In a gradient-accumulation scenario where a rank hits the empty path on the first microbatch, stale main_grad values from the previous step would survive.
| if ctx.weights_requires_grad: | ||
| wgrad_list = [torch.zeros_like(w) for w in weights] | ||
| if ctx.fuse_wgrad_accumulation: | ||
| # Mirror the mcore custom-DDP path: when weight gradient | ||
| # accumulation is fused into main_grad on the weight, mark | ||
| # main_grad as already updated (we accumulated zero) and | ||
| # let downstream skip its own add. | ||
| for w in weights: | ||
| if hasattr(w, "grad_added_to_main_grad"): | ||
| w.grad_added_to_main_grad = True | ||
| wgrad_list = [None] * N |
There was a problem hiding this comment.
Unnecessary allocation when
fuse_wgrad_accumulation=True
wgrad_list = [torch.zeros_like(w) for w in weights] allocates N zero tensors on every empty backward, but they are immediately discarded two lines later when fuse_wgrad_accumulation=True overrides wgrad_list with [None] * N. Consider guarding the allocation behind an if not ctx.fuse_wgrad_accumulation: check to avoid the redundant GPU memory traffic.
ptrendx
left a comment
There was a problem hiding this comment.
There are a few issues with this, but the most important one is that we cannot really do that check. We are moving towards the device-based m_splits in order to have CUDA graphs compatibility and the comparison sum(m_splits) == 0 is not going to work in that context. We need to make sure that cuBLAS works in that case too.
@ptrendx pointed out that `sum(m_splits) == 0` won't survive the move to device-side m_splits / CUDA-graphs compatibility. Re-running the new `empty_split="all"` test against current TE main *without* the Python short-circuit shows 2568 pass / 0 fail across bf16, fp32, and every supported FP8 recipe my Blackwell box can exercise. The per-group skip at gemm.cpp:509-516 plus the existing handling of an empty `nvte_multi_tensor_gemm` call already cover the all-zero-tokens path end-to-end. Keep the new test alone as regression coverage for the path our Megatron-LM SFT depends on. Signed-off-by: mnovikov <mnovikov@nvidia.com>
Thanks @ptrendx! Some context: we hit this on a Megatron-LM training run (32n H100 MoE SFT) where some EP ranks got zero tokens for a microbatch and the whole job deadlocked at the next NCCL allreduce. We worked around it in Megatron first (#4851 there), then @Victarry pointed us at TE as the proper place to fix. But re-running my coverage on current TE main without the Python check, the all-zero case already works end-to-end on bf16, fp32, and every supported FP8 recipe my Blackwell can execute (2568 pass, 0 fail). The per-group early return at gemm.cpp:509 + the empty nvte_multi_tensor_gemm handling seem to cover it. So I'm reducing this PR to just the regression test: new empty_split="all" parametrize value + a check that weight grads come back with full shape (not None) so DDP/EP allreduce doesn't desync. Keeps this from regressing without adding any code surface. What do you think? |
|
Sounds good. The changes look good to me, the only point I'm not sure about is not testing the single parameter path. Adding @ksivaman to look at that. |
When a rank receives zero tokens for every local expert (
sum(m_splits) == 0),_GroupedLinear.forwardpreviously fed an empty input intogeneral_grouped_gemm, which can crash or hang in the underlying cuBLAS grouped GEMM. Skipping the call entirely would also remove the weights from the autograd graph on that rank, desynchronizing DDP / EP collectives across ranks that did have tokens.Short-circuit
_GroupedLinear.forwardwhensum(m_splits) == 0: produce an empty output of the correct shape and dtype, register the weights and biases viactx.save_for_backwardso DDP sees matching shapes across ranks, and record anempty_inputflag on ctx.Mirror this in
_GroupedLinear.backward: whenctx.empty_inputis set, synthesize correctly-shaped zero gradients (dgrad, wgrad, bgrad) without touching the quantization / GEMM machinery. Thefuse_wgrad_accumulationgrad_added_to_main_gradon weights that opt into it, matching the existinghandle_custom_ddp_from_mcoresemantics.Add an
"all"value to the existingempty_splitparametrization intest_sanity_grouped_linear, exercising the zero-token forward and backward path across the same dtype / FP8 recipe / single_param matrix already in use, and assert that every weight / bias gradient comes back with the correct full shape (zeros are allowed butNoneis not, since that would desynchronize DDP / EP allreduce).The
te.ops.GroupedLinear(basic-ops API) was verified to already handle the all-zero-token case correctly without changes, so this PR is scoped to_GroupedLinearinmodule/grouped_linear.py.Reported in NVIDIA/Megatron-LM#4851; per @Victarry's review, the fix belongs in TE (analogous to TE PR #648 for the FP8 path).
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: