diff --git a/transformer_engine/common/fused_router/async_loader.h b/transformer_engine/common/fused_router/async_loader.h new file mode 100644 index 0000000000..32647f1545 --- /dev/null +++ b/transformer_engine/common/fused_router/async_loader.h @@ -0,0 +1,255 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_ +#define TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_ + +#include + +#include + +#include "../utils.cuh" +#include "utils.h" + +namespace transformer_engine { +namespace fused_router { + +// ============================================================================ +// Persistent kernel grid size computation +// ============================================================================ + +// Compute a persistent grid size: min(total_blocks_needed, SMs * max_blocks_per_SM). +// `kernel_func` is a pointer to the __global__ function. +// `block_size` is kThreadsPerBlock. +// `shmem_bytes` is the dynamic shared memory per block. +// `total_blocks` is ceil(num_tokens / tokens_per_block). +template +inline size_t compute_persistent_grid(KernelFunc kernel_func, int block_size, size_t shmem_bytes, + size_t total_blocks) { + int blocks_per_sm = 0; + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel_func, + block_size, shmem_bytes)); + if (blocks_per_sm <= 0) { + return total_blocks; + } + int device_id; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + int num_sms; + NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device_id)); + + size_t max_resident = static_cast(num_sms) * blocks_per_sm; + return (total_blocks < max_resident) ? 
total_blocks : max_resident; +} + +// ============================================================================ +// Occupancy-aware double-buffer decision +// ============================================================================ + +// Decide whether to use 1 or 2 buffers based on shmem budget. +// `single_buf_shmem` is the per-buffer shmem for the async-loaded data. +// `other_shmem_bytes` is shmem for everything else (scratch, work buffers). +// Returns 2 only if at least kMinBlocksPerSM double-buffered blocks fit per SM; otherwise 1. +inline int choose_num_buffers(size_t single_buf_shmem, size_t other_shmem_bytes) { + constexpr int kMinBlocksPerSM = 4; + + size_t total_single = single_buf_shmem + other_shmem_bytes; + size_t total_double = 2 * single_buf_shmem + other_shmem_bytes; + + int device_id; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + int max_smem_per_sm; + NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&max_smem_per_sm, + cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id)); + + int blocks_double = (total_double > 0) ? static_cast(max_smem_per_sm / total_double) : 0; + int blocks_single = (total_single > 0) ? static_cast(max_smem_per_sm / total_single) : 0; + + if (blocks_double >= kMinBlocksPerSM) return 2; + if (blocks_single >= kMinBlocksPerSM) return 1; + // Neither option meets the minimum; prefer single buffer for occupancy + // (total_double >= total_single, so blocks_single >= blocks_double always). + return 1; +} + +// ============================================================================ +// Vectorized global store/fill helpers (vector type built via BytesToType<>) +// ============================================================================ + +template +struct VecTraits { + static constexpr int kVecSize = (sizeof(T) <= 16) ? (16 / sizeof(T)) : 1; +}; + +// Vectorized store: write `count` elements from shmem/registers to global memory. 
+template +__device__ inline void vec_store_global(T *__restrict__ dst, const T *__restrict__ src, int count, + int lane_id) { + constexpr int kVecSize = VecTraits::kVecSize; + using VecType = typename BytesToType::Type; + + bool aligned = (reinterpret_cast(dst) % (sizeof(T) * kVecSize) == 0); + int aligned_count = (count / kVecSize) * kVecSize; + + if (aligned && aligned_count > 0) { + int vec_count = aligned_count / kVecSize; + for (int vi = lane_id; vi < vec_count; vi += kThreadsPerWarp) { + VecType v; + T *v_elts = reinterpret_cast(&v); +#pragma unroll + for (int e = 0; e < kVecSize; e++) { + v_elts[e] = src[vi * kVecSize + e]; + } + reinterpret_cast(dst)[vi] = v; + } + for (int i = aligned_count + lane_id; i < count; i += kThreadsPerWarp) { + dst[i] = src[i]; + } + } else { + for (int i = lane_id; i < count; i += kThreadsPerWarp) { + dst[i] = src[i]; + } + } +} + +// Vectorized fill: write `val` to `count` elements of global memory. +template +__device__ inline void vec_fill_global(T *__restrict__ dst, T val, int count, int lane_id) { + constexpr int kVecSize = VecTraits::kVecSize; + using VecType = typename BytesToType::Type; + + bool aligned = (reinterpret_cast(dst) % (sizeof(T) * kVecSize) == 0); + int aligned_count = (count / kVecSize) * kVecSize; + + if (aligned && aligned_count > 0) { + VecType v; + T *v_elts = reinterpret_cast(&v); +#pragma unroll + for (int e = 0; e < kVecSize; e++) { + v_elts[e] = val; + } + int vec_count = aligned_count / kVecSize; + for (int vi = lane_id; vi < vec_count; vi += kThreadsPerWarp) { + reinterpret_cast(dst)[vi] = v; + } + for (int i = aligned_count + lane_id; i < count; i += kThreadsPerWarp) { + dst[i] = val; + } + } else { + for (int i = lane_id; i < count; i += kThreadsPerWarp) { + dst[i] = val; + } + } +} + +// ============================================================================ +// cp.async wrappers — hardware async copy on sm_80+; on older archs the 16-byte copy is synchronous +// and commit/wait are no-ops. Always defined so callers don't need #if guards. 
+// ============================================================================ + +__device__ __forceinline__ void cp_async_16B(void *__restrict__ dst, const void *__restrict__ src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + __pipeline_memcpy_async(dst, src, 16); +#else + // Scalar fallback — callers must not rely on this being async. + *static_cast(dst) = *static_cast(src); +#endif +} + +__device__ __forceinline__ void cp_async_commit() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + __pipeline_commit(); +#endif +} + +__device__ __forceinline__ void cp_async_wait_all() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + __pipeline_wait_prior(0); +#endif +} + +// ============================================================================ +// RawAsyncLoader — double-buffered loader storing data in original type +// +// Enables cp.async for ALL data types (bf16, fp16, fp32) since no type +// conversion is needed during the copy. The kernel reads from shmem and +// casts to CompType during compute. +// ============================================================================ + +template +class RawAsyncLoader { + public: + // Shmem size calculation (usable from both host and device). + static __host__ __device__ inline size_t shmem_bytes(int count, int num_warps, int num_buffers) { + return static_cast(num_buffers) * count * num_warps * sizeof(T); + } + + // Device-side construction. + __device__ RawAsyncLoader(T *buf_base, int warp_id, int count, int num_warps, int num_buffers) + : phase_(0), double_buf_(num_buffers == 2) { + int per_buffer = count * num_warps; + buf_[0] = buf_base + warp_id * count; + buf_[1] = (num_buffers == 2) ? 
buf_base + per_buffer + warp_id * count : buf_[0]; + } + + __device__ __forceinline__ T *current_buf() { return buf_[phase_]; } + __device__ __forceinline__ T *next_buf() { return buf_[phase_ ^ 1]; } + __device__ __forceinline__ void flip() { + if (double_buf_) phase_ ^= 1; + } + + // Async load into the NEXT buffer (for prefetching). + __device__ void start_load(const T *__restrict__ src, int count, int lane_id) { + raw_load(src, next_buf(), count, lane_id); + } + + // Load into the CURRENT buffer (for the first load before the main loop). + __device__ void load_current(const T *__restrict__ src, int count, int lane_id) { + raw_load(src, current_buf(), count, lane_id); + } + + // Wait for pending async loads to complete. + __device__ __forceinline__ void wait() { + cp_async_wait_all(); + __syncwarp(); + } + + private: + T *buf_[2]; + int phase_; + bool double_buf_; + + // Raw copy: global → shmem, no type conversion. + // Uses 16-byte vectorised copies (cp.async on sm_80+, int4 on older archs) + // when both pointers are 16-byte aligned, with a scalar tail / fallback. 
+ __device__ void raw_load(const T *__restrict__ src, T *__restrict__ dst, int count, int lane_id) { + constexpr int kBytesPerCopy = 16; + constexpr int kEltsPerCopy = kBytesPerCopy / sizeof(T); + + bool src_aligned = (reinterpret_cast(src) % kBytesPerCopy == 0); + bool dst_aligned = (reinterpret_cast(dst) % kBytesPerCopy == 0); + int aligned_count = (count / kEltsPerCopy) * kEltsPerCopy; + + if (src_aligned && dst_aligned && aligned_count > 0) { + int vec_count = aligned_count / kEltsPerCopy; + for (int vi = lane_id; vi < vec_count; vi += kThreadsPerWarp) { + cp_async_16B(dst + vi * kEltsPerCopy, src + vi * kEltsPerCopy); + } + for (int i = aligned_count + lane_id; i < count; i += kThreadsPerWarp) { + dst[i] = src[i]; + } + cp_async_commit(); + } else { + for (int i = lane_id; i < count; i += kThreadsPerWarp) { + dst[i] = src[i]; + } + } + } +}; + +} // namespace fused_router +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_ diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index 7e516af97b..b8026be3fd 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -63,8 +63,8 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, const int warp_id = threadIdx.x / kThreadsPerWarp; const int lane_id = threadIdx.x % kThreadsPerWarp; if (warp_id == 0) { - CompType block_sum = warp_reduce_on_shmem(shmem_block, static_cast(blockDim.x), - ReduceFuncType::SUM, lane_id); + CompType block_sum = warp_reduce_on_shmem( + shmem_block, static_cast(blockDim.x), lane_id); if (lane_id == 0) { atomicAdd(&Coeff_buf[1], static_cast(block_sum * coeff)); } diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 4eb4240d7c..549d8754a5 100644 --- 
a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -8,26 +8,28 @@ #include #include +#include + #include "../common.h" #include "../util/logging.h" #include "../utils.cuh" +#include "async_loader.h" #include "utils.h" namespace transformer_engine { namespace fused_router { +// ============================================================================= +// Simple aux_loss forward kernel — exact upstream structure (no async loader, +// no persistent grid, runtime score_function dispatch). +// ============================================================================= + template -__global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens, - int num_experts, int topk, - int score_function, float *scores, - bool *routing_map, - CompType *intermediate_output) { - /*** - * Section: Global Variables/Addresses init - * - Each warp is responsible for one token, and has own shared memory buffer. 
- * Then __syncwarp() is used instead of __syncthreads() - */ - // Used variables/addresses init +__global__ void fused_score_for_moe_aux_loss_forward_simple_kernel(const DataType *logits, + int num_tokens, int num_experts, + int topk, int score_function, + float *scores, bool *routing_map, + CompType *intermediate_output) { int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; @@ -36,29 +38,17 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi CompType *topk_logits_buf = reinterpret_cast(logits_buf + num_experts * num_token_per_block); int *topk_indices_buf = reinterpret_cast(topk_logits_buf + topk * num_token_per_block); - // The address of buffers on the current warp CompType *local_logits = logits_buf + warp_id * num_experts; CompType *topk_logits = topk_logits_buf + warp_id * topk; int *topk_indices = topk_indices_buf + warp_id * topk; - /*** - * Section: Main Loop - * - Each warp is responsible for one token - */ int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; for (int round = blockIdx.x; round < total_round; round += gridDim.x) { int token_offset_cur_warp = round * num_token_per_block + warp_id; - // Each warp is responsible for one token if (token_offset_cur_warp >= num_tokens) break; - /*** - * Section: Init buffer - * - Clear the global buffer which will accept the result of this round - * - Clear/Init the shmem buffer used by current warp this round - * - Load the logits to shmem - */ int pos_offset = token_offset_cur_warp * num_experts; - // Clear the routing_map (num_experts) + // Clear the routing_map for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { routing_map[pos_offset + i] = false; if (score_function == 1) { @@ -72,47 +62,176 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi __threadfence_block(); __syncwarp(); - /*** - * Section: Preprocess - * 
Possible preprocess the scores before the topk operation - * - Pre-softmax - * - Sigmoid - * - Sqrtsoftplus - * - Sigmoid/Sqrtsoftplus post-processing when topk > 1 - * This is in-place scores update - */ - if (score_function == 1) { // score_function == 1 means softmax - // Apply softmax to the logits before the topk + // Preprocess: apply score function + if (score_function == 1) { apply_softmax_on_float(local_logits, num_experts, lane_id); __syncwarp(); - // Save the softmax output for backward for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = local_logits[i]; } - } else if (score_function == 0) { // score_function == 0 means sigmoid - // Apply sigmoid to the logits + } else if (score_function == 0) { apply_sigmoid_on_float(local_logits, num_experts, lane_id); __syncwarp(); - // Save the sigmoid output for backward for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = local_logits[i]; } - } else if (score_function == 2) { // score_function == 2 means sqrtsoftplus - // First save the original logits for backward (needed for gradient computation) + } else if (score_function == 2) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - intermediate_output[pos_offset + i] = local_logits[i]; // Save original logits + intermediate_output[pos_offset + i] = local_logits[i]; } __syncwarp(); - // Apply sqrtsoftplus to the logits apply_sqrtsoftplus_on_float(local_logits, num_experts, lane_id); } - __syncwarp(); //Confirm the scores is written to the output + __syncwarp(); - // Sigmoid/Sqrtsoftplus post-processing + // Sigmoid/Sqrtsoftplus post-processing: normalize if (score_function == 0 || score_function == 2) { auto sum_logits = - warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); + warp_reduce_on_shmem(local_logits, num_experts, lane_id); + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_logits[i] /= (sum_logits 
+ epsilon); + } + __syncwarp(); + } + + // Topk + topk_and_mask(local_logits, num_experts, topk, topk_indices, topk_logits, lane_id); + __syncwarp(); + + // Write outputs + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + routing_map[pos_offset + topk_indices[i]] = true; + } + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + scores[pos_offset + i] = local_logits[i]; + } + __threadfence_block(); + __syncwarp(); + } +} + +// ============================================================================= +// Optimized aux_loss forward kernel — async loader, persistent grid. +// ============================================================================= + +template +__global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens, + int num_experts, int topk, + float *scores, bool *routing_map, + CompType *intermediate_output, + int num_buffers) { + /*** + * Section: Global Variables/Addresses init + * - Each warp is responsible for one token, and has own shared memory buffer. 
+ * Then __syncwarp() is used instead of __syncthreads() + */ + int num_token_per_block = blockDim.x / kThreadsPerWarp; + int warp_id = threadIdx.x / kThreadsPerWarp; + int lane_id = threadIdx.x % kThreadsPerWarp; + extern __shared__ char shmem_raw_aux[]; + + // Shmem layout: logits_raw (async) | logits_work | topk_scratch + char *shmem_ptr = shmem_raw_aux; + DataType *logits_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader loader(logits_shmem_base, warp_id, num_experts, num_token_per_block, + num_buffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + + CompType *logits_work_buf = reinterpret_cast(shmem_ptr); + shmem_ptr += num_experts * num_token_per_block * sizeof(CompType); + + CompType *topk_logits_buf = reinterpret_cast(shmem_ptr); + int *topk_indices_buf = reinterpret_cast(topk_logits_buf + topk * num_token_per_block); + + // The address of buffers on the current warp + CompType *local_logits = logits_work_buf + warp_id * num_experts; + CompType *topk_logits = topk_logits_buf + warp_id * topk; + int *topk_indices = topk_indices_buf + warp_id * topk; + + /*** + * Section: Main Loop — persistent grid with double-buffered async load + * - Each warp is responsible for one token + */ + int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; + int first_round = blockIdx.x; + if (first_round >= total_round) return; + + // Kick off first async load + { + int first_token = first_round * num_token_per_block + warp_id; + if (first_token < num_tokens) { + loader.load_current(logits + first_token * num_experts, num_experts, lane_id); + } + } + + for (int round = first_round; round < total_round; round += gridDim.x) { + int token_offset_cur_warp = round * num_token_per_block + warp_id; + if (token_offset_cur_warp >= num_tokens) break; + + // Single-buffer: load current round here (no prefetch possible) + if (num_buffers == 1 && round != first_round) { + loader.load_current(logits + 
token_offset_cur_warp * num_experts, num_experts, lane_id); + } + + loader.wait(); + DataType *raw_logits = loader.current_buf(); + + // Prefetch next round (only when double-buffered) + if (num_buffers > 1) { + int next_round = round + gridDim.x; + if (next_round < total_round) { + int next_token = next_round * num_token_per_block + warp_id; + if (next_token < num_tokens) { + loader.start_load(logits + next_token * num_experts, num_experts, lane_id); + } + } + } + + /*** + * Section: Init buffer + Preprocess + * - Convert raw logits (DataType) → apply score function → save intermediate_output + * + * Fused into a single loop per score function where possible: + * score_function == 0 (sigmoid): convert, sigmoid, save → shmem + * score_function == 1 (softmax): convert → shmem, softmax (multi-pass), save + * score_function == 2 (sqrtsoftplus): convert, save logits, sqrtsoftplus → shmem + */ + int pos_offset = token_offset_cur_warp * num_experts; + + if constexpr (ScoreFunc == 1) { // Softmax + // Apply softmax to all logits, save softmax output for backward + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_logits[i] = static_cast(raw_logits[i]); + } + __syncwarp(); + apply_softmax_on_float(local_logits, num_experts, lane_id); + __syncwarp(); + // Save the softmax output for backward + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = local_logits[i]; + } + } else if constexpr (ScoreFunc == 0) { // Sigmoid + // Fused: convert → sigmoid → save sigmoid output for backward → shmem + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + float val = sigmoid_scalar(static_cast(raw_logits[i])); + intermediate_output[pos_offset + i] = val; // Save sigmoid output for backward + local_logits[i] = val; + } + } else if constexpr (ScoreFunc == 2) { // Sqrtsoftplus + // Fused: convert → save original logit for backward → sqrtsoftplus → shmem + for (int i = lane_id; i < num_experts; i += 
kThreadsPerWarp) { + float logit = static_cast(raw_logits[i]); + intermediate_output[pos_offset + i] = logit; // Save original logits for backward + local_logits[i] = sqrtsoftplus_scalar(logit); + } + } + __syncwarp(); + + // Sigmoid/Sqrtsoftplus post-processing: normalize scores to sum to 1 + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { + auto sum_logits = + warp_reduce_on_shmem(local_logits, num_experts, lane_id); for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_logits[i] /= (sum_logits + epsilon); } @@ -127,15 +246,15 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi __syncwarp(); // Write the routing_map to the output tensor + vec_fill_global(routing_map + pos_offset, false, num_experts, lane_id); for (int i = lane_id; i < topk; i += kThreadsPerWarp) { routing_map[pos_offset + topk_indices[i]] = true; } // Write the scores to the output tensor - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - scores[pos_offset + i] = local_logits[i]; - } - __threadfence_block(); + vec_store_global(scores + pos_offset, local_logits, num_experts, lane_id); __syncwarp(); + + loader.flip(); } } @@ -143,33 +262,75 @@ template void fused_score_for_moe_aux_loss_forward_kernel_launcher( const DataType *logits, int num_tokens, int num_experts, int topk, int score_function, float *scores, bool *routing_map, CompType *intermediate_output, cudaStream_t stream) { - // Meta data for the kernel + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk, + " num_experts=", num_experts); + NVTE_CHECK(static_cast(num_tokens) * num_experts <= INT_MAX, + "num_tokens * num_experts exceeds INT_MAX (kernel uses int offsets), got ", + static_cast(num_tokens) * num_experts); size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; - size_t grid_size = (num_tokens + num_token_per_block - 1) / 
num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // logits - + topk * num_token_per_block * sizeof(CompType) // topk_logits - + topk * num_token_per_block * sizeof(int); // topk_indices - check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); - // Radix selection is O(E), independent of K, but it needs 4 passes for 32-bit float; - // switch at K=16 where naive O(K^2*E) starts to dominate - if (topk < 16) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - fused_score_for_moe_aux_loss_forward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); - fused_score_for_moe_aux_loss_forward_kernel - <<>>( - logits, num_tokens, num_experts, topk, score_function, scores, routing_map, - intermediate_output); + size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; + + size_t scores_shmem = num_experts * num_token_per_block * sizeof(CompType); + size_t scratch_shmem = + topk * num_token_per_block * sizeof(CompType) + topk * num_token_per_block * sizeof(int); + size_t other_shmem = scores_shmem + scratch_shmem; + size_t logits_single_buf = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1); + int num_buffers = choose_num_buffers(logits_single_buf, other_shmem); + size_t logits_raw_shmem = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + size_t shared_memory_size = logits_raw_shmem + other_shmem; + + auto launch = [&](auto kernel) { + check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_memory_size)); + size_t grid_size = + compute_persistent_grid(kernel, kThreadsPerBlock, shared_memory_size, total_blocks); + kernel<<>>( + logits, num_tokens, num_experts, topk, scores, routing_map, intermediate_output, + num_buffers); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + // Dispatch: small topk uses 
the simple kernel (no async loader overhead); + // large topk uses the optimized kernel with radix selection + persistent grid. + // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 8). + if (topk < get_radix_topk_threshold()) { + // Simple path: exact upstream structure — no async loader, no persistent grid. + check_shared_memory_capacity_num_experts(other_shmem, num_experts); + + auto launch_simple = [&](auto kernel) { + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, other_shmem)); + kernel<<>>( + logits, num_tokens, num_experts, topk, score_function, scores, routing_map, + intermediate_output); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + launch_simple( + fused_score_for_moe_aux_loss_forward_simple_kernel); } else { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - fused_score_for_moe_aux_loss_forward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); - fused_score_for_moe_aux_loss_forward_kernel - <<>>( - logits, num_tokens, num_experts, topk, score_function, scores, routing_map, - intermediate_output); + // Optimized path: async loader + persistent grid + radix topk. 
+ NVTE_CHECK(num_experts <= kMaxExpertsRadixTopk, + "Radix topk requires num_experts <= ", kMaxExpertsRadixTopk, + " (packed 8-bit histogram), got ", num_experts, "."); + switch (score_function) { + case 0: + launch(fused_score_for_moe_aux_loss_forward_kernel); + break; + case 1: + launch(fused_score_for_moe_aux_loss_forward_kernel); + break; + case 2: + launch(fused_score_for_moe_aux_loss_forward_kernel); + break; + default: + NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); + } } - NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts, @@ -185,128 +346,154 @@ void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, reinterpret_cast(intermediate_output.data.dptr), stream);); } -template +// Backward: grad_scores + intermediate_output → grad_logits. +// No routing_map — all experts participate (unlike topk backward). +// Double-buffered cp.async loads both inputs. Two-pass fused approach. +// +// Shmem layout (B = num_buffers, W = warps/block): +// grad_buf: B × E × W × sizeof(CompType) — async-loaded grad +// act_buf: B × E × W × sizeof(CompType) — async-loaded activations + +template __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *intermediate_output, const float *grad_scores, int num_tokens, int num_experts, - int topk, int score_function, - DataType *grad_logits) { + DataType *grad_logits, + int num_buffers) { /*** * Section: Global Variables/Addresses init * - Each warp is responsible for one token, and has own shared memory buffer. 
* Then __syncwarp() is used instead of __syncthreads() */ - // Used variables/addresses init int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; - extern __shared__ float shmem[]; - CompType *grad_scores_buf = reinterpret_cast(shmem); - // To store the output of softmax/sigmoid from fwd, or original logits for sqrtsoftplus - CompType *act_from_fwd_buf = grad_scores_buf + num_experts * num_token_per_block; - CompType *comp_buf = act_from_fwd_buf + num_experts * num_token_per_block; - // The address of buffers on the current warp - CompType *local_grad = grad_scores_buf + warp_id * num_experts; - CompType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; - CompType *local_comp_buf = comp_buf + warp_id * num_experts; + + extern __shared__ char shmem_aux_bwd[]; + char *shmem_ptr = shmem_aux_bwd; + + CompType *grad_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader grad_loader(grad_shmem_base, warp_id, num_experts, num_token_per_block, + num_buffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + + CompType *act_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader act_loader(act_shmem_base, warp_id, num_experts, num_token_per_block, + num_buffers); /*** - * Section: Main Loop + * Section: Main Loop — persistent grid with double-buffered async load * - Each warp is responsible for one token */ int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; - for (int round = blockIdx.x; round < total_round; round += gridDim.x) { - int token_offset_cur_warp = round * num_token_per_block + warp_id; - // Each warp is responsible for one token - if (token_offset_cur_warp >= num_tokens) break; + int first_round = blockIdx.x; + if (first_round >= total_round) return; - /*** - * Section: Init buffer - * - Clear the global buffer which will accept the result of this round - * - Clear/Init the 
shmem buffer used by current warp this round - * - Load the dgrad/output_from_fwd to shmem - */ - int pos_offset = token_offset_cur_warp * num_experts; - // Load the dgrad/output_from_fwd to shmem - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_grad[i] = grad_scores[pos_offset + i]; - local_act_from_fwd[i] = intermediate_output[pos_offset + i]; + // Kick off first async load + { + int first_token = first_round * num_token_per_block + warp_id; + if (first_token < num_tokens) { + int pos = first_token * num_experts; + grad_loader.load_current(grad_scores + pos, num_experts, lane_id); + act_loader.load_current(intermediate_output + pos, num_experts, lane_id); } - __threadfence_block(); - __syncwarp(); + } - /*** - * Section: Backward of ops before the topk - * - Pre-softmax bwd - * - Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 - * - Sigmoid bwd - * - Sqrtsoftplus bwd - * - Write the grad_logits to the global mem - */ - // Sqrtsoftplus: First compute sqrtsoftplus output from original logits - // (needed for both post-processing bwd and activation bwd, compute once here) - // For sqrtsoftplus, intermediate_output stores original logits - if (score_function == 2) { - // Copy original logits to local_comp_buf and apply sqrtsoftplus in-place - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_comp_buf[i] = local_act_from_fwd[i]; - } - __syncwarp(); - apply_sqrtsoftplus_on_float(local_comp_buf, num_experts, lane_id); - __syncwarp(); + for (int round = first_round; round < total_round; round += gridDim.x) { + int token_idx = round * num_token_per_block + warp_id; + if (token_idx >= num_tokens) break; + int pos = token_idx * num_experts; + + if (num_buffers == 1 && round != first_round) { + grad_loader.load_current(grad_scores + pos, num_experts, lane_id); + act_loader.load_current(intermediate_output + pos, num_experts, lane_id); } - // Sigmoid/Sqrtsoftplus Post-processing bwd (normalization backward) - if 
(score_function == 0 || score_function == 2) { - // Select the correct activation output buffer: - // - Sigmoid: local_act_from_fwd already contains sigmoid output - // - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above - CompType *act_output = (score_function == 0) ? local_act_from_fwd : local_comp_buf; - - auto sum_fwd_input = - warp_reduce_on_shmem(act_output, num_experts, ReduceFuncType::SUM, lane_id); - // Compute sum of output * grad using registers - CompType local_sum_Output_x_Grad = 0.0; - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_sum_Output_x_Grad += local_grad[i] * act_output[i]; - } - // Warp reduce the sum - for (int s = 16; s > 0; s /= 2) { - local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s); - } - CompType sum_Output_x_Grad = local_sum_Output_x_Grad; - // In-place update - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_grad[i] = local_grad[i] / (sum_fwd_input + epsilon) - - sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon)); + grad_loader.wait(); + act_loader.wait(); + + CompType *raw_grad = grad_loader.current_buf(); + CompType *raw_act = act_loader.current_buf(); + + // Prefetch next round only when double-buffered; single-buffer loads above. 
+ if (num_buffers > 1) { + int next_round = round + gridDim.x; + if (next_round < total_round) { + int next_token = next_round * num_token_per_block + warp_id; + if (next_token < num_tokens) { + int next_pos = next_token * num_experts; + grad_loader.start_load(grad_scores + next_pos, num_experts, lane_id); + act_loader.start_load(intermediate_output + next_pos, num_experts, lane_id); + } } - __syncwarp(); } - // Pre-softmax bwd - if (score_function == 1) { - apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, nullptr, - num_experts, lane_id); - __syncwarp(); + /*** + * Section: Pass 1 — Reduction + * Accumulate warp-level sums needed by the backward passes: + * sigmoid/sqrtsoftplus: sum_act, sum_grad_act for normalization bwd + * softmax: sum_output_x_grad = Σ(grad * softmax_output) + * + * For sqrtsoftplus, intermediate_output stores original logits, so we + * recompute sqrtsoftplus(x) on the fly to get the activation value. + */ + CompType sum_act = 0.0f; + CompType sum_grad_act = 0.0f; + CompType sum_output_x_grad = 0.0f; + + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + CompType g = static_cast(raw_grad[i]); + CompType act = raw_act[i]; + if constexpr (ScoreFunc == 0) { // Sigmoid + // act = sigmoid output; accumulate over all experts + sum_act += act; + sum_grad_act += g * act; + } else if constexpr (ScoreFunc == 2) { // Sqrtsoftplus + // act = original logit; recompute sqrtsoftplus to get activation + CompType v = sqrtsoftplus_scalar(act); + sum_act += v; + sum_grad_act += g * v; + } else if constexpr (ScoreFunc == 1) { // Softmax + // act = softmax output + sum_output_x_grad += g * act; + } } - // Sigmoid bwd - if (score_function == 0) { - apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id); - __syncwarp(); + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { + sum_act = warp_allreduce_sum(sum_act); + sum_grad_act = warp_allreduce_sum(sum_grad_act); } - // Sqrtsoftplus bwd - // For 
sqrtsoftplus, local_comp_buf already contains sqrtsoftplus output computed earlier - // Now compute gradient: dy/dx = sigmoid(x) / (2 * y) - if (score_function == 2) { - apply_sqrtsoftplus_bwd_on_float(local_grad, local_comp_buf, local_act_from_fwd, num_experts, - lane_id); - __syncwarp(); + if constexpr (ScoreFunc == 1) { + sum_output_x_grad = warp_allreduce_sum(sum_output_x_grad); } - // Write the grad_logits to the global mem + + /*** + * Section: Pass 2 — Element-wise gradient + * Compute per-element gradient using the warp-level sums from Pass 1. + * Applies backward ops in reverse of forward order: + * sigmoid: normalization bwd → sigmoid bwd + * sqrtsoftplus: normalization bwd → sqrtsoftplus bwd + * softmax: softmax bwd + * Write the grad_logits to the global mem + */ for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - grad_logits[pos_offset + i] = static_cast(local_grad[i]); + CompType g = static_cast(raw_grad[i]); + CompType act = raw_act[i]; + + if constexpr (ScoreFunc == 0) { // Sigmoid bwd + g = normalize_bwd_scalar(g, true, sum_act, sum_grad_act); + g = sigmoid_bwd_scalar(g, act); + } else if constexpr (ScoreFunc == 2) { // Sqrtsoftplus bwd + g = normalize_bwd_scalar(g, true, sum_act, sum_grad_act); + g = sqrtsoftplus_bwd_scalar(g, act, sqrtsoftplus_scalar(act)); + } else if constexpr (ScoreFunc == 1) { // Softmax bwd + g = softmax_bwd_scalar(g, act, sum_output_x_grad); + } + + grad_logits[pos + i] = static_cast(g); } - __syncwarp(); + + grad_loader.flip(); + act_loader.flip(); } } @@ -314,22 +501,44 @@ template void fused_score_for_moe_aux_loss_backward_kernel_launcher( const CompType *intermediate_output, const float *grad_scores, int num_tokens, int num_experts, int topk, int score_function, DataType *grad_logits, cudaStream_t stream) { - // Meta data for the kernel + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(static_cast(num_tokens) * num_experts <= INT_MAX, + "num_tokens * num_experts 
exceeds INT_MAX (kernel uses int offsets), got ", + static_cast(num_tokens) * num_experts); size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; - size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // grad_scores - + - num_experts * num_token_per_block * sizeof(CompType) // act_from_fwd - + num_experts * num_token_per_block * sizeof(CompType); // comp_buf - check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); - NVTE_CHECK_CUDA(cudaFuncSetAttribute(fused_score_for_moe_aux_loss_backward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - shared_memory_size)); - fused_score_for_moe_aux_loss_backward_kernel - <<>>( - intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function, - grad_logits); - NVTE_CHECK_CUDA(cudaGetLastError()); + size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; + + size_t single_buf_shmem = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1); + int num_buffers = choose_num_buffers(single_buf_shmem, 0); + size_t shmem_bytes = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + check_shared_memory_capacity_num_experts(shmem_bytes, num_experts); + + auto launch = [&](auto kernel) { + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); + size_t grid_size = compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); + kernel<<>>( + intermediate_output, grad_scores, num_tokens, num_experts, grad_logits, num_buffers); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + switch (score_function) { + case 0: + launch(fused_score_for_moe_aux_loss_backward_kernel); + break; + case 1: + 
launch(fused_score_for_moe_aux_loss_backward_kernel); + break; + case 2: + launch(fused_score_for_moe_aux_loss_backward_kernel); + break; + default: + NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); + } } void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output, diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 9f7a830546..4384104688 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -8,25 +8,29 @@ #include #include +#include + #include "../common.h" #include "../util/logging.h" +#include "async_loader.h" #include "utils.h" namespace transformer_engine { namespace fused_router { +// ============================================================================= +// Simple forward kernel — exact upstream structure (no async loader, no +// persistent grid, runtime score_function dispatch). Faster for small topk +// due to lower scheduling overhead and separate load/compute/store phases. +// ============================================================================= + template -__global__ void fused_topk_with_score_function_forward_kernel( - const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, - int num_groups, int group_topk, float scaling_factor, int score_function, - const BiasType *expert_bias, DataType *probs, bool *routing_map, - CompType *intermediate_output) { - /*** - * Section: Global Variables/Addresses init - * - Each warp is responsible for one token, and has own shared memory buffer. 
- * Then __syncwarp() is used instead of __syncthreads() - */ - // Used variables/addresses init +__global__ void fused_topk_forward_simple_kernel(const DataType *logits, int num_tokens, + int num_experts, int topk, bool use_pre_softmax, + int num_groups, int group_topk, + float scaling_factor, int score_function, + const BiasType *expert_bias, DataType *probs, + bool *routing_map, CompType *intermediate_output) { int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; @@ -42,29 +46,17 @@ __global__ void fused_topk_with_score_function_forward_kernel( } else { topk_indices_buf = reinterpret_cast(topk_scores_buf + topk * num_token_per_block); } - // The address of buffers on the current warp CompType *scores = scores_buf + warp_id * num_experts; CompType *topk_scores = topk_scores_buf + warp_id * topk; CompType *masked_scores = masked_scores_buf + warp_id * num_experts; CompType *group_scores = group_scores_buf + warp_id * num_groups; int *topk_indices = topk_indices_buf + warp_id * topk; - /*** - * Section: Main Loop - * - Each warp is responsible for one token - */ int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; for (int round = blockIdx.x; round < total_round; round += gridDim.x) { int token_offset_cur_warp = round * num_token_per_block + warp_id; - // Each warp is responsible for one token if (token_offset_cur_warp >= num_tokens) break; - /*** - * Section: Init buffer - * - Clear the global buffer which will accept the result of this round - * - Clear/Init the shmem buffer used by current warp this round - * - Load the logits to shmem - */ int pos_offset = token_offset_cur_warp * num_experts; // Clear the probs/routing_map (num_experts) for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { @@ -87,44 +79,30 @@ __global__ void fused_topk_with_score_function_forward_kernel( __threadfence_block(); __syncwarp(); - /*** - * Section: 
Preprocess - * Possible preprocess the scores before the topk operation - * - Pre-softmax - * - Sigmoid - * - Sqrtsoftplus - * - Expert bias - * This is in-place scores update - */ - if (use_pre_softmax && score_function == 1) { // score_function == 1 means softmax - // Apply softmax to the logits before the topk + // Preprocess: apply score function in-place on shmem + if (use_pre_softmax && score_function == 1) { apply_softmax_on_float(scores, num_experts, lane_id); __syncwarp(); - // Save the softmax output for backward for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = scores[i]; } - } else if (score_function == 0) { // score_function == 0 means sigmoid - // Apply sigmoid to the logits + } else if (score_function == 0) { apply_sigmoid_on_float(scores, num_experts, lane_id); __syncwarp(); - // Save the sigmoid output for backward for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = scores[i]; } - } else if (score_function == 2) { // score_function == 2 means sqrtsoftplus - // First save the original logits for backward (needed for sqrtsoftplus gradient computation) + } else if (score_function == 2) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - intermediate_output[pos_offset + i] = scores[i]; // Save original logits + intermediate_output[pos_offset + i] = scores[i]; } __syncwarp(); - // Apply sqrtsoftplus to the logits apply_sqrtsoftplus_on_float(scores, num_experts, lane_id); } - __syncwarp(); //Confirm the scores is written to the output + __syncwarp(); - // Expert bias is only used at the sigmoid/sqrtsoftplus case + // Expert bias (sigmoid/sqrtsoftplus only) if (expert_bias && (score_function == 0 || score_function == 2)) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { scores[i] += static_cast(expert_bias[i]); @@ -132,6 +110,236 @@ __global__ void fused_topk_with_score_function_forward_kernel( __syncwarp(); } + // Topk 
selection + if (group_topk > 0) { + int group_size = num_experts / num_groups; + for (int i = 0; i < num_groups; i++) { + topk_and_mask(scores + i * group_size, group_size, topk / group_topk, + topk_indices, topk_scores, lane_id); + __syncwarp(); + if (lane_id == 0) { + CompType tmp = 0.0; + for (int j = 0; j < topk / group_topk; j++) { + tmp = tmp + topk_scores[j]; + } + group_scores[i] = tmp; + } + __syncwarp(); + } + topk_and_mask(group_scores, num_groups, group_topk, topk_indices, topk_scores, + lane_id); + __syncwarp(); + for (int i = 0; i < group_topk; i++) { + int st = topk_indices[i] * group_size; + int ed = st + group_size; + for (int j = st + lane_id; j < ed; j += kThreadsPerWarp) { + masked_scores[j] = scores[j]; + } + } + __syncwarp(); + topk_and_mask(masked_scores, num_experts, topk, topk_indices, topk_scores, lane_id); + } else { + topk_and_mask(scores, num_experts, topk, topk_indices, topk_scores, lane_id); + } + __syncwarp(); + + // Postprocess: revert bias, softmax, normalization + if (expert_bias && (score_function == 0 || score_function == 2)) { + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + topk_scores[i] = topk_scores[i] - static_cast(expert_bias[topk_indices[i]]); + } + __syncwarp(); + } + + if (!use_pre_softmax && score_function == 1) { + apply_softmax_on_float(topk_scores, topk, lane_id); + __syncwarp(); + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i]; + } + __syncwarp(); + } + + if (score_function == 0 || score_function == 2) { + if (topk > 1) { + CompType sum_scores = + warp_reduce_on_shmem(topk_scores, topk, lane_id); + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + topk_scores[i] = topk_scores[i] / (sum_scores + epsilon); + } + } + __syncwarp(); + } + + // Write outputs + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + routing_map[pos_offset + topk_indices[i]] = true; + probs[pos_offset + topk_indices[i]] = scaling_factor * 
topk_scores[i]; + } + __threadfence_block(); + __syncwarp(); + } +} + +// ============================================================================= +// Optimized forward kernel — async loader, persistent grid, double buffering. +// Used for larger topk where radix selection and compute dominate. +// ============================================================================= + +template +__global__ void fused_topk_with_score_function_forward_kernel( + const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, const BiasType *expert_bias, + DataType *probs, bool *routing_map, CompType *intermediate_output, int num_buffers) { + /*** + * Section: Global Variables/Addresses init + * - Each warp is responsible for one token, and has own shared memory buffer. + * Then __syncwarp() is used instead of __syncthreads() + */ + // Used variables/addresses init + int num_token_per_block = blockDim.x / kThreadsPerWarp; + int warp_id = threadIdx.x / kThreadsPerWarp; + int lane_id = threadIdx.x % kThreadsPerWarp; + extern __shared__ char shmem_raw[]; + + // Shmem layout: logits_raw (async) | scores | topk_scratch (+ group_topk scratch) + char *shmem_ptr = shmem_raw; + DataType *logits_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader loader(logits_shmem_base, warp_id, num_experts, num_token_per_block, + num_buffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + + CompType *scores_buf = reinterpret_cast(shmem_ptr); + shmem_ptr += num_experts * num_token_per_block * sizeof(CompType); + + CompType *topk_scores_buf = reinterpret_cast(shmem_ptr); + CompType *group_scores_buf = nullptr, *masked_scores_buf = nullptr; + int *topk_indices_buf = nullptr; + if (group_topk > 0) { + masked_scores_buf = topk_scores_buf + topk * num_token_per_block; + group_scores_buf = masked_scores_buf + num_experts * num_token_per_block; + topk_indices_buf = 
reinterpret_cast(group_scores_buf + num_groups * num_token_per_block); + } else { + topk_indices_buf = reinterpret_cast(topk_scores_buf + topk * num_token_per_block); + } + // The address of buffers on the current warp + CompType *scores = scores_buf + warp_id * num_experts; + CompType *topk_scores = topk_scores_buf + warp_id * topk; + CompType *masked_scores = + (masked_scores_buf != nullptr) ? masked_scores_buf + warp_id * num_experts : nullptr; + CompType *group_scores = + (group_scores_buf != nullptr) ? group_scores_buf + warp_id * num_groups : nullptr; + int *topk_indices = topk_indices_buf + warp_id * topk; + + /*** + * Section: Main Loop — persistent grid with double-buffered async load + * - Each warp is responsible for one token + */ + int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; + int first_round = blockIdx.x; + if (first_round >= total_round) return; + + // Kick off first async load + { + int first_token = first_round * num_token_per_block + warp_id; + if (first_token < num_tokens) { + loader.load_current(logits + first_token * num_experts, num_experts, lane_id); + } + } + + for (int round = first_round; round < total_round; round += gridDim.x) { + int token_offset_cur_warp = round * num_token_per_block + warp_id; + if (token_offset_cur_warp >= num_tokens) break; + + // Single-buffer: load current round here (no prefetch possible) + if (num_buffers == 1 && round != first_round) { + loader.load_current(logits + token_offset_cur_warp * num_experts, num_experts, lane_id); + } + + // Wait for current round's async load to complete + loader.wait(); + DataType *raw_logits = loader.current_buf(); + + // Prefetch next round (only when double-buffered, overlaps with compute) + if (num_buffers > 1) { + int next_round = round + gridDim.x; + if (next_round < total_round) { + int next_token = next_round * num_token_per_block + warp_id; + if (next_token < num_tokens) { + loader.start_load(logits + next_token * num_experts, 
num_experts, lane_id); + } + } + } + + /*** + * Section: Init buffer + Preprocess + * - Clear the global output buffers (probs, routing_map) + * - Convert raw logits (DataType) → apply score function → save intermediate → add bias + * + * Fused into a single loop per score function where possible: + * score_function == 0 (sigmoid): convert, sigmoid, save, +bias → scores + * score_function == 1 (softmax): convert → shmem, softmax (multi-pass), save + * score_function == 2 (sqrtsoftplus): convert, save logits, sqrtsoftplus, +bias → scores + * + * Expert bias is only used with sigmoid/sqrtsoftplus and is fused into + * the same loop that computes the score. + */ + int pos_offset = token_offset_cur_warp * num_experts; + + // Clear the probs/routing_map (num_experts) + vec_fill_global(probs + pos_offset, static_cast(0.0f), num_experts, lane_id); + vec_fill_global(routing_map + pos_offset, false, num_experts, lane_id); + + if constexpr (ScoreFunc == 1) { // Softmax + if (use_pre_softmax) { + // Pre-softmax: apply softmax to all logits before topk, save for backward. + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + scores[i] = static_cast(raw_logits[i]); + } + __syncwarp(); + apply_softmax_on_float(scores, num_experts, lane_id); + __syncwarp(); + // Save the softmax output for backward + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = scores[i]; + } + } else { + // Post-softmax: softmax applied after topk; init intermediate to -inf + // (only the topk positions will be filled in the postprocess section). 
+ for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + scores[i] = static_cast(raw_logits[i]); + intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); + } + } + } else if constexpr (ScoreFunc == 0) { // Sigmoid + // Fused: convert → sigmoid → save sigmoid output for backward → add bias → scores + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + float val = sigmoid_scalar(static_cast(raw_logits[i])); + intermediate_output[pos_offset + i] = val; // Save sigmoid output for backward + if (expert_bias) val += static_cast(expert_bias[i]); + scores[i] = val; + } + } else if constexpr (ScoreFunc == 2) { // Sqrtsoftplus + // Fused: convert → save original logit for backward → sqrtsoftplus → add bias → scores + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + float logit = static_cast(raw_logits[i]); + intermediate_output[pos_offset + i] = logit; // Save original logits for backward + float val = sqrtsoftplus_scalar(logit); + if (expert_bias) val += static_cast(expert_bias[i]); + scores[i] = val; + } + } + __syncwarp(); + + // If group_topk > 0, init the masked_scores to -inf + if (group_topk > 0) { + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + masked_scores[i] = -std::numeric_limits::infinity(); + } + __syncwarp(); + } + /*** * Section: Topk * Get the topk indices @@ -198,29 +406,33 @@ __global__ void fused_topk_with_score_function_forward_kernel( * - Write the result with scaling_factor */ // Revert Expert bias from the topk scores - if (expert_bias && (score_function == 0 || score_function == 2)) { - for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - topk_scores[i] = topk_scores[i] - static_cast(expert_bias[topk_indices[i]]); + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { + if (expert_bias) { + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + topk_scores[i] = topk_scores[i] - static_cast(expert_bias[topk_indices[i]]); + } + __syncwarp(); } - __syncwarp(); } - // 
score_function == 1 means softmax - if (!use_pre_softmax && score_function == 1) { - // Apply softmax to the topk logits - apply_softmax_on_float(topk_scores, topk, lane_id); - __syncwarp(); - // Save the softmax output for backward - for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i]; + if constexpr (ScoreFunc == 1) { + if (!use_pre_softmax) { + // Apply softmax to the topk logits + apply_softmax_on_float(topk_scores, topk, lane_id); + __syncwarp(); + // Save the softmax output for backward + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i]; + } + __syncwarp(); } - __syncwarp(); } // Sigmoid/Sqrtsoftplus post-processing when topk > 1 - if (score_function == 0 || score_function == 2) { + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { if (topk > 1) { - CompType sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id); + CompType sum_scores = + warp_reduce_on_shmem(topk_scores, topk, lane_id); for (int i = lane_id; i < topk; i += kThreadsPerWarp) { topk_scores[i] = topk_scores[i] / (sum_scores + epsilon); } @@ -233,8 +445,9 @@ __global__ void fused_topk_with_score_function_forward_kernel( routing_map[pos_offset + topk_indices[i]] = true; probs[pos_offset + topk_indices[i]] = scaling_factor * topk_scores[i]; } - __threadfence_block(); __syncwarp(); + + loader.flip(); } } @@ -244,36 +457,85 @@ void fused_topk_with_score_function_forward_kernel_launcher( int num_groups, int group_topk, float scaling_factor, int score_function, const BiasType *expert_bias, DataType *probs, bool *routing_map, CompType *intermediate_output, cudaStream_t stream) { + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk, + " num_experts=", num_experts); + NVTE_CHECK(static_cast(num_tokens) * 
num_experts <= INT_MAX, + "num_tokens * num_experts exceeds INT_MAX (kernel uses int offsets), got ", + static_cast(num_tokens) * num_experts); + if (group_topk > 0) { + NVTE_CHECK(topk % group_topk == 0, "topk must be divisible by group_topk, got topk=", topk, + " group_topk=", group_topk); + } size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; - size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // scores - + topk * num_token_per_block * sizeof(CompType) // topk_scores - + topk * num_token_per_block * sizeof(int); // topk_indices + size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; + size_t scores_shmem = num_experts * num_token_per_block * sizeof(CompType); + size_t scratch_shmem = + topk * num_token_per_block * sizeof(CompType) + topk * num_token_per_block * sizeof(int); if (group_topk > 0) { - shared_memory_size += num_groups * num_token_per_block * sizeof(CompType); // group_scores - shared_memory_size += num_experts * num_token_per_block * sizeof(CompType); // maksed_scores + scratch_shmem += num_groups * num_token_per_block * sizeof(CompType); + scratch_shmem += num_experts * num_token_per_block * sizeof(CompType); } - check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); - // Radix selection is O(E), independent of K, but it needs 4 passes for 32-bit float; - // switch at K=16 where naive O(K^2*E) starts to dominate - if (topk < 16) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - fused_topk_with_score_function_forward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); - fused_topk_with_score_function_forward_kernel - <<>>( - logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + size_t other_shmem = scores_shmem + scratch_shmem; + size_t 
logits_single_buf = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1); + int num_buffers = choose_num_buffers(logits_single_buf, other_shmem); + size_t logits_raw_shmem = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + size_t shared_memory_size = logits_raw_shmem + other_shmem; + + auto launch = [&](auto kernel) { + check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_memory_size)); + size_t grid_size = + compute_persistent_grid(kernel, kThreadsPerBlock, shared_memory_size, total_blocks); + kernel<<>>( + logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, expert_bias, probs, routing_map, intermediate_output, num_buffers); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + // Dispatch: small topk uses the simple kernel (no async loader overhead); + // large topk uses the optimized kernel with radix selection + persistent grid. + // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 8). + if (topk < get_radix_topk_threshold()) { + // Simple path: no async loader, no persistent grid — lower overhead for small K. + // Uses the exact upstream kernel structure with runtime score_function dispatch. 
+ check_shared_memory_capacity_num_experts(other_shmem, num_experts); + + auto launch_simple = [&](auto kernel) { + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, other_shmem)); + kernel<<>>( + logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + launch_simple(fused_topk_forward_simple_kernel); } else { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - fused_topk_with_score_function_forward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); - fused_topk_with_score_function_forward_kernel - <<>>( - logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + // Optimized path: async loader + persistent grid + radix topk. + NVTE_CHECK(num_experts <= kMaxExpertsRadixTopk, + "Radix topk requires num_experts <= ", kMaxExpertsRadixTopk, + " (packed 8-bit histogram), got ", num_experts, "."); + switch (score_function) { + case 0: + launch(fused_topk_with_score_function_forward_kernel); + break; + case 1: + launch(fused_topk_with_score_function_forward_kernel); + break; + case 2: + launch(fused_topk_with_score_function_forward_kernel); + break; + default: + NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); + } } - NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts, @@ -295,177 +557,210 @@ void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, reinterpret_cast(intermediate_output.data.dptr), stream););); } -template +// Backward: grad_probs + intermediate_output + routing_map → grad_logits. +// +// Double-buffered cp.async loads all 3 inputs in original types. 
Two-pass +// fused approach (eliminates the comp_buf shmem buffer): +// Pass 1 (reduction): accumulate warp-level sums needed by normalization/softmax bwd. +// Pass 2 (element-wise): compute per-element gradient and write to global memory. +// +// Shmem layout (B = num_buffers, W = warps/block): +// grad_raw: B × E × W × sizeof(DataType) — async-loaded grad +// act_buf: B × E × W × sizeof(CompType) — async-loaded activations +// mask_buf: B × E × W × sizeof(bool) — async-loaded routing mask + +template __global__ void fused_topk_with_score_function_backward_kernel( - // Inputs tensor const bool *routing_map, const CompType *intermediate_output, const DataType *grad_probs, - // Other parameters int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, - int score_function, - // Output tensor - DataType *grad_logits) { + DataType *grad_logits, int num_buffers) { /*** * Section: Global Variables/Addresses init * - Each warp is responsible for one token, and has own shared memory buffer. 
* Then __syncwarp() is used instead of __syncthreads() */ - // Used variables/addresses init int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; - extern __shared__ float shmem[]; - CompType *grad_probs_buf = reinterpret_cast(shmem); - // To store the output of softmax/sigmoid from fwd, or original logits for sqrtsoftplus - CompType *act_from_fwd_buf = grad_probs_buf + num_experts * num_token_per_block; - CompType *comp_buf = act_from_fwd_buf + num_experts * num_token_per_block; - // To store the routing_map from the fwd - bool *routing_map_buf = reinterpret_cast(comp_buf + num_experts * num_token_per_block); - // The address of buffers on the current warp - CompType *local_grad = grad_probs_buf + warp_id * num_experts; - CompType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; - CompType *local_comp_buf = comp_buf + warp_id * num_experts; - bool *local_routing_map = routing_map_buf + warp_id * num_experts; + + extern __shared__ char shmem_bwd[]; + char *shmem_ptr = shmem_bwd; + + DataType *grad_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader grad_loader(grad_shmem_base, warp_id, num_experts, num_token_per_block, + num_buffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + + CompType *act_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader act_loader(act_shmem_base, warp_id, num_experts, num_token_per_block, + num_buffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + + bool *mask_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader mask_loader(mask_shmem_base, warp_id, num_experts, num_token_per_block, + num_buffers); /*** - * Section: Main Loop + * Section: Main Loop — persistent grid with double-buffered async load * - Each warp is responsible for one token */ int total_round = (num_tokens + num_token_per_block - 1) / 
num_token_per_block; - for (int round = blockIdx.x; round < total_round; round += gridDim.x) { - int token_offset_cur_warp = round * num_token_per_block + warp_id; - // Each warp is responsible for one token - if (token_offset_cur_warp >= num_tokens) break; + int first_round = blockIdx.x; + if (first_round >= total_round) return; + + // Kick off first async load + { + int first_token = first_round * num_token_per_block + warp_id; + if (first_token < num_tokens) { + int pos = first_token * num_experts; + grad_loader.load_current(grad_probs + pos, num_experts, lane_id); + act_loader.load_current(intermediate_output + pos, num_experts, lane_id); + mask_loader.load_current(routing_map + pos, num_experts, lane_id); + } + } - /*** - * Section: Init buffer - * - Clear the global buffer which will accept the result of this round - * - Clear/Init the shmem buffer used by current warp this round - * - Load the dgrad/output_from_fwd to shmem - */ - int pos_offset = token_offset_cur_warp * num_experts; - // Load the dgrad/output_from_fwd to shmem - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_grad[i] = grad_probs[pos_offset + i]; - local_act_from_fwd[i] = intermediate_output[pos_offset + i]; - local_routing_map[i] = routing_map[pos_offset + i]; + for (int round = first_round; round < total_round; round += gridDim.x) { + int token_idx = round * num_token_per_block + warp_id; + if (token_idx >= num_tokens) break; + int pos = token_idx * num_experts; + + if (num_buffers == 1 && round != first_round) { + grad_loader.load_current(grad_probs + pos, num_experts, lane_id); + act_loader.load_current(intermediate_output + pos, num_experts, lane_id); + mask_loader.load_current(routing_map + pos, num_experts, lane_id); } - __threadfence_block(); - __syncwarp(); /*** - * Section: Backward of ops after the topk - * - Backward of the used scaling_factor - * - Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 - * - Softmax bwd if use_pre_softmax is false + * 
Section: Wait for async load + prefetch next round */ - // Backward of the used scaling_factor - // In-place update - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - if (local_routing_map[i]) { - local_grad[i] = local_grad[i] * scaling_factor; + grad_loader.wait(); + act_loader.wait(); + mask_loader.wait(); + + DataType *raw_grad = grad_loader.current_buf(); + CompType *local_act = act_loader.current_buf(); + bool *local_mask = mask_loader.current_buf(); + + // Prefetch next round only when double-buffered; single-buffer loads above. + if (num_buffers > 1) { + int next_round = round + gridDim.x; + if (next_round < total_round) { + int next_token = next_round * num_token_per_block + warp_id; + if (next_token < num_tokens) { + int next_pos = next_token * num_experts; + grad_loader.start_load(grad_probs + next_pos, num_experts, lane_id); + act_loader.start_load(intermediate_output + next_pos, num_experts, lane_id); + mask_loader.start_load(routing_map + next_pos, num_experts, lane_id); + } } } - __syncwarp(); - // Sqrtsoftplus: First compute sqrtsoftplus output from original logits - // (needed for both post-processing bwd and activation bwd, compute once here) - // For sqrtsoftplus, intermediate_output stores original logits - if (score_function == 2) { - // Copy original logits to local_comp_buf and apply sqrtsoftplus in-place - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_comp_buf[i] = local_act_from_fwd[i]; - } - __syncwarp(); - apply_sqrtsoftplus_on_float(local_comp_buf, num_experts, lane_id); - __syncwarp(); - } + /*** + * Section: Pass 1 — Reduction + * Accumulate warp-level sums needed by the backward passes: + * sigmoid/sqrtsoftplus (topk>1): sum_act, sum_grad_act for normalization bwd + * softmax: sum_output_x_grad = Σ(grad * softmax_output) + * + * For sqrtsoftplus, intermediate_output stores original logits, so we + * recompute sqrtsoftplus(x) on the fly to get the activation value. 
+ */ + CompType sum_act = 0.0f; + CompType sum_grad_act = 0.0f; + CompType sum_output_x_grad = 0.0f; - // Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) - if (topk > 1 && (score_function == 0 || score_function == 2)) { - // Select the correct activation output buffer: - // - Sigmoid: local_act_from_fwd already contains sigmoid output - // - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above - CompType *act_output = (score_function == 0) ? local_act_from_fwd : local_comp_buf; - - CompType sum_fwd_input = masked_warp_reduce_on_shmem( - /*data ptr = */ act_output, - /*mask ptr = */ local_routing_map, - /*data size = */ num_experts, - /*reduce func = */ ReduceFuncType::SUM, lane_id); - // Compute sum of output * grad using registers - CompType local_sum_Output_x_Grad = 0.0; + bool need_reduce = ((ScoreFunc == 0 || ScoreFunc == 2) && topk > 1) || (ScoreFunc == 1); + if (need_reduce) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - if (local_routing_map[i]) { - local_sum_Output_x_Grad += local_grad[i] * act_output[i]; + CompType g = static_cast(raw_grad[i]) * scaling_factor; + CompType act = local_act[i]; + bool routed = local_mask[i]; + + if constexpr (ScoreFunc == 0) { // Sigmoid + // act = sigmoid output; accumulate over routed experts only + if (routed) { + sum_act += act; + sum_grad_act += g * act; + } + } else if constexpr (ScoreFunc == 2) { // Sqrtsoftplus + // act = original logit; recompute sqrtsoftplus to get activation + if (routed) { + CompType v = sqrtsoftplus_scalar(act); + sum_act += v; + sum_grad_act += g * v; + } + } else if constexpr (ScoreFunc == 1) { // Softmax + if (!use_pre_softmax) { + // Post-softmax: act = softmax output (routed positions only) + if (routed) sum_output_x_grad += g * act; + } else { + // Pre-softmax: act = softmax output (all experts) + sum_output_x_grad += (routed ? 
g : 0.0f) * act; + } } } - // Warp reduce the sum - for (int s = 16; s > 0; s /= 2) { - local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s); + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { + sum_act = warp_allreduce_sum(sum_act); + sum_grad_act = warp_allreduce_sum(sum_grad_act); } - CompType sum_Output_x_Grad = local_sum_Output_x_Grad; - // In-place update - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - if (local_routing_map[i]) { - local_grad[i] = - local_grad[i] / (sum_fwd_input + epsilon) - - sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon)); - } else { - local_grad[i] = 0.0; - } + if constexpr (ScoreFunc == 1) { + sum_output_x_grad = warp_allreduce_sum(sum_output_x_grad); } - __syncwarp(); - } - - // Softmax bwd if use_pre_softmax is false - if (!use_pre_softmax && score_function == 1) { - apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, local_routing_map, - num_experts, lane_id); - __syncwarp(); } /*** - * Section: Backward of topk - * mask the unselected position in the grad + * Section: Pass 2 — Element-wise gradient + * Compute per-element gradient using the warp-level sums from Pass 1. + * Applies backward ops in reverse of forward order: + * 1. Backward of scaling_factor (multiply grad by scaling_factor) + * 2. Backward of normalization (sigmoid/sqrtsoftplus with topk > 1) + * 3. Backward of post-softmax / topk mask + * 4. Backward of pre-softmax + * 5. 
Backward of activation (sigmoid / sqrtsoftplus) + * Write the grad_logits to the global mem */ for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - if (!local_routing_map[i]) { - local_grad[i] = 0.0; + CompType g = static_cast(raw_grad[i]) * scaling_factor; + CompType act = local_act[i]; + bool routed = local_mask[i]; + + // Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { + if (topk > 1) { + g = normalize_bwd_scalar(g, routed, sum_act, sum_grad_act); + } } - } - __syncwarp(); - /*** - * Section: Backward of ops before the topk - * - Pre-softmax bwd - * - Sigmoid bwd - * - Sqrtsoftplus bwd - * - Write the grad_logits to the global mem - */ - // Pre-softmax bwd - if (score_function == 1 && use_pre_softmax) { - apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, nullptr, - num_experts, lane_id); - __syncwarp(); - } - // Sigmoid bwd - if (score_function == 0) { - apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id); - __syncwarp(); - } - // Sqrtsoftplus bwd - // For sqrtsoftplus, local_comp_buf already contains sqrtsoftplus output computed earlier - // Now compute gradient: dy/dx = sigmoid(x) / (2 * y) - if (score_function == 2) { - apply_sqrtsoftplus_bwd_on_float(local_grad, local_comp_buf, local_act_from_fwd, num_experts, - lane_id); - __syncwarp(); - } - // Write the grad_logits to the global mem - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - grad_logits[pos_offset + i] = local_grad[i]; + // Softmax bwd if use_pre_softmax is false (routed subset only) + if constexpr (ScoreFunc == 1) { + if (!use_pre_softmax) { + g = routed ? 
softmax_bwd_scalar(g, act, sum_output_x_grad) : 0.0f; + } + } + + // Backward of topk: mask the unselected position in the grad + if (!routed) g = 0.0f; + + // Pre-softmax bwd (all experts participate) + if constexpr (ScoreFunc == 1) { + if (use_pre_softmax) { + g = softmax_bwd_scalar(g, act, sum_output_x_grad); + } + } + + // Sigmoid bwd: dy/dx = y * (1 - y), where y = sigmoid output + if constexpr (ScoreFunc == 0) { + g = sigmoid_bwd_scalar(g, act); + // Sqrtsoftplus bwd: dy/dx = sigmoid(x) / (2 * y), where x = original logit + } else if constexpr (ScoreFunc == 2) { + g = sqrtsoftplus_bwd_scalar(g, act, sqrtsoftplus_scalar(act)); + } + + grad_logits[pos + i] = static_cast(g); } - __syncwarp(); + + grad_loader.flip(); + act_loader.flip(); + mask_loader.flip(); } } @@ -474,23 +769,47 @@ void fused_topk_with_score_function_backward_kernel_launcher( const bool *routing_map, const CompType *intermediate_output, const DataType *grad_probs, int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, DataType *grad_logits, cudaStream_t stream) { - // Meta data for the kernel + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(static_cast(num_tokens) * num_experts <= INT_MAX, + "num_tokens * num_experts exceeds INT_MAX (kernel uses int offsets), got ", + static_cast(num_tokens) * num_experts); size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; - size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // grad_probs - + - num_experts * num_token_per_block * sizeof(CompType) // act_from_fwd - + num_experts * num_token_per_block * sizeof(CompType) // comp_buf - + num_experts * num_token_per_block * sizeof(bool); // routing_map - check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); - 
NVTE_CHECK_CUDA(cudaFuncSetAttribute(fused_topk_with_score_function_backward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - shared_memory_size)); - fused_topk_with_score_function_backward_kernel - <<>>( - routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, - use_pre_softmax, scaling_factor, score_function, grad_logits); - NVTE_CHECK_CUDA(cudaGetLastError()); + size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; + + size_t single_buf_shmem = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1); + int num_buffers = choose_num_buffers(single_buf_shmem, 0); + size_t shmem_bytes = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + check_shared_memory_capacity_num_experts(shmem_bytes, num_experts); + + auto launch = [&](auto kernel) { + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); + size_t grid_size = compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); + kernel<<>>( + routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, + use_pre_softmax, scaling_factor, grad_logits, num_buffers); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + switch (score_function) { + case 0: + launch(fused_topk_with_score_function_backward_kernel); + break; + case 1: + launch(fused_topk_with_score_function_backward_kernel); + break; + case 2: + launch(fused_topk_with_score_function_backward_kernel); + break; + default: + NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); + } } void fused_topk_with_score_function_backward(const Tensor &routing_map, diff --git 
a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 08ad3d16a6..76929b1cd4 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -7,13 +7,30 @@ #ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ #define TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ +#include + #include "../util/logging.h" +#include "../util/system.h" #include "../utils.cuh" #include "transformer_engine/transformer_engine.h" namespace transformer_engine { namespace fused_router { +// Topk values below this threshold use naive O(K*E) selection; +// at or above it, use radix O(E) selection. Configurable via +// NVTE_RADIX_TOPK_THRESHOLD (default 8: radix is faster for K>=8, +// naive avoids overhead for very small K). +// +// NOTE: This is an inline function with a static local. An inline function with +// external linkage shares one copy of the static across all translation units in a +// binary, so the env var is read once (on first call). This is safe because +// environment variables are immutable during process lifetime in our usage. +inline int get_radix_topk_threshold() { + static int threshold = getenv("NVTE_RADIX_TOPK_THRESHOLD", 8); + return threshold; +} + // Check if requested shared memory size exceeds device capacity. // Throws an error with num_experts info to help users diagnose the issue. inline void check_shared_memory_capacity_num_experts(size_t shared_memory_size, int num_experts) { @@ -38,107 +55,63 @@ constexpr int kThreadsPerBlock = 128; // Using 4 warps in 1 CTA, Each warp is responsible for 1 token. constexpr float epsilon = 1e-20; -template -__device__ inline T max(T a, T b) { - return a > b ? 
a : b; -} - -template -__device__ inline T sum(T a, T b) { - return a + b; -} - enum ReduceFuncType { SUM, MAX, }; -template -__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type, - int lane_id) { - T (*reduce_func)(T, T); - CompType default_val = 0.0; - if (type == ReduceFuncType::SUM) { - reduce_func = sum; - default_val = 0.0; - } else if (type == ReduceFuncType::MAX) { - reduce_func = max; - default_val = -std::numeric_limits::infinity(); - } +// Warp-level reduction over shared memory data. Templated on the reduction +// type to enable compile-time dispatch (no function pointer overhead). +template +__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, int lane_id) { + constexpr CompType default_val = + (type == ReduceFuncType::SUM) ? 0.0f : -std::numeric_limits::infinity(); - // Some value is handled in local thread - // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... - // Reduce the value in local thread + // Each lane accumulates its strided slice CompType val = lane_id < data_size ? 
data_ptr[lane_id] : default_val; for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { - val = reduce_func(val, data_ptr[i]); - } - - // Warp shuffle between threads - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1)); - __syncwarp(); - return T(val); -} - -template -__device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size, - ReduceFuncType type, int lane_id) { - T (*reduce_func)(T, T); - CompType default_val = 0.0; - if (type == ReduceFuncType::SUM) { - reduce_func = sum; - default_val = 0.0; - } else if (type == ReduceFuncType::MAX) { - reduce_func = max; - default_val = -std::numeric_limits::infinity(); - } - - // Some value is handled in local thread - // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... - // Reduce the value in local thread - CompType val = lane_id < data_size && mask[lane_id] ? data_ptr[lane_id] : default_val; - for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { - if (mask[i]) { - val = reduce_func(val, data_ptr[i]); + if constexpr (type == ReduceFuncType::SUM) { + val += data_ptr[i]; + } else { + val = (data_ptr[i] > val) ? 
static_cast(data_ptr[i]) : val; } } - // Warp shuffle between threads - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1)); + // Warp shuffle butterfly reduction + if constexpr (type == ReduceFuncType::SUM) { + val += __shfl_xor_sync(0xffffffff, val, 16); + val += __shfl_xor_sync(0xffffffff, val, 8); + val += __shfl_xor_sync(0xffffffff, val, 4); + val += __shfl_xor_sync(0xffffffff, val, 2); + val += __shfl_xor_sync(0xffffffff, val, 1); + } else { + auto shfl_max = [](CompType a, CompType b) { return a > b ? a : b; }; + val = shfl_max(val, __shfl_xor_sync(0xffffffff, val, 16)); + val = shfl_max(val, __shfl_xor_sync(0xffffffff, val, 8)); + val = shfl_max(val, __shfl_xor_sync(0xffffffff, val, 4)); + val = shfl_max(val, __shfl_xor_sync(0xffffffff, val, 2)); + val = shfl_max(val, __shfl_xor_sync(0xffffffff, val, 1)); + } __syncwarp(); return T(val); } +// ============================================================================ +// Array (in-place on shmem) score functions — used by legacy simple kernel +// ============================================================================ + __device__ inline void apply_sigmoid_on_float(float *scores, int data_size, int lane_id) { for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { scores[i] = 1.0f / (1.0f + expf(-scores[i])); } } -__device__ inline void apply_sigmoid_bwd_on_float(float *grad, float *fwd_output, int data_size, - int lane_id) { - for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - grad[i] = grad[i] * fwd_output[i] * (1.0f - fwd_output[i]); - } -} - -// sqrtsoftplus: y = sqrt(softplus(x)) = sqrt(log(1 + exp(x))) __device__ inline void apply_sqrtsoftplus_on_float(float *scores, int data_size, int lane_id) { for (int i = lane_id; i < 
data_size; i += kThreadsPerWarp) { float x = scores[i]; - // softplus(x) = log(1 + exp(x)), numerically stable version - // Matches PyTorch's Softplus(beta=1.0, threshold=20.0) float softplus_val; if (x > 20.0f) { - softplus_val = x; // for large x, softplus(x) ≈ x + softplus_val = x; } else { softplus_val = log1pf(expf(x)); } @@ -146,61 +119,52 @@ __device__ inline void apply_sqrtsoftplus_on_float(float *scores, int data_size, } } -// sqrtsoftplus backward: -// y = sqrt(softplus(x)) -// Matches PyTorch's Softplus(beta=1.0, threshold=20.0) -// We need the original logits (x) to compute the gradient -__device__ inline void apply_sqrtsoftplus_bwd_on_float(float *grad, float *fwd_output, - float *logits_buf, int data_size, - int lane_id) { - for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - float x = logits_buf[i]; // original logit - float y = fwd_output[i]; // sqrtsoftplus output - float dy_dx; - if (x > 20.0f) { - // When softplus(x) = x, y = sqrt(x), dy/dx = 1/(2*y) - dy_dx = 1.0f / (2.0f * y + epsilon); - } else { - // When softplus(x) = log(1+exp(x)), dy/dx = sigmoid(x) / (2*y) - // where sigmoid(x) = 1 / (1 + exp(-x)) - float sigmoid_x = 1.0f / (1.0f + expf(-x)); - dy_dx = sigmoid_x / (2.0f * y + epsilon); - } - grad[i] = grad[i] * dy_dx; - } +// ============================================================================ +// Scalar (per-element) score functions — for fused paths +// ============================================================================ + +// Forward: y = sigmoid(x) = 1 / (1 + exp(-x)) +__device__ __forceinline__ float sigmoid_scalar(float x) { return 1.0f / (1.0f + expf(-x)); } + +// Forward: y = sqrt(softplus(x)) = sqrt(log(1 + exp(x))) +__device__ __forceinline__ float sqrtsoftplus_scalar(float x) { + float sp = (x > 20.0f) ? 
x : log1pf(expf(x)); + return sqrtf(sp); } -__device__ inline void apply_softmax_bwd_on_float(float *grad, float *fwd_output, float *comp_buf, - bool *mask, int data_size, int lane_id) { - // Put the result of output * grad to the comp_buf - for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - if (mask) { - if (mask[i]) - comp_buf[i] = grad[i] * fwd_output[i]; - else - comp_buf[i] = 0.0f; - } else { - comp_buf[i] = grad[i] * fwd_output[i]; - } - } - __syncwarp(); - float sum_Output_x_Grad = warp_reduce_on_shmem( - /*data ptr = */ comp_buf, - /*data size = */ data_size, - /*reduce func = */ ReduceFuncType::SUM, lane_id); - // In-place update - for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - if (mask) { - if (mask[i]) - grad[i] = fwd_output[i] * (grad[i] - sum_Output_x_Grad); - else - grad[i] = 0.0f; - } else { - grad[i] = fwd_output[i] * (grad[i] - sum_Output_x_Grad); - } - } +// Backward: sigmoid — given sigmoid output y, dy/dx = y * (1 - y) +__device__ __forceinline__ float sigmoid_bwd_scalar(float grad, float y) { + return grad * y * (1.0f - y); +} + +// Backward: sqrtsoftplus — given original logit x and sqrtsoftplus output y = sqrt(softplus(x)), +// dy/dx = sigmoid(x) / (2 * y). For large x (>20), softplus(x) ≈ x so dy/dx ≈ 1/(2*y). +__device__ __forceinline__ float sqrtsoftplus_bwd_scalar(float grad, float x, float y) { + float dy_dx = + (x > 20.0f) ? (1.0f / (2.0f * y + epsilon)) : (sigmoid_scalar(x) / (2.0f * y + epsilon)); + return grad * dy_dx; +} + +// Backward: normalization — given grad, routed flag, and pre-computed sums. +// Used by sigmoid/sqrtsoftplus with topk > 1. +// sum_act = sum of activation outputs over routed experts. +// sum_grad_act = sum of grad * act over routed experts. 
+__device__ __forceinline__ float normalize_bwd_scalar(float grad, bool routed, float sum_act, + float sum_grad_act) { + if (!routed) return 0.0f; + float denom = sum_act + epsilon; + return grad / denom - sum_grad_act / (denom * denom); +} + +// Backward: softmax element — given grad, softmax output, and sum(output * grad). +__device__ __forceinline__ float softmax_bwd_scalar(float grad, float act, float dot) { + return act * (grad - dot); } +// ============================================================================ +// Array (in-place on shmem) softmax — still used by forward kernels +// ============================================================================ + __device__ inline void apply_softmax_on_float(float *scores, int data_size, int lane_id) { // --- Pass 1: Online accumulation of max and sum_exp --- float local_max = -std::numeric_limits::infinity(); @@ -256,6 +220,11 @@ enum class TopkFuncType { Radix = 1, }; +// Maximum num_experts supported by the packed 8-bit radix topk histogram. +// Each lane scans at most ceil(data_size/32) elements per pass, and all of them may +// land in one bucket. An 8-bit counter holds at most 255, so data_size <= 255 * 32 = 8160. +constexpr int kMaxExpertsRadixTopk = 255 * 32; // 8160 + /******************************************************************************* * radix_topk_and_mask — Warp-level radix-selection based top-K * @@ -281,11 +250,16 @@ enum class TopkFuncType { * broken by ascending index for determinism matching torch.topk). * Write indices and scores to the output arrays. * + * Register pressure optimization: pack 16 bucket counts into 4 registers + * using 8-bit fields (4 counters per u32). The original counts[16] + + * total_counts[16] required 32 registers, causing massive spill to local + * memory on large kernels (81% of L1 traffic on E=2304, K=36). + * * Tie-breaking: (value DESC, index ASC) — matches torch.topk behavior. 
* * Constraints: * - 0 < topk <= data_size - * - No upper limit on topk or data_size (unlike v1's 128 cap) + * - data_size <= kMaxExpertsRadixTopk (8160) to avoid 8-bit overflow * - scores must be in shared memory accessible by the warp * * Complexity: 9 × O(E/32) = O(E) per warp, independent of K. @@ -293,8 +267,8 @@ enum class TopkFuncType { __device__ inline void radix_topk_and_mask(CompType *scores, int data_size, int topk, int *topk_indices, CompType *topk_scores, int lane_id) { - // assert(topk > 0 && "naive_topk_and_mask_v2: topk must be positive"); - // assert(topk <= data_size && "naive_topk_and_mask_v2: topk exceeds data_size"); + assert(data_size > 0 && data_size <= kMaxExpertsRadixTopk); + assert(topk > 0 && topk <= data_size); constexpr int RADIX_BITS = 4; constexpr int RADIX_SIZE = 1 << RADIX_BITS; // 16 buckets @@ -303,54 +277,58 @@ __device__ inline void radix_topk_and_mask(CompType *scores, int data_size, int // ========================================================================= // Phase 1: Radix selection — find the bit pattern of the K-th largest value + // + // Packed counters: 16 bucket counts are stored in 4 × u32 registers using + // 8-bit fields (4 counters per register). Bucket b is in byte (b % 4) of + // packed[b / 4]. This reduces register usage from 32 (counts[16] + + // total_counts[16]) to 4 registers. + // + // Max per-thread count per bucket = ceil(data_size / 32). + // For E=2304: max 72 — fits in 8 bits (max 255). + // Constraint: data_size must be <= kMaxExpertsRadixTopk (8160). 
// ========================================================================= unsigned int desired = 0; // accumulated bit pattern of the K-th value unsigned int desired_mask = 0; // bits determined so far int k_remaining = topk; // how many more elements we need to skip +#pragma unroll 1 for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { int digit_pos = pass * RADIX_BITS; - // Each thread counts elements per bucket that match the desired pattern - unsigned int counts[RADIX_SIZE]; -#pragma unroll - for (int b = 0; b < RADIX_SIZE; b++) { - counts[b] = 0; - } + // Packed counters: packed[i] holds 4 × 8-bit counts for buckets [4i..4i+3]. + // Bucket b is in byte (b % 4) of packed[b / 4]. + unsigned int packed[4] = {0, 0, 0, 0}; for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { unsigned int u = float_to_ordered_uint(scores[i]); - // Check if this element matches the desired pattern on already-decided bits if ((u & desired_mask) == desired) { int bucket = (u >> digit_pos) & RADIX_MASK; - counts[bucket]++; + int pack_idx = bucket >> 2; // bucket / 4 + int byte_idx = bucket & 3; // bucket % 4 + int shift = byte_idx << 3; // byte_idx * 8 + packed[pack_idx] += (1u << shift); // increment the 8-bit field } } - // Warp-reduce each bucket count across all 32 lanes - unsigned int total_counts[RADIX_SIZE]; -#pragma unroll - for (int b = 0; b < RADIX_SIZE; b++) { - unsigned int c = warp_allreduce_sum(counts[b]); - total_counts[b] = c; // same value on all lanes after full reduction - } - - // Scan buckets from LARGEST digit value (15) to smallest (0). - // We're looking for the top-K largest, so we want the highest-valued - // bucket first. Accumulate counts until we find the bucket containing - // the k_remaining-th element. + // Warp-reduce each bucket, then scan to find the target bucket. 
int target_bucket = 0; + int k_remaining_copy = k_remaining; +#pragma unroll for (int b = RADIX_SIZE - 1; b >= 0; b--) { - unsigned int bc = total_counts[b]; - if (bc < static_cast(k_remaining)) { - // All elements in this bucket are in the top set; skip them - k_remaining -= bc; + // Unpack: extract 8-bit count for bucket b from packed[b/4] + int pack_idx = b >> 2; + int shift = (b & 3) << 3; + unsigned int my_count = (packed[pack_idx] >> shift) & 0xFFu; + // Warp-reduce to get total count across all 32 lanes + unsigned int bc = warp_allreduce_sum(my_count); + if (bc < static_cast(k_remaining_copy)) { + k_remaining_copy -= bc; } else { - // The K-th element is in this bucket target_bucket = b; break; } } + k_remaining = k_remaining_copy; // Update the desired pattern and mask desired |= (static_cast(target_bucket) << digit_pos); @@ -448,7 +426,7 @@ __device__ inline void naive_topk_and_mask(CompType *scores, int data_size, int ? scores[lane_id] : -std::numeric_limits::infinity(); int index = (lane_id < data_size) ? lane_id : 0; - // Some value is hanlded in local thread + // Some value is handled in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {