Skip to content
72 changes: 56 additions & 16 deletions tests/pytorch/test_fused_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,32 @@
torch.cuda.manual_seed(seed)


def _get_tolerances(dtype: torch.dtype, num_experts: int):
"""Return (atol, rtol) scaled by the number of experts.

With many experts the fused and reference kernels accumulate
floating-point reductions (e.g. normalization sums) in different
orders, causing O(num_experts * machine_eps) rounding divergence.
Scale the default tolerances accordingly so that small expert
counts keep tight checks while large counts (1024+) get the
headroom they need.
"""
# Default tolerances for torch.testing.assert_close
base_atol, base_rtol = 1e-5, 1.3e-6

eps = {
torch.float32: 2e-7,
torch.float16: 1e-3,
torch.bfloat16: 4e-3,
}.get(dtype, 2e-7)
# The worst-case rounding error from summing N values is O(N * eps).
# Use 2 * num_experts * eps as the tolerance floor so tests pass for
# large expert counts while remaining tight for small ones.
atol = max(base_atol, 2 * num_experts * eps)
rtol = max(base_rtol, 2 * num_experts * eps)
return atol, rtol


# Pytorch-based group topk
def group_limited_topk(
scores: torch.Tensor,
Expand Down Expand Up @@ -153,6 +179,13 @@ def run_comparison(
score_function,
enable_bias,
):
if topk >= num_experts:
pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})")
if group_topk is not None and num_groups is not None:
group_size = num_experts // num_groups
per_group_topk = topk // group_topk
if per_group_topk >= group_size:
pytest.skip(f"per-group topk ({per_group_topk}) >= group_size ({group_size})")
# Set some parameters
if score_function in ("sigmoid", "sqrtsoftplus"):
# Construct logits with a narrow range to avoid very small activation values,
Expand Down Expand Up @@ -215,7 +248,8 @@ def run_comparison(
expert_bias=expert_bias_clone,
)

torch.testing.assert_close(probs, probs_fused)
atol, rtol = _get_tolerances(dtype, num_experts)
torch.testing.assert_close(probs, probs_fused, atol=atol, rtol=rtol)
torch.testing.assert_close(routing_map, routing_map_fused)

# Fake the loss
Expand All @@ -227,13 +261,13 @@ def run_comparison(
loss_fused.backward()

# Check the gradient
torch.testing.assert_close(logits.grad, logits_clone.grad)
torch.testing.assert_close(logits.grad, logits_clone.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("num_experts", [1024, 128, 32])
@pytest.mark.parametrize("topk", [4, 8, 16, 32])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
@pytest.mark.parametrize("enable_bias", [True, False])
Expand Down Expand Up @@ -263,8 +297,8 @@ def test_topk_sigmoid(

@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("num_experts", [1024, 128, 32])
@pytest.mark.parametrize("topk", [4, 8, 16, 32])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
@pytest.mark.parametrize("enable_bias", [True, False])
Expand Down Expand Up @@ -294,8 +328,8 @@ def test_topk_sqrtsoftplus(

@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("num_experts", [1024, 128, 32])
@pytest.mark.parametrize("topk", [4, 8, 16, 32])
@pytest.mark.parametrize("use_pre_softmax", [True, False])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
Expand Down Expand Up @@ -325,10 +359,12 @@ def test_topk_softmax(

@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [1, 4, 8])
@pytest.mark.parametrize("num_experts", [1024, 256, 128, 32])
@pytest.mark.parametrize("topk", [1, 4, 8, 16, 32])
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"])
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
if topk >= num_experts:
pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})")
if score_function in ("sigmoid", "sqrtsoftplus"):
# Construct logits with a narrow range to avoid very small activation values
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
Expand Down Expand Up @@ -364,22 +400,25 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
score_function=score_function,
)

torch.testing.assert_close(scores, scores_fused)
atol, rtol = _get_tolerances(dtype, num_experts)
torch.testing.assert_close(scores, scores_fused, atol=atol, rtol=rtol)
torch.testing.assert_close(routing_map, routing_map_fused)

loss = torch.sum(scores)
loss.backward()
loss_fused = torch.sum(scores_fused)
loss_fused.backward()

torch.testing.assert_close(logits.grad, logits_clone.grad)
torch.testing.assert_close(logits.grad, logits_clone.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4])
@pytest.mark.parametrize("num_experts", [1024, 256, 128, 32])
@pytest.mark.parametrize("topk", [4, 32])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
if topk >= num_experts:
pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})")
# Construct the special probs to avoid inf in the sigmoid function
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
Expand Down Expand Up @@ -411,13 +450,14 @@ def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
coeff=coeff,
)

torch.testing.assert_close(aux_loss, aux_loss_fused)
atol, rtol = _get_tolerances(dtype, num_experts)
torch.testing.assert_close(aux_loss, aux_loss_fused, atol=atol, rtol=rtol)

# Backward
aux_loss.backward()
aux_loss_fused.backward()

torch.testing.assert_close(probs.grad, probs_clone.grad)
torch.testing.assert_close(probs.grad, probs_clone.grad, atol=atol, rtol=rtol)


