From 325934c731f3f1623d9f57fdca6a2e4c29c42af5 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Thu, 21 May 2026 10:14:56 +0800 Subject: [PATCH 1/4] [Common] Allow expanded columns in fused MoE aux loss Signed-off-by: Harry Zhou --- .../common/fused_router/fused_moe_aux_loss.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index 7e516af97b..ab11ba8b43 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -87,8 +87,10 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, int num_cols, int topk, float coeff, DataType* aux_loss, float* Coeff_buf, cudaStream_t stream) { - NVTE_CHECK(num_experts == num_cols, "Number of experts (", num_experts, - ") must be equal to number of input columns (", num_cols, ")."); + NVTE_CHECK(num_cols > 0, "num_cols must be positive, got ", num_cols); + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(num_cols % num_experts == 0, "Number of input columns (", num_cols, + ") must be a multiple of number of experts (", num_experts, ")."); // Round up to a multiple of warp size for correct warp shuffles. const int block_size = ((std::min(1024, num_cols) + static_cast<int>(kThreadsPerWarp) - 1) / @@ -98,7 +100,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, // One CompType per thread in shared memory. const size_t smem_size = block_size * sizeof(CompType); - check_shared_memory_capacity_num_experts(smem_size, num_experts); + check_shared_memory_capacity_num_experts(smem_size, num_cols); // Compute final coefficient and zero the float accumulator (Coeff_buf[1]) before launch. 
const float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens; From dfddeb2f23933ee44a0adb10e5c8ffb28098cf27 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Thu, 21 May 2026 10:34:17 +0800 Subject: [PATCH 2/4] [PyTorch] Cover expanded columns in fused MoE aux loss test Signed-off-by: Harry Zhou --- tests/pytorch/test_fused_router.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index 274a35b81d..cc03fe4003 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -414,17 +414,20 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f @pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) @pytest.mark.parametrize("num_experts", [1024, 256, 128, 32]) @pytest.mark.parametrize("topk", [4, 32]) -def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): +@pytest.mark.parametrize("expert_multiplier", [1, 2]) +def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk, expert_multiplier): if topk >= num_experts: pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})") + # Sequence aux loss batches independent sequences along the expert dimension. 
+ num_cols = num_experts * expert_multiplier # Construct the special probs to avoid inf in the sigmoid function offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 - probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 + probs = torch.arange(-num_cols // 2, num_cols // 2, device="cuda", dtype=dtype) * 1e-2 probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) - probs = probs.view(num_tokens, num_experts) + probs = probs.view(num_tokens, num_cols) probs.requires_grad = True - tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32) + tokens_per_expert = torch.randint(1, 1000, (num_cols,), device="cuda", dtype=torch.int32) coeff = 0.01 probs_clone = deepcopy(probs) From 6eb11b17e8b0e680fb9e372247f30bf77e641b62 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Thu, 21 May 2026 10:46:19 +0800 Subject: [PATCH 3/4] [Common] Document sequence aux loss column expansion Signed-off-by: Harry Zhou --- transformer_engine/common/fused_router/fused_moe_aux_loss.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index ab11ba8b43..cc5e5e3bcc 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -89,6 +89,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, cudaStream_t stream) { NVTE_CHECK(num_cols > 0, "num_cols must be positive, got ", num_cols); NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + // Sequence aux loss batches independent sequences along the expert dimension. 
NVTE_CHECK(num_cols % num_experts == 0, "Number of input columns (", num_cols, ") must be a multiple of number of experts (", num_experts, ")."); From f55e9014d887edc5ebc2cc3426a5ab3f0ef1f2b8 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Thu, 21 May 2026 11:46:37 +0800 Subject: [PATCH 4/4] [PyTorch] Scale fused aux loss tolerance by column count Signed-off-by: Harry Zhou --- tests/pytorch/test_fused_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index cc03fe4003..f54d16abe2 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -451,7 +451,7 @@ def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk, expert_multipl coeff=coeff, ) - atol, rtol = _get_tolerances(dtype, num_experts) + atol, rtol = _get_tolerances(dtype, num_cols) torch.testing.assert_close(aux_loss, aux_loss_fused, atol=atol, rtol=rtol) # Backward