Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions tests/pytorch/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def test_sanity_linear_with_zero_tokens(
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("single_param", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("empty_split", ["first", "last", "middle", "all"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
dtype,
Expand All @@ -565,7 +565,15 @@ def test_sanity_grouped_linear(
ffn_hidden_size = 4 * config.hidden_size
# Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
bs = bs * 16
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
# The "all" empty_split exercises the path where every local expert
# receives zero tokens (e.g. a Megatron-LM MoE rank that didn't receive
# any locally-routed tokens for a microbatch). In that case num_tokens
# is 0; for the other empty_split values exactly one of num_gemms splits
# is empty so num_tokens covers the remaining (num_gemms - 1) groups.
if empty_split == "all":
num_tokens = 0
else:
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)

skip_unsupported_backward_override("grouped_linear", fp8_recipe, backward_override)
if fp8_recipe is not None:
Expand Down Expand Up @@ -612,12 +620,35 @@ def test_sanity_grouped_linear(
m_splits[-1] = 0
elif empty_split == "middle":
m_splits[num_gemms // 2] = 0
elif empty_split == "all":
m_splits = [0] * num_gemms

with autocast(enabled=use_fp8, recipe=fp8_recipe):
out = te_grouped_linear(inp_hidden_states, m_splits)
loss = out.sum()
loss.backward()
assert out.shape == (num_tokens, ffn_hidden_size)
# Weights and biases must receive correctly-shaped gradients (zeros
# allowed) regardless of how many tokens were routed; if any of them
# comes back as None or with a shape mismatch, a DDP / EP allreduce
# would hang or crash because other ranks may have produced
# full-shape grads for the same parameter on the same microbatch.
# `single_param` mode wraps weights/biases as a ``GroupedTensorStorage``
# whose ``.grad`` accumulation goes through a separate path not
# exercised by these tests, so we only assert on the per-GEMM
# parameter layout here.
if not single_param:
for name, param in te_grouped_linear.named_parameters():
if not param.requires_grad:
continue
assert param.grad is not None, (
f"{name}.grad is None after backward with empty_split={empty_split}; "
"this would desynchronize DDP / EP allreduce across ranks"
)
assert param.grad.shape == param.shape, (
f"{name}.grad has shape {tuple(param.grad.shape)} "
f"but expected {tuple(param.shape)} (empty_split={empty_split})"
)


@pytest.mark.parametrize("dtype", param_types)
Expand Down
Loading