diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index 36c09060ed..274a35b81d 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -17,6 +17,30 @@ torch.cuda.manual_seed(seed) +def _get_tolerances(dtype: torch.dtype, num_experts: int): + """Return (atol, rtol) scaled by the number of experts. + + With many experts the fused and reference kernels accumulate + floating-point reductions (e.g. normalization sums) in different + orders, causing O(num_experts * machine_eps) rounding divergence. + Scale the default tolerances accordingly so that small expert + counts keep tight checks while large counts (1024+) get the + headroom they need. + """ + # Default tolerances for torch.testing.assert_close + base_atol, base_rtol = 1e-5, 1.3e-6 + # TODO: account for fp16, bf16 as dtype + if dtype != torch.float32: + raise NotImplementedError("tolerances implemented for fp32 only") + eps = 2e-7 + # The worst-case rounding error from summing N values is O(N * eps). + # Use 2 * num_experts * eps as the tolerance floor so tests pass for + # large expert counts while remaining tight for small ones. 
+ atol = max(base_atol, 2 * num_experts * eps) + rtol = max(base_rtol, 2 * num_experts * eps) + return atol, rtol + + # Pytorch-based group topk def group_limited_topk( scores: torch.Tensor, @@ -153,6 +177,13 @@ def run_comparison( score_function, enable_bias, ): + if topk >= num_experts: + pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})") + if group_topk is not None and num_groups is not None: + group_size = num_experts // num_groups + per_group_topk = topk // group_topk + if per_group_topk >= group_size: + pytest.skip(f"per-group topk ({per_group_topk}) >= group_size ({group_size})") # Set some parameters if score_function in ("sigmoid", "sqrtsoftplus"): # Construct logits with a narrow range to avoid very small activation values, @@ -215,7 +246,8 @@ def run_comparison( expert_bias=expert_bias_clone, ) - torch.testing.assert_close(probs, probs_fused) + atol, rtol = _get_tolerances(dtype, num_experts) + torch.testing.assert_close(probs, probs_fused, atol=atol, rtol=rtol) torch.testing.assert_close(routing_map, routing_map_fused) # Fake the loss @@ -227,13 +259,13 @@ def run_comparison( loss_fused.backward() # Check the gradient - torch.testing.assert_close(logits.grad, logits_clone.grad) + torch.testing.assert_close(logits.grad, logits_clone.grad, atol=atol, rtol=rtol) @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("num_tokens", [2048, 7168, 8992]) -@pytest.mark.parametrize("num_experts", [128, 32]) -@pytest.mark.parametrize("topk", [4, 8]) +@pytest.mark.parametrize("num_experts", [1024, 128, 32]) +@pytest.mark.parametrize("topk", [4, 8, 16, 32]) @pytest.mark.parametrize("group_topk", [None, 4]) @pytest.mark.parametrize("scaling_factor", [None, 1.2]) @pytest.mark.parametrize("enable_bias", [True, False]) @@ -263,8 +295,8 @@ def test_topk_sigmoid( @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("num_tokens", [2048, 7168, 8992]) -@pytest.mark.parametrize("num_experts", [128, 32]) 
-@pytest.mark.parametrize("topk", [4, 8]) +@pytest.mark.parametrize("num_experts", [1024, 128, 32]) +@pytest.mark.parametrize("topk", [4, 8, 16, 32]) @pytest.mark.parametrize("group_topk", [None, 4]) @pytest.mark.parametrize("scaling_factor", [None, 1.2]) @pytest.mark.parametrize("enable_bias", [True, False]) @@ -294,8 +326,8 @@ def test_topk_sqrtsoftplus( @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) -@pytest.mark.parametrize("num_experts", [128, 32]) -@pytest.mark.parametrize("topk", [4, 8]) +@pytest.mark.parametrize("num_experts", [1024, 128, 32]) +@pytest.mark.parametrize("topk", [4, 8, 16, 32]) @pytest.mark.parametrize("use_pre_softmax", [True, False]) @pytest.mark.parametrize("group_topk", [None, 4]) @pytest.mark.parametrize("scaling_factor", [None, 1.2]) @@ -325,10 +357,12 @@ def test_topk_softmax( @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("num_tokens", [2048, 7168]) -@pytest.mark.parametrize("num_experts", [256, 128, 32]) -@pytest.mark.parametrize("topk", [1, 4, 8]) +@pytest.mark.parametrize("num_experts", [1024, 256, 128, 32]) +@pytest.mark.parametrize("topk", [1, 4, 8, 16, 32]) @pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"]) def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): + if topk >= num_experts: + pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})") if score_function in ("sigmoid", "sqrtsoftplus"): # Construct logits with a narrow range to avoid very small activation values offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 @@ -364,7 +398,8 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f score_function=score_function, ) - torch.testing.assert_close(scores, scores_fused) + atol, rtol = _get_tolerances(dtype, num_experts) + torch.testing.assert_close(scores, scores_fused, atol=atol, rtol=rtol) 
torch.testing.assert_close(routing_map, routing_map_fused) loss = torch.sum(scores) @@ -372,14 +407,16 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f loss_fused = torch.sum(scores_fused) loss_fused.backward() - torch.testing.assert_close(logits.grad, logits_clone.grad) + torch.testing.assert_close(logits.grad, logits_clone.grad, atol=atol, rtol=rtol) @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) -@pytest.mark.parametrize("num_experts", [256, 128, 32]) -@pytest.mark.parametrize("topk", [4]) +@pytest.mark.parametrize("num_experts", [1024, 256, 128, 32]) +@pytest.mark.parametrize("topk", [4, 32]) def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): + if topk >= num_experts: + pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})") # Construct the special probs to avoid inf in the sigmoid function offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 @@ -411,13 +448,14 @@ def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): coeff=coeff, ) - torch.testing.assert_close(aux_loss, aux_loss_fused) + atol, rtol = _get_tolerances(dtype, num_experts) + torch.testing.assert_close(aux_loss, aux_loss_fused, atol=atol, rtol=rtol) # Backward aux_loss.backward() aux_loss_fused.backward() - torch.testing.assert_close(probs.grad, probs_clone.grad) + torch.testing.assert_close(probs.grad, probs_clone.grad, atol=atol, rtol=rtol) def profile_topk_softmax( diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index ebdcb293e0..4eb4240d7c 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -16,7 +16,7 @@ namespace 
transformer_engine { namespace fused_router { -template +template __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens, int num_experts, int topk, int score_function, float *scores, @@ -123,7 +123,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi * Section: Topk * Get the topk indices */ - naive_topk_and_mask(local_logits, num_experts, topk, topk_indices, topk_logits, lane_id); + topk_and_mask(local_logits, num_experts, topk, topk_indices, topk_logits, lane_id); __syncwarp(); // Write the routing_map to the output tensor @@ -149,10 +149,26 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // logits + topk * num_token_per_block * sizeof(CompType) // topk_logits + topk * num_token_per_block * sizeof(int); // topk_indices - fused_score_for_moe_aux_loss_forward_kernel - <<>>( - logits, num_tokens, num_experts, topk, score_function, scores, routing_map, - intermediate_output); + check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); + // Radix selection is O(E), independent of K, but it needs 8 passes for 32-bit float; + // switch at K=16 where naive O(K^2*E) starts to dominate + if (topk < 16) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + fused_score_for_moe_aux_loss_forward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); + fused_score_for_moe_aux_loss_forward_kernel + <<>>( + logits, num_tokens, num_experts, topk, score_function, scores, routing_map, + intermediate_output); + } else { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + fused_score_for_moe_aux_loss_forward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); + fused_score_for_moe_aux_loss_forward_kernel + <<>>( + logits, num_tokens, num_experts, topk, score_function, scores, routing_map, + intermediate_output); + } NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -305,6 +321,10 
@@ void fused_score_for_moe_aux_loss_backward_kernel_launcher( + num_experts * num_token_per_block * sizeof(CompType) // act_from_fwd + num_experts * num_token_per_block * sizeof(CompType); // comp_buf + check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); + NVTE_CHECK_CUDA(cudaFuncSetAttribute(fused_score_for_moe_aux_loss_backward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_memory_size)); fused_score_for_moe_aux_loss_backward_kernel <<>>( intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function, diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 1bed871de8..9f7a830546 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -10,13 +10,12 @@ #include "../common.h" #include "../util/logging.h" -#include "../utils.cuh" #include "utils.h" namespace transformer_engine { namespace fused_router { -template +template __global__ void fused_topk_with_score_function_forward_kernel( const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, int num_groups, int group_topk, float scaling_factor, int score_function, @@ -146,7 +145,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( int group_size = num_experts / num_groups; // Top2 for (int i = 0; i < num_groups; i++) { - naive_topk_and_mask( + topk_and_mask( /*scores ptr = */ scores + i * group_size, /*data size = */ group_size, /*topk = */ topk / group_topk, @@ -166,7 +165,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( } // select the topk groups - naive_topk_and_mask( + topk_and_mask( /*scores ptr = */ group_scores, /*data size = */ num_groups, /*topk = */ group_topk, @@ -183,10 +182,10 @@ __global__ void fused_topk_with_score_function_forward_kernel( } } __syncwarp(); - 
naive_topk_and_mask(masked_scores, num_experts, topk, topk_indices, topk_scores, lane_id); + topk_and_mask(masked_scores, num_experts, topk, topk_indices, topk_scores, lane_id); } else { - naive_topk_and_mask(scores, num_experts, topk, topk_indices, topk_scores, lane_id); + topk_and_mask(scores, num_experts, topk, topk_indices, topk_scores, lane_id); } __syncwarp(); @@ -254,10 +253,26 @@ void fused_topk_with_score_function_forward_kernel_launcher( shared_memory_size += num_groups * num_token_per_block * sizeof(CompType); // group_scores shared_memory_size += num_experts * num_token_per_block * sizeof(CompType); // maksed_scores } - fused_topk_with_score_function_forward_kernel - <<>>( - logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); + // Radix selection is O(E), independent of K, but it needs 8 passes for 32-bit float; + // switch at K=16 where naive O(K^2*E) starts to dominate + if (topk < 16) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + fused_topk_with_score_function_forward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); + fused_topk_with_score_function_forward_kernel + <<>>( + logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + } else { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + fused_topk_with_score_function_forward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); + fused_topk_with_score_function_forward_kernel + <<>>( + logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + } NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -467,6 +482,10 @@ void 
fused_topk_with_score_function_backward_kernel_launcher( num_experts * num_token_per_block * sizeof(CompType) // act_from_fwd + num_experts * num_token_per_block * sizeof(CompType) // comp_buf + num_experts * num_token_per_block * sizeof(bool); // routing_map + check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); + NVTE_CHECK_CUDA(cudaFuncSetAttribute(fused_topk_with_score_function_backward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_memory_size)); fused_topk_with_score_function_backward_kernel <<>>( routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 372efdc490..08ad3d16a6 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -7,11 +7,26 @@ #ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ #define TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ +#include "../util/logging.h" +#include "../utils.cuh" #include "transformer_engine/transformer_engine.h" namespace transformer_engine { namespace fused_router { +// Check if requested shared memory size exceeds device capacity. +// Throws an error with num_experts info to help users diagnose the issue. +inline void check_shared_memory_capacity_num_experts(size_t shared_memory_size, int num_experts) { + int device_id; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + int max_smem_per_block; + NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&max_smem_per_block, + cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id)); + NVTE_CHECK(shared_memory_size <= static_cast(max_smem_per_block), "Shared memory size (", + shared_memory_size, " bytes) exceeds device capacity (", max_smem_per_block, + " bytes). Try reducing num_experts (currently ", num_experts, ")."); +} + // Using FP32 to handle all the calculations. // Currently, only FP32 is supported because // 1. 
The score functions (sigmoid, softmax, sqrtsoftplus) are implemented in FP32. @@ -51,7 +66,7 @@ __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncT default_val = -std::numeric_limits::infinity(); } - // Some value is hanlded in local thread + // Some value is handled in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread CompType val = lane_id < data_size ? data_ptr[lane_id] : default_val; @@ -82,7 +97,7 @@ __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int dat default_val = -std::numeric_limits::infinity(); } - // Some value is hanlded in local thread + // Some value is handled in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread CompType val = lane_id < data_size && mask[lane_id] ? data_ptr[lane_id] : default_val; @@ -187,22 +202,233 @@ __device__ inline void apply_softmax_bwd_on_float(float *grad, float *fwd_output } __device__ inline void apply_softmax_on_float(float *scores, int data_size, int lane_id) { - // 1. compute the max of value - float max_val = warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id); - // 2. value -> exp_value + // --- Pass 1: Online accumulation of max and sum_exp --- + float local_max = -std::numeric_limits::infinity(); + float local_sum = 0.0f; + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - scores[i] = expf(scores[i] - max_val); + float val = scores[i]; + if (val > local_max) { + // Rescale accumulated sum for the new max + local_sum *= expf(local_max - val); + local_max = val; + } + local_sum += expf(val - local_max); } - __syncwarp(); - // 3. compute the sum of exp_value - float sum_val = warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id); - // 4. update the softmax value + + // Warp-level reduction of (max, sum_exp) across 32 lanes. 
+ // When merging two lanes with (max_a, sum_a) and (max_b, sum_b): + // merged_max = max(max_a, max_b) + // merged_sum = sum_a * exp(max_a - merged_max) + sum_b * exp(max_b - merged_max) + // + // NaN guard: when data_size < 32, some lanes have (max=-inf, sum=0). + // Merging two such lanes computes expf(-inf - (-inf)) = expf(NaN) = NaN, + // and 0.0 * NaN = NaN in IEEE 754, contaminating valid lanes. + // Fix: treat -inf max as "no data" and skip the expf computation. +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + float other_max = warp_shuffle_xor(local_max, offset); + float other_sum = warp_shuffle_xor(local_sum, offset); + float new_max = fmaxf(local_max, other_max); + if (new_max > -std::numeric_limits::infinity()) { + // At least one side has real data; safe to compute expf differences + float my_scale = + (local_max > -std::numeric_limits::infinity()) ? expf(local_max - new_max) : 0.0f; + float other_scale = + (other_max > -std::numeric_limits::infinity()) ? expf(other_max - new_max) : 0.0f; + local_sum = local_sum * my_scale + other_sum * other_scale; + } + // else: both sides are -inf (no data), keep local_sum = 0 + local_max = new_max; + } + // After reduction, all lanes have the same (local_max, local_sum) + + // --- Pass 2: Normalize in-place --- + float inv_sum = 1.0f / local_sum; for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - scores[i] = scores[i] / sum_val; + scores[i] = expf(scores[i] - local_max) * inv_sum; } __syncwarp(); } +enum class TopkFuncType { + Naive = 0, + Radix = 1, +}; + +/******************************************************************************* + * radix_topk_and_mask — Warp-level radix-selection based top-K + * + * O(E) algorithm independent of K, adapted from PyTorch's radix selection. + * Uses 4-bit radix (16 buckets) → 8 passes for float32. 
+ * + * Algorithm: + * Phase 1 — Radix selection (8 passes): + * Convert float scores to "order-preserving" uint32 (flip sign bit for + * positives, flip all bits for negatives). Then iterate 4 bits at a time + * from the MSB. Each pass: + * 1. Each of 32 threads counts elements per radix bucket that match the + * "desired" bit pattern found so far. + * 2. Warp-reduce the per-thread histograms (16 sums). + * 3. Scan buckets from largest to smallest to locate which bucket + * contains the K-th largest element. + * 4. Narrow the desired pattern by 4 bits. + * After 8 passes: the exact uint32 bit pattern of the K-th value is known. + * + * Phase 2 — Gather (single pass over E): + * Collect elements strictly greater than the K-th value (same uint order), + * then fill remaining slots with elements equal to the K-th value (ties + * broken by ascending index for determinism matching torch.topk). + * Write indices and scores to the output arrays. + * + * Tie-breaking: (value DESC, index ASC) — matches torch.topk behavior. + * + * Constraints: + * - 0 < topk <= data_size + * - No upper limit on topk or data_size (unlike v1's 128 cap) + * - scores must be in shared memory accessible by the warp + * + * Complexity: 9 × O(E/32) = O(E) per warp, independent of K. 
+ ******************************************************************************/ + +__device__ inline void radix_topk_and_mask(CompType *scores, int data_size, int topk, + int *topk_indices, CompType *topk_scores, int lane_id) { + // assert(topk > 0 && "radix_topk_and_mask: topk must be positive"); + // assert(topk <= data_size && "radix_topk_and_mask: topk exceeds data_size"); + + constexpr int RADIX_BITS = 4; + constexpr int RADIX_SIZE = 1 << RADIX_BITS; // 16 buckets + constexpr int RADIX_MASK = RADIX_SIZE - 1; // 0xF + constexpr int NUM_PASSES = 32 / RADIX_BITS; // 8 passes for float32 + + // ========================================================================= + // Phase 1: Radix selection — find the bit pattern of the K-th largest value + // ========================================================================= + unsigned int desired = 0; // accumulated bit pattern of the K-th value + unsigned int desired_mask = 0; // bits determined so far + int k_remaining = topk; // how many more elements we need to skip + + for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { + int digit_pos = pass * RADIX_BITS; + + // Each thread counts elements per bucket that match the desired pattern + unsigned int counts[RADIX_SIZE]; +#pragma unroll + for (int b = 0; b < RADIX_SIZE; b++) { + counts[b] = 0; + } + + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + unsigned int u = float_to_ordered_uint(scores[i]); + // Check if this element matches the desired pattern on already-decided bits + if ((u & desired_mask) == desired) { + int bucket = (u >> digit_pos) & RADIX_MASK; + counts[bucket]++; + } + } + + // Warp-reduce each bucket count across all 32 lanes + unsigned int total_counts[RADIX_SIZE]; +#pragma unroll + for (int b = 0; b < RADIX_SIZE; b++) { + unsigned int c = warp_allreduce_sum(counts[b]); + total_counts[b] = c; // same value on all lanes after full reduction + } + + // Scan buckets from LARGEST digit value (15) to smallest (0). 
+ // We're looking for the top-K largest, so we want the highest-valued + // bucket first. Accumulate counts until we find the bucket containing + // the k_remaining-th element. + int target_bucket = 0; + for (int b = RADIX_SIZE - 1; b >= 0; b--) { + unsigned int bc = total_counts[b]; + if (bc < static_cast(k_remaining)) { + // All elements in this bucket are in the top set; skip them + k_remaining -= bc; + } else { + // The K-th element is in this bucket + target_bucket = b; + break; + } + } + + // Update the desired pattern and mask + desired |= (static_cast(target_bucket) << digit_pos); + desired_mask |= (static_cast(RADIX_MASK) << digit_pos); + } + + // After all passes, `desired` holds the exact ordered-uint bit pattern of + // the K-th largest value, and `k_remaining` is the number of elements with + // that exact value that should be included in the top-K set. + // (k_remaining >= 1 unless all elements equal the K-th value boundary) + + // ========================================================================= + // Phase 2: Gather — collect top-K elements into output arrays + // ========================================================================= + // Two sub-passes over the data: + // Pass A: Collect all elements strictly greater than the K-th value. + // Pass B: Collect elements equal to the K-th value (up to k_remaining), + // in ascending index order for deterministic tie-breaking. + // + // Since the warp processes indices in strided order, we need a warp-level + // prefix sum to assign output positions without conflicts. + + // --- Pass A: elements strictly greater than K-th value --- + // Use a warp-wide running counter for output position. + int write_pos = 0; // shared across warp via __shfl_sync + + for (int base = 0; base < data_size; base += kThreadsPerWarp) { + int i = base + lane_id; + bool valid = (i < data_size); + + unsigned int u = valid ? 
float_to_ordered_uint(scores[i]) : 0; + bool is_greater = valid && (u > desired); + + // Warp ballot to count how many lanes have a qualifying element + unsigned int ballot = __ballot_sync(0xffffffff, is_greater); + int lane_prefix = __popc(ballot & ((1u << lane_id) - 1)); // exclusive prefix + int total_qualifying = __popc(ballot); + + if (is_greater) { + int out_idx = write_pos + lane_prefix; + if (out_idx < topk) { + topk_indices[out_idx] = i; + topk_scores[out_idx] = scores[i]; + } + } + write_pos += total_qualifying; + } + + // --- Pass B: elements equal to K-th value (up to k_remaining) --- + int tie_remaining = k_remaining; // broadcast same value to all lanes + + for (int base = 0; base < data_size && tie_remaining > 0; base += kThreadsPerWarp) { + int i = base + lane_id; + bool valid = (i < data_size); + + unsigned int u = valid ? float_to_ordered_uint(scores[i]) : 0; + bool is_equal = valid && (u == desired); + + unsigned int ballot = __ballot_sync(0xffffffff, is_equal); + int lane_prefix = __popc(ballot & ((1u << lane_id) - 1)); + int total_equal = __popc(ballot); + + if (is_equal && lane_prefix < tie_remaining) { + int out_idx = write_pos + lane_prefix; + if (out_idx < topk) { + topk_indices[out_idx] = i; + topk_scores[out_idx] = scores[i]; + } + } + + int consumed = (total_equal < tie_remaining) ? 
total_equal : tie_remaining; + write_pos += consumed; + tie_remaining -= consumed; + } + + __syncwarp(); +} + __device__ inline void naive_topk_and_mask(CompType *scores, int data_size, int topk, int *topk_indices, CompType *topk_scores, int lane_id) { // Check if the index is masked by the later iteration @@ -249,6 +475,16 @@ __device__ inline void naive_topk_and_mask(CompType *scores, int data_size, int } } +template +__device__ __forceinline__ void topk_and_mask(CompType *scores, int data_size, int topk, + int *topk_indices, CompType *topk_scores, + int lane_id) { + if constexpr (TopkFunc == TopkFuncType::Radix) + return radix_topk_and_mask(scores, data_size, topk, topk_indices, topk_scores, lane_id); + else + return naive_topk_and_mask(scores, data_size, topk, topk_indices, topk_scores, lane_id); +} + // Current TE only support float32/bf16/fp16, float64 probs should be considered in the future #define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \ switch (dtype) { \ diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 8c50e83926..b322ce8fba 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -920,6 +920,32 @@ __device__ __forceinline__ void reciprocal(float *value_inv, const float *value_inv = __frcp_rn(value); } +// Convert float to an unsigned integer that preserves descending sort order. +// After conversion, a numerically larger float maps to a larger uint32. +__device__ __forceinline__ unsigned int float_to_ordered_uint(float f) { + unsigned int u = __float_as_uint(f); + // If sign bit is set (negative), flip all bits. + // If sign bit is clear (positive or +0), flip only the sign bit. + unsigned int mask = (u & 0x80000000u) ? 0xFFFFFFFFu : 0x80000000u; + return u ^ mask; +} + +// Convert back from ordered uint to float. +__device__ __forceinline__ float ordered_uint_to_float(unsigned int u) { + // Reverse the transformation: if MSB is set (was positive), flip sign bit. 
+ // If MSB is clear (was negative), flip all bits. + unsigned int mask = (u & 0x80000000u) ? 0x80000000u : 0xFFFFFFFFu; + return __uint_as_float(u ^ mask); +} + +template +__device__ __forceinline__ T warp_allreduce_sum(T x) { + // Butterfly reduction +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) x += warp_shuffle_xor(x, offset); + return x; +} + //////////////////////////////////////////////////////////////////////////////////////////////////// using fp8e4m3 = __nv_fp8_e4m3;