Skip to content

[Common] Fix fused MoE aux loss for sequence aux loss#3018

Open
harryzhou2000 wants to merge 4 commits into
NVIDIA:mainfrom
harryzhou2000:hhanyu/aux_loss_fix
Open

[Common] Fix fused MoE aux loss for sequence aux loss#3018
harryzhou2000 wants to merge 4 commits into
NVIDIA:mainfrom
harryzhou2000:hhanyu/aux_loss_fix

Conversation

@harryzhou2000
Copy link
Copy Markdown
Member

@harryzhou2000 harryzhou2000 commented May 21, 2026

Summary

  • Allow fused MoE aux loss inputs whose column count is a multiple of the logical expert count.
  • Document that sequence aux loss batches independent sequences along the expert dimension.
  • Extend the PyTorch fused router test to cover expanded expert columns.

Details

Sequence aux loss reshapes router scores from [seq_length * batch_size, num_experts] to [seq_length, batch_size * num_experts] before calling the same aux-loss helper. The logical num_experts is still used in the loss coefficient, while probs.shape[1] is the actual reduction width. The previous assertion required these values to be equal and rejected valid sequence aux-loss inputs when batch_size > 1.

Testing

  • PyTorch fused router UT on B200: 912 passed, 120 skipped, 3 warnings in 24.52s

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
@harryzhou2000 harryzhou2000 force-pushed the hhanyu/aux_loss_fix branch from 5e0c29e to 6eb11b1 Compare May 21, 2026 02:51
@harryzhou2000 harryzhou2000 marked this pull request as ready for review May 21, 2026 02:52
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 21, 2026

Greptile Summary

This PR relaxes the fused MoE auxiliary-loss validation to support sequence aux loss, where router scores are reshaped from [seq_len × batch, num_experts] to [seq_len, batch × num_experts] before being passed to the aux-loss helper. The coefficient computation intentionally continues to use the logical num_experts, while the actual reduction width is num_cols.

  • CUDA kernel (fused_moe_aux_loss.cu): The equality assertion num_experts == num_cols is replaced with a divisibility check num_cols % num_experts == 0, along with individual positivity guards; the shared-memory capacity call is updated to use num_cols.
  • Test (test_fused_router.py): A new expert_multiplier parameter (values 1 and 2) is introduced; num_cols, tensor shapes, and tokens_per_expert are all derived from num_experts * expert_multiplier, and _get_tolerances is correctly called with num_cols so the floating-point error floor scales with the actual reduction width.

Confidence Score: 5/5

Safe to merge — the change is a targeted relaxation of a single input validation, the kernel arithmetic is unchanged, and the test expansion covers the new code path.

The kernel logic is unmodified; only the pre-launch assertion changes from strict equality to divisibility, which is the correct condition for the sequence aux-loss reshape. The coefficient formula still uses num_experts as intended, tokens_per_expert indexing covers num_cols columns consistently in both kernel and test, and _get_tolerances now scales with the actual reduction width. No functional regressions are introduced.

The only minor item is in fused_moe_aux_loss.cu: check_shared_memory_capacity_num_experts is now called with num_cols but its error message still refers to num_experts, which would produce a misleading diagnostic if the shared-memory limit were ever hit with expert_multiplier > 1.

Important Files Changed

Filename Overview
transformer_engine/common/fused_router/fused_moe_aux_loss.cu Replaces the strict num_experts == num_cols assertion with a divisibility check (num_cols % num_experts == 0) and adds individual positivity guards; updates the shared-memory capacity check to use num_cols. The coefficient computation still correctly uses num_experts. One cosmetic issue: check_shared_memory_capacity_num_experts is now passed num_cols, but its error message still says "Try reducing num_experts".
tests/pytorch/test_fused_router.py Adds expert_multiplier parametrize (values 1 and 2) to test_fused_moe_aux_loss, derives num_cols = num_experts * expert_multiplier, widens all tensor shapes and tokens_per_expert accordingly, and correctly passes num_cols to _get_tolerances so the error floor scales with the actual reduction width.

Reviews (2): Last reviewed commit: "[PyTorch] Scale fused aux loss tolerance..." | Re-trigger Greptile

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
@ptrendx ptrendx requested a review from denera May 21, 2026 23:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant