From e4ddb5cb5f33e42d4cdb4d48a1b9f0dd92b6088a Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Tue, 19 May 2026 10:27:10 +0800 Subject: [PATCH 01/15] [Common] Fuse preprocess/backward loops in fused router kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace multi-loop preprocess (separate clear/load/score/save/bias loops) with single fused loops per score function in all 4 kernel paths (topk forward, topk backward, aux_loss forward, aux_loss backward). Replace multi-pass backward (array-based helpers + comp_buf shmem) with a two-pass approach using scalar helpers: Pass 1: reduction — warp-level sums via warp_allreduce_sum() Pass 2: element-wise — scalar gradient computation → write to global Add scalar helpers to utils.h: sigmoid_scalar, sqrtsoftplus_scalar, sigmoid_bwd_scalar, sqrtsoftplus_bwd_scalar, normalize_bwd_scalar, softmax_bwd_scalar. Remove dead array helpers from utils.h: apply_sigmoid_on_float, apply_sigmoid_bwd_on_float, apply_sqrtsoftplus_on_float, apply_sqrtsoftplus_bwd_on_float, apply_softmax_bwd_on_float, masked_warp_reduce_on_shmem. Backward shmem reduced by E×W×sizeof(float) per kernel (comp_buf eliminated). Net -226 lines across 3 files. Signed-off-by: Harry Zhou --- .../fused_score_for_moe_aux_loss.cu | 229 ++++++------ .../fused_topk_with_score_function.cu | 333 ++++++++---------- .../common/fused_router/utils.h | 137 ++----- 3 files changed, 287 insertions(+), 412 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 4eb4240d7c..6207faad07 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -52,64 +52,46 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi if (token_offset_cur_warp >= num_tokens) break; /*** - * Section: Init buffer - * - Clear the global buffer which will accept the result of this round - * - Clear/Init the shmem buffer used by current warp this round - * - Load the logits to shmem + * Section: Init buffer + Preprocess + * - Load logits → apply score function → save intermediate_output + * + * Fused into a single loop per score function where possible: + * score_function == 0 (sigmoid): load, sigmoid, save → shmem + * score_function == 1 (softmax): load → shmem, softmax (multi-pass), save + * score_function == 2 (sqrtsoftplus): load, save logits, sqrtsoftplus → shmem */ int pos_offset = token_offset_cur_warp * num_experts; - // Clear the routing_map (num_experts) - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - routing_map[pos_offset + i] = false; - if (score_function == 1) { - intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); - } - } - // Load the logits to shmem - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_logits[i] = static_cast(logits[pos_offset + i]); - } - __threadfence_block(); - __syncwarp(); - /*** - * Section: Preprocess - * Possible preprocess the scores before the topk operation - * - Pre-softmax - * - Sigmoid - * - Sqrtsoftplus - * - Sigmoid/Sqrtsoftplus post-processing when topk > 1 - * This is in-place scores update - */ - if (score_function == 1) { // score_function == 1 means softmax - // Apply softmax to the logits before the topk + if (score_function == 1) { // Softmax + // Apply softmax to all logits, save softmax output for backward + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_logits[i] = static_cast(logits[pos_offset + i]); + } + __syncwarp(); apply_softmax_on_float(local_logits, num_experts, lane_id); __syncwarp(); // Save the softmax output for backward for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = local_logits[i]; } - } else if (score_function == 0) { // score_function == 0 means sigmoid - // Apply sigmoid to the logits - apply_sigmoid_on_float(local_logits, num_experts, lane_id); - __syncwarp(); - // Save the sigmoid output for backward + } else if (score_function == 0) { // Sigmoid + // Fused: load logit → sigmoid → save sigmoid output for backward → shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - intermediate_output[pos_offset + i] = local_logits[i]; + float val = sigmoid_scalar(static_cast(logits[pos_offset + i])); + intermediate_output[pos_offset + i] = val; // Save sigmoid output for backward + local_logits[i] = val; } - } else if (score_function == 2) { // score_function == 2 means sqrtsoftplus - // First save the original logits for backward (needed for gradient computation) + } else if (score_function == 2) { // Sqrtsoftplus + // Fused: load logit → save original logit for backward → sqrtsoftplus → shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - intermediate_output[pos_offset + i] = local_logits[i]; // Save original logits + float logit = static_cast(logits[pos_offset + i]); + intermediate_output[pos_offset + i] = logit; // Save original logits for backward + local_logits[i] = sqrtsoftplus_scalar(logit); } - __syncwarp(); - // Apply sqrtsoftplus to the logits - apply_sqrtsoftplus_on_float(local_logits, num_experts, lane_id); } + __syncwarp(); - __syncwarp(); //Confirm the scores is written to the output - - // Sigmoid/Sqrtsoftplus post-processing + // Sigmoid/Sqrtsoftplus post-processing: normalize scores to sum to 1 if (score_function == 0 || score_function == 2) { auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); @@ -185,6 +167,16 @@ void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, reinterpret_cast(intermediate_output.data.dptr), stream);); } +// Backward: grad_scores + intermediate_output → grad_logits. +// No routing_map — all experts participate (unlike topk backward). +// +// Two-pass fused approach (eliminates the comp_buf shmem buffer): +// Pass 1 (reduction): accumulate warp-level sums for normalization/softmax bwd. +// Pass 2 (element-wise): compute per-element gradient and write to global memory. +// +// Shmem layout (W = warps/block): +// grad_buf: E × W × sizeof(CompType) — grad_scores loaded from global +// act_buf: E × W × sizeof(CompType) — intermediate_output (sigmoid/softmax out, or logits) template __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *intermediate_output, const float *grad_scores, @@ -196,19 +188,15 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int * - Each warp is responsible for one token, and has own shared memory buffer. * Then __syncwarp() is used instead of __syncthreads() */ - // Used variables/addresses init int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; extern __shared__ float shmem[]; - CompType *grad_scores_buf = reinterpret_cast(shmem); - // To store the output of softmax/sigmoid from fwd, or original logits for sqrtsoftplus - CompType *act_from_fwd_buf = grad_scores_buf + num_experts * num_token_per_block; - CompType *comp_buf = act_from_fwd_buf + num_experts * num_token_per_block; + CompType *grad_buf = reinterpret_cast(shmem); + CompType *act_buf = grad_buf + num_experts * num_token_per_block; // The address of buffers on the current warp - CompType *local_grad = grad_scores_buf + warp_id * num_experts; - CompType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; - CompType *local_comp_buf = comp_buf + warp_id * num_experts; + CompType *local_grad = grad_buf + warp_id * num_experts; + CompType *local_act = act_buf + warp_id * num_experts; /*** * Section: Main Loop @@ -216,95 +204,80 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int */ int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; for (int round = blockIdx.x; round < total_round; round += gridDim.x) { - int token_offset_cur_warp = round * num_token_per_block + warp_id; - // Each warp is responsible for one token - if (token_offset_cur_warp >= num_tokens) break; + int token_idx = round * num_token_per_block + warp_id; + if (token_idx >= num_tokens) break; + int pos = token_idx * num_experts; /*** - * Section: Init buffer - * - Clear the global buffer which will accept the result of this round - * - Clear/Init the shmem buffer used by current warp this round - * - Load the dgrad/output_from_fwd to shmem + * Section: Load inputs to shmem + * - Load the grad_scores/intermediate_output to shmem */ - int pos_offset = token_offset_cur_warp * num_experts; - // Load the dgrad/output_from_fwd to shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_grad[i] = grad_scores[pos_offset + i]; - local_act_from_fwd[i] = intermediate_output[pos_offset + i]; + local_grad[i] = grad_scores[pos + i]; + local_act[i] = intermediate_output[pos + i]; } - __threadfence_block(); __syncwarp(); /*** - * Section: Backward of ops before the topk - * - Pre-softmax bwd - * - Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 - * - Sigmoid bwd - * - Sqrtsoftplus bwd - * - Write the grad_logits to the global mem + * Section: Pass 1 — Reduction + * Accumulate warp-level sums needed by the backward passes: + * sigmoid/sqrtsoftplus: sum_act, sum_grad_act for normalization bwd + * softmax: sum_output_x_grad = Σ(grad * softmax_output) + * + * For sqrtsoftplus, intermediate_output stores original logits, so we + * recompute sqrtsoftplus(x) on the fly to get the activation value. */ - // Sqrtsoftplus: First compute sqrtsoftplus output from original logits - // (needed for both post-processing bwd and activation bwd, compute once here) - // For sqrtsoftplus, intermediate_output stores original logits - if (score_function == 2) { - // Copy original logits to local_comp_buf and apply sqrtsoftplus in-place - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_comp_buf[i] = local_act_from_fwd[i]; + CompType sum_act = 0.0f; + CompType sum_grad_act = 0.0f; + CompType sum_output_x_grad = 0.0f; + + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + CompType g = local_grad[i]; + CompType act = local_act[i]; + if (score_function == 0) { // Sigmoid + // act = sigmoid output; accumulate over all experts + sum_act += act; sum_grad_act += g * act; + } else if (score_function == 2) { // Sqrtsoftplus + // act = original logit; recompute sqrtsoftplus to get activation + CompType v = sqrtsoftplus_scalar(act); + sum_act += v; sum_grad_act += g * v; + } else if (score_function == 1) { // Softmax + // act = softmax output + sum_output_x_grad += g * act; } - __syncwarp(); - apply_sqrtsoftplus_on_float(local_comp_buf, num_experts, lane_id); - __syncwarp(); } - - // Sigmoid/Sqrtsoftplus Post-processing bwd (normalization backward) if (score_function == 0 || score_function == 2) { - // Select the correct activation output buffer: - // - Sigmoid: local_act_from_fwd already contains sigmoid output - // - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above - CompType *act_output = (score_function == 0) ? local_act_from_fwd : local_comp_buf; - - auto sum_fwd_input = - warp_reduce_on_shmem(act_output, num_experts, ReduceFuncType::SUM, lane_id); - // Compute sum of output * grad using registers - CompType local_sum_Output_x_Grad = 0.0; - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_sum_Output_x_Grad += local_grad[i] * act_output[i]; - } - // Warp reduce the sum - for (int s = 16; s > 0; s /= 2) { - local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s); - } - CompType sum_Output_x_Grad = local_sum_Output_x_Grad; - // In-place update - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_grad[i] = local_grad[i] / (sum_fwd_input + epsilon) - - sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon)); - } - __syncwarp(); + sum_act = warp_allreduce_sum(sum_act); + sum_grad_act = warp_allreduce_sum(sum_grad_act); } - - // Pre-softmax bwd if (score_function == 1) { - apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, nullptr, - num_experts, lane_id); - __syncwarp(); - } - // Sigmoid bwd - if (score_function == 0) { - apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id); - __syncwarp(); + sum_output_x_grad = warp_allreduce_sum(sum_output_x_grad); } - // Sqrtsoftplus bwd - // For sqrtsoftplus, local_comp_buf already contains sqrtsoftplus output computed earlier - // Now compute gradient: dy/dx = sigmoid(x) / (2 * y) - if (score_function == 2) { - apply_sqrtsoftplus_bwd_on_float(local_grad, local_comp_buf, local_act_from_fwd, num_experts, - lane_id); - __syncwarp(); - } - // Write the grad_logits to the global mem + + /*** + * Section: Pass 2 — Element-wise gradient + * Compute per-element gradient using the warp-level sums from Pass 1. + * Applies backward ops in reverse of forward order: + * sigmoid: normalization bwd → sigmoid bwd + * sqrtsoftplus: normalization bwd → sqrtsoftplus bwd + * softmax: softmax bwd + * Write the grad_logits to the global mem + */ for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - grad_logits[pos_offset + i] = static_cast(local_grad[i]); + CompType g = local_grad[i]; + CompType act = local_act[i]; + + if (score_function == 0) { // Sigmoid bwd + g = normalize_bwd_scalar(g, true, sum_act, sum_grad_act); + g = sigmoid_bwd_scalar(g, act); + } else if (score_function == 2) { // Sqrtsoftplus bwd + g = normalize_bwd_scalar(g, true, sum_act, sum_grad_act); + g = sqrtsoftplus_bwd_scalar(g, act, sqrtsoftplus_scalar(act)); + } else if (score_function == 1) { // Softmax bwd + g = softmax_bwd_scalar(g, act, sum_output_x_grad); + } + + grad_logits[pos + i] = static_cast(g); } __syncwarp(); } @@ -314,13 +287,11 @@ template void fused_score_for_moe_aux_loss_backward_kernel_launcher( const CompType *intermediate_output, const float *grad_scores, int num_tokens, int num_experts, int topk, int score_function, DataType *grad_logits, cudaStream_t stream) { - // Meta data for the kernel size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // grad_scores - + - num_experts * num_token_per_block * sizeof(CompType) // act_from_fwd - + num_experts * num_token_per_block * sizeof(CompType); // comp_buf + // Shmem: grad_buf + act_buf (no comp_buf — eliminated by fused scalar approach) + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // grad_buf + + num_experts * num_token_per_block * sizeof(CompType); // act_buf check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); NVTE_CHECK_CUDA(cudaFuncSetAttribute(fused_score_for_moe_aux_loss_backward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 9f7a830546..2142b808c4 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -60,74 +60,71 @@ __global__ void fused_topk_with_score_function_forward_kernel( if (token_offset_cur_warp >= num_tokens) break; /*** - * Section: Init buffer - * - Clear the global buffer which will accept the result of this round - * - Clear/Init the shmem buffer used by current warp this round - * - Load the logits to shmem + * Section: Init buffer + Preprocess + * - Clear the global output buffers (probs, routing_map) + * - Load logits → apply score function → save intermediate_output → add expert bias + * + * Fused into a single loop per score function where possible: + * score_function == 0 (sigmoid): load, sigmoid, save, +bias → scores + * score_function == 1 (softmax): load → shmem, softmax (multi-pass), save + * score_function == 2 (sqrtsoftplus): load, save logits, sqrtsoftplus, +bias → scores + * + * Expert bias is only used with sigmoid/sqrtsoftplus and is fused into + * the same loop that computes the score. */ int pos_offset = token_offset_cur_warp * num_experts; + // Clear the probs/routing_map (num_experts) for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { probs[pos_offset + i] = 0.0; routing_map[pos_offset + i] = false; - if (score_function == 1) { - intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); - } - } - // Load the logits to shmem - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - scores[i] = logits[pos_offset + i]; } - // If group_topk > 0, init the masked_scores to -inf - if (group_topk > 0) { - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - masked_scores[i] = -std::numeric_limits::infinity(); - } - } - __threadfence_block(); - __syncwarp(); - /*** - * Section: Preprocess - * Possible preprocess the scores before the topk operation - * - Pre-softmax - * - Sigmoid - * - Sqrtsoftplus - * - Expert bias - * This is in-place scores update - */ - if (use_pre_softmax && score_function == 1) { // score_function == 1 means softmax - // Apply softmax to the logits before the topk - apply_softmax_on_float(scores, num_experts, lane_id); - __syncwarp(); - // Save the softmax output for backward - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - intermediate_output[pos_offset + i] = scores[i]; + if (score_function == 1) { // Softmax + if (use_pre_softmax) { + // Pre-softmax: apply softmax to all logits before topk, save for backward. + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + scores[i] = static_cast(logits[pos_offset + i]); + } + __syncwarp(); + apply_softmax_on_float(scores, num_experts, lane_id); + __syncwarp(); + // Save the softmax output for backward + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = scores[i]; + } + } else { + // Post-softmax: softmax applied after topk; init intermediate to -inf + // (only the topk positions will be filled in the postprocess section). + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + scores[i] = static_cast(logits[pos_offset + i]); + intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); + } } - } else if (score_function == 0) { // score_function == 0 means sigmoid - // Apply sigmoid to the logits - apply_sigmoid_on_float(scores, num_experts, lane_id); - __syncwarp(); - // Save the sigmoid output for backward + } else if (score_function == 0) { // Sigmoid + // Fused: load logit → sigmoid → save sigmoid output for backward → add bias → scores for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - intermediate_output[pos_offset + i] = scores[i]; + float val = sigmoid_scalar(static_cast(logits[pos_offset + i])); + intermediate_output[pos_offset + i] = val; // Save sigmoid output for backward + if (expert_bias) val += static_cast(expert_bias[i]); + scores[i] = val; } - } else if (score_function == 2) { // score_function == 2 means sqrtsoftplus - // First save the original logits for backward (needed for sqrtsoftplus gradient computation) + } else if (score_function == 2) { // Sqrtsoftplus + // Fused: load logit → save original logit for backward → sqrtsoftplus → add bias → scores for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - intermediate_output[pos_offset + i] = scores[i]; // Save original logits + float logit = static_cast(logits[pos_offset + i]); + intermediate_output[pos_offset + i] = logit; // Save original logits for backward + float val = sqrtsoftplus_scalar(logit); + if (expert_bias) val += static_cast(expert_bias[i]); + scores[i] = val; } - __syncwarp(); - // Apply sqrtsoftplus to the logits - apply_sqrtsoftplus_on_float(scores, num_experts, lane_id); } + __syncwarp(); - __syncwarp(); //Confirm the scores is written to the output - - // Expert bias is only used at the sigmoid/sqrtsoftplus case - if (expert_bias && (score_function == 0 || score_function == 2)) { + // If group_topk > 0, init the masked_scores to -inf + if (group_topk > 0) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - scores[i] += static_cast(expert_bias[i]); + masked_scores[i] = -std::numeric_limits::infinity(); } __syncwarp(); } @@ -295,6 +292,16 @@ void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, reinterpret_cast(intermediate_output.data.dptr), stream););); } +// Backward: grad_probs + intermediate_output + routing_map → grad_logits. +// +// Two-pass fused approach (eliminates the comp_buf shmem buffer): +// Pass 1 (reduction): accumulate warp-level sums needed by normalization/softmax bwd. +// Pass 2 (element-wise): compute per-element gradient and write to global memory. +// +// Shmem layout (W = warps/block): +// grad_buf: E × W × sizeof(CompType) — grad_probs loaded from global +// act_buf: E × W × sizeof(CompType) — intermediate_output (sigmoid/softmax out, or logits) +// mask_buf: E × W × sizeof(bool) — routing_map from forward template __global__ void fused_topk_with_score_function_backward_kernel( // Inputs tensor @@ -309,22 +316,17 @@ __global__ void fused_topk_with_score_function_backward_kernel( * - Each warp is responsible for one token, and has own shared memory buffer. * Then __syncwarp() is used instead of __syncthreads() */ - // Used variables/addresses init int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; extern __shared__ float shmem[]; - CompType *grad_probs_buf = reinterpret_cast(shmem); - // To store the output of softmax/sigmoid from fwd, or original logits for sqrtsoftplus - CompType *act_from_fwd_buf = grad_probs_buf + num_experts * num_token_per_block; - CompType *comp_buf = act_from_fwd_buf + num_experts * num_token_per_block; - // To store the routing_map from the fwd - bool *routing_map_buf = reinterpret_cast(comp_buf + num_experts * num_token_per_block); + CompType *grad_buf = reinterpret_cast(shmem); + CompType *act_buf = grad_buf + num_experts * num_token_per_block; + bool *mask_buf = reinterpret_cast(act_buf + num_experts * num_token_per_block); // The address of buffers on the current warp - CompType *local_grad = grad_probs_buf + warp_id * num_experts; - CompType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; - CompType *local_comp_buf = comp_buf + warp_id * num_experts; - bool *local_routing_map = routing_map_buf + warp_id * num_experts; + CompType *local_grad = grad_buf + warp_id * num_experts; + CompType *local_act = act_buf + warp_id * num_experts; + bool *local_mask = mask_buf + warp_id * num_experts; /*** * Section: Main Loop @@ -332,138 +334,113 @@ __global__ void fused_topk_with_score_function_backward_kernel( */ int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; for (int round = blockIdx.x; round < total_round; round += gridDim.x) { - int token_offset_cur_warp = round * num_token_per_block + warp_id; - // Each warp is responsible for one token - if (token_offset_cur_warp >= num_tokens) break; + int token_idx = round * num_token_per_block + warp_id; + if (token_idx >= num_tokens) break; + int pos = token_idx * num_experts; /*** - * Section: Init buffer - * - Clear the global buffer which will accept the result of this round - * - Clear/Init the shmem buffer used by current warp this round - * - Load the dgrad/output_from_fwd to shmem + * Section: Load inputs to shmem + * - Load the grad_probs/intermediate_output/routing_map to shmem */ - int pos_offset = token_offset_cur_warp * num_experts; - // Load the dgrad/output_from_fwd to shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_grad[i] = grad_probs[pos_offset + i]; - local_act_from_fwd[i] = intermediate_output[pos_offset + i]; - local_routing_map[i] = routing_map[pos_offset + i]; + local_grad[i] = static_cast(grad_probs[pos + i]); + local_act[i] = intermediate_output[pos + i]; + local_mask[i] = routing_map[pos + i]; } - __threadfence_block(); __syncwarp(); /*** - * Section: Backward of ops after the topk - * - Backward of the used scaling_factor - * - Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 - * - Softmax bwd if use_pre_softmax is false + * Section: Pass 1 — Reduction + * Accumulate warp-level sums needed by the backward passes: + * sigmoid/sqrtsoftplus (topk>1): sum_act, sum_grad_act for normalization bwd + * softmax: sum_output_x_grad = Σ(grad * softmax_output) + * + * For sqrtsoftplus, intermediate_output stores original logits, so we + * recompute sqrtsoftplus(x) on the fly to get the activation value. */ - // Backward of the used scaling_factor - // In-place update - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - if (local_routing_map[i]) { - local_grad[i] = local_grad[i] * scaling_factor; - } - } - __syncwarp(); - - // Sqrtsoftplus: First compute sqrtsoftplus output from original logits - // (needed for both post-processing bwd and activation bwd, compute once here) - // For sqrtsoftplus, intermediate_output stores original logits - if (score_function == 2) { - // Copy original logits to local_comp_buf and apply sqrtsoftplus in-place - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_comp_buf[i] = local_act_from_fwd[i]; - } - __syncwarp(); - apply_sqrtsoftplus_on_float(local_comp_buf, num_experts, lane_id); - __syncwarp(); - } + CompType sum_act = 0.0f; + CompType sum_grad_act = 0.0f; + CompType sum_output_x_grad = 0.0f; - // Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) - if (topk > 1 && (score_function == 0 || score_function == 2)) { - // Select the correct activation output buffer: - // - Sigmoid: local_act_from_fwd already contains sigmoid output - // - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above - CompType *act_output = (score_function == 0) ? local_act_from_fwd : local_comp_buf; - - CompType sum_fwd_input = masked_warp_reduce_on_shmem( - /*data ptr = */ act_output, - /*mask ptr = */ local_routing_map, - /*data size = */ num_experts, - /*reduce func = */ ReduceFuncType::SUM, lane_id); - // Compute sum of output * grad using registers - CompType local_sum_Output_x_Grad = 0.0; + bool need_reduce = ((score_function == 0 || score_function == 2) && topk > 1) + || (score_function == 1); + if (need_reduce) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - if (local_routing_map[i]) { - local_sum_Output_x_Grad += local_grad[i] * act_output[i]; + CompType g = local_grad[i] * scaling_factor; + CompType act = local_act[i]; + bool routed = local_mask[i]; + + if (score_function == 0) { // Sigmoid + // act = sigmoid output; accumulate over routed experts only + if (routed) { sum_act += act; sum_grad_act += g * act; } + } else if (score_function == 2) { // Sqrtsoftplus + // act = original logit; recompute sqrtsoftplus to get activation + if (routed) { + CompType v = sqrtsoftplus_scalar(act); + sum_act += v; sum_grad_act += g * v; + } + } else if (score_function == 1) { // Softmax + if (!use_pre_softmax) { + // Post-softmax: act = softmax output (routed positions only) + if (routed) sum_output_x_grad += g * act; + } else { + // Pre-softmax: act = softmax output (all experts) + sum_output_x_grad += (routed ? g : 0.0f) * act; + } } } - // Warp reduce the sum - for (int s = 16; s > 0; s /= 2) { - local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s); + if (score_function == 0 || score_function == 2) { + sum_act = warp_allreduce_sum(sum_act); + sum_grad_act = warp_allreduce_sum(sum_grad_act); } - CompType sum_Output_x_Grad = local_sum_Output_x_Grad; - // In-place update - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - if (local_routing_map[i]) { - local_grad[i] = - local_grad[i] / (sum_fwd_input + epsilon) - - sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon)); - } else { - local_grad[i] = 0.0; - } + if (score_function == 1) { + sum_output_x_grad = warp_allreduce_sum(sum_output_x_grad); } - __syncwarp(); - } - - // Softmax bwd if use_pre_softmax is false - if (!use_pre_softmax && score_function == 1) { - apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, local_routing_map, - num_experts, lane_id); - __syncwarp(); } /*** - * Section: Backward of topk - * mask the unselected position in the grad + * Section: Pass 2 — Element-wise gradient + * Compute per-element gradient using the warp-level sums from Pass 1. + * Applies backward ops in reverse of forward order: + * 1. Backward of scaling_factor (multiply grad by scaling_factor) + * 2. Backward of normalization (sigmoid/sqrtsoftplus with topk > 1) + * 3. Backward of post-softmax / topk mask + * 4. Backward of pre-softmax + * 5. Backward of activation (sigmoid / sqrtsoftplus) + * Write the grad_logits to the global mem */ for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - if (!local_routing_map[i]) { - local_grad[i] = 0.0; + CompType g = local_grad[i] * scaling_factor; + CompType act = local_act[i]; + bool routed = local_mask[i]; + + // Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) + if ((score_function == 0 || score_function == 2) && topk > 1) { + g = normalize_bwd_scalar(g, routed, sum_act, sum_grad_act); } - } - __syncwarp(); - /*** - * Section: Backward of ops before the topk - * - Pre-softmax bwd - * - Sigmoid bwd - * - Sqrtsoftplus bwd - * - Write the grad_logits to the global mem - */ - // Pre-softmax bwd - if (score_function == 1 && use_pre_softmax) { - apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, nullptr, - num_experts, lane_id); - __syncwarp(); - } - // Sigmoid bwd - if (score_function == 0) { - apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id); - __syncwarp(); - } - // Sqrtsoftplus bwd - // For sqrtsoftplus, local_comp_buf already contains sqrtsoftplus output computed earlier - // Now compute gradient: dy/dx = sigmoid(x) / (2 * y) - if (score_function == 2) { - apply_sqrtsoftplus_bwd_on_float(local_grad, local_comp_buf, local_act_from_fwd, num_experts, - lane_id); - __syncwarp(); - } - // Write the grad_logits to the global mem - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - grad_logits[pos_offset + i] = local_grad[i]; + // Softmax bwd if use_pre_softmax is false (routed subset only) + if (score_function == 1 && !use_pre_softmax) { + g = routed ? softmax_bwd_scalar(g, act, sum_output_x_grad) : 0.0f; + } + + // Backward of topk: mask the unselected position in the grad + if (!routed) g = 0.0f; + + // Pre-softmax bwd (all experts participate) + if (score_function == 1 && use_pre_softmax) { + g = softmax_bwd_scalar(g, act, sum_output_x_grad); + } + + // Sigmoid bwd: dy/dx = y * (1 - y), where y = sigmoid output + if (score_function == 0) { + g = sigmoid_bwd_scalar(g, act); + // Sqrtsoftplus bwd: dy/dx = sigmoid(x) / (2 * y), where x = original logit + } else if (score_function == 2) { + g = sqrtsoftplus_bwd_scalar(g, act, sqrtsoftplus_scalar(act)); + } + + grad_logits[pos + i] = static_cast(g); } __syncwarp(); } @@ -474,14 +451,12 @@ void fused_topk_with_score_function_backward_kernel_launcher( const bool *routing_map, const CompType *intermediate_output, const DataType *grad_probs, int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, DataType *grad_logits, cudaStream_t stream) { - // Meta data for the kernel size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // grad_probs - + - num_experts * num_token_per_block * sizeof(CompType) // act_from_fwd - + num_experts * num_token_per_block * sizeof(CompType) // comp_buf - + num_experts * num_token_per_block * sizeof(bool); // routing_map + // Shmem: grad_buf + act_buf + mask_buf (no comp_buf — eliminated by fused scalar approach) + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // grad_buf + + num_experts * num_token_per_block * sizeof(CompType) // act_buf + + num_experts * num_token_per_block * sizeof(bool); // mask_buf check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); NVTE_CHECK_CUDA(cudaFuncSetAttribute(fused_topk_with_score_function_backward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 08ad3d16a6..2fb3752f0c 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -84,122 +84,51 @@ __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncT return T(val); } -template -__device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size, - ReduceFuncType type, int lane_id) { - T (*reduce_func)(T, T); - CompType default_val = 0.0; - if (type == ReduceFuncType::SUM) { - reduce_func = sum; - default_val = 0.0; - } else if (type == ReduceFuncType::MAX) { - reduce_func = max; - default_val = -std::numeric_limits::infinity(); - } +// ============================================================================ +// Scalar (per-element) score functions — for fused paths +// ============================================================================ - // Some value is handled in local thread - // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... - // Reduce the value in local thread - CompType val = lane_id < data_size && mask[lane_id] ? data_ptr[lane_id] : default_val; - for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { - if (mask[i]) { - val = reduce_func(val, data_ptr[i]); - } - } +// Forward: y = sigmoid(x) = 1 / (1 + exp(-x)) +__device__ __forceinline__ float sigmoid_scalar(float x) { return 1.0f / (1.0f + expf(-x)); } - // Warp shuffle between threads - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1)); - __syncwarp(); - return T(val); +// Forward: y = sqrt(softplus(x)) = sqrt(log(1 + exp(x))) +__device__ __forceinline__ float sqrtsoftplus_scalar(float x) { + float sp = (x > 20.0f) ? x : log1pf(expf(x)); + return sqrtf(sp); } -__device__ inline void apply_sigmoid_on_float(float *scores, int data_size, int lane_id) { - for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - scores[i] = 1.0f / (1.0f + expf(-scores[i])); - } +// Backward: sigmoid — given sigmoid output y, dy/dx = y * (1 - y) +__device__ __forceinline__ float sigmoid_bwd_scalar(float grad, float y) { + return grad * y * (1.0f - y); } -__device__ inline void apply_sigmoid_bwd_on_float(float *grad, float *fwd_output, int data_size, - int lane_id) { - for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - grad[i] = grad[i] * fwd_output[i] * (1.0f - fwd_output[i]); - } +// Backward: sqrtsoftplus — given original logit x and sqrtsoftplus output y = sqrt(softplus(x)), +// dy/dx = sigmoid(x) / (2 * y). For large x (>20), softplus(x) ≈ x so dy/dx ≈ 1/(2*y). +__device__ __forceinline__ float sqrtsoftplus_bwd_scalar(float grad, float x, float y) { + float dy_dx = (x > 20.0f) ? (1.0f / (2.0f * y + epsilon)) + : (sigmoid_scalar(x) / (2.0f * y + epsilon)); + return grad * dy_dx; } -// sqrtsoftplus: y = sqrt(softplus(x)) = sqrt(log(1 + exp(x))) -__device__ inline void apply_sqrtsoftplus_on_float(float *scores, int data_size, int lane_id) { - for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - float x = scores[i]; - // softplus(x) = log(1 + exp(x)), numerically stable version - // Matches PyTorch's Softplus(beta=1.0, threshold=20.0) - float softplus_val; - if (x > 20.0f) { - softplus_val = x; // for large x, softplus(x) ≈ x - } else { - softplus_val = log1pf(expf(x)); - } - scores[i] = sqrtf(softplus_val); - } +// Backward: normalization — given grad, routed flag, and pre-computed sums. +// Used by sigmoid/sqrtsoftplus with topk > 1. +// sum_act = sum of activation outputs over routed experts. +// sum_grad_act = sum of grad * act over routed experts. +__device__ __forceinline__ float normalize_bwd_scalar(float grad, bool routed, float sum_act, + float sum_grad_act) { + if (!routed) return 0.0f; + float denom = sum_act + epsilon; + return grad / denom - sum_grad_act / (denom * denom); } -// sqrtsoftplus backward: -// y = sqrt(softplus(x)) -// Matches PyTorch's Softplus(beta=1.0, threshold=20.0) -// We need the original logits (x) to compute the gradient -__device__ inline void apply_sqrtsoftplus_bwd_on_float(float *grad, float *fwd_output, - float *logits_buf, int data_size, - int lane_id) { - for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - float x = logits_buf[i]; // original logit - float y = fwd_output[i]; // sqrtsoftplus output - float dy_dx; - if (x > 20.0f) { - // When softplus(x) = x, y = sqrt(x), dy/dx = 1/(2*y) - dy_dx = 1.0f / (2.0f * y + epsilon); - } else { - // When softplus(x) = log(1+exp(x)), dy/dx = sigmoid(x) / (2*y) - // where sigmoid(x) = 1 / (1 + exp(-x)) - float sigmoid_x = 1.0f / (1.0f + expf(-x)); - dy_dx = sigmoid_x / (2.0f * y + epsilon); - } - grad[i] = grad[i] * dy_dx; - } +// Backward: softmax element — given grad, softmax output, and sum(output * grad). +__device__ __forceinline__ float softmax_bwd_scalar(float grad, float act, float dot) { + return act * (grad - dot); } -__device__ inline void apply_softmax_bwd_on_float(float *grad, float *fwd_output, float *comp_buf, - bool *mask, int data_size, int lane_id) { - // Put the result of output * grad to the comp_buf - for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - if (mask) { - if (mask[i]) - comp_buf[i] = grad[i] * fwd_output[i]; - else - comp_buf[i] = 0.0f; - } else { - comp_buf[i] = grad[i] * fwd_output[i]; - } - } - __syncwarp(); - float sum_Output_x_Grad = warp_reduce_on_shmem( - /*data ptr = */ comp_buf, - /*data size = */ data_size, - /*reduce func = */ ReduceFuncType::SUM, lane_id); - // In-place update - for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - if (mask) { - if (mask[i]) - grad[i] = fwd_output[i] * (grad[i] - sum_Output_x_Grad); - else - grad[i] = 0.0f; - } else { - grad[i] = fwd_output[i] * (grad[i] - sum_Output_x_Grad); - } - } -} +// ============================================================================ +// Array (in-place on shmem) softmax — still used by forward kernels +// ============================================================================ __device__ inline void apply_softmax_on_float(float *scores, int data_size, int lane_id) { // --- Pass 1: Online accumulation of max and sum_exp --- From 7fd8e45b37712cd50e1a7199c935e0b4469f5ff7 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Tue, 19 May 2026 10:27:10 +0800 Subject: [PATCH 02/15] [Common] Add async loader with persistent grid and double buffering MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add async_loader.h with: - RawAsyncLoader: cp.async on sm_80+, int4 fallback on sm_70, stores data in original type (no conversion during copy) - compute_persistent_grid(): occupancy-based grid sizing - choose_num_buffers(): shmem-aware 1-vs-2 buffer decision - vec_fill_global(), vec_store_global(): vectorized output helpers Forward kernels (topk + aux_loss): - Logits loaded via RawAsyncLoader with double-buffered prefetch - Persistent grid replaces 1-shot grid launch - DataType→CompType conversion during compute, not during load - vec_fill_global for clearing probs/routing_map Backward kernels (topk + aux_loss): - All inputs loaded via RawAsyncLoader (topk: 3 loaders for grad/act/mask; aux_loss: 2 loaders for grad/act) - Always double-buffered (kBwdNumBuffers=2, kAuxBwdNumBuffers=2) - Persistent grid with occupancy-based sizing Signed-off-by: Harry Zhou --- .../common/fused_router/async_loader.h | 254 +++++++++++++++++ .../fused_score_for_moe_aux_loss.cu | 242 +++++++++++------ .../fused_topk_with_score_function.cu | 256 ++++++++++++------ 3 files changed, 582 insertions(+), 170 deletions(-) create mode 100644 transformer_engine/common/fused_router/async_loader.h diff --git a/transformer_engine/common/fused_router/async_loader.h b/transformer_engine/common/fused_router/async_loader.h new file mode 100644 index 0000000000..93cef00118 --- /dev/null +++ b/transformer_engine/common/fused_router/async_loader.h @@ -0,0 +1,254 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_ +#define TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_ + +#include + +#include + +#include "../utils.cuh" +#include "utils.h" + +namespace transformer_engine { +namespace fused_router { + +// ============================================================================ +// Persistent kernel grid size computation +// ============================================================================ + +// Compute a persistent grid size: min(total_blocks_needed, SMs * max_blocks_per_SM). +// `kernel_func` is a pointer to the __global__ function. +// `block_size` is kThreadsPerBlock. +// `shmem_bytes` is the dynamic shared memory per block. +// `total_blocks` is ceil(num_tokens / tokens_per_block). +template +inline size_t compute_persistent_grid(KernelFunc kernel_func, int block_size, size_t shmem_bytes, + size_t total_blocks) { + int blocks_per_sm = 0; + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel_func, + block_size, shmem_bytes)); + if (blocks_per_sm <= 0) { + return total_blocks; + } + int device_id; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + int num_sms; + NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device_id)); + + size_t max_resident = static_cast(num_sms) * blocks_per_sm; + return (total_blocks < max_resident) ? total_blocks : max_resident; +} + +// ============================================================================ +// Occupancy-aware double-buffer decision +// ============================================================================ + +// Decide whether to use 1 or 2 buffers based on shmem budget. +// `single_buf_shmem` is the per-buffer shmem for the async-loaded data. +// `other_shmem_bytes` is shmem for everything else (scratch, work buffers). +// Returns 1 or 2. Ensures at least kMinBlocksPerSM blocks can co-reside. +inline int choose_num_buffers(size_t single_buf_shmem, size_t other_shmem_bytes) { + constexpr int kMinBlocksPerSM = 4; + + size_t total_single = single_buf_shmem + other_shmem_bytes; + size_t total_double = 2 * single_buf_shmem + other_shmem_bytes; + + int device_id; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + int max_smem; + NVTE_CHECK_CUDA( + cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id)); + + int blocks_double = (total_double > 0) ? static_cast(max_smem / total_double) : 0; + int blocks_single = (total_single > 0) ? static_cast(max_smem / total_single) : 0; + + if (blocks_double >= kMinBlocksPerSM) return 2; + if (blocks_single >= kMinBlocksPerSM) return 1; + return (blocks_double >= blocks_single) ? 2 : 1; +} + +// ============================================================================ +// Vectorized global store/fill helpers (using Vec<> from utils.cuh) +// ============================================================================ + +template +struct VecTraits { + static constexpr int kVecSize = (sizeof(T) <= 16) ? (16 / sizeof(T)) : 1; +}; + +// Vectorized store: write `count` elements from shmem/registers to global memory. +template +__device__ inline void vec_store_global(T *__restrict__ dst, const T *__restrict__ src, int count, + int lane_id) { + constexpr int kVecSize = VecTraits::kVecSize; + using VecType = typename BytesToType::Type; + + bool aligned = (reinterpret_cast(dst) % (sizeof(T) * kVecSize) == 0); + int aligned_count = (count / kVecSize) * kVecSize; + + if (aligned && aligned_count > 0) { + int vec_count = aligned_count / kVecSize; + for (int vi = lane_id; vi < vec_count; vi += kThreadsPerWarp) { + VecType v; + T *v_elts = reinterpret_cast(&v); +#pragma unroll + for (int e = 0; e < kVecSize; e++) { + v_elts[e] = src[vi * kVecSize + e]; + } + reinterpret_cast(dst)[vi] = v; + } + for (int i = aligned_count + lane_id; i < count; i += kThreadsPerWarp) { + dst[i] = src[i]; + } + } else { + for (int i = lane_id; i < count; i += kThreadsPerWarp) { + dst[i] = src[i]; + } + } +} + +// Vectorized fill: write `val` to `count` elements of global memory. +template +__device__ inline void vec_fill_global(T *__restrict__ dst, T val, int count, int lane_id) { + constexpr int kVecSize = VecTraits::kVecSize; + using VecType = typename BytesToType::Type; + + bool aligned = (reinterpret_cast(dst) % (sizeof(T) * kVecSize) == 0); + int aligned_count = (count / kVecSize) * kVecSize; + + if (aligned && aligned_count > 0) { + VecType v; + T *v_elts = reinterpret_cast(&v); +#pragma unroll + for (int e = 0; e < kVecSize; e++) { + v_elts[e] = val; + } + int vec_count = aligned_count / kVecSize; + for (int vi = lane_id; vi < vec_count; vi += kThreadsPerWarp) { + reinterpret_cast(dst)[vi] = v; + } + for (int i = aligned_count + lane_id; i < count; i += kThreadsPerWarp) { + dst[i] = val; + } + } else { + for (int i = lane_id; i < count; i += kThreadsPerWarp) { + dst[i] = val; + } + } +} + +// ============================================================================ +// cp.async wrappers — use hardware async copy on sm_80+, no-op on older archs. +// Always defined so callers don't need #if guards. +// ============================================================================ + +__device__ __forceinline__ void cp_async_16B(void *__restrict__ dst, const void *__restrict__ src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + __pipeline_memcpy_async(dst, src, 16); +#else + // Scalar fallback — callers must not rely on this being async. + *static_cast(dst) = *static_cast(src); +#endif +} + +__device__ __forceinline__ void cp_async_commit() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + __pipeline_commit(); +#endif +} + +__device__ __forceinline__ void cp_async_wait_all() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + __pipeline_wait_prior(0); +#endif +} + +// ============================================================================ +// RawAsyncLoader — double-buffered loader storing data in original type +// +// Enables cp.async for ALL data types (bf16, fp16, fp32) since no type +// conversion is needed during the copy. The kernel reads from shmem and +// casts to CompType during compute. +// ============================================================================ + +template +class RawAsyncLoader { + public: + // Shmem size calculation (usable from both host and device). + static __host__ __device__ inline size_t shmem_bytes(int count, int num_warps, int num_buffers) { + return static_cast(num_buffers) * count * num_warps * sizeof(T); + } + + // Device-side construction. + __device__ RawAsyncLoader(T *buf_base, int warp_id, int count, int num_warps, int num_buffers) + : phase_(0), double_buf_(num_buffers == 2) { + int per_buffer = count * num_warps; + buf_[0] = buf_base + warp_id * count; + buf_[1] = (num_buffers == 2) ? buf_base + per_buffer + warp_id * count : buf_[0]; + } + + __device__ __forceinline__ T *current_buf() { return buf_[phase_]; } + __device__ __forceinline__ T *next_buf() { return buf_[phase_ ^ 1]; } + __device__ __forceinline__ void flip() { + if (double_buf_) phase_ ^= 1; + } + + // Async load into the NEXT buffer (for prefetching). + __device__ void start_load(const T *__restrict__ src, int count, int lane_id) { + raw_load(src, next_buf(), count, lane_id); + } + + // Load into the CURRENT buffer (for the first load before the main loop). + __device__ void load_current(const T *__restrict__ src, int count, int lane_id) { + raw_load(src, current_buf(), count, lane_id); + } + + // Wait for pending async loads to complete. + __device__ __forceinline__ void wait() { + cp_async_wait_all(); + __syncwarp(); + } + + private: + T *buf_[2]; + int phase_; + bool double_buf_; + + // Raw copy: global → shmem, no type conversion. + // Uses 16-byte vectorised copies (cp.async on sm_80+, int4 on older archs) + // when both pointers are 16-byte aligned, with a scalar tail / fallback. + __device__ void raw_load(const T *__restrict__ src, T *__restrict__ dst, int count, int lane_id) { + constexpr int kBytesPerCopy = 16; + constexpr int kEltsPerCopy = kBytesPerCopy / sizeof(T); + + bool src_aligned = (reinterpret_cast(src) % kBytesPerCopy == 0); + bool dst_aligned = (reinterpret_cast(dst) % kBytesPerCopy == 0); + int aligned_count = (count / kEltsPerCopy) * kEltsPerCopy; + + if (src_aligned && dst_aligned && aligned_count > 0) { + int vec_count = aligned_count / kEltsPerCopy; + for (int vi = lane_id; vi < vec_count; vi += kThreadsPerWarp) { + cp_async_16B(dst + vi * kEltsPerCopy, src + vi * kEltsPerCopy); + } + for (int i = aligned_count + lane_id; i < count; i += kThreadsPerWarp) { + dst[i] = src[i]; + } + cp_async_commit(); + } else { + for (int i = lane_id; i < count; i += kThreadsPerWarp) { + dst[i] = src[i]; + } + cp_async_commit(); // No-op on sm_70; matches wait() expectation on sm_80+. + } + } +}; + +} // namespace fused_router +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_ diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 6207faad07..30411f570c 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -11,61 +11,92 @@ #include "../common.h" #include "../util/logging.h" #include "../utils.cuh" +#include "async_loader.h" #include "utils.h" namespace transformer_engine { namespace fused_router { template -__global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens, - int num_experts, int topk, - int score_function, float *scores, - bool *routing_map, - CompType *intermediate_output) { +__global__ void fused_score_for_moe_aux_loss_forward_kernel( + const DataType *logits, int num_tokens, int num_experts, int topk, + int score_function, float *scores, bool *routing_map, CompType *intermediate_output, + int num_buffers) { /*** * Section: Global Variables/Addresses init * - Each warp is responsible for one token, and has own shared memory buffer. * Then __syncwarp() is used instead of __syncthreads() */ - // Used variables/addresses init int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; - extern __shared__ float shmem_scores_for_aux_loss[]; - CompType *logits_buf = reinterpret_cast(shmem_scores_for_aux_loss); - CompType *topk_logits_buf = - reinterpret_cast(logits_buf + num_experts * num_token_per_block); + extern __shared__ char shmem_raw_aux[]; + + // Shmem layout: logits_raw (async) | logits_work | topk_scratch + char *shmem_ptr = shmem_raw_aux; + DataType *logits_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader loader(logits_shmem_base, warp_id, num_experts, num_token_per_block, + num_buffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + + CompType *logits_work_buf = reinterpret_cast(shmem_ptr); + shmem_ptr += num_experts * num_token_per_block * sizeof(CompType); + + CompType *topk_logits_buf = reinterpret_cast(shmem_ptr); int *topk_indices_buf = reinterpret_cast(topk_logits_buf + topk * num_token_per_block); + // The address of buffers on the current warp - CompType *local_logits = logits_buf + warp_id * num_experts; + CompType *local_logits = logits_work_buf + warp_id * num_experts; CompType *topk_logits = topk_logits_buf + warp_id * topk; int *topk_indices = topk_indices_buf + warp_id * topk; /*** - * Section: Main Loop + * Section: Main Loop — persistent grid with double-buffered async load * - Each warp is responsible for one token */ int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; - for (int round = blockIdx.x; round < total_round; round += gridDim.x) { + int first_round = blockIdx.x; + if (first_round >= total_round) return; + + // Kick off first async load + { + int first_token = first_round * num_token_per_block + warp_id; + if (first_token < num_tokens) { + loader.load_current(logits + first_token * num_experts, num_experts, lane_id); + } + } + + for (int round = first_round; round < total_round; round += gridDim.x) { int token_offset_cur_warp = round * num_token_per_block + warp_id; - // Each warp is responsible for one token if (token_offset_cur_warp >= num_tokens) break; + loader.wait(); + DataType *raw_logits = loader.current_buf(); + + // Prefetch next round + int next_round = round + gridDim.x; + if (next_round < total_round) { + int next_token = next_round * num_token_per_block + warp_id; + if (next_token < num_tokens) { + loader.start_load(logits + next_token * num_experts, num_experts, lane_id); + } + } + /*** * Section: Init buffer + Preprocess - * - Load logits → apply score function → save intermediate_output + * - Convert raw logits (DataType) → apply score function → save intermediate_output * * Fused into a single loop per score function where possible: - * score_function == 0 (sigmoid): load, sigmoid, save → shmem - * score_function == 1 (softmax): load → shmem, softmax (multi-pass), save - * score_function == 2 (sqrtsoftplus): load, save logits, sqrtsoftplus → shmem + * score_function == 0 (sigmoid): convert, sigmoid, save → shmem + * score_function == 1 (softmax): convert → shmem, softmax (multi-pass), save + * score_function == 2 (sqrtsoftplus): convert, save logits, sqrtsoftplus → shmem */ int pos_offset = token_offset_cur_warp * num_experts; if (score_function == 1) { // Softmax // Apply softmax to all logits, save softmax output for backward for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_logits[i] = static_cast(logits[pos_offset + i]); + local_logits[i] = static_cast(raw_logits[i]); } __syncwarp(); apply_softmax_on_float(local_logits, num_experts, lane_id); @@ -75,16 +106,16 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi intermediate_output[pos_offset + i] = local_logits[i]; } } else if (score_function == 0) { // Sigmoid - // Fused: load logit → sigmoid → save sigmoid output for backward → shmem + // Fused: convert → sigmoid → save sigmoid output for backward → shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - float val = sigmoid_scalar(static_cast(logits[pos_offset + i])); + float val = sigmoid_scalar(static_cast(raw_logits[i])); intermediate_output[pos_offset + i] = val; // Save sigmoid output for backward local_logits[i] = val; } } else if (score_function == 2) { // Sqrtsoftplus - // Fused: load logit → save original logit for backward → sqrtsoftplus → shmem + // Fused: convert → save original logit for backward → sqrtsoftplus → shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - float logit = static_cast(logits[pos_offset + i]); + float logit = static_cast(raw_logits[i]); intermediate_output[pos_offset + i] = logit; // Save original logits for backward local_logits[i] = sqrtsoftplus_scalar(logit); } @@ -109,15 +140,15 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi __syncwarp(); // Write the routing_map to the output tensor + vec_fill_global(routing_map + pos_offset, false, num_experts, lane_id); for (int i = lane_id; i < topk; i += kThreadsPerWarp) { routing_map[pos_offset + topk_indices[i]] = true; } // Write the scores to the output tensor - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - scores[pos_offset + i] = local_logits[i]; - } - __threadfence_block(); + vec_store_global(scores + pos_offset, local_logits, num_experts, lane_id); __syncwarp(); + + loader.flip(); } } @@ -125,33 +156,39 @@ template void fused_score_for_moe_aux_loss_forward_kernel_launcher( const DataType *logits, int num_tokens, int num_experts, int topk, int score_function, float *scores, bool *routing_map, CompType *intermediate_output, cudaStream_t stream) { - // Meta data for the kernel size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; - size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // logits - + topk * num_token_per_block * sizeof(CompType) // topk_logits - + topk * num_token_per_block * sizeof(int); // topk_indices + size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; + + size_t scores_shmem = num_experts * num_token_per_block * sizeof(CompType); + size_t scratch_shmem = topk * num_token_per_block * sizeof(CompType) + + topk * num_token_per_block * sizeof(int); + size_t other_shmem = scores_shmem + scratch_shmem; + size_t logits_single_buf = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1); + int num_buffers = choose_num_buffers(logits_single_buf, other_shmem); + size_t logits_raw_shmem = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + size_t shared_memory_size = logits_raw_shmem + other_shmem; check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); + + auto launch = [&](auto kernel) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_memory_size)); + size_t grid_size = + compute_persistent_grid(kernel, kThreadsPerBlock, shared_memory_size, total_blocks); + kernel<<>>( + logits, num_tokens, num_experts, topk, score_function, scores, routing_map, + intermediate_output, num_buffers); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + // Radix selection is O(E), independent of K, but it needs 4 passes for 32-bit float; // switch at K=16 where naive O(K^2*E) starts to dominate if (topk < 16) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - fused_score_for_moe_aux_loss_forward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); - fused_score_for_moe_aux_loss_forward_kernel - <<>>( - logits, num_tokens, num_experts, topk, score_function, scores, routing_map, - intermediate_output); + launch(fused_score_for_moe_aux_loss_forward_kernel); } else { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - fused_score_for_moe_aux_loss_forward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); - fused_score_for_moe_aux_loss_forward_kernel - <<>>( - logits, num_tokens, num_experts, topk, score_function, scores, routing_map, - intermediate_output); + launch(fused_score_for_moe_aux_loss_forward_kernel); } - NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts, @@ -169,14 +206,13 @@ void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, // Backward: grad_scores + intermediate_output → grad_logits. // No routing_map — all experts participate (unlike topk backward). +// Double-buffered cp.async loads both inputs. Two-pass fused approach. // -// Two-pass fused approach (eliminates the comp_buf shmem buffer): -// Pass 1 (reduction): accumulate warp-level sums for normalization/softmax bwd. -// Pass 2 (element-wise): compute per-element gradient and write to global memory. -// -// Shmem layout (W = warps/block): -// grad_buf: E × W × sizeof(CompType) — grad_scores loaded from global -// act_buf: E × W × sizeof(CompType) — intermediate_output (sigmoid/softmax out, or logits) +// Shmem layout (B = 2, W = warps/block): +// grad_buf: B × E × W × sizeof(CompType) — double-buffered async load +// act_buf: B × E × W × sizeof(CompType) — double-buffered async load +constexpr int kAuxBwdNumBuffers = 2; + template __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *intermediate_output, const float *grad_scores, @@ -191,32 +227,59 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; - extern __shared__ float shmem[]; - CompType *grad_buf = reinterpret_cast(shmem); - CompType *act_buf = grad_buf + num_experts * num_token_per_block; - // The address of buffers on the current warp - CompType *local_grad = grad_buf + warp_id * num_experts; - CompType *local_act = act_buf + warp_id * num_experts; + + extern __shared__ char shmem_aux_bwd[]; + char *shmem_ptr = shmem_aux_bwd; + + CompType *grad_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader grad_loader(grad_shmem_base, warp_id, num_experts, + num_token_per_block, kAuxBwdNumBuffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, + kAuxBwdNumBuffers); + + CompType *act_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader act_loader(act_shmem_base, warp_id, num_experts, + num_token_per_block, kAuxBwdNumBuffers); /*** - * Section: Main Loop + * Section: Main Loop — persistent grid with double-buffered async load * - Each warp is responsible for one token */ int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; - for (int round = blockIdx.x; round < total_round; round += gridDim.x) { + int first_round = blockIdx.x; + if (first_round >= total_round) return; + + // Kick off first async load + { + int first_token = first_round * num_token_per_block + warp_id; + if (first_token < num_tokens) { + int pos = first_token * num_experts; + grad_loader.load_current(grad_scores + pos, num_experts, lane_id); + act_loader.load_current(intermediate_output + pos, num_experts, lane_id); + } + } + + for (int round = first_round; round < total_round; round += gridDim.x) { int token_idx = round * num_token_per_block + warp_id; if (token_idx >= num_tokens) break; int pos = token_idx * num_experts; - /*** - * Section: Load inputs to shmem - * - Load the grad_scores/intermediate_output to shmem - */ - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_grad[i] = grad_scores[pos + i]; - local_act[i] = intermediate_output[pos + i]; + grad_loader.wait(); + act_loader.wait(); + + CompType *raw_grad = grad_loader.current_buf(); + CompType *raw_act = act_loader.current_buf(); + + // Prefetch next round + int next_round = round + gridDim.x; + if (next_round < total_round) { + int next_token = next_round * num_token_per_block + warp_id; + if (next_token < num_tokens) { + int next_pos = next_token * num_experts; + grad_loader.start_load(grad_scores + next_pos, num_experts, lane_id); + act_loader.start_load(intermediate_output + next_pos, num_experts, lane_id); + } } - __syncwarp(); /*** * Section: Pass 1 — Reduction @@ -232,8 +295,8 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int CompType sum_output_x_grad = 0.0f; for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - CompType g = local_grad[i]; - CompType act = local_act[i]; + CompType g = static_cast(raw_grad[i]); + CompType act = raw_act[i]; if (score_function == 0) { // Sigmoid // act = sigmoid output; accumulate over all experts sum_act += act; sum_grad_act += g * act; @@ -264,8 +327,8 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int * Write the grad_logits to the global mem */ for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - CompType g = local_grad[i]; - CompType act = local_act[i]; + CompType g = static_cast(raw_grad[i]); + CompType act = raw_act[i]; if (score_function == 0) { // Sigmoid bwd g = normalize_bwd_scalar(g, true, sum_act, sum_grad_act); @@ -279,7 +342,9 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int grad_logits[pos + i] = static_cast(g); } - __syncwarp(); + + grad_loader.flip(); + act_loader.flip(); } } @@ -288,18 +353,21 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher( const CompType *intermediate_output, const float *grad_scores, int num_tokens, int num_experts, int topk, int score_function, DataType *grad_logits, cudaStream_t stream) { size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; - size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - // Shmem: grad_buf + act_buf (no comp_buf — eliminated by fused scalar approach) - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // grad_buf - + num_experts * num_token_per_block * sizeof(CompType); // act_buf - check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); - NVTE_CHECK_CUDA(cudaFuncSetAttribute(fused_score_for_moe_aux_loss_backward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - shared_memory_size)); - fused_score_for_moe_aux_loss_backward_kernel - <<>>( - intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function, - grad_logits); + size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; + + size_t shmem_bytes = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kAuxBwdNumBuffers) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kAuxBwdNumBuffers); + check_shared_memory_capacity_num_experts(shmem_bytes, num_experts); + + auto kernel = fused_score_for_moe_aux_loss_backward_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); + size_t grid_size = + compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); + kernel<<>>( + intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function, + grad_logits); NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 2142b808c4..cc1e6511b0 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -10,6 +10,7 @@ #include "../common.h" #include "../util/logging.h" +#include "async_loader.h" #include "utils.h" namespace transformer_engine { @@ -20,7 +21,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, int num_groups, int group_topk, float scaling_factor, int score_function, const BiasType *expert_bias, DataType *probs, bool *routing_map, - CompType *intermediate_output) { + CompType *intermediate_output, int num_buffers) { /*** * Section: Global Variables/Addresses init * - Each warp is responsible for one token, and has own shared memory buffer. @@ -30,9 +31,19 @@ __global__ void fused_topk_with_score_function_forward_kernel( int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; - extern __shared__ float shmem[]; - CompType *scores_buf = reinterpret_cast(shmem); - CompType *topk_scores_buf = scores_buf + num_experts * num_token_per_block; + extern __shared__ char shmem_raw[]; + + // Shmem layout: logits_raw (async) | scores | topk_scratch (+ group_topk scratch) + char *shmem_ptr = shmem_raw; + DataType *logits_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader loader(logits_shmem_base, warp_id, num_experts, num_token_per_block, + num_buffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + + CompType *scores_buf = reinterpret_cast(shmem_ptr); + shmem_ptr += num_experts * num_token_per_block * sizeof(CompType); + + CompType *topk_scores_buf = reinterpret_cast(shmem_ptr); CompType *group_scores_buf = nullptr, *masked_scores_buf = nullptr; int *topk_indices_buf = nullptr; if (group_topk > 0) { @@ -45,29 +56,54 @@ __global__ void fused_topk_with_score_function_forward_kernel( // The address of buffers on the current warp CompType *scores = scores_buf + warp_id * num_experts; CompType *topk_scores = topk_scores_buf + warp_id * topk; - CompType *masked_scores = masked_scores_buf + warp_id * num_experts; - CompType *group_scores = group_scores_buf + warp_id * num_groups; + CompType *masked_scores = + (masked_scores_buf != nullptr) ? masked_scores_buf + warp_id * num_experts : nullptr; + CompType *group_scores = + (group_scores_buf != nullptr) ? group_scores_buf + warp_id * num_groups : nullptr; int *topk_indices = topk_indices_buf + warp_id * topk; /*** - * Section: Main Loop + * Section: Main Loop — persistent grid with double-buffered async load * - Each warp is responsible for one token */ int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; - for (int round = blockIdx.x; round < total_round; round += gridDim.x) { + int first_round = blockIdx.x; + if (first_round >= total_round) return; + + // Kick off first async load + { + int first_token = first_round * num_token_per_block + warp_id; + if (first_token < num_tokens) { + loader.load_current(logits + first_token * num_experts, num_experts, lane_id); + } + } + + for (int round = first_round; round < total_round; round += gridDim.x) { int token_offset_cur_warp = round * num_token_per_block + warp_id; - // Each warp is responsible for one token if (token_offset_cur_warp >= num_tokens) break; + // Wait for current round's async load to complete + loader.wait(); + DataType *raw_logits = loader.current_buf(); + + // Prefetch next round (overlaps with compute below) + int next_round = round + gridDim.x; + if (next_round < total_round) { + int next_token = next_round * num_token_per_block + warp_id; + if (next_token < num_tokens) { + loader.start_load(logits + next_token * num_experts, num_experts, lane_id); + } + } + /*** * Section: Init buffer + Preprocess * - Clear the global output buffers (probs, routing_map) - * - Load logits → apply score function → save intermediate_output → add expert bias + * - Convert raw logits (DataType) → apply score function → save intermediate → add bias * * Fused into a single loop per score function where possible: - * score_function == 0 (sigmoid): load, sigmoid, save, +bias → scores - * score_function == 1 (softmax): load → shmem, softmax (multi-pass), save - * score_function == 2 (sqrtsoftplus): load, save logits, sqrtsoftplus, +bias → scores + * score_function == 0 (sigmoid): convert, sigmoid, save, +bias → scores + * score_function == 1 (softmax): convert → shmem, softmax (multi-pass), save + * score_function == 2 (sqrtsoftplus): convert, save logits, sqrtsoftplus, +bias → scores * * Expert bias is only used with sigmoid/sqrtsoftplus and is fused into * the same loop that computes the score. @@ -75,16 +111,14 @@ __global__ void fused_topk_with_score_function_forward_kernel( int pos_offset = token_offset_cur_warp * num_experts; // Clear the probs/routing_map (num_experts) - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - probs[pos_offset + i] = 0.0; - routing_map[pos_offset + i] = false; - } + vec_fill_global(probs + pos_offset, static_cast(0), num_experts, lane_id); + vec_fill_global(routing_map + pos_offset, false, num_experts, lane_id); if (score_function == 1) { // Softmax if (use_pre_softmax) { // Pre-softmax: apply softmax to all logits before topk, save for backward. for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - scores[i] = static_cast(logits[pos_offset + i]); + scores[i] = static_cast(raw_logits[i]); } __syncwarp(); apply_softmax_on_float(scores, num_experts, lane_id); @@ -97,22 +131,22 @@ __global__ void fused_topk_with_score_function_forward_kernel( // Post-softmax: softmax applied after topk; init intermediate to -inf // (only the topk positions will be filled in the postprocess section). for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - scores[i] = static_cast(logits[pos_offset + i]); + scores[i] = static_cast(raw_logits[i]); intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); } } } else if (score_function == 0) { // Sigmoid - // Fused: load logit → sigmoid → save sigmoid output for backward → add bias → scores + // Fused: convert → sigmoid → save sigmoid output for backward → add bias → scores for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - float val = sigmoid_scalar(static_cast(logits[pos_offset + i])); + float val = sigmoid_scalar(static_cast(raw_logits[i])); intermediate_output[pos_offset + i] = val; // Save sigmoid output for backward if (expert_bias) val += static_cast(expert_bias[i]); scores[i] = val; } } else if (score_function == 2) { // Sqrtsoftplus - // Fused: load logit → save original logit for backward → sqrtsoftplus → add bias → scores + // Fused: convert → save original logit for backward → sqrtsoftplus → add bias → scores for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - float logit = static_cast(logits[pos_offset + i]); + float logit = static_cast(raw_logits[i]); intermediate_output[pos_offset + i] = logit; // Save original logits for backward float val = sqrtsoftplus_scalar(logit); if (expert_bias) val += static_cast(expert_bias[i]); @@ -230,8 +264,9 @@ __global__ void fused_topk_with_score_function_forward_kernel( routing_map[pos_offset + topk_indices[i]] = true; probs[pos_offset + topk_indices[i]] = scaling_factor * topk_scores[i]; } - __threadfence_block(); __syncwarp(); + + loader.flip(); } } @@ -242,35 +277,44 @@ void fused_topk_with_score_function_forward_kernel_launcher( const BiasType *expert_bias, DataType *probs, bool *routing_map, CompType *intermediate_output, cudaStream_t stream) { size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; - size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // scores - + topk * num_token_per_block * sizeof(CompType) // topk_scores - + topk * num_token_per_block * sizeof(int); // topk_indices + size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; + size_t scores_shmem = num_experts * num_token_per_block * sizeof(CompType); + size_t scratch_shmem = topk * num_token_per_block * sizeof(CompType) + + topk * num_token_per_block * sizeof(int); if (group_topk > 0) { - shared_memory_size += num_groups * num_token_per_block * sizeof(CompType); // group_scores - shared_memory_size += num_experts * num_token_per_block * sizeof(CompType); // maksed_scores + scratch_shmem += num_groups * num_token_per_block * sizeof(CompType); + scratch_shmem += num_experts * num_token_per_block * sizeof(CompType); } + size_t other_shmem = scores_shmem + scratch_shmem; + size_t logits_single_buf = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1); + int num_buffers = choose_num_buffers(logits_single_buf, other_shmem); + size_t logits_raw_shmem = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + size_t shared_memory_size = logits_raw_shmem + other_shmem; check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); + + auto launch = [&](auto kernel) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_memory_size)); + size_t grid_size = + compute_persistent_grid(kernel, kThreadsPerBlock, shared_memory_size, total_blocks); + kernel<<>>( + logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output, + num_buffers); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + // Radix selection is O(E), independent of K, but it needs 4 passes for 32-bit float; // switch at K=16 where naive O(K^2*E) starts to dominate if (topk < 16) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - fused_topk_with_score_function_forward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); - fused_topk_with_score_function_forward_kernel - <<>>( - logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + launch(fused_topk_with_score_function_forward_kernel); } else { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - fused_topk_with_score_function_forward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); - fused_topk_with_score_function_forward_kernel - <<>>( - logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + launch(fused_topk_with_score_function_forward_kernel); } - NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts, @@ -294,14 +338,17 @@ void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, // Backward: grad_probs + intermediate_output + routing_map → grad_logits. // -// Two-pass fused approach (eliminates the comp_buf shmem buffer): +// Double-buffered cp.async loads all 3 inputs in original types. Two-pass +// fused approach (eliminates the comp_buf shmem buffer): // Pass 1 (reduction): accumulate warp-level sums needed by normalization/softmax bwd. // Pass 2 (element-wise): compute per-element gradient and write to global memory. // -// Shmem layout (W = warps/block): -// grad_buf: E × W × sizeof(CompType) — grad_probs loaded from global -// act_buf: E × W × sizeof(CompType) — intermediate_output (sigmoid/softmax out, or logits) -// mask_buf: E × W × sizeof(bool) — routing_map from forward +// Shmem layout (B = 2, W = warps/block): +// grad_raw: B × E × W × sizeof(DataType) — double-buffered async load +// act_buf: B × E × W × sizeof(CompType) — double-buffered async load +// mask_buf: B × E × W × sizeof(bool) — double-buffered async load +constexpr int kBwdNumBuffers = 2; + template __global__ void fused_topk_with_score_function_backward_kernel( // Inputs tensor @@ -319,35 +366,72 @@ __global__ void fused_topk_with_score_function_backward_kernel( int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; - extern __shared__ float shmem[]; - CompType *grad_buf = reinterpret_cast(shmem); - CompType *act_buf = grad_buf + num_experts * num_token_per_block; - bool *mask_buf = reinterpret_cast(act_buf + num_experts * num_token_per_block); - // The address of buffers on the current warp - CompType *local_grad = grad_buf + warp_id * num_experts; - CompType *local_act = act_buf + warp_id * num_experts; - bool *local_mask = mask_buf + warp_id * num_experts; + + extern __shared__ char shmem_bwd[]; + char *shmem_ptr = shmem_bwd; + + DataType *grad_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader grad_loader(grad_shmem_base, warp_id, num_experts, + num_token_per_block, kBwdNumBuffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, + kBwdNumBuffers); + + CompType *act_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader act_loader(act_shmem_base, warp_id, num_experts, + num_token_per_block, kBwdNumBuffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, + kBwdNumBuffers); + + bool *mask_shmem_base = reinterpret_cast(shmem_ptr); + RawAsyncLoader mask_loader(mask_shmem_base, warp_id, num_experts, + num_token_per_block, kBwdNumBuffers); /*** - * Section: Main Loop + * Section: Main Loop — persistent grid with double-buffered async load * - Each warp is responsible for one token */ int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; - for (int round = blockIdx.x; round < total_round; round += gridDim.x) { + int first_round = blockIdx.x; + if (first_round >= total_round) return; + + // Kick off first async load + { + int first_token = first_round * num_token_per_block + warp_id; + if (first_token < num_tokens) { + int pos = first_token * num_experts; + grad_loader.load_current(grad_probs + pos, num_experts, lane_id); + act_loader.load_current(intermediate_output + pos, num_experts, lane_id); + mask_loader.load_current(routing_map + pos, num_experts, lane_id); + } + } + + for (int round = first_round; round < total_round; round += gridDim.x) { int token_idx = round * num_token_per_block + warp_id; if (token_idx >= num_tokens) break; int pos = token_idx * num_experts; /*** - * Section: Load inputs to shmem - * - Load the grad_probs/intermediate_output/routing_map to shmem + * Section: Wait for async load + prefetch next round */ - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_grad[i] = static_cast(grad_probs[pos + i]); - local_act[i] = intermediate_output[pos + i]; - local_mask[i] = routing_map[pos + i]; + grad_loader.wait(); + act_loader.wait(); + mask_loader.wait(); + + DataType *raw_grad = grad_loader.current_buf(); + CompType *local_act = act_loader.current_buf(); + bool *local_mask = mask_loader.current_buf(); + + // Prefetch next round + int next_round = round + gridDim.x; + if (next_round < total_round) { + int next_token = next_round * num_token_per_block + warp_id; + if (next_token < num_tokens) { + int next_pos = next_token * num_experts; + grad_loader.start_load(grad_probs + next_pos, num_experts, lane_id); + act_loader.start_load(intermediate_output + next_pos, num_experts, lane_id); + mask_loader.start_load(routing_map + next_pos, num_experts, lane_id); + } } - __syncwarp(); /*** * Section: Pass 1 — Reduction @@ -366,7 +450,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( || (score_function == 1); if (need_reduce) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - CompType g = local_grad[i] * scaling_factor; + CompType g = static_cast(raw_grad[i]) * scaling_factor; CompType act = local_act[i]; bool routed = local_mask[i]; @@ -410,7 +494,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( * Write the grad_logits to the global mem */ for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - CompType g = local_grad[i] * scaling_factor; + CompType g = static_cast(raw_grad[i]) * scaling_factor; CompType act = local_act[i]; bool routed = local_mask[i]; @@ -442,7 +526,10 @@ __global__ void fused_topk_with_score_function_backward_kernel( grad_logits[pos + i] = static_cast(g); } - __syncwarp(); + + grad_loader.flip(); + act_loader.flip(); + mask_loader.flip(); } } @@ -452,19 +539,22 @@ void fused_topk_with_score_function_backward_kernel_launcher( int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, DataType *grad_logits, cudaStream_t stream) { size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; - size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - // Shmem: grad_buf + act_buf + mask_buf (no comp_buf — eliminated by fused scalar approach) - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // grad_buf - + num_experts * num_token_per_block * sizeof(CompType) // act_buf - + num_experts * num_token_per_block * sizeof(bool); // mask_buf - check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); - NVTE_CHECK_CUDA(cudaFuncSetAttribute(fused_topk_with_score_function_backward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - shared_memory_size)); - fused_topk_with_score_function_backward_kernel - <<>>( - routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, - use_pre_softmax, scaling_factor, score_function, grad_logits); + size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; + + size_t shmem_bytes = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kBwdNumBuffers) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kBwdNumBuffers) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kBwdNumBuffers); + check_shared_memory_capacity_num_experts(shmem_bytes, num_experts); + + auto kernel = fused_topk_with_score_function_backward_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); + size_t grid_size = + compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); + kernel<<>>( + routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, + use_pre_softmax, scaling_factor, score_function, grad_logits); NVTE_CHECK_CUDA(cudaGetLastError()); } From 87a0cc3b32fd564202da8a844f9ed279c21f88b7 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Tue, 19 May 2026 10:27:10 +0800 Subject: [PATCH 03/15] [Common] Pack radix topk histogram into 8-bit fields Replace counts[16] + total_counts[16] (32 registers) with 4 packed u32 registers using 8-bit fields (4 counters per register). Eliminates massive register spill to local memory on large kernels (81% of L1 traffic on E=2304, K=36). Add kMaxExpertsRadixTopk constant (8160 = 255 * 32) and runtime checks in both forward launchers to guard against 8-bit overflow. All current MoE configurations (max E=2304) are well within this limit. Signed-off-by: Harry Zhou --- .../fused_score_for_moe_aux_loss.cu | 3 + .../fused_topk_with_score_function.cu | 3 + .../common/fused_router/utils.h | 69 +++++++++++-------- 3 files changed, 46 insertions(+), 29 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 30411f570c..d4301ee2e3 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -187,6 +187,9 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( if (topk < 16) { launch(fused_score_for_moe_aux_loss_forward_kernel); } else { + NVTE_CHECK(num_experts <= kMaxExpertsRadixTopk, + "Radix topk requires num_experts <= ", kMaxExpertsRadixTopk, + " (packed 8-bit histogram), got ", num_experts, "."); launch(fused_score_for_moe_aux_loss_forward_kernel); } } diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index cc1e6511b0..24485a3c36 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -312,6 +312,9 @@ void fused_topk_with_score_function_forward_kernel_launcher( launch(fused_topk_with_score_function_forward_kernel); } else { + NVTE_CHECK(num_experts <= kMaxExpertsRadixTopk, + "Radix topk requires num_experts <= ", kMaxExpertsRadixTopk, + " (packed 8-bit histogram), got ", num_experts, "."); launch(fused_topk_with_score_function_forward_kernel); } diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 2fb3752f0c..7d0d34a969 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -185,6 +185,11 @@ enum class TopkFuncType { Radix = 1, }; +// Maximum num_experts supported by the packed 8-bit radix topk histogram. +// Each thread processes ceil(data_size/32) elements per bucket. With 8-bit +// counters the max per-thread count is 255, so data_size <= 255 * 32 = 8160. +constexpr int kMaxExpertsRadixTopk = 255 * 32; // 8160 + /******************************************************************************* * radix_topk_and_mask — Warp-level radix-selection based top-K * @@ -210,11 +215,16 @@ enum class TopkFuncType { * broken by ascending index for determinism matching torch.topk). * Write indices and scores to the output arrays. * + * Register pressure optimization: pack 16 bucket counts into 4 registers + * using 8-bit fields (4 counters per u32). The original counts[16] + + * total_counts[16] required 32 registers, causing massive spill to local + * memory on large kernels (81% of L1 traffic on E=2304, K=36). + * * Tie-breaking: (value DESC, index ASC) — matches torch.topk behavior. * * Constraints: * - 0 < topk <= data_size - * - No upper limit on topk or data_size (unlike v1's 128 cap) + * - data_size <= kMaxExpertsRadixTopk (8160) to avoid 8-bit overflow * - scores must be in shared memory accessible by the warp * * Complexity: 9 × O(E/32) = O(E) per warp, independent of K. @@ -222,9 +232,6 @@ enum class TopkFuncType { __device__ inline void radix_topk_and_mask(CompType *scores, int data_size, int topk, int *topk_indices, CompType *topk_scores, int lane_id) { - // assert(topk > 0 && "naive_topk_and_mask_v2: topk must be positive"); - // assert(topk <= data_size && "naive_topk_and_mask_v2: topk exceeds data_size"); - constexpr int RADIX_BITS = 4; constexpr int RADIX_SIZE = 1 << RADIX_BITS; // 16 buckets constexpr int RADIX_MASK = RADIX_SIZE - 1; // 0xF @@ -232,54 +239,58 @@ __device__ inline void radix_topk_and_mask(CompType *scores, int data_size, int // ========================================================================= // Phase 1: Radix selection — find the bit pattern of the K-th largest value + // + // Packed counters: 16 bucket counts are stored in 4 × u32 registers using + // 8-bit fields (4 counters per register). Bucket b is in byte (b % 4) of + // packed[b / 4]. This reduces register usage from 32 (counts[16] + + // total_counts[16]) to 4 registers. + // + // Max per-thread count per bucket = ceil(data_size / 32). + // For E=2304: max 72 — fits in 8 bits (max 255). + // Constraint: data_size must be <= kMaxExpertsRadixTopk (8160). // ========================================================================= unsigned int desired = 0; // accumulated bit pattern of the K-th value unsigned int desired_mask = 0; // bits determined so far int k_remaining = topk; // how many more elements we need to skip +#pragma unroll 1 for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { int digit_pos = pass * RADIX_BITS; - // Each thread counts elements per bucket that match the desired pattern - unsigned int counts[RADIX_SIZE]; -#pragma unroll - for (int b = 0; b < RADIX_SIZE; b++) { - counts[b] = 0; - } + // Packed counters: packed[i] holds 4 × 8-bit counts for buckets [4i..4i+3]. + // Bucket b is in byte (b % 4) of packed[b / 4]. + unsigned int packed[4] = {0, 0, 0, 0}; for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { unsigned int u = float_to_ordered_uint(scores[i]); - // Check if this element matches the desired pattern on already-decided bits if ((u & desired_mask) == desired) { int bucket = (u >> digit_pos) & RADIX_MASK; - counts[bucket]++; + int pack_idx = bucket >> 2; // bucket / 4 + int byte_idx = bucket & 3; // bucket % 4 + int shift = byte_idx << 3; // byte_idx * 8 + packed[pack_idx] += (1u << shift); // increment the 8-bit field } } - // Warp-reduce each bucket count across all 32 lanes - unsigned int total_counts[RADIX_SIZE]; -#pragma unroll - for (int b = 0; b < RADIX_SIZE; b++) { - unsigned int c = warp_allreduce_sum(counts[b]); - total_counts[b] = c; // same value on all lanes after full reduction - } - - // Scan buckets from LARGEST digit value (15) to smallest (0). - // We're looking for the top-K largest, so we want the highest-valued - // bucket first. Accumulate counts until we find the bucket containing - // the k_remaining-th element. + // Warp-reduce each bucket, then scan to find the target bucket. int target_bucket = 0; + int k_remaining_copy = k_remaining; +#pragma unroll for (int b = RADIX_SIZE - 1; b >= 0; b--) { - unsigned int bc = total_counts[b]; - if (bc < static_cast(k_remaining)) { - // All elements in this bucket are in the top set; skip them - k_remaining -= bc; + // Unpack: extract 8-bit count for bucket b from packed[b/4] + int pack_idx = b >> 2; + int shift = (b & 3) << 3; + unsigned int my_count = (packed[pack_idx] >> shift) & 0xFFu; + // Warp-reduce to get total count across all 32 lanes + unsigned int bc = warp_allreduce_sum(my_count); + if (bc < static_cast(k_remaining_copy)) { + k_remaining_copy -= bc; } else { - // The K-th element is in this bucket target_bucket = b; break; } } + k_remaining = k_remaining_copy; // Update the desired pattern and mask desired |= (static_cast(target_bucket) << digit_pos); From bfd694a46f034fb8aeb8f2a92bb25a5f90526670 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Tue, 19 May 2026 10:27:10 +0800 Subject: [PATCH 04/15] [Common] Template fused router kernels on ScoreFunc for compile-time dispatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace runtime score_function parameter in all 4 kernel __global__ functions with template int ScoreFunc (0=sigmoid, 1=softmax, 2=sqrtsoftplus). All score_function branches now use if constexpr, eliminating dead-code register pressure and branch overhead. Forward launchers dispatch on TopkFunc × ScoreFunc = 6 instantiations per DataType. Backward launchers dispatch on ScoreFunc = 3 instantiations per DataType. Signed-off-by: Harry Zhou --- .../fused_score_for_moe_aux_loss.cu | 98 +++++++---- .../fused_topk_with_score_function.cu | 159 ++++++++++++------ 2 files changed, 172 insertions(+), 85 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index d4301ee2e3..458d8b458f 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -17,10 +17,10 @@ namespace transformer_engine { namespace fused_router { -template +template __global__ void fused_score_for_moe_aux_loss_forward_kernel( const DataType *logits, int num_tokens, int num_experts, int topk, - int score_function, float *scores, bool *routing_map, CompType *intermediate_output, + float *scores, bool *routing_map, CompType *intermediate_output, int num_buffers) { /*** * Section: Global Variables/Addresses init @@ -93,7 +93,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel( */ int pos_offset = token_offset_cur_warp * num_experts; - if (score_function == 1) { // Softmax + if constexpr (ScoreFunc == 1) { // Softmax // Apply softmax to all logits, save softmax output for backward for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_logits[i] = static_cast(raw_logits[i]); @@ -105,14 +105,14 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel( for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = local_logits[i]; } - } else if (score_function == 0) { // Sigmoid + } else if constexpr (ScoreFunc == 0) { // Sigmoid // Fused: convert → sigmoid → save sigmoid output for backward → shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { float val = sigmoid_scalar(static_cast(raw_logits[i])); intermediate_output[pos_offset + i] = val; // Save sigmoid output for backward local_logits[i] = val; } - } else if (score_function == 2) { // Sqrtsoftplus + } else if constexpr (ScoreFunc == 2) { // Sqrtsoftplus // Fused: convert → save original logit for backward → sqrtsoftplus → shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { float logit = static_cast(raw_logits[i]); @@ -123,7 +123,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel( __syncwarp(); // Sigmoid/Sqrtsoftplus post-processing: normalize scores to sum to 1 - if (score_function == 0 || score_function == 2) { + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { @@ -177,20 +177,44 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( size_t grid_size = compute_persistent_grid(kernel, kThreadsPerBlock, shared_memory_size, total_blocks); kernel<<>>( - logits, num_tokens, num_experts, topk, score_function, scores, routing_map, + logits, num_tokens, num_experts, topk, scores, routing_map, intermediate_output, num_buffers); NVTE_CHECK_CUDA(cudaGetLastError()); }; - // Radix selection is O(E), independent of K, but it needs 4 passes for 32-bit float; - // switch at K=16 where naive O(K^2*E) starts to dominate + // Dispatch on TopkFunc × ScoreFunc (6 instantiations per DataType). + // Radix selection is O(E), independent of K; switch at K=16 where naive O(K^2*E) dominates. if (topk < 16) { - launch(fused_score_for_moe_aux_loss_forward_kernel); + switch (score_function) { + case 0: + launch(fused_score_for_moe_aux_loss_forward_kernel); + break; + case 1: + launch(fused_score_for_moe_aux_loss_forward_kernel); + break; + case 2: + launch(fused_score_for_moe_aux_loss_forward_kernel); + break; + default: + NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); + } } else { NVTE_CHECK(num_experts <= kMaxExpertsRadixTopk, "Radix topk requires num_experts <= ", kMaxExpertsRadixTopk, " (packed 8-bit histogram), got ", num_experts, "."); - launch(fused_score_for_moe_aux_loss_forward_kernel); + switch (score_function) { + case 0: + launch(fused_score_for_moe_aux_loss_forward_kernel); + break; + case 1: + launch(fused_score_for_moe_aux_loss_forward_kernel); + break; + case 2: + launch(fused_score_for_moe_aux_loss_forward_kernel); + break; + default: + NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); + } } } @@ -216,11 +240,11 @@ void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, // act_buf: B × E × W × sizeof(CompType) — double-buffered async load constexpr int kAuxBwdNumBuffers = 2; -template +template __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *intermediate_output, const float *grad_scores, int num_tokens, int num_experts, - int topk, int score_function, + int topk, DataType *grad_logits) { /*** * Section: Global Variables/Addresses init @@ -300,23 +324,23 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { CompType g = static_cast(raw_grad[i]); CompType act = raw_act[i]; - if (score_function == 0) { // Sigmoid + if constexpr (ScoreFunc == 0) { // Sigmoid // act = sigmoid output; accumulate over all experts sum_act += act; sum_grad_act += g * act; - } else if (score_function == 2) { // Sqrtsoftplus + } else if constexpr (ScoreFunc == 2) { // Sqrtsoftplus // act = original logit; recompute sqrtsoftplus to get activation CompType v = sqrtsoftplus_scalar(act); sum_act += v; sum_grad_act += g * v; - } else if (score_function == 1) { // Softmax + } else if constexpr (ScoreFunc == 1) { // Softmax // act = softmax output sum_output_x_grad += g * act; } } - if (score_function == 0 || score_function == 2) { + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { sum_act = warp_allreduce_sum(sum_act); sum_grad_act = warp_allreduce_sum(sum_grad_act); } - if (score_function == 1) { + if constexpr (ScoreFunc == 1) { sum_output_x_grad = warp_allreduce_sum(sum_output_x_grad); } @@ -333,13 +357,13 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int CompType g = static_cast(raw_grad[i]); CompType act = raw_act[i]; - if (score_function == 0) { // Sigmoid bwd + if constexpr (ScoreFunc == 0) { // Sigmoid bwd g = normalize_bwd_scalar(g, true, sum_act, sum_grad_act); g = sigmoid_bwd_scalar(g, act); - } else if (score_function == 2) { // Sqrtsoftplus bwd + } else if constexpr (ScoreFunc == 2) { // Sqrtsoftplus bwd g = normalize_bwd_scalar(g, true, sum_act, sum_grad_act); g = sqrtsoftplus_bwd_scalar(g, act, sqrtsoftplus_scalar(act)); - } else if (score_function == 1) { // Softmax bwd + } else if constexpr (ScoreFunc == 1) { // Softmax bwd g = softmax_bwd_scalar(g, act, sum_output_x_grad); } @@ -363,15 +387,29 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher( RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kAuxBwdNumBuffers); check_shared_memory_capacity_num_experts(shmem_bytes, num_experts); - auto kernel = fused_score_for_moe_aux_loss_backward_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); - size_t grid_size = - compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); - kernel<<>>( - intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function, - grad_logits); - NVTE_CHECK_CUDA(cudaGetLastError()); + auto launch = [&](auto kernel) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); + size_t grid_size = + compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); + kernel<<>>( + intermediate_output, grad_scores, num_tokens, num_experts, topk, grad_logits); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + switch (score_function) { + case 0: + launch(fused_score_for_moe_aux_loss_backward_kernel); + break; + case 1: + launch(fused_score_for_moe_aux_loss_backward_kernel); + break; + case 2: + launch(fused_score_for_moe_aux_loss_backward_kernel); + break; + default: + NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); + } } void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output, diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 24485a3c36..8c88b8f7e3 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -16,10 +16,11 @@ namespace transformer_engine { namespace fused_router { -template +template __global__ void fused_topk_with_score_function_forward_kernel( const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, - int num_groups, int group_topk, float scaling_factor, int score_function, + int num_groups, int group_topk, float scaling_factor, const BiasType *expert_bias, DataType *probs, bool *routing_map, CompType *intermediate_output, int num_buffers) { /*** @@ -114,7 +115,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( vec_fill_global(probs + pos_offset, static_cast(0), num_experts, lane_id); vec_fill_global(routing_map + pos_offset, false, num_experts, lane_id); - if (score_function == 1) { // Softmax + if constexpr (ScoreFunc == 1) { // Softmax if (use_pre_softmax) { // Pre-softmax: apply softmax to all logits before topk, save for backward. for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { @@ -135,7 +136,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); } } - } else if (score_function == 0) { // Sigmoid + } else if constexpr (ScoreFunc == 0) { // Sigmoid // Fused: convert → sigmoid → save sigmoid output for backward → add bias → scores for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { float val = sigmoid_scalar(static_cast(raw_logits[i])); @@ -143,7 +144,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( if (expert_bias) val += static_cast(expert_bias[i]); scores[i] = val; } - } else if (score_function == 2) { // Sqrtsoftplus + } else if constexpr (ScoreFunc == 2) { // Sqrtsoftplus // Fused: convert → save original logit for backward → sqrtsoftplus → add bias → scores for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { float logit = static_cast(raw_logits[i]); @@ -229,27 +230,30 @@ __global__ void fused_topk_with_score_function_forward_kernel( * - Write the result with scaling_factor */ // Revert Expert bias from the topk scores - if (expert_bias && (score_function == 0 || score_function == 2)) { - for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - topk_scores[i] = topk_scores[i] - static_cast(expert_bias[topk_indices[i]]); + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { + if (expert_bias) { + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + topk_scores[i] = topk_scores[i] - static_cast(expert_bias[topk_indices[i]]); + } + __syncwarp(); } - __syncwarp(); } - // score_function == 1 means softmax - if (!use_pre_softmax && score_function == 1) { - // Apply softmax to the topk logits - apply_softmax_on_float(topk_scores, topk, lane_id); - __syncwarp(); - // Save the softmax output for backward - for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i]; + if constexpr (ScoreFunc == 1) { + if (!use_pre_softmax) { + // Apply softmax to the topk logits + apply_softmax_on_float(topk_scores, topk, lane_id); + __syncwarp(); + // Save the softmax output for backward + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i]; + } + __syncwarp(); } - __syncwarp(); } // Sigmoid/Sqrtsoftplus post-processing when topk > 1 - if (score_function == 0 || score_function == 2) { + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { if (topk > 1) { CompType sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id); for (int i = lane_id; i < topk; i += kThreadsPerWarp) { @@ -301,22 +305,50 @@ void fused_topk_with_score_function_forward_kernel_launcher( compute_persistent_grid(kernel, kThreadsPerBlock, shared_memory_size, total_blocks); kernel<<>>( logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output, + scaling_factor, expert_bias, probs, routing_map, intermediate_output, num_buffers); NVTE_CHECK_CUDA(cudaGetLastError()); }; - // Radix selection is O(E), independent of K, but it needs 4 passes for 32-bit float; - // switch at K=16 where naive O(K^2*E) starts to dominate + // Dispatch on TopkFunc × ScoreFunc (6 instantiations per DataType × BiasType). + // Radix selection is O(E), independent of K; switch at K=16 where naive O(K^2*E) dominates. if (topk < 16) { - launch(fused_topk_with_score_function_forward_kernel); + switch (score_function) { + case 0: + launch(fused_topk_with_score_function_forward_kernel); + break; + case 1: + launch(fused_topk_with_score_function_forward_kernel); + break; + case 2: + launch(fused_topk_with_score_function_forward_kernel); + break; + default: + NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); + } } else { NVTE_CHECK(num_experts <= kMaxExpertsRadixTopk, "Radix topk requires num_experts <= ", kMaxExpertsRadixTopk, " (packed 8-bit histogram), got ", num_experts, "."); - launch(fused_topk_with_score_function_forward_kernel); + switch (score_function) { + case 0: + launch(fused_topk_with_score_function_forward_kernel); + break; + case 1: + launch(fused_topk_with_score_function_forward_kernel); + break; + case 2: + launch(fused_topk_with_score_function_forward_kernel); + break; + default: + NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); + } } } @@ -352,14 +384,10 @@ void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, // mask_buf: B × E × W × sizeof(bool) — double-buffered async load constexpr int kBwdNumBuffers = 2; -template +template __global__ void fused_topk_with_score_function_backward_kernel( - // Inputs tensor const bool *routing_map, const CompType *intermediate_output, const DataType *grad_probs, - // Other parameters int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, - int score_function, - // Output tensor DataType *grad_logits) { /*** * Section: Global Variables/Addresses init @@ -449,24 +477,24 @@ __global__ void fused_topk_with_score_function_backward_kernel( CompType sum_grad_act = 0.0f; CompType sum_output_x_grad = 0.0f; - bool need_reduce = ((score_function == 0 || score_function == 2) && topk > 1) - || (score_function == 1); + bool need_reduce = ((ScoreFunc == 0 || ScoreFunc == 2) && topk > 1) + || (ScoreFunc == 1); if (need_reduce) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { CompType g = static_cast(raw_grad[i]) * scaling_factor; CompType act = local_act[i]; bool routed = local_mask[i]; - if (score_function == 0) { // Sigmoid + if constexpr (ScoreFunc == 0) { // Sigmoid // act = sigmoid output; accumulate over routed experts only if (routed) { sum_act += act; sum_grad_act += g * act; } - } else if (score_function == 2) { // Sqrtsoftplus + } else if constexpr (ScoreFunc == 2) { // Sqrtsoftplus // act = original logit; recompute sqrtsoftplus to get activation if (routed) { CompType v = sqrtsoftplus_scalar(act); sum_act += v; sum_grad_act += g * v; } - } else if (score_function == 1) { // Softmax + } else if constexpr (ScoreFunc == 1) { // Softmax if (!use_pre_softmax) { // Post-softmax: act = softmax output (routed positions only) if (routed) sum_output_x_grad += g * act; @@ -476,11 +504,11 @@ __global__ void fused_topk_with_score_function_backward_kernel( } } } - if (score_function == 0 || score_function == 2) { + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { sum_act = warp_allreduce_sum(sum_act); sum_grad_act = warp_allreduce_sum(sum_grad_act); } - if (score_function == 1) { + if constexpr (ScoreFunc == 1) { sum_output_x_grad = warp_allreduce_sum(sum_output_x_grad); } } @@ -502,28 +530,34 @@ __global__ void fused_topk_with_score_function_backward_kernel( bool routed = local_mask[i]; // Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) - if ((score_function == 0 || score_function == 2) && topk > 1) { - g = normalize_bwd_scalar(g, routed, sum_act, sum_grad_act); + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { + if (topk > 1) { + g = normalize_bwd_scalar(g, routed, sum_act, sum_grad_act); + } } // Softmax bwd if use_pre_softmax is false (routed subset only) - if (score_function == 1 && !use_pre_softmax) { - g = routed ? softmax_bwd_scalar(g, act, sum_output_x_grad) : 0.0f; + if constexpr (ScoreFunc == 1) { + if (!use_pre_softmax) { + g = routed ? softmax_bwd_scalar(g, act, sum_output_x_grad) : 0.0f; + } } // Backward of topk: mask the unselected position in the grad if (!routed) g = 0.0f; // Pre-softmax bwd (all experts participate) - if (score_function == 1 && use_pre_softmax) { - g = softmax_bwd_scalar(g, act, sum_output_x_grad); + if constexpr (ScoreFunc == 1) { + if (use_pre_softmax) { + g = softmax_bwd_scalar(g, act, sum_output_x_grad); + } } // Sigmoid bwd: dy/dx = y * (1 - y), where y = sigmoid output - if (score_function == 0) { + if constexpr (ScoreFunc == 0) { g = sigmoid_bwd_scalar(g, act); // Sqrtsoftplus bwd: dy/dx = sigmoid(x) / (2 * y), where x = original logit - } else if (score_function == 2) { + } else if constexpr (ScoreFunc == 2) { g = sqrtsoftplus_bwd_scalar(g, act, sqrtsoftplus_scalar(act)); } @@ -550,15 +584,30 @@ void fused_topk_with_score_function_backward_kernel_launcher( RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kBwdNumBuffers); check_shared_memory_capacity_num_experts(shmem_bytes, num_experts); - auto kernel = fused_topk_with_score_function_backward_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); - size_t grid_size = - compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); - kernel<<>>( - routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, - use_pre_softmax, scaling_factor, score_function, grad_logits); - NVTE_CHECK_CUDA(cudaGetLastError()); + auto launch = [&](auto kernel) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); + size_t grid_size = + compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); + kernel<<>>( + routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, + use_pre_softmax, scaling_factor, grad_logits); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + switch (score_function) { + case 0: + launch(fused_topk_with_score_function_backward_kernel); + break; + case 1: + launch(fused_topk_with_score_function_backward_kernel); + break; + case 2: + launch(fused_topk_with_score_function_backward_kernel); + break; + default: + NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); + } } void fused_topk_with_score_function_backward(const Tensor &routing_map, From c9efdfa8b9dccdb038533754813ab73f94536173 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Tue, 19 May 2026 10:27:10 +0800 Subject: [PATCH 05/15] [Common] Add NVTE_RADIX_TOPK_THRESHOLD for topk algorithm selection Fix broken topk < 0 threshold (radix was always selected, naive unreachable). Replace with configurable NVTE_RADIX_TOPK_THRESHOLD env var (default 0, i.e. always use radix). Set to 16 to restore the old naive-for-small-K behavior. Uses the standard TE pattern: static local + getenv (read once, cached for process lifetime). Signed-off-by: Harry Zhou --- .../fused_router/fused_score_for_moe_aux_loss.cu | 12 ++++++++++-- .../fused_router/fused_topk_with_score_function.cu | 14 ++++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 458d8b458f..c81760b608 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -10,6 +10,7 @@ #include "../common.h" #include "../util/logging.h" +#include "../util/system.h" #include "../utils.cuh" #include "async_loader.h" #include "utils.h" @@ -17,6 +18,12 @@ namespace transformer_engine { namespace fused_router { +// Reuse the same threshold as the topk kernel (see fused_topk_with_score_function.cu). +static int get_radix_topk_threshold() { + static int threshold = getenv("NVTE_RADIX_TOPK_THRESHOLD", 0); + return threshold; +} + template __global__ void fused_score_for_moe_aux_loss_forward_kernel( const DataType *logits, int num_tokens, int num_experts, int topk, @@ -183,8 +190,9 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( }; // Dispatch on TopkFunc × ScoreFunc (6 instantiations per DataType). - // Radix selection is O(E), independent of K; switch at K=16 where naive O(K^2*E) dominates. - if (topk < 16) { + // Radix selection is O(E), independent of K; naive is O(K*E). + // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 16). + if (topk < get_radix_topk_threshold()) { switch (score_function) { case 0: launch(fused_score_for_moe_aux_loss_forward_kernel); diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 8c88b8f7e3..024231d8db 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -10,12 +10,21 @@ #include "../common.h" #include "../util/logging.h" +#include "../util/system.h" #include "async_loader.h" #include "utils.h" namespace transformer_engine { namespace fused_router { +// Topk values below this threshold use naive O(K*E) selection; +// at or above it, use radix O(E) selection. Configurable via +// NVTE_RADIX_TOPK_THRESHOLD (default 0, i.e. always radix). +static int get_radix_topk_threshold() { + static int threshold = getenv("NVTE_RADIX_TOPK_THRESHOLD", 0); + return threshold; +} + template __global__ void fused_topk_with_score_function_forward_kernel( @@ -311,8 +320,9 @@ void fused_topk_with_score_function_forward_kernel_launcher( }; // Dispatch on TopkFunc × ScoreFunc (6 instantiations per DataType × BiasType). - // Radix selection is O(E), independent of K; switch at K=16 where naive O(K^2*E) dominates. - if (topk < 16) { + // Radix selection is O(E), independent of K; naive is O(K*E). + // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 16). + if (topk < get_radix_topk_threshold()) { switch (score_function) { case 0: launch(fused_topk_with_score_function_forward_kernel Date: Tue, 19 May 2026 10:27:10 +0800 Subject: [PATCH 06/15] [Common] Fix single-buffer prefetch clobbering in forward kernels When choose_num_buffers() returns 1 (shmem too tight for double buffering, e.g. E=1024 with group_topk scratch), buf_[0] and buf_[1] alias the same memory. The prefetch via start_load(next_buf()) then overwrites the current buffer while compute is still reading it. Fix: guard the prefetch on num_buffers > 1. When single-buffered, load the current round's data at the top of each iteration instead. The first round's load_current is still issued before the loop. Backward kernels are unaffected (always kBwdNumBuffers=2). Signed-off-by: Harry Zhou --- .../fused_score_for_moe_aux_loss.cu | 19 +++++++++++++------ .../fused_topk_with_score_function.cu | 19 +++++++++++++------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index c81760b608..f97884e636 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -77,15 +77,22 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel( int token_offset_cur_warp = round * num_token_per_block + warp_id; if (token_offset_cur_warp >= num_tokens) break; + // Single-buffer: load current round here (no prefetch possible) + if (num_buffers == 1 && round != first_round) { + loader.load_current(logits + token_offset_cur_warp * num_experts, num_experts, lane_id); + } + loader.wait(); DataType *raw_logits = loader.current_buf(); - // Prefetch next round - int next_round = round + gridDim.x; - if (next_round < total_round) { - int next_token = next_round * num_token_per_block + warp_id; - if (next_token < num_tokens) { - loader.start_load(logits + next_token * num_experts, num_experts, lane_id); + // Prefetch next round (only when double-buffered) + if (num_buffers > 1) { + int next_round = round + gridDim.x; + if (next_round < total_round) { + int next_token = next_round * num_token_per_block + warp_id; + if (next_token < num_tokens) { + loader.start_load(logits + next_token * num_experts, num_experts, lane_id); + } } } diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 024231d8db..b506be4231 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -92,16 +92,23 @@ __global__ void fused_topk_with_score_function_forward_kernel( int token_offset_cur_warp = round * num_token_per_block + warp_id; if (token_offset_cur_warp >= num_tokens) break; + // Single-buffer: load current round here (no prefetch possible) + if (num_buffers == 1 && round != first_round) { + loader.load_current(logits + token_offset_cur_warp * num_experts, num_experts, lane_id); + } + // Wait for current round's async load to complete loader.wait(); DataType *raw_logits = loader.current_buf(); - // Prefetch next round (overlaps with compute below) - int next_round = round + gridDim.x; - if (next_round < total_round) { - int next_token = next_round * num_token_per_block + warp_id; - if (next_token < num_tokens) { - loader.start_load(logits + next_token * num_experts, num_experts, lane_id); + // Prefetch next round (only when double-buffered, overlaps with compute) + if (num_buffers > 1) { + int next_round = round + gridDim.x; + if (next_round < total_round) { + int next_token = next_round * num_token_per_block + warp_id; + if (next_token < num_tokens) { + loader.start_load(logits + next_token * num_experts, num_experts, lane_id); + } } } From b8a02dd083242f5699cc360d9901846b754f52bc Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Tue, 19 May 2026 10:27:10 +0800 Subject: [PATCH 07/15] code formatting Signed-off-by: Harry Zhou --- .../fused_score_for_moe_aux_loss.cu | 45 ++++++------- .../fused_topk_with_score_function.cu | 64 +++++++++---------- .../common/fused_router/utils.h | 4 +- 3 files changed, 57 insertions(+), 56 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index f97884e636..60fb718949 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -25,10 +25,11 @@ static int get_radix_topk_threshold() { } template -__global__ void fused_score_for_moe_aux_loss_forward_kernel( - const DataType *logits, int num_tokens, int num_experts, int topk, - float *scores, bool *routing_map, CompType *intermediate_output, - int num_buffers) { +__global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens, + int num_experts, int topk, + float *scores, bool *routing_map, + CompType *intermediate_output, + int num_buffers) { /*** * Section: Global Variables/Addresses init * - Each warp is responsible for one token, and has own shared memory buffer. @@ -174,8 +175,8 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; size_t scores_shmem = num_experts * num_token_per_block * sizeof(CompType); - size_t scratch_shmem = topk * num_token_per_block * sizeof(CompType) - + topk * num_token_per_block * sizeof(int); + size_t scratch_shmem = + topk * num_token_per_block * sizeof(CompType) + topk * num_token_per_block * sizeof(int); size_t other_shmem = scores_shmem + scratch_shmem; size_t logits_single_buf = RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1); @@ -191,8 +192,8 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( size_t grid_size = compute_persistent_grid(kernel, kThreadsPerBlock, shared_memory_size, total_blocks); kernel<<>>( - logits, num_tokens, num_experts, topk, scores, routing_map, - intermediate_output, num_buffers); + logits, num_tokens, num_experts, topk, scores, routing_map, intermediate_output, + num_buffers); NVTE_CHECK_CUDA(cudaGetLastError()); }; @@ -259,8 +260,7 @@ template __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *intermediate_output, const float *grad_scores, int num_tokens, int num_experts, - int topk, - DataType *grad_logits) { + int topk, DataType *grad_logits) { /*** * Section: Global Variables/Addresses init * - Each warp is responsible for one token, and has own shared memory buffer. @@ -274,14 +274,14 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int char *shmem_ptr = shmem_aux_bwd; CompType *grad_shmem_base = reinterpret_cast(shmem_ptr); - RawAsyncLoader grad_loader(grad_shmem_base, warp_id, num_experts, - num_token_per_block, kAuxBwdNumBuffers); - shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, - kAuxBwdNumBuffers); + RawAsyncLoader grad_loader(grad_shmem_base, warp_id, num_experts, num_token_per_block, + kAuxBwdNumBuffers); + shmem_ptr += + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kAuxBwdNumBuffers); CompType *act_shmem_base = reinterpret_cast(shmem_ptr); - RawAsyncLoader act_loader(act_shmem_base, warp_id, num_experts, - num_token_per_block, kAuxBwdNumBuffers); + RawAsyncLoader act_loader(act_shmem_base, warp_id, num_experts, num_token_per_block, + kAuxBwdNumBuffers); /*** * Section: Main Loop — persistent grid with double-buffered async load @@ -341,11 +341,13 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int CompType act = raw_act[i]; if constexpr (ScoreFunc == 0) { // Sigmoid // act = sigmoid output; accumulate over all experts - sum_act += act; sum_grad_act += g * act; + sum_act += act; + sum_grad_act += g * act; } else if constexpr (ScoreFunc == 2) { // Sqrtsoftplus // act = original logit; recompute sqrtsoftplus to get activation CompType v = sqrtsoftplus_scalar(act); - sum_act += v; sum_grad_act += g * v; + sum_act += v; + sum_grad_act += g * v; } else if constexpr (ScoreFunc == 1) { // Softmax // act = softmax output sum_output_x_grad += g * act; @@ -403,10 +405,9 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher( check_shared_memory_capacity_num_experts(shmem_bytes, num_experts); auto launch = [&](auto kernel) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); - size_t grid_size = - compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); + size_t grid_size = compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); kernel<<>>( intermediate_output, grad_scores, num_tokens, num_experts, topk, grad_logits); NVTE_CHECK_CUDA(cudaGetLastError()); diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index b506be4231..5dcccd3e77 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -29,9 +29,8 @@ template __global__ void fused_topk_with_score_function_forward_kernel( const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, - int num_groups, int group_topk, float scaling_factor, - const BiasType *expert_bias, DataType *probs, bool *routing_map, - CompType *intermediate_output, int num_buffers) { + int num_groups, int group_topk, float scaling_factor, const BiasType *expert_bias, + DataType *probs, bool *routing_map, CompType *intermediate_output, int num_buffers) { /*** * Section: Global Variables/Addresses init * - Each warp is responsible for one token, and has own shared memory buffer. @@ -299,8 +298,8 @@ void fused_topk_with_score_function_forward_kernel_launcher( size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; size_t scores_shmem = num_experts * num_token_per_block * sizeof(CompType); - size_t scratch_shmem = topk * num_token_per_block * sizeof(CompType) - + topk * num_token_per_block * sizeof(int); + size_t scratch_shmem = + topk * num_token_per_block * sizeof(CompType) + topk * num_token_per_block * sizeof(int); if (group_topk > 0) { scratch_shmem += num_groups * num_token_per_block * sizeof(CompType); scratch_shmem += num_experts * num_token_per_block * sizeof(CompType); @@ -321,8 +320,7 @@ void fused_topk_with_score_function_forward_kernel_launcher( compute_persistent_grid(kernel, kThreadsPerBlock, shared_memory_size, total_blocks); kernel<<>>( logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, expert_bias, probs, routing_map, intermediate_output, - num_buffers); + scaling_factor, expert_bias, probs, routing_map, intermediate_output, num_buffers); NVTE_CHECK_CUDA(cudaGetLastError()); }; @@ -333,15 +331,15 @@ void fused_topk_with_score_function_forward_kernel_launcher( switch (score_function) { case 0: launch(fused_topk_with_score_function_forward_kernel); + TopkFuncType::Naive, 0>); break; case 1: launch(fused_topk_with_score_function_forward_kernel); + TopkFuncType::Naive, 1>); break; case 2: launch(fused_topk_with_score_function_forward_kernel); + TopkFuncType::Naive, 2>); break; default: NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); @@ -353,15 +351,15 @@ void fused_topk_with_score_function_forward_kernel_launcher( switch (score_function) { case 0: launch(fused_topk_with_score_function_forward_kernel); + TopkFuncType::Radix, 0>); break; case 1: launch(fused_topk_with_score_function_forward_kernel); + TopkFuncType::Radix, 1>); break; case 2: launch(fused_topk_with_score_function_forward_kernel); + TopkFuncType::Radix, 2>); break; default: NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); @@ -419,20 +417,20 @@ __global__ void fused_topk_with_score_function_backward_kernel( char *shmem_ptr = shmem_bwd; DataType *grad_shmem_base = reinterpret_cast(shmem_ptr); - RawAsyncLoader grad_loader(grad_shmem_base, warp_id, num_experts, - num_token_per_block, kBwdNumBuffers); - shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, - kBwdNumBuffers); + RawAsyncLoader grad_loader(grad_shmem_base, warp_id, num_experts, num_token_per_block, + kBwdNumBuffers); + shmem_ptr += + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kBwdNumBuffers); CompType *act_shmem_base = reinterpret_cast(shmem_ptr); - RawAsyncLoader act_loader(act_shmem_base, warp_id, num_experts, - num_token_per_block, kBwdNumBuffers); - shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, - kBwdNumBuffers); + RawAsyncLoader act_loader(act_shmem_base, warp_id, num_experts, num_token_per_block, + kBwdNumBuffers); + shmem_ptr += + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kBwdNumBuffers); bool *mask_shmem_base = reinterpret_cast(shmem_ptr); - RawAsyncLoader mask_loader(mask_shmem_base, warp_id, num_experts, - num_token_per_block, kBwdNumBuffers); + RawAsyncLoader mask_loader(mask_shmem_base, warp_id, num_experts, num_token_per_block, + kBwdNumBuffers); /*** * Section: Main Loop — persistent grid with double-buffered async load @@ -494,8 +492,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( CompType sum_grad_act = 0.0f; CompType sum_output_x_grad = 0.0f; - bool need_reduce = ((ScoreFunc == 0 || ScoreFunc == 2) && topk > 1) - || (ScoreFunc == 1); + bool need_reduce = ((ScoreFunc == 0 || ScoreFunc == 2) && topk > 1) || (ScoreFunc == 1); if (need_reduce) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { CompType g = static_cast(raw_grad[i]) * scaling_factor; @@ -504,12 +501,16 @@ __global__ void fused_topk_with_score_function_backward_kernel( if constexpr (ScoreFunc == 0) { // Sigmoid // act = sigmoid output; accumulate over routed experts only - if (routed) { sum_act += act; sum_grad_act += g * act; } + if (routed) { + sum_act += act; + sum_grad_act += g * act; + } } else if constexpr (ScoreFunc == 2) { // Sqrtsoftplus // act = original logit; recompute sqrtsoftplus to get activation if (routed) { CompType v = sqrtsoftplus_scalar(act); - sum_act += v; sum_grad_act += g * v; + sum_act += v; + sum_grad_act += g * v; } } else if constexpr (ScoreFunc == 1) { // Softmax if (!use_pre_softmax) { @@ -573,7 +574,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( // Sigmoid bwd: dy/dx = y * (1 - y), where y = sigmoid output if constexpr (ScoreFunc == 0) { g = sigmoid_bwd_scalar(g, act); - // Sqrtsoftplus bwd: dy/dx = sigmoid(x) / (2 * y), where x = original logit + // Sqrtsoftplus bwd: dy/dx = sigmoid(x) / (2 * y), where x = original logit } else if constexpr (ScoreFunc == 2) { g = sqrtsoftplus_bwd_scalar(g, act, sqrtsoftplus_scalar(act)); } @@ -602,10 +603,9 @@ void fused_topk_with_score_function_backward_kernel_launcher( check_shared_memory_capacity_num_experts(shmem_bytes, num_experts); auto launch = [&](auto kernel) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); - size_t grid_size = - compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); + size_t grid_size = compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); kernel<<>>( routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, use_pre_softmax, scaling_factor, grad_logits); diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 7d0d34a969..6b42f1df14 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -105,8 +105,8 @@ __device__ __forceinline__ float sigmoid_bwd_scalar(float grad, float y) { // Backward: sqrtsoftplus — given original logit x and sqrtsoftplus output y = sqrt(softplus(x)), // dy/dx = sigmoid(x) / (2 * y). For large x (>20), softplus(x) ≈ x so dy/dx ≈ 1/(2*y). __device__ __forceinline__ float sqrtsoftplus_bwd_scalar(float grad, float x, float y) { - float dy_dx = (x > 20.0f) ? (1.0f / (2.0f * y + epsilon)) - : (sigmoid_scalar(x) / (2.0f * y + epsilon)); + float dy_dx = + (x > 20.0f) ? (1.0f / (2.0f * y + epsilon)) : (sigmoid_scalar(x) / (2.0f * y + epsilon)); return grad * dy_dx; } From e8c8fc3b7380997ca00b239e468ef4aad8f1219e Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Tue, 19 May 2026 10:54:19 +0800 Subject: [PATCH 08/15] [Common] Harden fused router: assertions, shmem budget fix, cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Code review fixes: - C1: choose_num_buffers() now queries cudaDevAttrMaxSharedMemoryPerMultiprocessor (per-SM budget) instead of cudaDevAttrMaxSharedMemoryPerBlockOptin (per-block max). These coincide on Hopper/Blackwell but differ on Ampere. - H3: Remove dead fallback branch in choose_num_buffers() — since total_double >= total_single always, blocks_single >= blocks_double, so the old ternary always returned 1 anyway. - H4/M8: Add host-side NVTE_CHECK in all 4 launchers: - num_experts > 0 - topk in [1, num_experts] - (int64_t)num_tokens * num_experts <= INT_MAX (kernel uses int offsets) - M9: Assert topk % group_topk == 0 when group_topk > 0. - H6: Add device-side assert(data_size <= kMaxExpertsRadixTopk) in radix_topk_and_mask() — zero cost in release (NDEBUG), catches 8-bit histogram overflow in debug builds. - L1: Fix stale comments claiming default threshold is 16 (it is 0). - L4: Fix typo 'hanlded' -> 'handled'. - L8: Remove unused topk parameter from aux loss backward kernel. Signed-off-by: Harry Zhou --- .../common/fused_router/async_loader.h | 16 ++++++++++------ .../fused_score_for_moe_aux_loss.cu | 17 ++++++++++++++--- .../fused_topk_with_score_function.cu | 18 +++++++++++++++++- transformer_engine/common/fused_router/utils.h | 7 ++++++- 4 files changed, 47 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/fused_router/async_loader.h b/transformer_engine/common/fused_router/async_loader.h index 93cef00118..004a2f7fc6 100644 --- a/transformer_engine/common/fused_router/async_loader.h +++ b/transformer_engine/common/fused_router/async_loader.h @@ -60,16 +60,20 @@ inline int choose_num_buffers(size_t single_buf_shmem, size_t other_shmem_bytes) int device_id; NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); - int max_smem; - NVTE_CHECK_CUDA( - cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id)); + int max_smem_per_sm; + NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&max_smem_per_sm, + cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id)); - int blocks_double = (total_double > 0) ? static_cast(max_smem / total_double) : 0; - int blocks_single = (total_single > 0) ? static_cast(max_smem / total_single) : 0; + int blocks_double = + (total_double > 0) ? static_cast(max_smem_per_sm / total_double) : 0; + int blocks_single = + (total_single > 0) ? static_cast(max_smem_per_sm / total_single) : 0; if (blocks_double >= kMinBlocksPerSM) return 2; if (blocks_single >= kMinBlocksPerSM) return 1; - return (blocks_double >= blocks_single) ? 2 : 1; + // Neither option meets the minimum; prefer single buffer for occupancy + // (total_double >= total_single, so blocks_single >= blocks_double always). + return 1; } // ============================================================================ diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 60fb718949..ddc3a0e593 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -5,6 +5,7 @@ ************************************************************************/ #include +#include #include #include @@ -171,6 +172,12 @@ template void fused_score_for_moe_aux_loss_forward_kernel_launcher( const DataType *logits, int num_tokens, int num_experts, int topk, int score_function, float *scores, bool *routing_map, CompType *intermediate_output, cudaStream_t stream) { + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk, + " num_experts=", num_experts); + NVTE_CHECK(static_cast(num_tokens) * num_experts <= INT_MAX, + "num_tokens * num_experts exceeds INT_MAX (kernel uses int offsets), got ", + static_cast(num_tokens) * num_experts); size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; @@ -199,7 +206,7 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( // Dispatch on TopkFunc × ScoreFunc (6 instantiations per DataType). // Radix selection is O(E), independent of K; naive is O(K*E). - // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 16). + // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 0, i.e. always radix). if (topk < get_radix_topk_threshold()) { switch (score_function) { case 0: @@ -260,7 +267,7 @@ template __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *intermediate_output, const float *grad_scores, int num_tokens, int num_experts, - int topk, DataType *grad_logits) { + DataType *grad_logits) { /*** * Section: Global Variables/Addresses init * - Each warp is responsible for one token, and has own shared memory buffer. @@ -396,6 +403,10 @@ template void fused_score_for_moe_aux_loss_backward_kernel_launcher( const CompType *intermediate_output, const float *grad_scores, int num_tokens, int num_experts, int topk, int score_function, DataType *grad_logits, cudaStream_t stream) { + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(static_cast(num_tokens) * num_experts <= INT_MAX, + "num_tokens * num_experts exceeds INT_MAX (kernel uses int offsets), got ", + static_cast(num_tokens) * num_experts); size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; @@ -409,7 +420,7 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); size_t grid_size = compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); kernel<<>>( - intermediate_output, grad_scores, num_tokens, num_experts, topk, grad_logits); + intermediate_output, grad_scores, num_tokens, num_experts, grad_logits); NVTE_CHECK_CUDA(cudaGetLastError()); }; diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 5dcccd3e77..0c7a0a6a74 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -5,6 +5,7 @@ ************************************************************************/ #include +#include #include #include @@ -295,6 +296,17 @@ void fused_topk_with_score_function_forward_kernel_launcher( int num_groups, int group_topk, float scaling_factor, int score_function, const BiasType *expert_bias, DataType *probs, bool *routing_map, CompType *intermediate_output, cudaStream_t stream) { + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk, + " num_experts=", num_experts); + NVTE_CHECK(static_cast(num_tokens) * num_experts <= INT_MAX, + "num_tokens * num_experts exceeds INT_MAX (kernel uses int offsets), got ", + static_cast(num_tokens) * num_experts); + if (group_topk > 0) { + NVTE_CHECK(topk % group_topk == 0, + "topk must be divisible by group_topk, got topk=", topk, + " group_topk=", group_topk); + } size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; size_t scores_shmem = num_experts * num_token_per_block * sizeof(CompType); @@ -326,7 +338,7 @@ void fused_topk_with_score_function_forward_kernel_launcher( // Dispatch on TopkFunc × ScoreFunc (6 instantiations per DataType × BiasType). // Radix selection is O(E), independent of K; naive is O(K*E). - // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 16). + // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 0, i.e. always radix). if (topk < get_radix_topk_threshold()) { switch (score_function) { case 0: @@ -593,6 +605,10 @@ void fused_topk_with_score_function_backward_kernel_launcher( const bool *routing_map, const CompType *intermediate_output, const DataType *grad_probs, int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, DataType *grad_logits, cudaStream_t stream) { + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(static_cast(num_tokens) * num_experts <= INT_MAX, + "num_tokens * num_experts exceeds INT_MAX (kernel uses int offsets), got ", + static_cast(num_tokens) * num_experts); size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 6b42f1df14..f7ca3ddc07 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -7,6 +7,8 @@ #ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ #define TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ +#include + #include "../util/logging.h" #include "../utils.cuh" #include "transformer_engine/transformer_engine.h" @@ -232,6 +234,9 @@ constexpr int kMaxExpertsRadixTopk = 255 * 32; // 8160 __device__ inline void radix_topk_and_mask(CompType *scores, int data_size, int topk, int *topk_indices, CompType *topk_scores, int lane_id) { + assert(data_size > 0 && data_size <= kMaxExpertsRadixTopk); + assert(topk > 0 && topk <= data_size); + constexpr int RADIX_BITS = 4; constexpr int RADIX_SIZE = 1 << RADIX_BITS; // 16 buckets constexpr int RADIX_MASK = RADIX_SIZE - 1; // 0xF @@ -388,7 +393,7 @@ __device__ inline void naive_topk_and_mask(CompType *scores, int data_size, int ? scores[lane_id] : -std::numeric_limits::infinity(); int index = (lane_id < data_size) ? lane_id : 0; - // Some value is hanlded in local thread + // Some value is handled in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { From 271a4bc9b60d918b342ee088e7fcd58a6ee901ad Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Tue, 19 May 2026 11:03:47 +0800 Subject: [PATCH 09/15] [Common] Consolidate get_radix_topk_threshold() into utils.h Move the duplicated static function from both .cu files into utils.h as an inline function. Each TU gets its own static local (read-once per TU), which is safe since environment variables are immutable during process lifetime. Documented this in a NOTE comment. Signed-off-by: Harry Zhou --- .../fused_router/fused_score_for_moe_aux_loss.cu | 7 ------- .../fused_router/fused_topk_with_score_function.cu | 9 --------- transformer_engine/common/fused_router/utils.h | 14 ++++++++++++++ 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index ddc3a0e593..09cd1d35ef 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -11,7 +11,6 @@ #include "../common.h" #include "../util/logging.h" -#include "../util/system.h" #include "../utils.cuh" #include "async_loader.h" #include "utils.h" @@ -19,12 +18,6 @@ namespace transformer_engine { namespace fused_router { -// Reuse the same threshold as the topk kernel (see fused_topk_with_score_function.cu). -static int get_radix_topk_threshold() { - static int threshold = getenv("NVTE_RADIX_TOPK_THRESHOLD", 0); - return threshold; -} - template __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens, int num_experts, int topk, diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 0c7a0a6a74..f275529a9e 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -11,21 +11,12 @@ #include "../common.h" #include "../util/logging.h" -#include "../util/system.h" #include "async_loader.h" #include "utils.h" namespace transformer_engine { namespace fused_router { -// Topk values below this threshold use naive O(K*E) selection; -// at or above it, use radix O(E) selection. Configurable via -// NVTE_RADIX_TOPK_THRESHOLD (default 0, i.e. always radix). -static int get_radix_topk_threshold() { - static int threshold = getenv("NVTE_RADIX_TOPK_THRESHOLD", 0); - return threshold; -} - template __global__ void fused_topk_with_score_function_forward_kernel( diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index f7ca3ddc07..9b842d06e1 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -10,12 +10,26 @@ #include #include "../util/logging.h" +#include "../util/system.h" #include "../utils.cuh" #include "transformer_engine/transformer_engine.h" namespace transformer_engine { namespace fused_router { +// Topk values below this threshold use naive O(K*E) selection; +// at or above it, use radix O(E) selection. Configurable via +// NVTE_RADIX_TOPK_THRESHOLD (default 0, i.e. always radix). +// +// NOTE: This is an inline function with a static local. Each translation unit +// that includes this header gets its own copy of the static, so the env var is +// read once per TU (not once globally). This is safe because environment +// variables are immutable during process lifetime in our usage. +inline int get_radix_topk_threshold() { + static int threshold = getenv("NVTE_RADIX_TOPK_THRESHOLD", 0); + return threshold; +} + // Check if requested shared memory size exceeds device capacity. // Throws an error with num_experts info to help users diagnose the issue. inline void check_shared_memory_capacity_num_experts(size_t shared_memory_size, int num_experts) { From 7590d5836ca038fea07617941b57c564552bc3f2 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Tue, 19 May 2026 11:09:14 +0800 Subject: [PATCH 10/15] [Common] Template warp_reduce_on_shmem on ReduceFuncType Replace runtime function-pointer dispatch with compile-time if constexpr. Eliminates indirect call overhead in the reduction loop and warp shuffle butterfly, allowing the compiler to emit straight-line arithmetic. Removes the now-unused max() and sum() helper functions. Signed-off-by: Harry Zhou --- .../common/fused_router/fused_moe_aux_loss.cu | 4 +- .../fused_score_for_moe_aux_loss.cu | 4 +- .../fused_topk_with_score_function.cu | 4 +- .../common/fused_router/utils.h | 64 +++++++++---------- 4 files changed, 36 insertions(+), 40 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index 7e516af97b..b8026be3fd 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -63,8 +63,8 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, const int warp_id = threadIdx.x / kThreadsPerWarp; const int lane_id = threadIdx.x % kThreadsPerWarp; if (warp_id == 0) { - CompType block_sum = warp_reduce_on_shmem(shmem_block, static_cast(blockDim.x), - ReduceFuncType::SUM, lane_id); + CompType block_sum = warp_reduce_on_shmem( + shmem_block, static_cast(blockDim.x), lane_id); if (lane_id == 0) { atomicAdd(&Coeff_buf[1], static_cast(block_sum * coeff)); } diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 09cd1d35ef..c0673cee59 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -134,7 +134,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi // Sigmoid/Sqrtsoftplus post-processing: normalize scores to sum to 1 if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { auto sum_logits = - warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); + warp_reduce_on_shmem(local_logits, num_experts, lane_id); for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_logits[i] /= (sum_logits + epsilon); } @@ -199,7 +199,7 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( // Dispatch on TopkFunc × ScoreFunc (6 instantiations per DataType). // Radix selection is O(E), independent of K; naive is O(K*E). - // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 0, i.e. always radix). + // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 8). if (topk < get_radix_topk_threshold()) { switch (score_function) { case 0: diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index f275529a9e..9c6ed228a3 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -262,7 +262,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( // Sigmoid/Sqrtsoftplus post-processing when topk > 1 if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { if (topk > 1) { - CompType sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id); + CompType sum_scores = warp_reduce_on_shmem(topk_scores, topk, lane_id); for (int i = lane_id; i < topk; i += kThreadsPerWarp) { topk_scores[i] = topk_scores[i] / (sum_scores + epsilon); } @@ -329,7 +329,7 @@ void fused_topk_with_score_function_forward_kernel_launcher( // Dispatch on TopkFunc × ScoreFunc (6 instantiations per DataType × BiasType). // Radix selection is O(E), independent of K; naive is O(K*E). - // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 0, i.e. always radix). + // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 8). if (topk < get_radix_topk_threshold()) { switch (score_function) { case 0: diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 9b842d06e1..dd694e0666 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -19,14 +19,15 @@ namespace fused_router { // Topk values below this threshold use naive O(K*E) selection; // at or above it, use radix O(E) selection. Configurable via -// NVTE_RADIX_TOPK_THRESHOLD (default 0, i.e. always radix). +// NVTE_RADIX_TOPK_THRESHOLD (default 8: radix is faster for K>=8, +// naive avoids overhead for very small K). // // NOTE: This is an inline function with a static local. Each translation unit // that includes this header gets its own copy of the static, so the env var is // read once per TU (not once globally). This is safe because environment // variables are immutable during process lifetime in our usage. inline int get_radix_topk_threshold() { - static int threshold = getenv("NVTE_RADIX_TOPK_THRESHOLD", 0); + static int threshold = getenv("NVTE_RADIX_TOPK_THRESHOLD", 8); return threshold; } @@ -54,48 +55,43 @@ constexpr int kThreadsPerBlock = 128; // Using 4 warps in 1 CTA, Each warp is responsible for 1 token. constexpr float epsilon = 1e-20; -template -__device__ inline T max(T a, T b) { - return a > b ? a : b; -} - -template -__device__ inline T sum(T a, T b) { - return a + b; -} - enum ReduceFuncType { SUM, MAX, }; -template -__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type, - int lane_id) { - T (*reduce_func)(T, T); - CompType default_val = 0.0; - if (type == ReduceFuncType::SUM) { - reduce_func = sum; - default_val = 0.0; - } else if (type == ReduceFuncType::MAX) { - reduce_func = max; - default_val = -std::numeric_limits::infinity(); - } +// Warp-level reduction over shared memory data. Templated on the reduction +// type to enable compile-time dispatch (no function pointer overhead). +template +__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, int lane_id) { + constexpr CompType default_val = + (type == ReduceFuncType::SUM) ? 0.0f : -std::numeric_limits::infinity(); - // Some value is handled in local thread - // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... - // Reduce the value in local thread + // Each lane accumulates its strided slice CompType val = lane_id < data_size ? data_ptr[lane_id] : default_val; for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { - val = reduce_func(val, data_ptr[i]); + if constexpr (type == ReduceFuncType::SUM) { + val += data_ptr[i]; + } else { + val = (data_ptr[i] > val) ? static_cast(data_ptr[i]) : val; + } } - // Warp shuffle between threads - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2)); - val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1)); + // Warp shuffle butterfly reduction + if constexpr (type == ReduceFuncType::SUM) { + val += __shfl_xor_sync(0xffffffff, val, 16); + val += __shfl_xor_sync(0xffffffff, val, 8); + val += __shfl_xor_sync(0xffffffff, val, 4); + val += __shfl_xor_sync(0xffffffff, val, 2); + val += __shfl_xor_sync(0xffffffff, val, 1); + } else { + auto shfl_max = [](CompType a, CompType b) { return a > b ? a : b; }; + val = shfl_max(val, __shfl_xor_sync(0xffffffff, val, 16)); + val = shfl_max(val, __shfl_xor_sync(0xffffffff, val, 8)); + val = shfl_max(val, __shfl_xor_sync(0xffffffff, val, 4)); + val = shfl_max(val, __shfl_xor_sync(0xffffffff, val, 2)); + val = shfl_max(val, __shfl_xor_sync(0xffffffff, val, 1)); + } __syncwarp(); return T(val); } From 0e510cf735a109280829d80f247eafe25d186415 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Tue, 19 May 2026 14:11:44 +0800 Subject: [PATCH 11/15] [Common] Add simple forward kernel path for small topk MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When topk < NVTE_RADIX_TOPK_THRESHOLD (default 8), use a lightweight forward kernel that avoids the async loader and persistent grid overhead. The simple kernel loads logits directly from global memory to shmem and uses Naive iterative-argmax topk — matching the baseline structure that was faster for small K due to lower launch/scheduling overhead. The optimized path (async loader + persistent grid + radix topk) remains the default for topk >= 8 where the compute savings dominate. Both topk and aux_loss forward kernels get the simple variant. Backward kernels are unchanged (always use the optimized path). Signed-off-by: Harry Zhou --- .../fused_score_for_moe_aux_loss.cu | 126 +++++++++-- .../fused_topk_with_score_function.cu | 203 ++++++++++++++++-- .../common/fused_router/utils.h | 23 ++ 3 files changed, 319 insertions(+), 33 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index c0673cee59..f556078239 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -18,6 +18,100 @@ namespace transformer_engine { namespace fused_router { +// ============================================================================= +// Simple aux_loss forward kernel — exact upstream structure (no async loader, +// no persistent grid, runtime score_function dispatch). +// ============================================================================= + +template +__global__ void fused_score_for_moe_aux_loss_forward_simple_kernel( + const DataType *logits, int num_tokens, int num_experts, int topk, int score_function, + float *scores, bool *routing_map, CompType *intermediate_output) { + int num_token_per_block = blockDim.x / kThreadsPerWarp; + int warp_id = threadIdx.x / kThreadsPerWarp; + int lane_id = threadIdx.x % kThreadsPerWarp; + extern __shared__ float shmem_scores_for_aux_loss[]; + CompType *logits_buf = reinterpret_cast(shmem_scores_for_aux_loss); + CompType *topk_logits_buf = + reinterpret_cast(logits_buf + num_experts * num_token_per_block); + int *topk_indices_buf = reinterpret_cast(topk_logits_buf + topk * num_token_per_block); + CompType *local_logits = logits_buf + warp_id * num_experts; + CompType *topk_logits = topk_logits_buf + warp_id * topk; + int *topk_indices = topk_indices_buf + warp_id * topk; + + int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; + for (int round = blockIdx.x; round < total_round; round += gridDim.x) { + int token_offset_cur_warp = round * num_token_per_block + warp_id; + if (token_offset_cur_warp >= num_tokens) break; + + int pos_offset = token_offset_cur_warp * num_experts; + // Clear the routing_map + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + routing_map[pos_offset + i] = false; + if (score_function == 1) { + intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); + } + } + // Load the logits to shmem + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_logits[i] = static_cast(logits[pos_offset + i]); + } + __threadfence_block(); + __syncwarp(); + + // Preprocess: apply score function + if (score_function == 1) { + apply_softmax_on_float(local_logits, num_experts, lane_id); + __syncwarp(); + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = local_logits[i]; + } + } else if (score_function == 0) { + apply_sigmoid_on_float(local_logits, num_experts, lane_id); + __syncwarp(); + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = local_logits[i]; + } + } else if (score_function == 2) { + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = local_logits[i]; + } + __syncwarp(); + apply_sqrtsoftplus_on_float(local_logits, num_experts, lane_id); + } + + __syncwarp(); + + // Sigmoid/Sqrtsoftplus post-processing: normalize + if (score_function == 0 || score_function == 2) { + auto sum_logits = + warp_reduce_on_shmem(local_logits, num_experts, lane_id); + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_logits[i] /= (sum_logits + epsilon); + } + __syncwarp(); + } + + // Topk + topk_and_mask(local_logits, num_experts, topk, topk_indices, topk_logits, lane_id); + __syncwarp(); + + // Write outputs + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + routing_map[pos_offset + topk_indices[i]] = true; + } + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + scores[pos_offset + i] = local_logits[i]; + } + __threadfence_block(); + __syncwarp(); + } +} + +// ============================================================================= +// Optimized aux_loss forward kernel — async loader, persistent grid. +// ============================================================================= + template __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens, int num_experts, int topk, @@ -197,24 +291,26 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( NVTE_CHECK_CUDA(cudaGetLastError()); }; - // Dispatch on TopkFunc × ScoreFunc (6 instantiations per DataType). - // Radix selection is O(E), independent of K; naive is O(K*E). + // Dispatch: small topk uses the simple kernel (no async loader overhead); + // large topk uses the optimized kernel with radix selection + persistent grid. // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 8). if (topk < get_radix_topk_threshold()) { - switch (score_function) { - case 0: - launch(fused_score_for_moe_aux_loss_forward_kernel); - break; - case 1: - launch(fused_score_for_moe_aux_loss_forward_kernel); - break; - case 2: - launch(fused_score_for_moe_aux_loss_forward_kernel); - break; - default: - NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); - } + // Simple path: exact upstream structure — no async loader, no persistent grid. + check_shared_memory_capacity_num_experts(other_shmem, num_experts); + + auto launch_simple = [&](auto kernel) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + other_shmem)); + kernel<<>>( + logits, num_tokens, num_experts, topk, score_function, scores, routing_map, + intermediate_output); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + launch_simple( + fused_score_for_moe_aux_loss_forward_simple_kernel); } else { + // Optimized path: async loader + persistent grid + radix topk. NVTE_CHECK(num_experts <= kMaxExpertsRadixTopk, "Radix topk requires num_experts <= ", kMaxExpertsRadixTopk, " (packed 8-bit histogram), got ", num_experts, "."); diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 9c6ed228a3..1475a9a84a 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -17,6 +17,174 @@ namespace transformer_engine { namespace fused_router { +// ============================================================================= +// Simple forward kernel — exact upstream structure (no async loader, no +// persistent grid, runtime score_function dispatch). Faster for small topk +// due to lower scheduling overhead and separate load/compute/store phases. +// ============================================================================= + +template +__global__ void fused_topk_forward_simple_kernel( + const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, int score_function, + const BiasType *expert_bias, DataType *probs, bool *routing_map, + CompType *intermediate_output) { + int num_token_per_block = blockDim.x / kThreadsPerWarp; + int warp_id = threadIdx.x / kThreadsPerWarp; + int lane_id = threadIdx.x % kThreadsPerWarp; + extern __shared__ float shmem[]; + CompType *scores_buf = reinterpret_cast(shmem); + CompType *topk_scores_buf = scores_buf + num_experts * num_token_per_block; + CompType *group_scores_buf = nullptr, *masked_scores_buf = nullptr; + int *topk_indices_buf = nullptr; + if (group_topk > 0) { + masked_scores_buf = topk_scores_buf + topk * num_token_per_block; + group_scores_buf = masked_scores_buf + num_experts * num_token_per_block; + topk_indices_buf = reinterpret_cast(group_scores_buf + num_groups * num_token_per_block); + } else { + topk_indices_buf = reinterpret_cast(topk_scores_buf + topk * num_token_per_block); + } + CompType *scores = scores_buf + warp_id * num_experts; + CompType *topk_scores = topk_scores_buf + warp_id * topk; + CompType *masked_scores = masked_scores_buf + warp_id * num_experts; + CompType *group_scores = group_scores_buf + warp_id * num_groups; + int *topk_indices = topk_indices_buf + warp_id * topk; + + int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; + for (int round = blockIdx.x; round < total_round; round += gridDim.x) { + int token_offset_cur_warp = round * num_token_per_block + warp_id; + if (token_offset_cur_warp >= num_tokens) break; + + int pos_offset = token_offset_cur_warp * num_experts; + // Clear the probs/routing_map (num_experts) + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + probs[pos_offset + i] = 0.0; + routing_map[pos_offset + i] = false; + if (score_function == 1) { + intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); + } + } + // Load the logits to shmem + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + scores[i] = logits[pos_offset + i]; + } + // If group_topk > 0, init the masked_scores to -inf + if (group_topk > 0) { + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + masked_scores[i] = -std::numeric_limits::infinity(); + } + } + __threadfence_block(); + __syncwarp(); + + // Preprocess: apply score function in-place on shmem + if (use_pre_softmax && score_function == 1) { + apply_softmax_on_float(scores, num_experts, lane_id); + __syncwarp(); + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = scores[i]; + } + } else if (score_function == 0) { + apply_sigmoid_on_float(scores, num_experts, lane_id); + __syncwarp(); + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = scores[i]; + } + } else if (score_function == 2) { + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = scores[i]; + } + __syncwarp(); + apply_sqrtsoftplus_on_float(scores, num_experts, lane_id); + } + + __syncwarp(); + + // Expert bias (sigmoid/sqrtsoftplus only) + if (expert_bias && (score_function == 0 || score_function == 2)) { + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + scores[i] += static_cast(expert_bias[i]); + } + __syncwarp(); + } + + // Topk selection + if (group_topk > 0) { + int group_size = num_experts / num_groups; + for (int i = 0; i < num_groups; i++) { + topk_and_mask( + scores + i * group_size, group_size, topk / group_topk, + topk_indices, topk_scores, lane_id); + __syncwarp(); + if (lane_id == 0) { + CompType tmp = 0.0; + for (int j = 0; j < topk / group_topk; j++) { + tmp = tmp + topk_scores[j]; + } + group_scores[i] = tmp; + } + __syncwarp(); + } + topk_and_mask( + group_scores, num_groups, group_topk, topk_indices, topk_scores, lane_id); + __syncwarp(); + for (int i = 0; i < group_topk; i++) { + int st = topk_indices[i] * group_size; + int ed = st + group_size; + for (int j = st + lane_id; j < ed; j += kThreadsPerWarp) { + masked_scores[j] = scores[j]; + } + } + __syncwarp(); + topk_and_mask(masked_scores, num_experts, topk, topk_indices, topk_scores, lane_id); + } else { + topk_and_mask(scores, num_experts, topk, topk_indices, topk_scores, lane_id); + } + __syncwarp(); + + // Postprocess: revert bias, softmax, normalization + if (expert_bias && (score_function == 0 || score_function == 2)) { + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + topk_scores[i] = topk_scores[i] - static_cast(expert_bias[topk_indices[i]]); + } + __syncwarp(); + } + + if (!use_pre_softmax && score_function == 1) { + apply_softmax_on_float(topk_scores, topk, lane_id); + __syncwarp(); + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i]; + } + __syncwarp(); + } + + if (score_function == 0 || score_function == 2) { + if (topk > 1) { + CompType sum_scores = + warp_reduce_on_shmem(topk_scores, topk, lane_id); + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + topk_scores[i] = topk_scores[i] / (sum_scores + epsilon); + } + } + __syncwarp(); + } + + // Write outputs + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + routing_map[pos_offset + topk_indices[i]] = true; + probs[pos_offset + topk_indices[i]] = scaling_factor * topk_scores[i]; + } + __threadfence_block(); + __syncwarp(); + } +} + +// ============================================================================= +// Optimized forward kernel — async loader, persistent grid, double buffering. +// Used for larger topk where radix selection and compute dominate. +// ============================================================================= + template __global__ void fused_topk_with_score_function_forward_kernel( @@ -327,27 +495,26 @@ void fused_topk_with_score_function_forward_kernel_launcher( NVTE_CHECK_CUDA(cudaGetLastError()); }; - // Dispatch on TopkFunc × ScoreFunc (6 instantiations per DataType × BiasType). - // Radix selection is O(E), independent of K; naive is O(K*E). + // Dispatch: small topk uses the simple kernel (no async loader overhead); + // large topk uses the optimized kernel with radix selection + persistent grid. // Threshold configurable via NVTE_RADIX_TOPK_THRESHOLD (default 8). if (topk < get_radix_topk_threshold()) { - switch (score_function) { - case 0: - launch(fused_topk_with_score_function_forward_kernel); - break; - case 1: - launch(fused_topk_with_score_function_forward_kernel); - break; - case 2: - launch(fused_topk_with_score_function_forward_kernel); - break; - default: - NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); - } + // Simple path: no async loader, no persistent grid — lower overhead for small K. + // Uses the exact upstream kernel structure with runtime score_function dispatch. + check_shared_memory_capacity_num_experts(other_shmem, num_experts); + + auto launch_simple = [&](auto kernel) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + other_shmem)); + kernel<<>>( + logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + launch_simple(fused_topk_forward_simple_kernel); } else { + // Optimized path: async loader + persistent grid + radix topk. NVTE_CHECK(num_experts <= kMaxExpertsRadixTopk, "Radix topk requires num_experts <= ", kMaxExpertsRadixTopk, " (packed 8-bit histogram), got ", num_experts, "."); diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index dd694e0666..76929b1cd4 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -96,6 +96,29 @@ __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, int lane_id return T(val); } +// ============================================================================ +// Array (in-place on shmem) score functions — used by legacy simple kernel +// ============================================================================ + +__device__ inline void apply_sigmoid_on_float(float *scores, int data_size, int lane_id) { + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + scores[i] = 1.0f / (1.0f + expf(-scores[i])); + } +} + +__device__ inline void apply_sqrtsoftplus_on_float(float *scores, int data_size, int lane_id) { + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + float x = scores[i]; + float softplus_val; + if (x > 20.0f) { + softplus_val = x; + } else { + softplus_val = log1pf(expf(x)); + } + scores[i] = sqrtf(softplus_val); + } +} + // ============================================================================ // Scalar (per-element) score functions — for fused paths // ============================================================================ From bd96bc2ee91bc3a249f8f51d152a0510fc6217bf Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Tue, 19 May 2026 18:11:46 +0800 Subject: [PATCH 12/15] pre-commit run Signed-off-by: Harry Zhou --- .../common/fused_router/async_loader.h | 6 ++-- .../fused_score_for_moe_aux_loss.cu | 15 +++++---- .../fused_topk_with_score_function.cu | 33 ++++++++++--------- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/transformer_engine/common/fused_router/async_loader.h b/transformer_engine/common/fused_router/async_loader.h index 004a2f7fc6..dcc856c28f 100644 --- a/transformer_engine/common/fused_router/async_loader.h +++ b/transformer_engine/common/fused_router/async_loader.h @@ -64,10 +64,8 @@ inline int choose_num_buffers(size_t single_buf_shmem, size_t other_shmem_bytes) NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id)); - int blocks_double = - (total_double > 0) ? static_cast(max_smem_per_sm / total_double) : 0; - int blocks_single = - (total_single > 0) ? static_cast(max_smem_per_sm / total_single) : 0; + int blocks_double = (total_double > 0) ? static_cast(max_smem_per_sm / total_double) : 0; + int blocks_single = (total_single > 0) ? static_cast(max_smem_per_sm / total_single) : 0; if (blocks_double >= kMinBlocksPerSM) return 2; if (blocks_single >= kMinBlocksPerSM) return 1; diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index f556078239..22d2b2c3e3 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -5,10 +5,11 @@ ************************************************************************/ #include -#include #include #include +#include + #include "../common.h" #include "../util/logging.h" #include "../utils.cuh" @@ -24,9 +25,11 @@ namespace fused_router { // ============================================================================= template -__global__ void fused_score_for_moe_aux_loss_forward_simple_kernel( - const DataType *logits, int num_tokens, int num_experts, int topk, int score_function, - float *scores, bool *routing_map, CompType *intermediate_output) { +__global__ void fused_score_for_moe_aux_loss_forward_simple_kernel(const DataType *logits, + int num_tokens, int num_experts, + int topk, int score_function, + float *scores, bool *routing_map, + CompType *intermediate_output) { int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; @@ -299,8 +302,8 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( check_shared_memory_capacity_num_experts(other_shmem, num_experts); auto launch_simple = [&](auto kernel) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - other_shmem)); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, other_shmem)); kernel<<>>( logits, num_tokens, num_experts, topk, score_function, scores, routing_map, intermediate_output); diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 1475a9a84a..7a9c9ed9c1 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -5,10 +5,11 @@ ************************************************************************/ #include -#include #include #include +#include + #include "../common.h" #include "../util/logging.h" #include "async_loader.h" @@ -24,11 +25,12 @@ namespace fused_router { // ============================================================================= template -__global__ void fused_topk_forward_simple_kernel( - const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, - int num_groups, int group_topk, float scaling_factor, int score_function, - const BiasType *expert_bias, DataType *probs, bool *routing_map, - CompType *intermediate_output) { +__global__ void fused_topk_forward_simple_kernel(const DataType *logits, int num_tokens, + int num_experts, int topk, bool use_pre_softmax, + int num_groups, int group_topk, + float scaling_factor, int score_function, + const BiasType *expert_bias, DataType *probs, + bool *routing_map, CompType *intermediate_output) { int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; @@ -112,9 +114,8 @@ __global__ void fused_topk_forward_simple_kernel( if (group_topk > 0) { int group_size = num_experts / num_groups; for (int i = 0; i < num_groups; i++) { - topk_and_mask( - scores + i * group_size, group_size, topk / group_topk, - topk_indices, topk_scores, lane_id); + topk_and_mask(scores + i * group_size, group_size, topk / group_topk, + topk_indices, topk_scores, lane_id); __syncwarp(); if (lane_id == 0) { CompType tmp = 0.0; @@ -125,8 +126,8 @@ __global__ void fused_topk_forward_simple_kernel( } __syncwarp(); } - topk_and_mask( - group_scores, num_groups, group_topk, topk_indices, topk_scores, lane_id); + topk_and_mask(group_scores, num_groups, group_topk, topk_indices, topk_scores, + lane_id); __syncwarp(); for (int i = 0; i < group_topk; i++) { int st = topk_indices[i] * group_size; @@ -430,7 +431,8 @@ __global__ void fused_topk_with_score_function_forward_kernel( // Sigmoid/Sqrtsoftplus post-processing when topk > 1 if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { if (topk > 1) { - CompType sum_scores = warp_reduce_on_shmem(topk_scores, topk, lane_id); + CompType sum_scores = + warp_reduce_on_shmem(topk_scores, topk, lane_id); for (int i = lane_id; i < topk; i += kThreadsPerWarp) { topk_scores[i] = topk_scores[i] / (sum_scores + epsilon); } @@ -462,8 +464,7 @@ void fused_topk_with_score_function_forward_kernel_launcher( "num_tokens * num_experts exceeds INT_MAX (kernel uses int offsets), got ", static_cast(num_tokens) * num_experts); if (group_topk > 0) { - NVTE_CHECK(topk % group_topk == 0, - "topk must be divisible by group_topk, got topk=", topk, + NVTE_CHECK(topk % group_topk == 0, "topk must be divisible by group_topk, got topk=", topk, " group_topk=", group_topk); } size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; @@ -504,8 +505,8 @@ void fused_topk_with_score_function_forward_kernel_launcher( check_shared_memory_capacity_num_experts(other_shmem, num_experts); auto launch_simple = [&](auto kernel) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - other_shmem)); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, other_shmem)); kernel<<>>( logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); From 8efda2f1fd9c8e2199b0a5391e659660c20f94d7 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Tue, 19 May 2026 21:40:56 +0800 Subject: [PATCH 13/15] [Common] Fix bf16 ambiguous constructor in vec_fill_global call Use 0.0f instead of 0 to avoid ambiguity between __nv_bfloat16(float) and __nv_bfloat16(double) constructors on older CUDA toolkits. Signed-off-by: Harry Zhou --- .../common/fused_router/fused_topk_with_score_function.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 7a9c9ed9c1..7939a11fad 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -288,7 +288,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( int pos_offset = token_offset_cur_warp * num_experts; // Clear the probs/routing_map (num_experts) - vec_fill_global(probs + pos_offset, static_cast(0), num_experts, lane_id); + vec_fill_global(probs + pos_offset, static_cast(0.0f), num_experts, lane_id); vec_fill_global(routing_map + pos_offset, false, num_experts, lane_id); if constexpr (ScoreFunc == 1) { // Softmax From 3bab7cb13006987bca8e50c0b1169783af11190b Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Thu, 21 May 2026 11:02:50 +0800 Subject: [PATCH 14/15] [Common] Address fused router review issues Signed-off-by: Harry Zhou --- .../common/fused_router/async_loader.h | 1 - .../fused_score_for_moe_aux_loss.cu | 51 +++++++++------ .../fused_topk_with_score_function.cu | 62 +++++++++++-------- 3 files changed, 68 insertions(+), 46 deletions(-) diff --git a/transformer_engine/common/fused_router/async_loader.h b/transformer_engine/common/fused_router/async_loader.h index dcc856c28f..32647f1545 100644 --- a/transformer_engine/common/fused_router/async_loader.h +++ b/transformer_engine/common/fused_router/async_loader.h @@ -245,7 +245,6 @@ class RawAsyncLoader { for (int i = lane_id; i < count; i += kThreadsPerWarp) { dst[i] = src[i]; } - cp_async_commit(); // No-op on sm_70; matches wait() expectation on sm_80+. } } }; diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 22d2b2c3e3..f40c20741b 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -281,9 +281,9 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( size_t logits_raw_shmem = RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); size_t shared_memory_size = logits_raw_shmem + other_shmem; - check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); auto launch = [&](auto kernel) { + check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); size_t grid_size = @@ -350,16 +350,16 @@ void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, // No routing_map — all experts participate (unlike topk backward). // Double-buffered cp.async loads both inputs. Two-pass fused approach. // -// Shmem layout (B = 2, W = warps/block): -// grad_buf: B × E × W × sizeof(CompType) — double-buffered async load -// act_buf: B × E × W × sizeof(CompType) — double-buffered async load -constexpr int kAuxBwdNumBuffers = 2; +// Shmem layout (B = num_buffers, W = warps/block): +// grad_buf: B × E × W × sizeof(CompType) — async-loaded grad +// act_buf: B × E × W × sizeof(CompType) — async-loaded activations template __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *intermediate_output, const float *grad_scores, int num_tokens, int num_experts, - DataType *grad_logits) { + DataType *grad_logits, + int num_buffers) { /*** * Section: Global Variables/Addresses init * - Each warp is responsible for one token, and has own shared memory buffer. @@ -374,13 +374,13 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int CompType *grad_shmem_base = reinterpret_cast(shmem_ptr); RawAsyncLoader grad_loader(grad_shmem_base, warp_id, num_experts, num_token_per_block, - kAuxBwdNumBuffers); + num_buffers); shmem_ptr += - RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kAuxBwdNumBuffers); + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); CompType *act_shmem_base = reinterpret_cast(shmem_ptr); RawAsyncLoader act_loader(act_shmem_base, warp_id, num_experts, num_token_per_block, - kAuxBwdNumBuffers); + num_buffers); /*** * Section: Main Loop — persistent grid with double-buffered async load @@ -405,20 +405,27 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int if (token_idx >= num_tokens) break; int pos = token_idx * num_experts; + if (num_buffers == 1 && round != first_round) { + grad_loader.load_current(grad_scores + pos, num_experts, lane_id); + act_loader.load_current(intermediate_output + pos, num_experts, lane_id); + } + grad_loader.wait(); act_loader.wait(); CompType *raw_grad = grad_loader.current_buf(); CompType *raw_act = act_loader.current_buf(); - // Prefetch next round - int next_round = round + gridDim.x; - if (next_round < total_round) { - int next_token = next_round * num_token_per_block + warp_id; - if (next_token < num_tokens) { - int next_pos = next_token * num_experts; - grad_loader.start_load(grad_scores + next_pos, num_experts, lane_id); - act_loader.start_load(intermediate_output + next_pos, num_experts, lane_id); + // Prefetch next round only when double-buffered; single-buffer loads above. + if (num_buffers > 1) { + int next_round = round + gridDim.x; + if (next_round < total_round) { + int next_token = next_round * num_token_per_block + warp_id; + if (next_token < num_tokens) { + int next_pos = next_token * num_experts; + grad_loader.start_load(grad_scores + next_pos, num_experts, lane_id); + act_loader.start_load(intermediate_output + next_pos, num_experts, lane_id); + } } } @@ -502,9 +509,13 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher( size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; + size_t single_buf_shmem = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1); + int num_buffers = choose_num_buffers(single_buf_shmem, 0); size_t shmem_bytes = - RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kAuxBwdNumBuffers) + - RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kAuxBwdNumBuffers); + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); check_shared_memory_capacity_num_experts(shmem_bytes, num_experts); auto launch = [&](auto kernel) { @@ -512,7 +523,7 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes)); size_t grid_size = compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); kernel<<>>( - intermediate_output, grad_scores, num_tokens, num_experts, grad_logits); + intermediate_output, grad_scores, num_tokens, num_experts, grad_logits, num_buffers); NVTE_CHECK_CUDA(cudaGetLastError()); }; diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 7939a11fad..167738d96c 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -483,9 +483,9 @@ void fused_topk_with_score_function_forward_kernel_launcher( size_t logits_raw_shmem = RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); size_t shared_memory_size = logits_raw_shmem + other_shmem; - check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); auto launch = [&](auto kernel) { + check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); size_t grid_size = @@ -564,17 +564,16 @@ void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, // Pass 1 (reduction): accumulate warp-level sums needed by normalization/softmax bwd. // Pass 2 (element-wise): compute per-element gradient and write to global memory. // -// Shmem layout (B = 2, W = warps/block): -// grad_raw: B × E × W × sizeof(DataType) — double-buffered async load -// act_buf: B × E × W × sizeof(CompType) — double-buffered async load -// mask_buf: B × E × W × sizeof(bool) — double-buffered async load -constexpr int kBwdNumBuffers = 2; +// Shmem layout (B = num_buffers, W = warps/block): +// grad_raw: B × E × W × sizeof(DataType) — async-loaded grad +// act_buf: B × E × W × sizeof(CompType) — async-loaded activations +// mask_buf: B × E × W × sizeof(bool) — async-loaded routing mask template __global__ void fused_topk_with_score_function_backward_kernel( const bool *routing_map, const CompType *intermediate_output, const DataType *grad_probs, int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, - DataType *grad_logits) { + DataType *grad_logits, int num_buffers) { /*** * Section: Global Variables/Addresses init * - Each warp is responsible for one token, and has own shared memory buffer. @@ -589,19 +588,19 @@ __global__ void fused_topk_with_score_function_backward_kernel( DataType *grad_shmem_base = reinterpret_cast(shmem_ptr); RawAsyncLoader grad_loader(grad_shmem_base, warp_id, num_experts, num_token_per_block, - kBwdNumBuffers); + num_buffers); shmem_ptr += - RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kBwdNumBuffers); + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); CompType *act_shmem_base = reinterpret_cast(shmem_ptr); RawAsyncLoader act_loader(act_shmem_base, warp_id, num_experts, num_token_per_block, - kBwdNumBuffers); + num_buffers); shmem_ptr += - RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kBwdNumBuffers); + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); bool *mask_shmem_base = reinterpret_cast(shmem_ptr); RawAsyncLoader mask_loader(mask_shmem_base, warp_id, num_experts, num_token_per_block, - kBwdNumBuffers); + num_buffers); /*** * Section: Main Loop — persistent grid with double-buffered async load @@ -627,6 +626,12 @@ __global__ void fused_topk_with_score_function_backward_kernel( if (token_idx >= num_tokens) break; int pos = token_idx * num_experts; + if (num_buffers == 1 && round != first_round) { + grad_loader.load_current(grad_probs + pos, num_experts, lane_id); + act_loader.load_current(intermediate_output + pos, num_experts, lane_id); + mask_loader.load_current(routing_map + pos, num_experts, lane_id); + } + /*** * Section: Wait for async load + prefetch next round */ @@ -638,15 +643,17 @@ __global__ void fused_topk_with_score_function_backward_kernel( CompType *local_act = act_loader.current_buf(); bool *local_mask = mask_loader.current_buf(); - // Prefetch next round - int next_round = round + gridDim.x; - if (next_round < total_round) { - int next_token = next_round * num_token_per_block + warp_id; - if (next_token < num_tokens) { - int next_pos = next_token * num_experts; - grad_loader.start_load(grad_probs + next_pos, num_experts, lane_id); - act_loader.start_load(intermediate_output + next_pos, num_experts, lane_id); - mask_loader.start_load(routing_map + next_pos, num_experts, lane_id); + // Prefetch next round only when double-buffered; single-buffer loads above. + if (num_buffers > 1) { + int next_round = round + gridDim.x; + if (next_round < total_round) { + int next_token = next_round * num_token_per_block + warp_id; + if (next_token < num_tokens) { + int next_pos = next_token * num_experts; + grad_loader.start_load(grad_probs + next_pos, num_experts, lane_id); + act_loader.start_load(intermediate_output + next_pos, num_experts, lane_id); + mask_loader.start_load(routing_map + next_pos, num_experts, lane_id); + } } } @@ -771,10 +778,15 @@ void fused_topk_with_score_function_backward_kernel_launcher( size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; + size_t single_buf_shmem = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1); + int num_buffers = choose_num_buffers(single_buf_shmem, 0); size_t shmem_bytes = - RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kBwdNumBuffers) + - RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kBwdNumBuffers) + - RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, kBwdNumBuffers); + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers) + + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); check_shared_memory_capacity_num_experts(shmem_bytes, num_experts); auto launch = [&](auto kernel) { @@ -783,7 +795,7 @@ void fused_topk_with_score_function_backward_kernel_launcher( size_t grid_size = compute_persistent_grid(kernel, kThreadsPerBlock, shmem_bytes, total_blocks); kernel<<>>( routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, - use_pre_softmax, scaling_factor, grad_logits); + use_pre_softmax, scaling_factor, grad_logits, num_buffers); NVTE_CHECK_CUDA(cudaGetLastError()); }; From 690f4179e668285afd71b79215f1b233ddb0ae0d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 May 2026 03:03:52 +0000 Subject: [PATCH 15/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_router/fused_score_for_moe_aux_loss.cu | 3 +-- .../common/fused_router/fused_topk_with_score_function.cu | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index f40c20741b..549d8754a5 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -375,8 +375,7 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int CompType *grad_shmem_base = reinterpret_cast(shmem_ptr); RawAsyncLoader grad_loader(grad_shmem_base, warp_id, num_experts, num_token_per_block, num_buffers); - shmem_ptr += - RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); CompType *act_shmem_base = reinterpret_cast(shmem_ptr); RawAsyncLoader act_loader(act_shmem_base, warp_id, num_experts, num_token_per_block, diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 167738d96c..4384104688 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -589,14 +589,12 @@ __global__ void fused_topk_with_score_function_backward_kernel( DataType *grad_shmem_base = reinterpret_cast(shmem_ptr); RawAsyncLoader grad_loader(grad_shmem_base, warp_id, num_experts, num_token_per_block, num_buffers); - shmem_ptr += - RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); CompType *act_shmem_base = reinterpret_cast(shmem_ptr); RawAsyncLoader act_loader(act_shmem_base, warp_id, num_experts, num_token_per_block, num_buffers); - shmem_ptr += - RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + shmem_ptr += RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); bool *mask_shmem_base = reinterpret_cast(shmem_ptr); RawAsyncLoader mask_loader(mask_shmem_base, warp_id, num_experts, num_token_per_block,