Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions tests/pytorch/test_fused_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,17 +414,20 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [1024, 256, 128, 32])
@pytest.mark.parametrize("topk", [4, 32])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
@pytest.mark.parametrize("expert_multiplier", [1, 2])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk, expert_multiplier):
if topk >= num_experts:
pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})")
# Sequence aux loss batches independent sequences along the expert dimension.
num_cols = num_experts * expert_multiplier
# Construct the special probs to avoid inf in the sigmoid function
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
probs = torch.arange(-num_cols // 2, num_cols // 2, device="cuda", dtype=dtype) * 1e-2
probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
probs = probs.view(num_tokens, num_experts)
probs = probs.view(num_tokens, num_cols)
probs.requires_grad = True

tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32)
tokens_per_expert = torch.randint(1, 1000, (num_cols,), device="cuda", dtype=torch.int32)
coeff = 0.01

probs_clone = deepcopy(probs)
Expand All @@ -448,7 +451,7 @@ def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
coeff=coeff,
)

atol, rtol = _get_tolerances(dtype, num_experts)
atol, rtol = _get_tolerances(dtype, num_cols)
torch.testing.assert_close(aux_loss, aux_loss_fused, atol=atol, rtol=rtol)

# Backward
Expand Down
9 changes: 6 additions & 3 deletions transformer_engine/common/fused_router/fused_moe_aux_loss.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
int num_cols, int topk, float coeff,
DataType* aux_loss, float* Coeff_buf,
cudaStream_t stream) {
NVTE_CHECK(num_experts == num_cols, "Number of experts (", num_experts,
") must be equal to number of input columns (", num_cols, ").");
NVTE_CHECK(num_cols > 0, "num_cols must be positive, got ", num_cols);
NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts);
// Sequence aux loss batches independent sequences along the expert dimension.
NVTE_CHECK(num_cols % num_experts == 0, "Number of input columns (", num_cols,
") must be a multiple of number of experts (", num_experts, ").");

// Round up to a multiple of warp size for correct warp shuffles.
const int block_size = ((std::min(1024, num_cols) + static_cast<int>(kThreadsPerWarp) - 1) /
Expand All @@ -98,7 +101,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,

// One CompType per thread in shared memory.
const size_t smem_size = block_size * sizeof(CompType);
check_shared_memory_capacity_num_experts(smem_size, num_experts);
check_shared_memory_capacity_num_experts(smem_size, num_cols);

// Compute final coefficient and zero the float accumulator (Coeff_buf[1]) before launch.
const float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;
Expand Down
Loading