def profile_topk_softmax(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
namespace transformer_engine {
namespace fused_router {

template <typename DataType>
template <typename DataType, TopkFuncType TopkFunc = TopkFuncType::Naive>
__global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens,
int num_experts, int topk,
int score_function, float *scores,
Expand Down Expand Up @@ -123,7 +123,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
* Section: Topk
* Get the topk indices
*/
naive_topk_and_mask(local_logits, num_experts, topk, topk_indices, topk_logits, lane_id);
topk_and_mask<TopkFunc>(local_logits, num_experts, topk, topk_indices, topk_logits, lane_id);
__syncwarp();

// Write the routing_map to the output tensor
Expand All @@ -149,10 +149,26 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher(
size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // logits
+ topk * num_token_per_block * sizeof(CompType) // topk_logits
+ topk * num_token_per_block * sizeof(int); // topk_indices
fused_score_for_moe_aux_loss_forward_kernel<DataType>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
intermediate_output);
check_shared_memory_capacity_num_experts(shared_memory_size, num_experts);
// Radix selection is O(E), independent of K, but it needs 4 passes for 32-bit float;
// switch at K=16 where naive O(K^2*E) starts to dominate
if (topk < 16) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
fused_score_for_moe_aux_loss_forward_kernel<DataType, TopkFuncType::Naive>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
fused_score_for_moe_aux_loss_forward_kernel<DataType, TopkFuncType::Naive>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
intermediate_output);
} else {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
fused_score_for_moe_aux_loss_forward_kernel<DataType, TopkFuncType::Radix>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
fused_score_for_moe_aux_loss_forward_kernel<DataType, TopkFuncType::Radix>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
intermediate_output);
}
NVTE_CHECK_CUDA(cudaGetLastError());
}

Expand Down Expand Up @@ -305,6 +321,10 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher(
+
num_experts * num_token_per_block * sizeof(CompType) // act_from_fwd
+ num_experts * num_token_per_block * sizeof(CompType); // comp_buf
check_shared_memory_capacity_num_experts(shared_memory_size, num_experts);
NVTE_CHECK_CUDA(cudaFuncSetAttribute(fused_score_for_moe_aux_loss_backward_kernel<DataType>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_memory_size));
fused_score_for_moe_aux_loss_backward_kernel<DataType>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@

#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
#include "utils.h"

namespace transformer_engine {
namespace fused_router {

template <typename DataType, typename BiasType>
template <typename DataType, typename BiasType, TopkFuncType TopkFunc = TopkFuncType::Naive>
__global__ void fused_topk_with_score_function_forward_kernel(
const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax,
int num_groups, int group_topk, float scaling_factor, int score_function,
Expand Down Expand Up @@ -146,7 +145,7 @@ __global__ void fused_topk_with_score_function_forward_kernel(
int group_size = num_experts / num_groups;
// Top2
for (int i = 0; i < num_groups; i++) {
naive_topk_and_mask(
topk_and_mask<TopkFunc>(
/*scores ptr = */ scores + i * group_size,
/*data size = */ group_size,
/*topk = */ topk / group_topk,
Expand All @@ -166,7 +165,7 @@ __global__ void fused_topk_with_score_function_forward_kernel(
}

// select the topk groups
naive_topk_and_mask(
topk_and_mask<TopkFunc>(
/*scores ptr = */ group_scores,
/*data size = */ num_groups,
/*topk = */ group_topk,
Expand All @@ -183,10 +182,10 @@ __global__ void fused_topk_with_score_function_forward_kernel(
}
}
__syncwarp();
naive_topk_and_mask(masked_scores, num_experts, topk, topk_indices, topk_scores, lane_id);
topk_and_mask<TopkFunc>(masked_scores, num_experts, topk, topk_indices, topk_scores, lane_id);

} else {
naive_topk_and_mask(scores, num_experts, topk, topk_indices, topk_scores, lane_id);
topk_and_mask<TopkFunc>(scores, num_experts, topk, topk_indices, topk_scores, lane_id);
}
__syncwarp();

Expand Down Expand Up @@ -254,10 +253,26 @@ void fused_topk_with_score_function_forward_kernel_launcher(
shared_memory_size += num_groups * num_token_per_block * sizeof(CompType); // group_scores
shared_memory_size += num_experts * num_token_per_block * sizeof(CompType); // masked_scores
}
fused_topk_with_score_function_forward_kernel<DataType, BiasType>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk,
scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output);
check_shared_memory_capacity_num_experts(shared_memory_size, num_experts);
// Radix selection is O(E), independent of K, but it needs 4 passes for 32-bit float;
// switch at K=16 where naive O(K^2*E) starts to dominate
if (topk < 16) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
fused_topk_with_score_function_forward_kernel<DataType, BiasType, TopkFuncType::Naive>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
fused_topk_with_score_function_forward_kernel<DataType, BiasType, TopkFuncType::Naive>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk,
scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output);
} else {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
fused_topk_with_score_function_forward_kernel<DataType, BiasType, TopkFuncType::Radix>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
fused_topk_with_score_function_forward_kernel<DataType, BiasType, TopkFuncType::Radix>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk,
scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output);
}
NVTE_CHECK_CUDA(cudaGetLastError());
}

Expand Down Expand Up @@ -467,6 +482,10 @@ void fused_topk_with_score_function_backward_kernel_launcher(
num_experts * num_token_per_block * sizeof(CompType) // act_from_fwd
+ num_experts * num_token_per_block * sizeof(CompType) // comp_buf
+ num_experts * num_token_per_block * sizeof(bool); // routing_map
check_shared_memory_capacity_num_experts(shared_memory_size, num_experts);
NVTE_CHECK_CUDA(cudaFuncSetAttribute(fused_topk_with_score_function_backward_kernel<DataType>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_memory_size));
fused_topk_with_score_function_backward_kernel<DataType>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk,
Expand Down
Loading
Loading