From b78659753fd2c66c54806cc16955cfc59ad5203d Mon Sep 17 00:00:00 2001 From: mnovikov Date: Wed, 20 May 2026 20:52:08 -0700 Subject: [PATCH 1/2] fix(grouped_linear): handle all-zero-token forward and backward When a rank receives zero tokens for every local expert (`sum(m_splits) == 0`), `_GroupedLinear.forward` previously fed an empty input into `general_grouped_gemm`, which can crash or hang in the underlying cuBLAS grouped GEMM. Skipping the call entirely would also remove the weights from the autograd graph on that rank, desynchronizing DDP / EP collectives across ranks that did have tokens. Short-circuit `_GroupedLinear.forward` when `sum(m_splits) == 0`: produce an empty output of the correct shape and dtype, register the weights and biases via `ctx.save_for_backward` so DDP sees matching shapes across ranks, and record an `empty_input` flag on ctx. Mirror this in `_GroupedLinear.backward`: when `ctx.empty_input` is set, synthesize correctly-shaped zero gradients (dgrad, wgrad, bgrad) without touching the quantization / GEMM machinery. The `fuse_wgrad_accumulation` + Mcore-DDP path is preserved by flipping `grad_added_to_main_grad` on weights that opt into it, matching the existing `handle_custom_ddp_from_mcore` semantics. Add an `"all"` value to the existing `empty_split` parametrization in `test_sanity_grouped_linear`, exercising the zero-token forward and backward path across the same dtype / FP8 recipe / single_param matrix already in use, and assert that every weight / bias gradient comes back with the correct full shape (zeros are allowed but `None` is not, since that would desynchronize DDP / EP allreduce). The `te.ops.GroupedLinear` (basic-ops API) was verified to already handle the all-zero-token case correctly without changes, so this PR is scoped to `_GroupedLinear` in `module/grouped_linear.py`. Reported in NVIDIA/Megatron-LM#4851; per @Victarry's review, the fix belongs in TE (analogous to TE PR #648 for the FP8 path). Signed-off-by: mnovikov --- tests/pytorch/test_sanity.py | 35 ++++++++++- .../pytorch/module/grouped_linear.py | 60 +++++++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index c811342df5..e9f478605b 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -547,7 +547,7 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @pytest.mark.parametrize("single_param", all_boolean) -@pytest.mark.parametrize("empty_split", ["first", "last", "middle"]) +@pytest.mark.parametrize("empty_split", ["first", "last", "middle", "all"]) @pytest.mark.parametrize("num_gemms", [4]) def test_sanity_grouped_linear( dtype, @@ -565,7 +565,15 @@ def test_sanity_grouped_linear( ffn_hidden_size = 4 * config.hidden_size # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. bs = bs * 16 - num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) + # The "all" empty_split exercises the path where every local expert + # receives zero tokens, e.g. a Megatron-LM MoE rank that didn't receive + # any locally-routed tokens for a microbatch. In that case num_tokens + # is 0; for the other empty_split values exactly one of num_gemms splits + # is empty so num_tokens covers the remaining (num_gemms - 1) groups. + if empty_split == "all": + num_tokens = 0 + else: + num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) skip_unsupported_backward_override("grouped_linear", fp8_recipe, backward_override) if fp8_recipe is not None: @@ -612,12 +620,35 @@ def test_sanity_grouped_linear( m_splits[-1] = 0 elif empty_split == "middle": m_splits[num_gemms // 2] = 0 + elif empty_split == "all": + m_splits = [0] * num_gemms with autocast(enabled=use_fp8, recipe=fp8_recipe): out = te_grouped_linear(inp_hidden_states, m_splits) loss = out.sum() loss.backward() assert out.shape == (num_tokens, ffn_hidden_size) + # Weights and biases must receive correctly-shaped gradients (zeros + # allowed) regardless of how many tokens were routed; if any of them + # comes back as None or with a shape mismatch, a DDP / EP allreduce + # would hang or crash because other ranks may have produced + # full-shape grads for the same parameter on the same microbatch. + # `single_param` mode wraps weights/biases as a ``GroupedTensorStorage`` + # whose ``.grad`` accumulation goes through a separate path not + # exercised by these tests, so we only assert on the per-GEMM + # parameter layout here. + if not single_param: + for name, param in te_grouped_linear.named_parameters(): + if not param.requires_grad: + continue + assert param.grad is not None, ( + f"{name}.grad is None after backward with empty_split={empty_split}; " + "this would desynchronize DDP / EP allreduce across ranks" + ) + assert param.grad.shape == param.shape, ( + f"{name}.grad has shape {tuple(param.grad.shape)} " + f"but expected {tuple(param.shape)} (empty_split={empty_split})" + ) @pytest.mark.parametrize("dtype", param_types) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 627144345c..dd6053ec37 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -114,6 +114,32 @@ def forward( device = inp.device weight_requires_grad = weights[0].requires_grad + # Short-circuit when this rank has zero tokens for every local expert. + # Calling the underlying grouped GEMM with sum(m_splits) == 0 can + # crash or hang in cuBLAS, and skipping it entirely would also remove + # the weights from the autograd graph. That asymmetry desynchronizes + # the autograd graphs across ranks (some of which DO have tokens) and + # produces hangs at the next DDP / EP collective. Here we produce an + # empty output of the correct shape and dtype, register weights and + # biases with ctx.save_for_backward so DDP sees matching shapes, and + # let _GroupedLinear.backward return correctly-shaped zero gradients. + if sum(m_splits) == 0: + out_features = weights[0].size(0) + out = torch.empty([0, out_features], dtype=activation_dtype, device=device) + if is_grad_enabled: + ctx.empty_input = True + ctx.num_gemms = num_gemms + ctx.use_bias = use_bias + ctx.activation_dtype = activation_dtype + ctx.inp_shape = inp.shape + ctx.requires_dgrad = inp.requires_grad + ctx.weights_requires_grad = weight_requires_grad + ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + ctx.save_for_backward(*weights, *biases) + return out, [None] * num_gemms + if is_grad_enabled: + ctx.empty_input = False + # Configure quantizers if save_original_input and isinstance(input_quantizers[0], Float8Quantizer): if FP8GlobalStateManager.get_fp8_recipe().custom(): @@ -354,6 +380,40 @@ def backward( ctx, grad_output: torch.Tensor, _grad_workspaces ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring + if getattr(ctx, "empty_input", False): + # Counterpart to the empty-input short-circuit in forward. Skip all + # quantization / GEMM work and synthesize correctly-shaped zero + # gradients so that DDP / EP collectives over weight and bias + # gradients see matching shapes across ranks (some of which had + # non-empty input on this microbatch). + saved_tensors = ctx.saved_tensors + N = ctx.num_gemms + weights = saved_tensors[:N] + biases = saved_tensors[N : 2 * N] if ctx.use_bias else [None] * N + device = weights[0].device + dgrad = ( + torch.zeros(ctx.inp_shape, dtype=ctx.activation_dtype, device=device) + if ctx.requires_dgrad + else None + ) + if ctx.weights_requires_grad: + wgrad_list = [torch.zeros_like(w) for w in weights] + if ctx.fuse_wgrad_accumulation: + # Mirror the mcore custom-DDP path: when weight gradient + # accumulation is fused into main_grad on the weight, mark + # main_grad as already updated (we accumulated zero) and + # let downstream skip its own add. + for w in weights: + if hasattr(w, "grad_added_to_main_grad"): + w.grad_added_to_main_grad = True + wgrad_list = [None] * N + else: + wgrad_list = [None] * N + if ctx.use_bias: + grad_biases = [torch.zeros_like(b) for b in biases] + else: + grad_biases = [None] * N + return (dgrad, None, *wgrad_list, *grad_biases) with get_nvtx_range_context("_GroupedLinear_backward"): saved_tensors = restore_from_func_ctx(ctx) N = ctx.num_gemms From 8a7933bcedb175486072697b27342fc95bad7c64 Mon Sep 17 00:00:00 2001 From: mnovikov Date: Thu, 21 May 2026 12:25:49 -0700 Subject: [PATCH 2/2] Drop sum(m_splits)==0 short-circuit; keep only regression test @ptrendx pointed out that `sum(m_splits) == 0` won't survive the move to device-side m_splits / CUDA-graphs compatibility. Re-running the new `empty_split="all"` test against current TE main *without* the Python short-circuit shows 2568 pass / 0 fail across bf16, fp32, and every supported FP8 recipe my Blackwell box can exercise. The per-group skip at gemm.cpp:509-516 plus the existing handling of an empty `nvte_multi_tensor_gemm` call already cover the all-zero-tokens path end-to-end. Keep the new test alone as regression coverage for the path our Megatron-LM SFT depends on. Signed-off-by: mnovikov --- tests/pytorch/test_sanity.py | 4 +- .../pytorch/module/grouped_linear.py | 60 ------------------- 2 files changed, 2 insertions(+), 62 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index e9f478605b..601593afe5 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -566,8 +566,8 @@ def test_sanity_grouped_linear( # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. bs = bs * 16 # The "all" empty_split exercises the path where every local expert - # receives zero tokens, e.g. a Megatron-LM MoE rank that didn't receive - # any locally-routed tokens for a microbatch. In that case num_tokens + # receives zero tokens (e.g. a Megatron-LM MoE rank that didn't receive + # any locally-routed tokens for a microbatch). In that case num_tokens # is 0; for the other empty_split values exactly one of num_gemms splits # is empty so num_tokens covers the remaining (num_gemms - 1) groups. if empty_split == "all": diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index dd6053ec37..627144345c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -114,32 +114,6 @@ def forward( device = inp.device weight_requires_grad = weights[0].requires_grad - # Short-circuit when this rank has zero tokens for every local expert. - # Calling the underlying grouped GEMM with sum(m_splits) == 0 can - # crash or hang in cuBLAS, and skipping it entirely would also remove - # the weights from the autograd graph. That asymmetry desynchronizes - # the autograd graphs across ranks (some of which DO have tokens) and - # produces hangs at the next DDP / EP collective. Here we produce an - # empty output of the correct shape and dtype, register weights and - # biases with ctx.save_for_backward so DDP sees matching shapes, and - # let _GroupedLinear.backward return correctly-shaped zero gradients. - if sum(m_splits) == 0: - out_features = weights[0].size(0) - out = torch.empty([0, out_features], dtype=activation_dtype, device=device) - if is_grad_enabled: - ctx.empty_input = True - ctx.num_gemms = num_gemms - ctx.use_bias = use_bias - ctx.activation_dtype = activation_dtype - ctx.inp_shape = inp.shape - ctx.requires_dgrad = inp.requires_grad - ctx.weights_requires_grad = weight_requires_grad - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.save_for_backward(*weights, *biases) - return out, [None] * num_gemms - if is_grad_enabled: - ctx.empty_input = False - # Configure quantizers if save_original_input and isinstance(input_quantizers[0], Float8Quantizer): if FP8GlobalStateManager.get_fp8_recipe().custom(): @@ -380,40 +354,6 @@ def backward( ctx, grad_output: torch.Tensor, _grad_workspaces ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring - if getattr(ctx, "empty_input", False): - # Counterpart to the empty-input short-circuit in forward. Skip all - # quantization / GEMM work and synthesize correctly-shaped zero - # gradients so that DDP / EP collectives over weight and bias - # gradients see matching shapes across ranks (some of which had - # non-empty input on this microbatch). - saved_tensors = ctx.saved_tensors - N = ctx.num_gemms - weights = saved_tensors[:N] - biases = saved_tensors[N : 2 * N] if ctx.use_bias else [None] * N - device = weights[0].device - dgrad = ( - torch.zeros(ctx.inp_shape, dtype=ctx.activation_dtype, device=device) - if ctx.requires_dgrad - else None - ) - if ctx.weights_requires_grad: - wgrad_list = [torch.zeros_like(w) for w in weights] - if ctx.fuse_wgrad_accumulation: - # Mirror the mcore custom-DDP path: when weight gradient - # accumulation is fused into main_grad on the weight, mark - # main_grad as already updated (we accumulated zero) and - # let downstream skip its own add. - for w in weights: - if hasattr(w, "grad_added_to_main_grad"): - w.grad_added_to_main_grad = True - wgrad_list = [None] * N - else: - wgrad_list = [None] * N - if ctx.use_bias: - grad_biases = [torch.zeros_like(b) for b in biases] - else: - grad_biases = [None] * N - return (dgrad, None, *wgrad_list, *grad_biases) with get_nvtx_range_context("_GroupedLinear_backward"): saved_tensors = restore_from_func_ctx(ctx) N = ctx.num_gemms