diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index c811342df5..601593afe5 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -547,7 +547,7 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @pytest.mark.parametrize("single_param", all_boolean) -@pytest.mark.parametrize("empty_split", ["first", "last", "middle"]) +@pytest.mark.parametrize("empty_split", ["first", "last", "middle", "all"]) @pytest.mark.parametrize("num_gemms", [4]) def test_sanity_grouped_linear( dtype, @@ -565,7 +565,15 @@ def test_sanity_grouped_linear( ffn_hidden_size = 4 * config.hidden_size # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. bs = bs * 16 - num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) + # The "all" empty_split exercises the path where every local expert + # receives zero tokens (e.g. a Megatron-LM MoE rank that didn't receive + # any locally-routed tokens for a microbatch). In that case num_tokens + # is 0; for the other empty_split values exactly one of num_gemms splits + # is empty so num_tokens covers the remaining (num_gemms - 1) groups. + if empty_split == "all": + num_tokens = 0 + else: + num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) skip_unsupported_backward_override("grouped_linear", fp8_recipe, backward_override) if fp8_recipe is not None: @@ -612,12 +620,35 @@ def test_sanity_grouped_linear( m_splits[-1] = 0 elif empty_split == "middle": m_splits[num_gemms // 2] = 0 + elif empty_split == "all": + m_splits = [0] * num_gemms with autocast(enabled=use_fp8, recipe=fp8_recipe): out = te_grouped_linear(inp_hidden_states, m_splits) loss = out.sum() loss.backward() assert out.shape == (num_tokens, ffn_hidden_size) + # Weights and biases must receive correctly-shaped gradients (zeros + # allowed) regardless of how many tokens were routed; if any of them + # comes back as None or with a shape mismatch, a DDP / EP allreduce + # would hang or crash because other ranks may have produced + # full-shape grads for the same parameter on the same microbatch. + # `single_param` mode wraps weights/biases as a ``GroupedTensorStorage`` + # whose ``.grad`` accumulation goes through a separate path not + # exercised by these tests, so we only assert on the per-GEMM + # parameter layout here. + if not single_param: + for name, param in te_grouped_linear.named_parameters(): + if not param.requires_grad: + continue + assert param.grad is not None, ( + f"{name}.grad is None after backward with empty_split={empty_split}; " + "this would desynchronize DDP / EP allreduce across ranks" + ) + assert param.grad.shape == param.shape, ( + f"{name}.grad has shape {tuple(param.grad.shape)} " + f"but expected {tuple(param.shape)} (empty_split={empty_split})" + ) @pytest.mark.parametrize("dtype", param_types)