[Common] Fix fused MoE aux loss for sequence aux loss #3018
harryzhou2000 wants to merge 4 commits into
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
5e0c29e to 6eb11b1
Greptile Summary
This PR relaxes the fused MoE auxiliary-loss validation to support sequence aux loss, where router scores are reshaped from
Confidence Score: 5/5
Safe to merge: the change is a targeted relaxation of a single input validation, the kernel arithmetic is unchanged, and the test expansion covers the new code path.
The kernel logic is unmodified; only the pre-launch assertion changes from strict equality to divisibility, which is the correct condition for the sequence aux-loss reshape. The coefficient formula still uses
The only minor item is in
Important Files Changed
Reviews (2): Last reviewed commit: "[PyTorch] Scale fused aux loss tolerance..."
Summary
Details
Sequence aux loss reshapes router scores from `[seq_length * batch_size, num_experts]` to `[seq_length, batch_size * num_experts]` before calling the same aux-loss helper. The logical `num_experts` is still used in the loss coefficient, while `probs.shape[1]` is the actual reduction width. The previous assertion required these values to be equal and rejected valid sequence aux-loss inputs when `batch_size > 1`.
Testing
912 passed, 120 skipped, 3 warnings in 24.52s
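For readers who want to see the shape arithmetic from Details concretely, here is a minimal pure-Python sketch. It is not the code from this PR: `strict_check`, `relaxed_check`, and `sequence_aux_shape` are hypothetical names, and the direction of the divisibility test (reduction width divisible by `num_experts`) is an assumption based on the description above.

```python
# Hypothetical illustration only; these helpers are not part of the PR.

def strict_check(reduction_width: int, num_experts: int) -> bool:
    """Old-style validation: reduction width must equal num_experts."""
    return reduction_width == num_experts


def relaxed_check(reduction_width: int, num_experts: int) -> bool:
    """Assumed relaxed validation: reduction width is a multiple of num_experts."""
    return num_experts > 0 and reduction_width % num_experts == 0


def sequence_aux_shape(seq_length: int, batch_size: int, num_experts: int):
    """Shape of the router scores after the sequence aux-loss reshape."""
    return (seq_length, batch_size * num_experts)


seq_length, batch_size, num_experts = 4, 2, 8

# Token-level layout: [seq_length * batch_size, num_experts]
token_shape = (seq_length * batch_size, num_experts)
# Sequence aux-loss layout: [seq_length, batch_size * num_experts]
seq_shape = sequence_aux_shape(seq_length, batch_size, num_experts)

# The reshape preserves the total number of elements.
assert token_shape[0] * token_shape[1] == seq_shape[0] * seq_shape[1]

print(strict_check(seq_shape[1], num_experts))   # False: 16 != 8
print(relaxed_check(seq_shape[1], num_experts))  # True: 16 % 8 == 0
```

With `batch_size = 1` both checks agree, which is consistent with the description that only inputs with `batch_size > 1` were rejected.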