From 091f86225c6bf9078067373e64d33b6e4a1a6cca Mon Sep 17 00:00:00 2001 From: yongqiangma Date: Tue, 12 May 2026 14:12:22 +0800 Subject: [PATCH] opt moe_align_kernel --- custom_ops/gpu_ops/helper.h | 2 + custom_ops/gpu_ops/moe/moe_align_kernel.cu | 604 ++++++++++++++++++ .../gpu_ops/moe/tritonmoe_preprocess.cu | 165 ++--- custom_ops/setup_ops.py | 2 + tests/operators/test_tritonmoe_preprocess.py | 392 +++++++++++- 5 files changed, 1035 insertions(+), 130 deletions(-) create mode 100644 custom_ops/gpu_ops/moe/moe_align_kernel.cu diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h index 83f3ad1077d..cb8c2e3e623 100644 --- a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -73,6 +73,8 @@ namespace cub = hipcub; using json = nlohmann::json; #endif +#define CEILDIV(a, b) (((a + b - 1) / b)) + #define CUDA_CHECK(call) \ do { \ const cudaError_t error_code = call; \ diff --git a/custom_ops/gpu_ops/moe/moe_align_kernel.cu b/custom_ops/gpu_ops/moe/moe_align_kernel.cu new file mode 100644 index 00000000000..4d2a01d8dd9 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_align_kernel.cu @@ -0,0 +1,604 @@ + +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Reference +// https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/sgl-kernel/csrc/moe/moe_align_kernel.cu +// Licensed under Apache License 2.0 +// with further performance optimizations applied. 
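+//
+// Overview:
+//   moe_align_block_size() sorts token indices by expert and pads every
+//   expert's segment up to a multiple of block_size so the grouped GEMM can
+//   tile it evenly. The host launcher at the bottom of this file picks one
+//   of three paths:
+//     * small-batch kernel : numel < 1024 and num_experts <= 64,
+//                            single thread block;
+//     * cooperative kernel : numel >= 16384, one fused launch via
+//                            cudaLaunchCooperativeKernel (falls back to the
+//                            two-kernel path if the launch fails);
+//     * two-kernel path    : moe_align_block_size_kernel followed by
+//                            count_and_sort_expert_tokens_kernel.
+//   Expert ids are shifted by +1 everywhere (topk_ids[i] + 1) so that
+//   entries equal to -1 (tokens dropped by expert-parallel filtering) land
+//   in a dummy bucket 0 instead of a real expert's slice; the wrapper in
+//   tritonmoe_preprocess.cu therefore passes num_experts + 1.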
+ +#include + +#include "helper.h" +#include "paddle/extension.h" + +#define VEC_SIZE 4 +using Vec = int4; + +template +__global__ void count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + size_t numel) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); + sorted_token_ids[rank_post_pad] = i; + } +} + +#ifdef __CUDA_ARCH__ +__device__ __forceinline__ int warp_exclusive_scan( + int v, unsigned mask = 0xffffffffu) { + int original = v; +#pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + int n = __shfl_up_sync(mask, v, offset); + if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n; + } + return v - original; +} +#endif + +template +__global__ void moe_align_block_size_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + int32_t* __restrict__ cumsum, + bool pad_sorted_token_ids, + const int32_t scan_size, + int32_t max_num_tokens_padded) { + // Use a separate thread block to populate sorted_token_ids + if (blockIdx.x == 1) { + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = threadIdx.x; i < total_vecs; i += blockDim.x) { + out_ptr[i] = fill_vec; + } + } + return; + } + + extern __shared__ int32_t smem[]; + int32_t* shared_counts = smem; // [num_experts] + int32_t* prefix = shared_counts + num_experts; // [num_experts + 1] + int32_t* scan_buf = prefix + num_experts + 1; // [scan_size] + __shared__ int32_t s_total_tokens_post_pad; + + const size_t tid = threadIdx.x; + const size_t stride = blockDim.x; + + if (tid < num_experts) { + shared_counts[tid] = 0; + } + + __syncthreads(); + + for (size_t i = tid; i < numel; i += stride) { + int expert_id = topk_ids[i] + 1; + atomicAdd(&shared_counts[expert_id], 1); + } + + __syncthreads(); + + int32_t padded_count = 0; + if (tid < num_experts) { + int32_t count = shared_counts[tid]; + padded_count = (count + block_size - 1) / block_size * block_size; + scan_buf[tid] = padded_count; + } + +#ifndef __CUDA_ARCH__ // HIP + + if (tid >= num_experts && tid < scan_size) { + scan_buf[tid] = 0; + } + + __syncthreads(); + + // Blelloch scan + int offset = 1; +#pragma unroll + for (int d = scan_size >> 1; d > 0; d >>= 1) { + if (tid < d) { + int ai = offset * (2 * tid + 1) - 1; + int bi = offset * (2 * tid + 2) - 1; + scan_buf[bi] += scan_buf[ai]; + } + offset <<= 1; + __syncthreads(); + } + + // down-sweep + if (tid == 0) { + prefix[num_experts] = scan_buf[scan_size - 1]; + scan_buf[scan_size - 1] = 0; + } + __syncthreads(); + +#pragma unroll + for (int d = 1; d < scan_size; d <<= 1) { + offset >>= 1; + if (tid < d) { + int ai = offset * (2 * tid + 1) - 1; + int bi = offset * (2 * tid + 2) - 1; + if (bi < scan_size) { + int temp = scan_buf[ai]; + scan_buf[ai] = scan_buf[bi]; + scan_buf[bi] += temp; + } + } + __syncthreads(); + } + + if (tid < num_experts) { + prefix[tid] = scan_buf[tid]; + } + + if (tid == 0) { + s_total_tokens_post_pad = 
prefix[num_experts]; + *total_tokens_post_pad = s_total_tokens_post_pad; + } + __syncthreads(); + +#else // CUDA + + // Intra warp prefix sum + int32_t* warp_sums = scan_buf + scan_size; // [<= 32] + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid & (WARP_SIZE - 1); + const int num_warps_for_scan = (scan_size + WARP_SIZE - 1) / WARP_SIZE; + const int warp_sum = warp_exclusive_scan(padded_count) + padded_count; + if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = warp_sum; + __syncthreads(); + + // warp0 accumulate all the block's prefix sum + if (tid < WARP_SIZE) { + int val = (tid < num_warps_for_scan) ? warp_sums[tid] : 0; + int incl = warp_exclusive_scan(val) + val; + warp_sums[tid] = incl; + } + __syncthreads(); + + // Every thread obtains the whole block's sum + if (tid == 0) { + prefix[num_experts] = warp_sums[num_warps_for_scan - 1]; + s_total_tokens_post_pad = prefix[num_experts]; + *total_tokens_post_pad = s_total_tokens_post_pad; + } + __syncthreads(); + + // Fill 0 to scan_buf extended area (tid >= num_expert) + if (tid >= num_experts && tid < scan_size) scan_buf[tid] = 0; + __syncthreads(); + + // Perform 2 level exclusive-prefix-sum to scan_buf + int v = (tid < scan_size) ? scan_buf[tid] : 0; + int pre = warp_exclusive_scan(v); + if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v; + __syncthreads(); + + if (warp_id == 0) { + int val = (lane_id < num_warps_for_scan) ? warp_sums[lane_id] : 0; + warp_sums[lane_id] = warp_exclusive_scan(val); + } + __syncthreads(); + + int offset = warp_sums[warp_id]; + if (tid < scan_size) scan_buf[tid] = pre + offset; + __syncthreads(); + + // Write prefix[0..num_experts - 1] and cumsum + if (tid < num_experts) prefix[tid] = scan_buf[tid]; +#endif + + if (tid <= num_experts) { + cumsum[tid] = prefix[tid]; + } + // fill expert_ids + const int32_t num_blocks = s_total_tokens_post_pad / block_size; + for (int32_t i = tid; i < num_blocks; i += stride) { + int32_t block_start = i * block_size; + int left = 0, right = num_experts; + while (left < right) { + int mid = (left + right) >> 1; + if (prefix[mid] <= block_start) { + left = mid + 1; + } else { + right = mid; + } + } + expert_ids[i] = left - 2; + } +} + +// ===== Cooperative fused kernel for large batch (single launch, grid.sync) + +namespace cg = cooperative_groups; + +template +__global__ void moe_align_block_size_cooperative_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ global_counts, // [num_experts+1], zeroed by caller + int32_t num_experts, + int32_t block_size, + size_t numel, + bool pad_sorted_token_ids, + int32_t max_num_tokens_padded) { + cg::grid_group grid = cg::this_grid(); + + extern __shared__ int32_t smem[]; + // smem layout: [num_experts] local_hist + [num_experts+1] expert_starts + int32_t* local_hist = smem; + int32_t* expert_starts_local = smem + num_experts; + + const int bid = blockIdx.x; + const int tid = threadIdx.x; + const int nthreads = blockDim.x; + const int nblocks = gridDim.x; + + __shared__ int32_t s_total; + + // ===== Stage 0: Cooperative initialization ===== + // Fill sorted_token_ids with sentinel value (all blocks cooperate) + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = + static_cast(numel); + int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i 
= bid * nthreads + tid; i < total_vecs; + i += nblocks * nthreads) { + out_ptr[i] = fill_vec; + } + } + + // Initialize local histogram to 0 + for (int i = tid; i < num_experts; i += nthreads) { + local_hist[i] = 0; + } + __syncthreads(); + + // ===== Stage 1: Local histogram + global atomic merge ===== + for (size_t i = (size_t)bid * nthreads + tid; i < numel; + i += (size_t)nblocks * nthreads) { + int expert_id = static_cast(topk_ids[i]) + 1; + atomicAdd(&local_hist[expert_id], 1); + } + __syncthreads(); + + // Merge local counts into global via atomic fetch-and-add. + // Return value = prefix_before (reuse local_hist to store it). + for (int i = tid; i < num_experts; i += nthreads) { + int32_t count = local_hist[i]; + int32_t prefix_before = atomicAdd(&global_counts[i], count); + local_hist[i] = prefix_before; + } + + grid.sync(); // all histograms merged, global_counts has totals + + // ===== Stage 2: Redundant prefix sum per block ===== + if (tid == 0) { + int32_t running_sum = 0; + for (int i = 0; i < num_experts; i++) { + int32_t count = global_counts[i]; + int32_t padded = (count + block_size - 1) / block_size * block_size; + expert_starts_local[i] = running_sum; + running_sum += padded; + } + expert_starts_local[num_experts] = running_sum; // total + s_total = running_sum; + } + + grid.sync(); + + // Block 0 writes total_tokens_post_pad and cumsum (global_counts) + if (bid == 0) { + // Write expert starts to global_counts for external consumers + if (tid <= num_experts) { + global_counts[tid] = expert_starts_local[tid]; + } + if (tid == 0) { + *total_tokens_post_pad = s_total; + } + } + + // ===== Stage 3: Fill expert_ids (all blocks cooperate) ===== + const int32_t num_blocks_out = s_total / block_size; + for (int32_t i = bid * nthreads + tid; i < num_blocks_out; + i += nblocks * nthreads) { + int32_t block_start = i * block_size; + // Binary search: find the expert whose start <= block_start < next start + int left = 0, right = num_experts; + while (left < right) { + int mid = (left + right) >> 1; + if (expert_starts_local[mid + 1] <= block_start) { + left = mid + 1; + } else { + right = mid; + } + } + expert_ids[i] = left - 1; // expert indexing: topk_ids uses +1 offset + } + + // ===== Stage 4: Scatter tokens using shared memory atomics ===== + // local_hist[i] currently holds prefix_before for this block. + // We do atomic_add on local_hist to get each token's rank within the expert, + // then add expert_starts_local to get the final position. + for (size_t i = (size_t)bid * nthreads + tid; i < numel; + i += (size_t)nblocks * nthreads) { + int expert_id = static_cast(topk_ids[i]) + 1; + int32_t rank = atomicAdd(&local_hist[expert_id], 1); + int32_t pos = rank + expert_starts_local[expert_id]; + sorted_token_ids[pos] = i; + } +} + +template +__global__ void moe_align_block_size_small_batch_expert_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + bool pad_sorted_token_ids, + int32_t max_num_tokens_padded) { + // Adapted from + // https://github.com/vllm-project/vllm/pull/29642/files#diff-5647b1413f4ae9aacba904eca8f8a8aee9079321eadff4c10101a2c6962dcc53R226 + // Use an additional group of threads to fill sorted_token_ids. + // Since the kernel will use sorted_token_ids afterward, + // we fill sorted_token_ids within the same threadblock to make + // synchronization easier. 
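+  //
+  // Thread partition: threads [0, fill_threads) do nothing but fill
+  // sorted_token_ids with the sentinel value `numel` (when
+  // pad_sorted_token_ids is set), while the remaining
+  // (blockDim.x - fill_threads) threads build the per-thread counts, the
+  // cumulative sums, and the final scatter. The fill group executes the same
+  // number of __syncthreads() barriers as the compute group below, so the
+  // block-wide barriers stay matched.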
+ if (threadIdx.x < fill_threads) { + // Initialize sorted_token_ids with numel + if (pad_sorted_token_ids) { + for (int32_t it = threadIdx.x; it < max_num_tokens_padded; + it += fill_threads) { + sorted_token_ids[it] = numel; + } + } + // Three __syncthreads() corresponding to the other threads + __syncthreads(); + __syncthreads(); + __syncthreads(); + return; + } + + const size_t tid = threadIdx.x - fill_threads; + const size_t stride = blockDim.x - fill_threads; + + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; + int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[(tid + 1) * num_experts + i] = 0; + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + ++tokens_cnts[(tid + 1) * num_experts + expert_id]; + } + + __syncthreads(); + + if (tid < num_experts) { + tokens_cnts[tid] = 0; + for (int i = 1; i <= stride; ++i) { + tokens_cnts[i * num_experts + tid] += + tokens_cnts[(i - 1) * num_experts + tid]; + } + } + + __syncthreads(); + + if (tid == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = + cumsum[i - 1] + + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * + block_size; + } + *total_tokens_post_pad = static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + if (tid < num_experts) { + for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) { + expert_ids[i / block_size] = tid - 1; + } + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + int32_t rank_post_pad = + tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[tid * num_experts + expert_id]; + } +} + +template +void moe_align_block_size(const paddle::Tensor& topk_ids, + int64_t num_experts, + int64_t block_size, + paddle::Tensor& sorted_token_ids, + paddle::Tensor& experts_ids, + paddle::Tensor& num_tokens_post_pad, + paddle::Tensor& cumsum_buffer, + bool pad_sorted_token_ids) { + int threads = 1024; + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + auto stream = topk_ids.stream(); + + const size_t numel = topk_ids.numel(); + const int64_t max_num_tokens_padded = sorted_token_ids.shape()[0]; + + bool small_batch_expert_mode = (numel < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t expert_threads = max((int32_t)num_experts, WARP_SIZE); + constexpr int32_t fill_threads = 256; + const int32_t shared_mem_size = + ((expert_threads + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); + + auto small_batch_expert_kernel = + moe_align_block_size_small_batch_expert_kernel; + small_batch_expert_kernel<<<1, + fill_threads + expert_threads, + shared_mem_size, + stream>>>(topk_ids.data(), + sorted_token_ids.data(), + experts_ids.data(), + num_tokens_post_pad.data(), + num_experts, + block_size, + numel, + pad_sorted_token_ids, + max_num_tokens_padded); + } else { + // Use cooperative fused kernel for large inputs where multi-block + // parallelism outweighs cooperative launch overhead + if (numel >= 16384) { + const int coop_threads = 256; + const size_t coop_smem = (2 * num_experts + 1) * sizeof(int32_t); + + auto coop_kernel = moe_align_block_size_cooperative_kernel; + + static int cached_max_blocks_per_sm = 0; + static int cached_num_sms = 0; + if (cached_num_sms == 0) { + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&cached_max_blocks_per_sm, + (void*)coop_kernel, + coop_threads, + 
coop_smem); + int device_id; + cudaGetDevice(&device_id); + cudaDeviceGetAttribute( + &cached_num_sms, cudaDevAttrMultiProcessorCount, device_id); + } + + int max_coop_blocks = cached_max_blocks_per_sm * cached_num_sms; + int desired_blocks = std::max( + 1, std::min(256, static_cast(numel / (coop_threads * 4)))); + int coop_blocks = std::min(desired_blocks, max_coop_blocks); + if (coop_blocks < 1) coop_blocks = 1; + + const scalar_t* topk_ids_ptr = topk_ids.data(); + int32_t* sorted_token_ids_ptr = sorted_token_ids.data(); + int32_t* experts_ids_ptr = experts_ids.data(); + int32_t* num_tokens_post_pad_ptr = num_tokens_post_pad.data(); + int32_t* cumsum_ptr = cumsum_buffer.data(); + int32_t num_experts_i32 = static_cast(num_experts); + int32_t block_size_i32 = static_cast(block_size); + size_t numel_val = numel; + bool pad_val = pad_sorted_token_ids; + int32_t max_padded_i32 = static_cast(max_num_tokens_padded); + + void* args[] = {&topk_ids_ptr, + &sorted_token_ids_ptr, + &experts_ids_ptr, + &num_tokens_post_pad_ptr, + &cumsum_ptr, + &num_experts_i32, + &block_size_i32, + &numel_val, + &pad_val, + &max_padded_i32}; + + cudaError_t err = cudaLaunchCooperativeKernel((void*)coop_kernel, + dim3(coop_blocks), + dim3(coop_threads), + args, + coop_smem, + stream); + + if (err == cudaSuccess) { + return; + } + // Fall through to original path if cooperative launch failed + } + + // Original 2-kernel approach (for medium inputs or cooperative fallback) + auto align_kernel = moe_align_block_size_kernel; + + const size_t scan_size = next_pow_2(num_experts); + const size_t shared_mem_size = + (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * + sizeof(int32_t); + align_kernel<<<2, threads, shared_mem_size, stream>>>( + topk_ids.data(), + sorted_token_ids.data(), + experts_ids.data(), + num_tokens_post_pad.data(), + num_experts, + block_size, + numel, + cumsum_buffer.data(), + pad_sorted_token_ids, + scan_size, + max_num_tokens_padded); + + const int block_threads = std::min(256, (int)threads); + const int num_blocks = ((int)numel + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + auto sort_kernel = count_and_sort_expert_tokens_kernel; + sort_kernel<<>>( + topk_ids.data(), + sorted_token_ids.data(), + cumsum_buffer.data(), + numel); + } +} + +// Explicit instantiations for use from other translation units (e.g. 
+// tritonmoe_preprocess.cu)
+template void moe_align_block_size<int64_t>(const paddle::Tensor&,
+                                            int64_t,
+                                            int64_t,
+                                            paddle::Tensor&,
+                                            paddle::Tensor&,
+                                            paddle::Tensor&,
+                                            paddle::Tensor&,
+                                            bool);
+template void moe_align_block_size<int32_t>(const paddle::Tensor&,
+                                            int64_t,
+                                            int64_t,
+                                            paddle::Tensor&,
+                                            paddle::Tensor&,
+                                            paddle::Tensor&,
+                                            paddle::Tensor&,
+                                            bool);
diff --git a/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu b/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu
index 071e0a9b418..eb680ea744e 100644
--- a/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu
+++ b/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu
@@ -15,83 +15,40 @@
 #include "helper.h"
 #include "paddle/extension.h"
 
-#define CEILDIV(a, b) (((a + b - 1) / b))
-
 template <typename scalar_t>
-__global__ void count_and_sort_expert_tokens_kernel(
-    const scalar_t* __restrict__ topk_ids,
-    int32_t* __restrict__ sorted_token_ids,
-    int32_t* __restrict__ cumsum_buffer,
-    size_t numel) {
-  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
-  const size_t stride = blockDim.x * gridDim.x;
-
-  for (size_t i = tid; i < numel; i += stride) {
-    int32_t expert_id = topk_ids[i];
-    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
-    sorted_token_ids[rank_post_pad] = i;
-  }
-}
-
-template <typename scalar_t, int num_experts>
-__global__ void moe_align_block_size_kernel(
-    const scalar_t* __restrict__ topk_ids,
-    int32_t* __restrict__ expert_ids,
-    int32_t* __restrict__ total_tokens_post_pad,
-    int32_t GEMM_BLOCK_SIZE_M,
-    size_t numel,
-    int32_t* __restrict__ cumsum_buffer) {
-  __shared__ int32_t tokens_per_ep[num_experts];
-
-  for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
-    tokens_per_ep[i] = 0;
-  }
-
-  __syncthreads();
-
-  for (int i = threadIdx.x; i < numel; i += blockDim.x) {
-    int expert_id = topk_ids[i];
-    atomicAdd(&tokens_per_ep[expert_id], 1);
-  }
-
-  __syncthreads();
-
-  if (threadIdx.x == 0) {
-    cumsum_buffer[0] = 0;
-    for (int i = 1; i <= num_experts; ++i) {
-      int expert_count = tokens_per_ep[i - 1];
-      cumsum_buffer[i] =
-          cumsum_buffer[i - 1] +
-          CEILDIV(expert_count, GEMM_BLOCK_SIZE_M) * GEMM_BLOCK_SIZE_M;
-    }
-    *total_tokens_post_pad = cumsum_buffer[num_experts];
-  }
-
-  __syncthreads();
-
-  if (threadIdx.x < num_experts) {
-    for (int i = cumsum_buffer[threadIdx.x]; i < cumsum_buffer[threadIdx.x + 1];
-         i += GEMM_BLOCK_SIZE_M) {
-      expert_ids[i / GEMM_BLOCK_SIZE_M] = threadIdx.x;
-    }
-  }
-}
+void moe_align_block_size(const paddle::Tensor& topk_ids,
+                          int64_t num_experts,
+                          int64_t block_size,
+                          paddle::Tensor& sorted_token_ids,
+                          paddle::Tensor& experts_ids,
+                          paddle::Tensor& num_tokens_post_pad,
+                          paddle::Tensor& cumsum_buffer,
+                          bool pad_sorted_token_ids);
 
 std::vector<std::vector<int64_t>> tritonmoe_preprocessInferShape(
     const std::vector<int64_t>& topk_ids,
     int64_t num_experts,
     int64_t GEMM_BLOCK_SIZE_M) {
-  int topk_ids_numel = topk_ids[0] * topk_ids[1];
-  int max_num_tokens_padded =
-      topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1);
+  int topk_ids_numel = 1;
+  for (int64_t dim : topk_ids) {
+    topk_ids_numel *= static_cast<int>(dim);
+  }
+  int max_num_tokens_padded;
+  if (topk_ids_numel < num_experts + 1) {
+    max_num_tokens_padded = topk_ids_numel * GEMM_BLOCK_SIZE_M;
+  } else {
+    max_num_tokens_padded =
+        topk_ids_numel + (num_experts + 1) * (GEMM_BLOCK_SIZE_M - 1);
+  }
   std::vector<int64_t> sorted_ids = {max_num_tokens_padded};
-  int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M;
-  std::vector<int64_t> expert_ids = {max_num_m_blocks};
+  int max_num_m_blocks =
+      (max_num_tokens_padded + GEMM_BLOCK_SIZE_M - 1) / GEMM_BLOCK_SIZE_M;
+  std::vector<int64_t> experts_ids = {max_num_m_blocks};
   std::vector<int64_t> num_tokens_post_pad = {1};
-  return {sorted_ids, expert_ids, num_tokens_post_pad};
+  return {sorted_ids, experts_ids, num_tokens_post_pad};
 }
 
 std::vector<paddle::DataType> tritonmoe_preprocessIferDtype(
@@ -127,76 +84,50 @@ std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(
     const paddle::Tensor& topk_ids,
     int64_t num_experts,
     int64_t GEMM_BLOCK_SIZE_M) {
-  int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1];
-  int max_num_tokens_padded =
-      topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1);
+  int topk_ids_numel = static_cast<int>(topk_ids.numel());
+
+  int max_num_tokens_padded;
+  if (topk_ids_numel < num_experts + 1) {
+    max_num_tokens_padded = topk_ids_numel * GEMM_BLOCK_SIZE_M;
+  } else {
+    max_num_tokens_padded =
+        topk_ids_numel + (num_experts + 1) * (GEMM_BLOCK_SIZE_M - 1);
+  }
   auto sorted_ids = paddle::full({max_num_tokens_padded},
                                  topk_ids_numel,
                                  paddle::DataType::INT32,
                                  topk_ids.place());
-  int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M;
+  int max_num_m_blocks =
+      (max_num_tokens_padded + GEMM_BLOCK_SIZE_M - 1) / GEMM_BLOCK_SIZE_M;
 
-  auto expert_ids = paddle::empty(
+  auto experts_ids = paddle::empty(
       {max_num_m_blocks}, paddle::DataType::INT32, topk_ids.place());
 
   auto num_tokens_post_pad =
       paddle::empty({1}, paddle::DataType::INT32, topk_ids.place());
 
-  auto cumsum_buffer = paddle::empty(
-      {num_experts + 1}, paddle::DataType::INT32, topk_ids.place());
+  auto cumsum_buffer = paddle::zeros(
+      {num_experts + 2}, paddle::DataType::INT32, topk_ids.place());
 
-  auto stream = topk_ids.stream();
   using scalar_t = int64_t;
-
-#define run_align_kernel(num_experts)                                       \
-  auto align_kernel = moe_align_block_size_kernel<scalar_t, num_experts>;   \
-  align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data<scalar_t>(),           \
-                                       expert_ids.data<int32_t>(),          \
-                                       num_tokens_post_pad.data<int32_t>(), \
-                                       GEMM_BLOCK_SIZE_M,                   \
-                                       topk_ids_numel,                      \
-                                       cumsum_buffer.data<int32_t>());
-
-  if (num_experts == 8) {
-    run_align_kernel(8);
-  } else if (num_experts == 256) {
-    run_align_kernel(256);
-  } else if (num_experts == 2) {
-    run_align_kernel(2);
-  } else if (num_experts == 64) {
-    run_align_kernel(64);
-  } else if (num_experts == 128) {
-    run_align_kernel(128);
-  } else if (num_experts == 160) {
-    run_align_kernel(160);
-  } else if (num_experts == 32) {
-    run_align_kernel(32);
-  } else {
-    PD_THROW("Not support num_experts: %d", num_experts);
-  }
-
-  const int block_threads = 256;
-  const int num_blocks = CEILDIV(topk_ids_numel, block_threads);
-  const int max_blocks = 65535;
-  const int actual_blocks = std::min(num_blocks, max_blocks);
-
-  auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
-
-  sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
-      topk_ids.data<scalar_t>(),
-      sorted_ids.data<int32_t>(),
-      cumsum_buffer.data<int32_t>(),
-      topk_ids_numel);
-
-  return {sorted_ids, expert_ids, num_tokens_post_pad};
+  moe_align_block_size<scalar_t>(topk_ids,
+                                 num_experts + 1,
+                                 GEMM_BLOCK_SIZE_M,
+                                 sorted_ids,
+                                 experts_ids,
+                                 num_tokens_post_pad,
+                                 cumsum_buffer,
+                                 true);
+
+  return {sorted_ids, experts_ids, num_tokens_post_pad};
 }
 
 PD_BUILD_STATIC_OP(tritonmoe_preprocess)
     .Inputs({"topk_ids"})
     .Attrs({"num_experts: int64_t", "GEMM_BLOCK_SIZE_M: int64_t"})
-    .Outputs({"sorted_ids", "expert_ids", "num_tokens_post_pad"})
+    .Outputs({"sorted_ids", "experts_ids", "num_tokens_post_pad"})
     .SetKernelFn(PD_KERNEL(tritonmoe_preprocess_kernel))
    .SetInferShapeFn(PD_INFER_SHAPE(tritonmoe_preprocessInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(tritonmoe_preprocessIferDtype));
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index 268cde02825..bcb02f4759e 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -237,6
+237,7 @@ def find_end_files(directory, end_str): "gpu_ops/set_data_ipc.cu", "gpu_ops/unset_data_ipc.cu", "gpu_ops/moe/tritonmoe_preprocess.cu", + "gpu_ops/moe/moe_align_kernel.cu", "gpu_ops/step_system_cache.cu", "gpu_ops/get_output_ep.cc", "gpu_ops/speculate_decoding/speculate_get_padding_offset.cu", @@ -693,6 +694,7 @@ def find_end_files(directory, end_str): "gpu_ops/append_attn/mla_cache_kernel.cu", "gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu", "gpu_ops/moe/tritonmoe_preprocess.cu", + "gpu_ops/moe/moe_align_kernel.cu", "gpu_ops/moe/moe_topk_select.cu", "gpu_ops/get_img_boundaries.cc", "gpu_ops/remote_cache_kv_ipc.cc", diff --git a/tests/operators/test_tritonmoe_preprocess.py b/tests/operators/test_tritonmoe_preprocess.py index 94d85c956e1..7071e275225 100644 --- a/tests/operators/test_tritonmoe_preprocess.py +++ b/tests/operators/test_tritonmoe_preprocess.py @@ -12,12 +12,159 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Correctness tests for tritonmoe_preprocess +========================================== + +Tests the fastdeploy wrapper: + tritonmoe_preprocess(topk_ids, num_experts, block_size) + -> (sorted_token_ids, expert_ids, num_tokens_post_padded) + +The verification approach mirrors FlagTree/python/tutorials/tle/02-moe_align_block_size.py: + - Use paddle.bincount as an independent reference (no second kernel to cross-compare). + - Validate three dimensions: + 1. num_tokens_post_padded – total token count after per-expert block alignment + 2. expert_ids – each block is mapped to the correct expert + 3. sorted_token_ids – every token is routed to the right expert's slot, + and padding slots carry sentinel values >= num_tokens +""" + import unittest import numpy as np import paddle -from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess +# --------------------------------------------------------------------------- +# Import guard – skip entire module when CUDA is unavailable or +# fastdeploy is not installed (e.g. CPU-only CI environments). +# --------------------------------------------------------------------------- +try: + from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess + + _AVAILABLE = paddle.device.is_compiled_with_cuda() +except Exception: + _AVAILABLE = False + +DEVICE = "gpu" + +# 仅对小规模 case 打印详细 tensor,超过此阈值只打印统计摘要 +_PRINT_TENSOR_NUMEL_LIMIT = 64 + + +def _fmt_tensor(t: paddle.Tensor, name: str) -> str: + t_cpu = t.cpu() + if t_cpu.numel() <= _PRINT_TENSOR_NUMEL_LIMIT: + return f"{name}{list(t_cpu.shape)} = {t_cpu.tolist()}" + return ( + f"{name}{list(t_cpu.shape)} | " + f"min={int(t_cpu.min())} max={int(t_cpu.max())} " + f"mean={float(t_cpu.cast('float32').mean()):.2f} numel={t_cpu.numel()}" + ) + + +# --------------------------------------------------------------------------- +# Reference helpers (CPU, independent of the kernel under test) +# --------------------------------------------------------------------------- + + +def _ref_counts_and_cumsum(topk_ids_flat: paddle.Tensor, num_experts: int, block_size: int): + """ + Compute per-expert token counts and the cumulative sum of block-aligned counts. 
+ + Returns: + counts : int32 tensor of shape (num_experts,) + cumsum : int32 tensor of shape (num_experts,) – cumulative aligned counts + """ + # Only consider valid expert ids [0, num_experts); ignore -1 (EP filtered) + valid_mask = (topk_ids_flat >= 0) & (topk_ids_flat < num_experts) + valid_ids = topk_ids_flat[valid_mask] + counts = paddle.bincount(valid_ids.cast("int64"), minlength=num_experts).cast("int32") + aligned = ((counts + block_size - 1) // block_size) * block_size + cumsum = paddle.cumsum(aligned, axis=0).cast("int32") + return counts, cumsum + + +# --------------------------------------------------------------------------- +# Core verification logic (shared across all test cases) +# --------------------------------------------------------------------------- + + +def _verify(topk_ids: paddle.Tensor, block_size: int, num_experts: int, label: str = ""): + """ + Run tritonmoe_preprocess and verify all three output tensors. + topk_ids may be 1-D or 2-D; dtype int32 or int64. + Prints inputs, golden references, kernel outputs, and per-check comparison. + """ + tag = f"[{label}] " if label else "" + sep = "=" * 70 + + sorted_token_ids, expert_ids, num_tokens_post_pad = tritonmoe_preprocess(topk_ids, num_experts, block_size) + + topk_ids_flat = topk_ids.flatten().cast("int64").cpu() + num_tokens = topk_ids_flat.numel() + + counts, cumsum = _ref_counts_and_cumsum(topk_ids_flat, num_experts, block_size) + aligned = ((counts + block_size - 1) // block_size) * block_size + valid_length = int(cumsum[-1].item()) + num_blocks = valid_length // block_size + + expected_expert_ids = paddle.repeat_interleave( + paddle.arange(num_experts, dtype="int32"), # CPU + (aligned // block_size).cast("int32"), + ) + + np.testing.assert_array_equal( + num_tokens_post_pad.cpu().numpy(), + cumsum[-1:].cpu().numpy(), + ) + + # ------------------------------------------------------------------ # + # Check 2: expert_ids – each block maps to the expected expert # + # ------------------------------------------------------------------ # + got_eids = expert_ids[:num_blocks].cpu() + want_eids = expected_expert_ids.cpu() + np.testing.assert_array_equal( + got_eids.numpy(), + want_eids.numpy(), + ) + + # ------------------------------------------------------------------ # + # Check 3: sorted_token_ids – routing correctness per expert # + # ------------------------------------------------------------------ # + + start = 0 + for expert_id in range(num_experts): + end = int(cumsum[expert_id].item()) + tokens = sorted_token_ids[start:end].cpu() + valid_tokens = tokens[tokens < num_tokens] + # padding_tokens = tokens[tokens >= num_tokens] + + want_count = int(counts[expert_id].item()) + got_count = valid_tokens.numel() + count_ok = got_count == want_count + + assert count_ok, f"expert {expert_id}: expected {want_count} valid tokens, got {got_count}" + if counts[expert_id] > 0: + np.testing.assert_array_equal( + topk_ids_flat[valid_tokens.cast("int64")].numpy(), + paddle.full_like(valid_tokens, expert_id).numpy(), + ) + start = end + + # padding 区域哨兵检查 + if valid_length < sorted_token_ids.numel(): + padding_region = sorted_token_ids[valid_length:].cpu() + sentinel_ok = paddle.all(padding_region >= num_tokens).item() + + assert sentinel_ok, "padding slots beyond valid_length contain non-sentinel values" + + print(f"\n{tag}ALL CHECKS PASSED") + print(sep) + + +# --------------------------------------------------------------------------- +# Original unittest-based tests (kept for backward compatibility) +# 
--------------------------------------------------------------------------- class TestTritonMOEPreprocess(unittest.TestCase): @@ -35,10 +182,14 @@ def _check_output_shapes( self, sorted_ids, expert_ids, num_tokens_post_pad, topk_ids_np, num_experts, GEMM_BLOCK_SIZE_M ): """Check output shapes and dtypes""" - expected_max_num_tokens_padded = topk_ids_np.size + num_experts * (GEMM_BLOCK_SIZE_M - 1) + if topk_ids_np.size < num_experts + 1: + expected_max_num_tokens_padded = topk_ids_np.size * GEMM_BLOCK_SIZE_M + else: + expected_max_num_tokens_padded = topk_ids_np.size + (num_experts + 1) * (GEMM_BLOCK_SIZE_M - 1) + self.assertEqual(sorted_ids.shape[0], expected_max_num_tokens_padded) - expected_max_num_m_blocks = expected_max_num_tokens_padded // GEMM_BLOCK_SIZE_M + expected_max_num_m_blocks = (expected_max_num_tokens_padded + GEMM_BLOCK_SIZE_M - 1) // GEMM_BLOCK_SIZE_M self.assertEqual(expert_ids.shape[0], expected_max_num_m_blocks) self.assertEqual(num_tokens_post_pad.shape[0], 1) @@ -104,17 +255,232 @@ def test_basic_case(self): ) self._check_output_values_basic(sorted_ids, expert_ids, num_tokens_post_pad) - def test_unsupported_num_experts(self): - """Test unsupported num_experts raises OSError""" - topk_ids_np = np.array([[0, 1], [1, 0]], dtype=np.int64) - unsupported_experts = [3, 9, 65, 129] - GEMM_BLOCK_SIZE_M = 4 - for num_experts in unsupported_experts: - with self.subTest(num_experts=num_experts): - with self.assertRaises(OSError): - self._run_op(topk_ids_np, num_experts, GEMM_BLOCK_SIZE_M) +# --------------------------------------------------------------------------- +# Correctness tests (ported from test_moe_align_block_size.py) +# --------------------------------------------------------------------------- + + +class TestTritonMoePreprocessBasic(unittest.TestCase): + """Basic / small cases – easy to reason about manually.""" + + def setUp(self): + if not _AVAILABLE: + self.skipTest("CUDA or fastdeploy not available") + + def test_docstring_example(self): + """Reproduce the example from the function docstring.""" + topk_ids = paddle.to_tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], dtype="int64") + _verify(topk_ids, block_size=4, num_experts=5, label="docstring_example") + + def test_single_token_single_expert(self): + """Minimal input: one token assigned to one expert.""" + topk_ids = paddle.to_tensor([[0]], dtype="int64") + _verify(topk_ids, block_size=16, num_experts=8, label="single_token_single_expert") + + def test_all_tokens_same_expert(self): + """All tokens go to expert 0 – only one expert's slot is used.""" + topk_ids = paddle.zeros((64, 1), dtype="int64") + _verify(topk_ids, block_size=16, num_experts=8, label="all_tokens_same_expert") + + def test_uniform_1d(self): + """1-D topk_ids (top_k=1 squeezed) with uniform distribution.""" + paddle.seed(42) + topk_ids = paddle.randint(0, 8, (128,), dtype="int64") + _verify(topk_ids, block_size=16, num_experts=8, label="uniform_1d") + + def test_topk_equals_num_experts(self): + """Every token selects all experts (top_k == num_experts).""" + num_experts = 4 + topk_ids = paddle.arange(num_experts, dtype="int64").unsqueeze(0).expand((8, num_experts)) + _verify(topk_ids, block_size=4, num_experts=num_experts, label="topk_equals_num_experts") + + def test_num_tokens_less_than_num_experts(self): + """Fewer tokens than experts – exercises the small-input branch.""" + topk_ids = paddle.to_tensor([[0], [3]], dtype="int64") + _verify(topk_ids, block_size=16, num_experts=64, label="num_tokens_less_than_num_experts") + + def 
test_exact_block_boundary(self): + """Token count per expert is exactly block_size – no padding needed.""" + block_size = 16 + num_experts = 4 + topk_ids = paddle.concat([paddle.full((block_size,), e, dtype="int64") for e in range(num_experts)]) + _verify(topk_ids, block_size=block_size, num_experts=num_experts, label="exact_block_boundary") + + def test_block_size_1(self): + """block_size=1 means no padding is ever added.""" + paddle.seed(0) + topk_ids = paddle.randint(0, 16, (64,), dtype="int64") + _verify(topk_ids, block_size=1, num_experts=16, label="block_size_1") + + +class TestTritonMoePreprocessEdgeCases(unittest.TestCase): + """Edge / boundary cases.""" + + def setUp(self): + if not _AVAILABLE: + self.skipTest("CUDA or fastdeploy not available") + + def test_empty_topk_ids(self): + """Zero-token input should not crash; num_tokens_post_pad == 0.""" + topk_ids = paddle.empty((0,), dtype="int64").cuda() + sorted_ids, expert_ids_out, num_post = tritonmoe_preprocess(topk_ids, 8, 16) + got = int(num_post.item()) + + self.assertEqual(got, 0) + + def test_one_expert(self): + """Single expert: all tokens must end up in expert 0's bucket.""" + paddle.seed(1) + topk_ids = paddle.zeros((32,), dtype="int64") + _verify(topk_ids, block_size=8, num_experts=1, label="one_expert") + + def test_large_block_size(self): + """block_size larger than total tokens.""" + topk_ids = paddle.randint(0, 4, (8,), dtype="int64") + _verify(topk_ids, block_size=128, num_experts=4, label="large_block_size") + + def test_int64_dtype(self): + """topk_ids in int64 – the kernel should handle dtype conversion.""" + paddle.seed(7) + topk_ids = paddle.randint(0, 8, (64, 2), dtype="int64") + _verify(topk_ids, block_size=16, num_experts=8, label="int64_dtype") + + +class TestTritonMoePreprocessRealistic(unittest.TestCase): + """Larger, more realistic MoE shapes.""" + def setUp(self): + if not _AVAILABLE: + self.skipTest("CUDA or fastdeploy not available") + + def _run_uniform_distribution(self, num_tokens, num_experts, block_size): + """Uniform random token-to-expert assignment across common MoE shapes.""" + paddle.seed(0) + topk_ids = paddle.randint(0, num_experts, (num_tokens,), dtype="int64") + _verify( + topk_ids, + block_size=block_size, + num_experts=num_experts, + label=f"uniform_T{num_tokens}_E{num_experts}_B{block_size}", + ) + + def test_uniform_distribution(self): + """Uniform random token-to-expert assignment across common MoE shapes.""" + for num_tokens, num_experts, block_size in [ + (256, 8, 16), + (1024, 16, 16), + (4096, 64, 16), + (8192, 64, 32), + (8192, 128, 64), + (16384, 128, 128), + (16384, 256, 128), + (16384, 512, 256), + (32768, 512, 256), + (32768, 512, 64), + (163840, 1024, 256), + ]: + with self.subTest(num_tokens=num_tokens, num_experts=num_experts, block_size=block_size): + self._run_uniform_distribution(num_tokens, num_experts, block_size) + + def _run_topk_2d(self, num_tokens, top_k, num_experts, block_size): + """2-D topk_ids as produced by the router (shape [num_tokens, top_k]).""" + paddle.seed(0) + topk_ids = paddle.randint(0, num_experts, (num_tokens, top_k), dtype="int64") + _verify( + topk_ids, + block_size=block_size, + num_experts=num_experts, + label=f"topk2d_T{num_tokens}_K{top_k}_E{num_experts}_B{block_size}", + ) + + def test_topk_2d(self): + """2-D topk_ids as produced by the router (shape [num_tokens, top_k]).""" + for num_tokens, top_k, num_experts, block_size in [ + (512, 2, 8, 16), + (1024, 4, 16, 16), + (2048, 8, 64, 16), + ]: + with self.subTest(num_tokens=num_tokens, 
top_k=top_k, num_experts=num_experts, block_size=block_size): + self._run_topk_2d(num_tokens, top_k, num_experts, block_size) + + def _run_zipf_distribution(self, alpha): + """Skewed (Zipf) token distribution – simulates real MoE load imbalance.""" + num_tokens, num_experts, block_size = 8192, 64, 16 + ranks = paddle.arange(1, num_experts + 1, dtype="float32") + probs = 1.0 / ranks**alpha + probs = probs / probs.sum() + paddle.seed(0) + topk_ids = paddle.multinomial(probs, num_tokens, replacement=True).cast("int64") + _verify(topk_ids, block_size=block_size, num_experts=num_experts, label=f"zipf_alpha{alpha}") + + def test_zipf_distribution(self): + """Skewed (Zipf) token distribution – simulates real MoE load imbalance.""" + for alpha in [0.5, 1.2, 2.0]: + with self.subTest(alpha=alpha): + self._run_zipf_distribution(alpha) + + def test_deterministic_with_fixed_seed(self): + """Same seed must produce the same outputs (kernel is deterministic).""" + num_tokens, num_experts, block_size = 4096, 64, 16 + + paddle.seed(99) + topk_ids = paddle.randint(0, num_experts, (num_tokens,), dtype="int64").cuda() + s1, e1, n1 = tritonmoe_preprocess(topk_ids, num_experts, block_size) + + paddle.seed(99) + topk_ids2 = paddle.randint(0, num_experts, (num_tokens,), dtype="int64").cuda() + s2, e2, n2 = tritonmoe_preprocess(topk_ids2, num_experts, block_size) + + valid = int(n1.item()) + + np.testing.assert_array_equal(n1.numpy(), n2.numpy()) + np.testing.assert_array_equal(e1[: valid // block_size].numpy(), e2[: valid // block_size].numpy()) + np.testing.assert_array_equal(paddle.sort(s1[:valid]).numpy(), paddle.sort(s2[:valid]).numpy()) + + +# --------------------------------------------------------------------------- +# Direct-run entry point (python test_tritonmoe_preprocess.py) +# --------------------------------------------------------------------------- if __name__ == "__main__": - unittest.main() + if not _AVAILABLE: + print("SKIP: CUDA or fastdeploy not available.") + else: + basic = TestTritonMoePreprocessBasic() + basic.test_docstring_example() + basic.test_single_token_single_expert() + basic.test_all_tokens_same_expert() + basic.test_uniform_1d() + basic.test_topk_equals_num_experts() + basic.test_num_tokens_less_than_num_experts() + basic.test_exact_block_boundary() + basic.test_block_size_1() + + edge = TestTritonMoePreprocessEdgeCases() + edge.test_empty_topk_ids() + edge.test_one_expert() + edge.test_large_block_size() + edge.test_int64_dtype() + + real = TestTritonMoePreprocessRealistic() + for num_tokens, num_experts, block_size in [ + (256, 8, 16), + (1024, 16, 16), + (4096, 64, 16), + (8192, 64, 32), + (8192, 128, 64), + (16384, 256, 128), + ]: + real._run_uniform_distribution(num_tokens, num_experts, block_size) + for num_tokens, top_k, num_experts, block_size in [ + (512, 2, 8, 16), + (1024, 4, 16, 16), + (2048, 8, 64, 16), + ]: + real._run_topk_2d(num_tokens, top_k, num_experts, block_size) + for alpha in [0.5, 1.2, 2.0]: + real._run_zipf_distribution(alpha) + real.test_deterministic_with_fixed_seed() + + print("\n*** All direct-run tests passed ***")