From bbe8bff66bf827a65fccb9d0de474777ce9cc5c8 Mon Sep 17 00:00:00 2001 From: Bingoo <33573610+BingooYang@users.noreply.github.com> Date: Tue, 28 Apr 2026 14:12:32 +0800 Subject: [PATCH 1/3] [Optimization] Elemenwise fusion (#6880) * conflict * support more cast type * modify test * add type check * fix config issues * enable more backend * modify 2025->2026 * only support gpu backend and fix test issues * support gpu backend * modify format * modify except type * conflict fix * format * fix dynamic_load_weight config issues * fix dynamic_load_weight config issue * fix fake get_moe_scores config issues * add mock test * modify RL control env * Modify code standardization * fix test issue * fix test issue * fix conflict issue * fix mockdata config issue * enhance test converage * fix test --- custom_ops/gpu_ops/cpp_extensions.cc | 11 + custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu | 206 ++++++++ custom_ops/setup_ops.py | 2 + .../layers/moe/fused_cast_sigmoid_bias.py | 73 +++ .../layers/moe/fused_moe_cutlass_backend.py | 11 +- .../layers/moe/fused_moe_deepgemm_backend.py | 6 +- .../layers/moe/fused_moe_triton_backend.py | 7 +- fastdeploy/model_executor/layers/moe/moe.py | 14 +- tests/layers/test_deepgemm_fused_moe.py | 21 + tests/layers/test_fused_cast_sigmoid_bias.py | 497 ++++++++++++++++++ .../layers/test_fused_moe_cutlass_backend.py | 151 +++++- tests/layers/test_fused_moe_triton_backend.py | 89 ++++ 12 files changed, 1081 insertions(+), 7 deletions(-) create mode 100644 custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu create mode 100644 fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py create mode 100644 tests/layers/test_fused_cast_sigmoid_bias.py diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 15866c57643..591cf363f06 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -691,6 +691,10 @@ std::vector NoauxTc(paddle::Tensor& scores, bool renormalize, float routed_scaling_factor); +std::vector FusedCastSigmoidBias(const paddle::Tensor& input, + const paddle::Tensor& bias, + std::string cast_type); + std::vector NoauxTcRedundant( paddle::Tensor& scores, paddle::Tensor& scores_with_bias, @@ -1700,6 +1704,13 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); + m.def("fused_cast_sigmoid_bias", + &FusedCastSigmoidBias, + "Fused cast+sigmoid+bias for MoE gating scores", + py::arg("input"), + py::arg("bias"), + py::arg("cast_type") = std::string("float32")); + m.def("noaux_tc_redundant", &NoauxTcRedundant, "noaux_tc_redundant for MoE compute"); diff --git a/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu new file mode 100644 index 00000000000..f25084076c4 --- /dev/null +++ b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu @@ -0,0 +1,206 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
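+//
+// Typical Python-side usage (through the thin wrapper this patch adds in
+// fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py):
+//
+//   from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import (
+//       fused_cast_sigmoid_bias,
+//   )
+//   # gate_out: [num_tokens, num_experts] bf16/fp16/fp32; bias: [num_experts] fp32
+//   scores, scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias, cast_type="float32")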
+ +#include "helper.h" + +// Fused kernel: cast(input, cast_type) -> sigmoid -> scores, scores + bias -> +// scores_with_bias +// +// For each element (token i, expert j): +// scores[i][j] = OutT(sigmoid(float(input[i][j]))) +// scores_with_bias[i][j] = OutT(sigmoid(float(input[i][j])) + bias[j]) +// +// Input: input [num_tokens, num_experts] bf16/fp16/fp32 +// bias [num_experts] or [1, num_experts] fp32 +// Output: scores [num_tokens, num_experts] cast_type (fp32/fp16/bf16) +// scores_with_bias [num_tokens, num_experts] cast_type (fp32/fp16/bf16) +// +// Precision guarantee: +// All intermediate computations (cast, sigmoid, bias addition) are performed +// in float32, regardless of input/output types. The cast to OutT only happens +// at the final store. This matches the reference implementation: +// gate_fp32 = gate_out.cast("float32") +// scores_fp32 = sigmoid(gate_fp32) +// scores_with_bias_fp32 = scores_fp32 + bias // bias is always float32 +// scores = scores_fp32.cast(cast_type) +// scores_with_bias = scores_with_bias_fp32.cast(cast_type) +// +// When cast_type is "float32", the fused kernel is numerically identical to +// the reference. For fp16/bf16 output, the only precision loss comes from +// the final static_cast, equivalent to .cast() in the reference path. +// +// Note: bias is intentionally kept as float32 (not converted to OutT) to +// ensure the addition s + bias[j] is always computed in full float32 +// precision before the final downcast. + +template +__global__ void fused_cast_sigmoid_bias_kernel( + const InT* __restrict__ input, + const float* __restrict__ bias, + OutT* __restrict__ scores, + OutT* __restrict__ scores_with_bias, + const int num_experts) { + const int64_t token_idx = blockIdx.x; + const int64_t offset = token_idx * num_experts; + + for (int j = threadIdx.x; j < num_experts; j += blockDim.x) { + // All intermediate computation in float32 for precision + float val = static_cast(input[offset + j]); + float s = 1.0f / (1.0f + expf(-val)); + // s (float32) + bias[j] (float32) -> float32 addition, then downcast + scores[offset + j] = static_cast(s); + scores_with_bias[offset + j] = static_cast(s + bias[j]); + } +} + +// Vectorized version for better memory throughput +template +__global__ void fused_cast_sigmoid_bias_vec_kernel( + const InT* __restrict__ input, + const float* __restrict__ bias, // kept as float32 for full-precision add + OutT* __restrict__ scores, + OutT* __restrict__ scores_with_bias, + const int num_experts) { + const int64_t token_idx = blockIdx.x; + const int64_t offset = token_idx * num_experts; + + using in_vec_t = AlignedVector; + using out_vec_t = AlignedVector; + using bias_vec_t = AlignedVector; // float32 bias vectors + + const int vec_count = num_experts / kVecSize; + for (int idx = threadIdx.x; idx < vec_count; idx += blockDim.x) { + const int base = idx * kVecSize; + in_vec_t in_vec; + bias_vec_t bias_vec; + Load(input + offset + base, &in_vec); + Load(bias + base, &bias_vec); + + out_vec_t s_vec, sb_vec; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + // All intermediate computation in float32 for precision + float val = static_cast(in_vec[i]); + float s = 1.0f / (1.0f + expf(-val)); + // s (float32) + bias_vec[i] (float32) -> float32 addition, then downcast + s_vec[i] = static_cast(s); + sb_vec[i] = static_cast(s + bias_vec[i]); + } + + Store(s_vec, scores + offset + base); + Store(sb_vec, scores_with_bias + offset + base); + } + + // Handle remaining elements (same float32 precision guarantee) + const int 
remaining_start = vec_count * kVecSize; + for (int j = remaining_start + threadIdx.x; j < num_experts; + j += blockDim.x) { + float val = static_cast(input[offset + j]); + float s = 1.0f / (1.0f + expf(-val)); + scores[offset + j] = static_cast(s); + scores_with_bias[offset + j] = static_cast(s + bias[j]); + } +} + +static paddle::DataType ParseCastType(const std::string& cast_type) { + if (cast_type == "float32") return paddle::DataType::FLOAT32; + if (cast_type == "float16") return paddle::DataType::FLOAT16; + if (cast_type == "bfloat16") return paddle::DataType::BFLOAT16; + PD_THROW("Unsupported cast_type: " + cast_type + + ". Only float32, float16, bfloat16 are supported."); +} + +std::vector FusedCastSigmoidBias(const paddle::Tensor& input, + const paddle::Tensor& bias, + std::string cast_type) { + auto input_shape = input.shape(); + PD_CHECK(input_shape.size() == 2, + "input must be 2D [num_tokens, num_experts]"); + auto bias_shape = bias.shape(); + // Support both [num_experts] and [1, num_experts] bias shapes + PD_CHECK( + bias_shape.size() == 1 || (bias_shape.size() == 2 && bias_shape[0] == 1), + "bias must be 1D [num_experts] or 2D [1, num_experts]"); + + int64_t num_tokens = input_shape[0]; + int64_t num_experts = input_shape[1]; + int64_t bias_numel = (bias_shape.size() == 1) ? bias_shape[0] : bias_shape[1]; + PD_CHECK(bias_numel == num_experts, "bias size must match num_experts"); + PD_CHECK(bias.dtype() == paddle::DataType::FLOAT32, + "bias must be float32, got ", + bias.dtype()); + + auto place = input.place(); + auto stream = input.stream(); + auto out_dtype = ParseCastType(cast_type); + + auto scores = paddle::empty({num_tokens, num_experts}, out_dtype, place); + auto scores_with_bias = + paddle::empty({num_tokens, num_experts}, out_dtype, place); + + if (num_tokens == 0) { + return {scores, scores_with_bias}; + } + + dim3 grid(num_tokens); + int block_size = std::min(static_cast(1024), num_experts); + // Round up to warp size + block_size = ((block_size + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(block_size); + + DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), in_scalar_t, { + DISPATCH_FLOAT_FP6_DTYPE(out_dtype, out_scalar_t, { + constexpr int kVecSize = 16 / sizeof(in_scalar_t); + if (num_experts % kVecSize == 0 && num_experts >= kVecSize) { + fused_cast_sigmoid_bias_vec_kernel + <<>>(input.data(), + bias.data(), + scores.data(), + scores_with_bias.data(), + num_experts); + } else { + fused_cast_sigmoid_bias_kernel + <<>>(input.data(), + bias.data(), + scores.data(), + scores_with_bias.data(), + num_experts); + } + }); + }); + + return {scores, scores_with_bias}; +} + +std::vector FusedCastSigmoidBiasInferDtype( + const paddle::DataType& input_dtype, + const paddle::DataType& bias_dtype, + std::string cast_type) { + auto out_dtype = ParseCastType(cast_type); + return {out_dtype, out_dtype}; +} + +std::vector> FusedCastSigmoidBiasInferShape( + const std::vector& input_shape, + const std::vector& bias_shape) { + return {input_shape, input_shape}; +} + +PD_BUILD_STATIC_OP(fused_cast_sigmoid_bias) + .Inputs({"input", "bias"}) + .Outputs({"scores", "scores_with_bias"}) + .Attrs({"cast_type: std::string"}) + .SetKernelFn(PD_KERNEL(FusedCastSigmoidBias)) + .SetInferShapeFn(PD_INFER_SHAPE(FusedCastSigmoidBiasInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(FusedCastSigmoidBiasInferDtype)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 268cde02825..7ae1e964761 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -330,6 +330,7 @@ def 
find_end_files(directory, end_str): "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", + "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/custom_all_reduce/all_reduce.cu", "gpu_ops/merge_prefill_decode_output.cu", "gpu_ops/limit_thinking_content_length.cu", @@ -684,6 +685,7 @@ def find_end_files(directory, end_str): "gpu_ops/recover_decode_task.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", + "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/text_image_gather_scatter.cu", "gpu_ops/text_image_index_out.cu", diff --git a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py new file mode 100644 index 00000000000..44d7e54ae88 --- /dev/null +++ b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py @@ -0,0 +1,73 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +_FUSED_CAST_SIGMOID_BIAS_IMPORT_ERROR = None + +try: + from fastdeploy.model_executor.ops.gpu import ( + fused_cast_sigmoid_bias as _fused_cast_sigmoid_bias_cuda, + ) +except ImportError as e: + _fused_cast_sigmoid_bias_cuda = None + _FUSED_CAST_SIGMOID_BIAS_IMPORT_ERROR = e + + +def is_available() -> bool: + """Return whether the fused GPU custom op is available.""" + return _fused_cast_sigmoid_bias_cuda is not None + + +def fused_cast_sigmoid_bias( + gate_out: paddle.Tensor, + e_score_correction_bias: paddle.Tensor, + cast_type: str = "float32", +) -> tuple: + """ + Fused operation: cast gate_out to the specified type, apply sigmoid, and add bias. + + This function fuses the following three separate operations: + 1. gate_out = gate_out.cast(cast_type) + 2. scores = sigmoid(gate_out) + 3. scores_with_bias = scores + e_score_correction_bias + + Args: + gate_out: [num_tokens, num_experts], bf16/fp16/fp32 dtype - raw gate output + e_score_correction_bias: [num_experts], fp32 dtype - correction bias + cast_type: output dtype string, supports "float32", "float16", "bfloat16" + + Returns: + scores: [num_tokens, num_experts], cast_type dtype - result of sigmoid(gate_out) + scores_with_bias: [num_tokens, num_experts], cast_type dtype - scores with bias added + + Precision: + All intermediate computations (cast, sigmoid, bias addition) are performed + in float32 precision; conversion to cast_type happens only at the final store. + When cast_type is "float32", the result is bit-exact with the following + reference implementation: + gate_fp32 = gate_out.cast("float32") + scores = sigmoid(gate_fp32) + scores_with_bias = scores + bias + When cast_type is "float16"/"bfloat16", the only precision loss comes from + the final type conversion, equivalent to calling .cast(cast_type) after + computing in float32. + """ + if _fused_cast_sigmoid_bias_cuda is None: + raise ImportError( + "fused_cast_sigmoid_bias is not available. " "Please ensure the GPU custom ops are compiled." 
+ ) from _FUSED_CAST_SIGMOID_BIAS_IMPORT_ERROR + return _fused_cast_sigmoid_bias_cuda(gate_out, e_score_correction_bias, cast_type) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 92e039dd742..a6df323a6d9 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -340,9 +340,11 @@ def apply_tp( Paddle Cutlass compute Fused MoE. """ gate_out = gate(x) - gate_out = gate_out.cast("float32") if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16": if layer.topk_method == "noaux_tc": + use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() + if not use_fused: + gate_out = gate_out.cast("float32") gate_out, topk_weights, topk_idx = get_moe_scores( gate_out, layer.n_group, @@ -351,8 +353,10 @@ def apply_tp( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, @@ -405,6 +409,9 @@ def apply_tp( return fused_moe_out if layer.topk_method == "noaux_tc": + use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() + if not use_fused: + gate_out = gate_out.cast("float32") gate_out, topk_weights, topk_idx = get_moe_scores( gate_out, layer.n_group, @@ -414,6 +421,7 @@ def apply_tp( layer.gate_correction_bias, getattr(layer, "renormalize", True), topk_reduce_func=getattr(layer, "topk_reduce_func", None), + use_fused_cast=use_fused, ) ( @@ -438,6 +446,7 @@ def apply_tp( topk_only_mode=True, ) else: + gate_out = gate_out.cast("float32") ( permute_input, token_nums_per_expert, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index a16e5ccbe9c..acc0751c11a 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -742,9 +742,11 @@ def apply_tp( below is TP compute method. 
""" gate_out = gate(x) - gate_out = gate_out.cast("float32") if layer.topk_method == "noaux_tc": + use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() + if not use_fused: + gate_out = gate_out.cast("float32") _, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores( gate_out, layer.n_group, @@ -754,8 +756,10 @@ def apply_tp( layer.gate_correction_bias, getattr(layer, "renormalize", True), topk_reduce_func=getattr(layer, "topk_reduce_func", None), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 65d1d23b9be..5b46d31362c 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -28,6 +28,7 @@ set_weight_attrs, weight_fully_copied, ) +from fastdeploy.platforms import current_platform from fastdeploy.utils import ceil_div, register_custom_python_op from ..quantization.quant_base import QuantMethodBase @@ -299,7 +300,6 @@ def apply( if token_num == 0: return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) gate_out = gate(x) - gate_out = gate_out.cast("float32") top_k = layer.top_k num_local_experts = layer.num_local_experts top_k = layer.top_k @@ -307,6 +307,9 @@ def apply( hidden_size = layer.hidden_size if layer.topk_method == "noaux_tc": + use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() + if not use_fused: + gate_out = gate_out.cast("float32") gate_out, topk_weights, topk_ids = get_moe_scores( gate_out, layer.n_group, @@ -315,8 +318,10 @@ def apply( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 9cb9340cf01..5dc20cac0a2 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -39,8 +39,14 @@ from fastdeploy.model_executor.ops.gpu import noaux_tc, noaux_tc_redundant except: logger.warning("import noaux_tc Failed!") + import numpy as np +if current_platform.is_cuda(): + from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( + fused_cast_sigmoid_bias, + ) + def get_moe_method(layer=None): """ @@ -91,13 +97,17 @@ def get_moe_scores( tokens_per_expert_stats_list: paddle.Tensor = None, redundant_ep_rank_num_plus_one: int = 1, topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True) + 1e-20, + use_fused_cast: bool = False, ) -> paddle.Tensor: """ compute moe scores using e_score_correction_bias. """ - scores = paddle.nn.functional.sigmoid(gating_output) assert e_score_correction_bias is not None, "e_score_correction_bias is none!" 
- scores_with_bias = scores + e_score_correction_bias + if use_fused_cast and current_platform.is_cuda(): + scores, scores_with_bias = fused_cast_sigmoid_bias(gating_output, e_score_correction_bias, cast_type="float32") + else: + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias if envs.FD_USE_PHI_MOE_TOPK: # calculate renormalize and routed_scaling_factor value outside the noaux_tc diff --git a/tests/layers/test_deepgemm_fused_moe.py b/tests/layers/test_deepgemm_fused_moe.py index 5381ee866a3..66910544756 100644 --- a/tests/layers/test_deepgemm_fused_moe.py +++ b/tests/layers/test_deepgemm_fused_moe.py @@ -205,6 +205,27 @@ def hook(topk_ids): assert "topk_ids" in captured assert list(out.shape) == [NUM_TOKENS, HIDDEN_SIZE] + @requires_deepgemm + def test_apply_tp_noaux_tc_with_use_fused_false(self): + """noaux_tc path with FD_ENABLE_RL=True: triggers use_fused=False and gate_out.cast('float32').""" + layer = DummyLayer() + layer.topk_method = "noaux_tc" + gate = DummyGate(layer.num_local_experts) + method = _make_method() + + x = paddle.randn([NUM_TOKENS, HIDDEN_SIZE], dtype="bfloat16") + + import fastdeploy.envs as fd_envs + + original_fd_enable_rl = fd_envs.FD_ENABLE_RL + fd_envs.FD_ENABLE_RL = True + + try: + out = method.apply(layer, x, gate) + assert list(out.shape) == [NUM_TOKENS, HIDDEN_SIZE] + finally: + fd_envs.FD_ENABLE_RL = original_fd_enable_rl + @requires_deepgemm def test_apply_tp_aux_path(self): """Non-noaux_tc: moe_topk_select → fp8_quant_blockwise → moe_permute → deepgemm → moe_unpermute.""" diff --git a/tests/layers/test_fused_cast_sigmoid_bias.py b/tests/layers/test_fused_cast_sigmoid_bias.py new file mode 100644 index 00000000000..21bfb0901fd --- /dev/null +++ b/tests/layers/test_fused_cast_sigmoid_bias.py @@ -0,0 +1,497 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import importlib +import os +import sys +from unittest import mock + +import paddle +import paddle.nn.functional as F +import pytest + +from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( + fused_cast_sigmoid_bias, + is_available, +) + +DTYPE_MAP = { + "float16": paddle.float16, + "bfloat16": paddle.bfloat16, + "float32": paddle.float32, +} + + +def _ensure_gpu_test_environment(): + """Ensure GPU runtime and required custom ops are available for this test module.""" + if not paddle.is_compiled_with_cuda(): + pytest.skip( + "fused_cast_sigmoid_bias requires CUDA-enabled Paddle.", + allow_module_level=True, + ) + paddle.set_device("gpu") + + +_ensure_gpu_test_environment() + + +def reference_cast_sigmoid_bias(gate_out, bias, cast_type="float32"): + """Reference implementation: compute in fp32, cast output to cast_type.""" + gate_fp32 = gate_out.cast("float32") + scores_fp32 = F.sigmoid(gate_fp32) + scores_with_bias_fp32 = scores_fp32 + bias + scores = scores_fp32.cast(cast_type) + scores_with_bias = scores_with_bias_fp32.cast(cast_type) + return scores, scores_with_bias + + +def test_functionality(): + """Test basic functionality: correct shapes and dtypes (default cast_type=float32).""" + print("=" * 60) + print("Test 1: Functionality (default cast_type=float32)") + print("=" * 60) + + for dtype_name in ["float16", "bfloat16", "float32"]: + for num_tokens in [1, 7, 128, 1024]: + for num_experts in [8, 64, 128, 256]: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + scores, scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias) + + assert scores.shape == [ + num_tokens, + num_experts, + ], f"scores shape mismatch: {scores.shape} vs {[num_tokens, num_experts]}" + assert scores_with_bias.shape == [ + num_tokens, + num_experts, + ], f"scores_with_bias shape mismatch: {scores_with_bias.shape}" + assert scores.dtype == paddle.float32, f"scores dtype mismatch: {scores.dtype}" + assert ( + scores_with_bias.dtype == paddle.float32 + ), f"scores_with_bias dtype mismatch: {scores_with_bias.dtype}" + + # Sigmoid output should be in [0, 1] + assert bool(paddle.all(scores >= 0.0).item()) and bool( + paddle.all(scores <= 1.0).item() + ), "scores out of [0,1] range" + print(f" [PASS] dtype={dtype_name}") + + print(" All functionality tests passed.\n") + + +def test_functionality_cast_types(): + """Test functionality with different cast_type values.""" + print("=" * 60) + print("Test 1b: Functionality with different cast_type") + print("=" * 60) + + for input_dtype in ["float16", "bfloat16", "float32"]: + for cast_type in ["float16", "bfloat16", "float32"]: + expected_paddle_dtype = DTYPE_MAP[cast_type] + for num_tokens in [1, 64, 256]: + for num_experts in [8, 64, 256]: + gate_out = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = paddle.randn([num_experts], dtype="float32") + + scores, scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + + assert scores.shape == [num_tokens, num_experts], f"scores shape mismatch: {scores.shape}" + assert scores_with_bias.shape == [ + num_tokens, + num_experts, + ], f"scores_with_bias shape mismatch: {scores_with_bias.shape}" + assert ( + scores.dtype == expected_paddle_dtype + ), f"scores dtype mismatch: got {scores.dtype}, expected {expected_paddle_dtype}" + assert ( + scores_with_bias.dtype == expected_paddle_dtype + ), f"scores_with_bias dtype mismatch: got {scores_with_bias.dtype}, expected {expected_paddle_dtype}" + + 
print(f" [PASS] input_dtype={input_dtype}, cast_type={cast_type}") + + print(" All cast_type functionality tests passed.\n") + + +def test_accuracy(): + """Test numerical accuracy against reference implementation (default cast_type=float32).""" + print("=" * 60) + print("Test 2: Accuracy (default cast_type=float32)") + print("=" * 60) + + test_cases = [ + ("float16", 1, 8), + ("float16", 128, 256), + ("float16", 1024, 256), + ("bfloat16", 1, 8), + ("bfloat16", 128, 256), + ("bfloat16", 1024, 256), + ("float32", 1, 8), + ("float32", 128, 256), + ("float32", 1024, 256), + ] + + for dtype_name, num_tokens, num_experts in test_cases: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + # Fused kernel + fused_scores, fused_scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias) + + # Reference + ref_scores, ref_scores_with_bias = reference_cast_sigmoid_bias(gate_out, bias) + + # Compare + scores_diff = paddle.abs(fused_scores - ref_scores).max().item() + scores_bias_diff = paddle.abs(fused_scores_with_bias - ref_scores_with_bias).max().item() + + atol = 1e-6 if dtype_name == "float32" else 1e-3 + passed = scores_diff < atol and scores_bias_diff < atol + + status = "PASS" if passed else "FAIL" + print( + f" [{status}] dtype={dtype_name}, tokens={num_tokens}, experts={num_experts} | " + f"scores_max_diff={scores_diff:.2e}, scores_with_bias_max_diff={scores_bias_diff:.2e}" + ) + + if not passed: + raise AssertionError( + f"Accuracy test failed for dtype={dtype_name}, tokens={num_tokens}, experts={num_experts}. " + f"scores_diff={scores_diff}, scores_bias_diff={scores_bias_diff}, atol={atol}" + ) + + print(" All accuracy tests passed.\n") + + +def test_accuracy_cast_types(): + """Test numerical accuracy with different cast_type values.""" + print("=" * 60) + print("Test 2b: Accuracy with different cast_type") + print("=" * 60) + + # (input_dtype, cast_type, num_tokens, num_experts) + test_cases = [ + # cast to float32 (original behavior) + ("float16", "float32", 128, 256), + ("bfloat16", "float32", 128, 256), + ("float32", "float32", 128, 256), + # cast to float16 + ("float16", "float16", 128, 256), + ("bfloat16", "float16", 128, 256), + ("float32", "float16", 128, 256), + # cast to bfloat16 + ("float16", "bfloat16", 128, 256), + ("bfloat16", "bfloat16", 128, 256), + ("float32", "bfloat16", 128, 256), + # different shapes + ("bfloat16", "float16", 1, 8), + ("bfloat16", "float16", 1024, 256), + ("float16", "bfloat16", 1, 8), + ("float16", "bfloat16", 1024, 256), + ] + + for input_dtype, cast_type, num_tokens, num_experts in test_cases: + gate_out = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = paddle.randn([num_experts], dtype="float32") + + # Fused kernel + fused_scores, fused_scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + + # Reference + ref_scores, ref_scores_with_bias = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + + # Compare in float32 for stable diff computation + scores_diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + scores_bias_diff = ( + paddle.abs(fused_scores_with_bias.cast("float32") - ref_scores_with_bias.cast("float32")).max().item() + ) + + # Tolerance depends on cast_type precision + if cast_type == "float32": + atol = 1e-6 + elif cast_type == "bfloat16": + atol = 1e-2 # bfloat16 has fewer mantissa bits + else: # float16 + atol = 1e-3 + + passed = scores_diff < atol and scores_bias_diff < atol + + status 
= "PASS" if passed else "FAIL" + print( + f" [{status}] input={input_dtype}, cast_type={cast_type}, " + f"tokens={num_tokens}, experts={num_experts} | " + f"scores_diff={scores_diff:.2e}, bias_diff={scores_bias_diff:.2e}" + ) + + if not passed: + raise AssertionError( + f"Accuracy test failed for input={input_dtype}, cast_type={cast_type}, " + f"tokens={num_tokens}, experts={num_experts}. " + f"scores_diff={scores_diff}, bias_diff={scores_bias_diff}, atol={atol}" + ) + + print(" All cast_type accuracy tests passed.\n") + + +def test_accuracy_extreme_values(): + """Test accuracy with extreme input values.""" + print("=" * 60) + print("Test 3: Accuracy with extreme values") + print("=" * 60) + + num_tokens, num_experts = 64, 256 + + for dtype_name in ["float16", "bfloat16"]: + # Large positive values -> sigmoid ~ 1.0 + gate_out = paddle.full([num_tokens, num_experts], 10.0, dtype=dtype_name) + bias = paddle.zeros([num_experts], dtype="float32") + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + print(f" [{'PASS' if diff < 1e-5 else 'FAIL'}] dtype={dtype_name}, large positive: max_diff={diff:.2e}") + + # Large negative values -> sigmoid ~ 0.0 + gate_out = paddle.full([num_tokens, num_experts], -10.0, dtype=dtype_name) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + print(f" [{'PASS' if diff < 1e-5 else 'FAIL'}] dtype={dtype_name}, large negative: max_diff={diff:.2e}") + + # Zero values -> sigmoid = 0.5 + gate_out = paddle.zeros([num_tokens, num_experts], dtype=dtype_name) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + assert diff < 1e-6, f"Zero input test failed: diff={diff}" + print(f" [PASS] dtype={dtype_name}, zeros: max_diff={diff:.2e}") + + print(" All extreme value tests passed.\n") + + +def test_accuracy_extreme_values_cast_types(): + """Test accuracy with extreme values across different cast_type values.""" + print("=" * 60) + print("Test 3b: Accuracy with extreme values + different cast_type") + print("=" * 60) + + num_tokens, num_experts = 64, 256 + + for input_dtype in ["float16", "bfloat16"]: + for cast_type in ["float16", "bfloat16", "float32"]: + bias = paddle.zeros([num_experts], dtype="float32") + + # Large positive + gate_out = paddle.full([num_tokens, num_experts], 10.0, dtype=input_dtype) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + atol = 1e-2 if cast_type == "bfloat16" else 1e-5 + status = "PASS" if diff < atol else "FAIL" + print(f" [{status}] input={input_dtype}, cast={cast_type}, " f"large positive: diff={diff:.2e}") + + # Zero values + gate_out = paddle.zeros([num_tokens, num_experts], dtype=input_dtype) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + atol = 1e-2 if cast_type == "bfloat16" else 1e-5 + assert diff < atol, f"Zero input test failed: input={input_dtype}, cast={cast_type}, 
diff={diff}" + print(f" [PASS] input={input_dtype}, cast={cast_type}, " f"zeros: diff={diff:.2e}") + + print(" All extreme value cast_type tests passed.\n") + + +@pytest.mark.skipif( + os.getenv("RUN_PERFORMANCE_TESTS") != "1", + reason="Performance benchmark is disabled by default. Set RUN_PERFORMANCE_TESTS=1 to enable.", +) +def test_performance(): + """Benchmark fused kernel vs reference implementation using CUDA events.""" + print("=" * 60) + print("Test 4: Performance (CUDA event timing)") + print("=" * 60) + + configs = [ + ("bfloat16", 1, 256), # single token decode + ("bfloat16", 8, 256), # small batch decode + ("bfloat16", 64, 256), # medium batch + ("bfloat16", 256, 256), # typical DeepSeek-V3 config + ("bfloat16", 1024, 256), # large prefill + ("bfloat16", 4096, 256), # very large prefill + ] + + warmup_iters = 100 + bench_iters = 500 + + for dtype_name, num_tokens, num_experts in configs: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + # Warmup fused + for _ in range(warmup_iters): + fused_cast_sigmoid_bias(gate_out, bias) + paddle.device.synchronize() + + # Benchmark fused with CUDA events + start_event = paddle.device.cuda.Event(enable_timing=True) + end_event = paddle.device.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(bench_iters): + fused_cast_sigmoid_bias(gate_out, bias) + end_event.record() + paddle.device.synchronize() + fused_time = start_event.elapsed_time(end_event) / bench_iters * 1e3 # us + + # Warmup reference + for _ in range(warmup_iters): + reference_cast_sigmoid_bias(gate_out, bias) + paddle.device.synchronize() + + # Benchmark reference with CUDA events + start_event = paddle.device.cuda.Event(enable_timing=True) + end_event = paddle.device.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(bench_iters): + reference_cast_sigmoid_bias(gate_out, bias) + end_event.record() + paddle.device.synchronize() + ref_time = start_event.elapsed_time(end_event) / bench_iters * 1e3 # us + + speedup = ref_time / fused_time if fused_time > 0 else float("inf") + print( + f" tokens={num_tokens:5d}, experts={num_experts:3d} | " + f"ref={ref_time:8.1f}us, fused={fused_time:8.1f}us, speedup={speedup:.2f}x" + ) + + print() + print(" Note: The CUDA custom op fuses cast+sigmoid+bias into a single kernel,") + print(" eliminating 2 intermediate tensors and reducing kernel launches from 3 to 1.") + print(" Expected speedup: ~3x over the reference 3-op implementation.") + print(" Performance benchmark complete.\n") + + +def test_is_available(): + """Test is_available() function returns True when GPU ops are available.""" + print("=" * 60) + print("Test: is_available()") + print("=" * 60) + + # In normal GPU test environment, is_available should return True + result = is_available() + assert isinstance(result, bool), f"is_available() should return bool, got {type(result)}" + assert result is True, f"is_available() should return True when GPU ops are compiled, got {result}" + print(f" [PASS] is_available() returned {result}") + print(" is_available() test passed.\n") + + +def test_import_error(): + """Test that ImportError is raised when GPU ops are not available.""" + print("=" * 60) + print("Test 5: Import error handling") + print("=" * 60) + + module_name = "fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias" + gpu_ops_module = "fastdeploy.model_executor.ops.gpu" + + # Save original module references + original_module = sys.modules.pop(module_name, 
None) + original_gpu_ops = sys.modules.get(gpu_ops_module) + + try: + # Mock the GPU ops module to raise ImportError on import + with mock.patch.dict(sys.modules, {gpu_ops_module: None}): + # Re-import the module so it picks up the mocked (missing) GPU ops + reloaded = importlib.import_module(module_name) + importlib.reload(reloaded) + + # The module should load successfully, but calling the function + # should raise ImportError because the cuda op is unavailable. + dummy_gate = paddle.randn([1, 8], dtype="float32") + dummy_bias = paddle.randn([8], dtype="float32") + try: + reloaded.fused_cast_sigmoid_bias(dummy_gate, dummy_bias) + raise AssertionError("Expected ImportError was not raised") + except ImportError as e: + assert "fused_cast_sigmoid_bias is not available" in str(e), f"Unexpected error message: {e}" + print(f" [PASS] ImportError raised with correct message: {e}") + finally: + # Restore original modules + sys.modules.pop(module_name, None) + if original_module is not None: + sys.modules[module_name] = original_module + if original_gpu_ops is not None: + sys.modules[gpu_ops_module] = original_gpu_ops + + print(" Import error handling test passed.\n") + + +def test_is_available_when_ops_unavailable(): + """Test is_available() returns False when GPU ops are not available.""" + print("=" * 60) + print("Test: is_available() when ops unavailable") + print("=" * 60) + + module_name = "fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias" + gpu_ops_module = "fastdeploy.model_executor.ops.gpu" + + # Save original module references + original_module = sys.modules.pop(module_name, None) + original_gpu_ops = sys.modules.get(gpu_ops_module) + + try: + # Mock the GPU ops module to raise ImportError on import + with mock.patch.dict(sys.modules, {gpu_ops_module: None}): + # Re-import the module so it picks up the mocked (missing) GPU ops + reloaded = importlib.import_module(module_name) + importlib.reload(reloaded) + + # is_available should return False when ops are not available + result = reloaded.is_available() + assert isinstance(result, bool), f"is_available() should return bool, got {type(result)}" + assert result is False, f"is_available() should return False when GPU ops are unavailable, got {result}" + print(f" [PASS] is_available() returned {result} when ops unavailable") + finally: + # Restore original modules + sys.modules.pop(module_name, None) + if original_module is not None: + sys.modules[module_name] = original_module + if original_gpu_ops is not None: + sys.modules[gpu_ops_module] = original_gpu_ops + + print(" is_available() when ops unavailable test passed.\n") + + +if __name__ == "__main__": + print("Running fused_cast_sigmoid_bias tests...\n") + + test_is_available() + test_functionality() + test_functionality_cast_types() + test_accuracy() + test_accuracy_cast_types() + test_accuracy_extreme_values() + test_accuracy_extreme_values_cast_types() + test_import_error() + test_is_available_when_ops_unavailable() + if os.getenv("RUN_PERFORMANCE_TESTS") == "1": + test_performance() + else: + print("Skipping performance benchmark. 
Set RUN_PERFORMANCE_TESTS=1 to enable.\n") + + print("=" * 60) + print("All tests passed!") + print("=" * 60) diff --git a/tests/layers/test_fused_moe_cutlass_backend.py b/tests/layers/test_fused_moe_cutlass_backend.py index 98185a04c38..f3854acaf65 100644 --- a/tests/layers/test_fused_moe_cutlass_backend.py +++ b/tests/layers/test_fused_moe_cutlass_backend.py @@ -57,7 +57,7 @@ def name(self): class DummyFDConfig: def __init__(self, load_choices="default_v1"): self.model_config = types.SimpleNamespace(model="dummy", prefix_layer_name="prefix") - self.load_config = types.SimpleNamespace(load_choices=load_choices) + self.load_config = types.SimpleNamespace(load_choices=load_choices, dynamic_load_weight=False) class DummyLayer(paddle.nn.Layer): @@ -394,7 +394,15 @@ def combine(self, ffn_out, topk_idx, topk_weights, handle, quant_group_size=-1): def test_apply_tp_with_dispatch_and_reduce(self, monkeypatch): def fake_get_moe_scores( - gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize, topk_reduce_func=None + gate_out, + n_group, + topk_group, + top_k, + routed_scaling_factor, + bias, + renormalize, + topk_reduce_func=None, + use_fused_cast=False, ): return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) @@ -831,6 +839,74 @@ def spy_permute(*args, **kwargs): assert not paddle.isnan(out).any(), "output contains NaN" assert not paddle.isinf(out).any(), "output contains Inf" + def test_apply_tp_noaux_tc_with_use_fused_false(self, monkeypatch): + fc1_called = {"count": 0} + + class FC1Proj(paddle.nn.Layer): + def forward(self, x): + fc1_called["count"] += 1 + return x * 2 + + fc1_latent_proj = FC1Proj() + + def fake_get_moe_scores( + gate_out, + n_group, + topk_group, + top_k, + routed_scaling_factor, + bias, + renormalize, + topk_reduce_func=None, + use_fused_cast=False, + ): + return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) + + def fake_dispatch(*args, **kwargs): + return ( + paddle.ones([1, 2]), + paddle.to_tensor([1, 0]), + paddle.to_tensor([0]), + paddle.to_tensor([[0.6, 0.4]]), + paddle.to_tensor([[0, 1]]), + paddle.to_tensor([0]), + None, + None, + ) + + def fake_reduce(*args, **kwargs): + return paddle.ones([1, 2]) * 5 + + def fake_compute_ffn(*args, **kwargs): + return paddle.ones([1, 2]) * 2 + + monkeypatch.setattr(backend, "get_moe_scores", fake_get_moe_scores, raising=False) + monkeypatch.setattr(backend, "moe_expert_dispatch", fake_dispatch, raising=False) + monkeypatch.setattr(backend, "moe_expert_reduce", fake_reduce, raising=False) + + # Mock compute_ffn on the class to avoid real GPU op data type issues + monkeypatch.setattr(backend.CutlassMoEMethod, "compute_ffn", fake_compute_ffn) + + # Set FD_ENABLE_RL=True to trigger use_fused = False + monkeypatch.setattr(backend.fastdeploy.envs, "FD_ENABLE_RL", True) + + layer = DummyLayer(with_bias=False) + layer.topk_method = "noaux_tc" + # Add necessary attributes for compute_ffn access + layer.up_gate_proj_weight = paddle.zeros([2, 2 * 1], dtype="float16") + layer.down_proj_weight = paddle.zeros([2, 2], dtype="float16") + layer.activation = "silu" + + method = backend.CutlassMoEMethod(None) + + x = paddle.ones([1, 2]) + gate = paddle.nn.Identity() + + method.apply(layer, x, gate, fc1_latent_proj=fc1_latent_proj) + + # Verify fc1_latent_proj was called (line 354/425-426 was executed) + assert fc1_called["count"] > 0, "fc1_latent_proj should have been called" + @requires_cuda def test_apply_ep_prefill_moe_permute_real_ops(self, monkeypatch): 
"""FD_USE_PHI_MOE_PERMUTE=True + w16a16: EP prefill uses real moe_permute / @@ -950,3 +1026,74 @@ def spy_permute(*args, **kwargs): assert list(out.shape) == [num_tokens, hidden_size], f"wrong shape: {out.shape}" assert not paddle.isnan(out).any(), "output contains NaN" assert not paddle.isinf(out).any(), "output contains Inf" + + def test_apply_tp_with_both_latent_projs(self, monkeypatch): + """Test apply_tp with both fc1_latent_proj and fc2_latent_proj applied.""" + fc1_called = {"count": 0} + fc2_called = {"count": 0} + + class FC1Proj(paddle.nn.Layer): + def forward(self, x): + fc1_called["count"] += 1 + return x * 2 + + class FC2Proj(paddle.nn.Layer): + def forward(self, x): + fc2_called["count"] += 1 + return x + 10 + + fc1_latent_proj = FC1Proj() + fc2_latent_proj = FC2Proj() + + def fake_get_moe_scores( + gate_out, + n_group, + topk_group, + top_k, + routed_scaling_factor, + bias, + renormalize, + topk_reduce_func=None, + use_fused_cast=False, + ): + return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) + + def fake_dispatch(*args, **kwargs): + permute_input = paddle.ones([1, 2]) * 2 # fc1_latent_proj applied + token_nums_per_expert = paddle.to_tensor([1, 0]) + permute_indices_per_token = paddle.to_tensor([0]) + topk_weights = paddle.to_tensor([[0.6, 0.4]]) + topk_idx = paddle.to_tensor([[0, 1]]) + expert_idx_per_token = paddle.to_tensor([0]) + dequant_scale = None + max_tokens_per_expert = None + return ( + permute_input, + token_nums_per_expert, + permute_indices_per_token, + topk_weights, + topk_idx, + expert_idx_per_token, + dequant_scale, + max_tokens_per_expert, + ) + + def fake_reduce(*args, **kwargs): + return paddle.ones([1, 2]) * 5 + + monkeypatch.setattr(backend, "get_moe_scores", fake_get_moe_scores, raising=False) + monkeypatch.setattr(backend, "moe_expert_dispatch", fake_dispatch, raising=False) + monkeypatch.setattr(backend, "moe_expert_reduce", fake_reduce, raising=False) + + layer = DummyLayer(topk_method="noaux_tc") + method = backend.CutlassMoEMethod(None) + monkeypatch.setattr(method, "compute_ffn", lambda *args, **kwargs: paddle.ones([1, 2]) * 4) + + x = paddle.ones([1, 2]) + gate = paddle.nn.Identity() + out = method.apply_tp(layer, x, gate, fc1_latent_proj=fc1_latent_proj, fc2_latent_proj=fc2_latent_proj) + + # Output should be 5 (from reduce) + 10 (from fc2_latent_proj) = 15 + np.testing.assert_allclose(out.numpy(), np.full((1, 2), 15.0)) + assert fc1_called["count"] == 1, "fc1_latent_proj should be called exactly once" + assert fc2_called["count"] == 1, "fc2_latent_proj should be called exactly once" diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py index 7dacbbe390d..f196240e1d8 100644 --- a/tests/layers/test_fused_moe_triton_backend.py +++ b/tests/layers/test_fused_moe_triton_backend.py @@ -695,3 +695,92 @@ def fake_transform_scale_ue8m0(sf, mn, weight_block_size=None): # Verify the quant_weight_ue8m0 branch was executed assert len(quant_calls) > 0, "quant_weight_ue8m0 should have been called" assert len(transform_calls) > 0, "transform_scale_ue8m0 should have been called" + + def test_triton_weight_only_apply_noaux_tc_with_fd_enable_rl(self, fake_ops, monkeypatch): + quant_config = DummyQuantConfig(is_checkpoint_bf16=False) + layer = DummyLayer(quant_config) + layer.topk_method = "noaux_tc" + method = backend.TritonWeightOnlyMoEMethod(quant_config) + method.create_weights(layer, model_format="torch") + + layer._up_weights = [ + paddle.arange(layer.hidden_size * 
layer.moe_intermediate_size * 2, dtype="float32").reshape( + [layer.hidden_size, layer.moe_intermediate_size * 2] + ) + for _ in range(layer.num_local_experts) + ] + layer._down_weights = [ + paddle.arange(layer.moe_intermediate_size * layer.hidden_size, dtype="float32").reshape( + [layer.moe_intermediate_size, layer.hidden_size] + ) + for _ in range(layer.num_local_experts) + ] + method.process_loaded_weights(layer, state_dict={}) + + kernel = DummyKernel() + monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) + + # Set FD_ENABLE_RL=True to trigger use_fused = False at line 313 + # This should trigger gate_out.cast('float32') at line 315 + monkeypatch.setattr(backend.fastdeploy.envs, "FD_ENABLE_RL", True) + + x = paddle.randn([1, layer.hidden_size], dtype="float32") + gate = DummyGate(layer.num_local_experts) + + captured = {} + + def hook(topk_ids): + captured["topk_ids"] = topk_ids + + _ = method.apply(layer, x, gate, topk_ids_hookfunc=hook) + assert "topk_ids" in captured + + def test_triton_weight_only_apply_noaux_tc_with_non_cuda(self, fake_ops, monkeypatch): + quant_config = DummyQuantConfig(is_checkpoint_bf16=False) + layer = DummyLayer(quant_config) + # Ensure topk_method is "noaux_tc" to enter the target branch + layer.topk_method = "noaux_tc" + method = backend.TritonWeightOnlyMoEMethod(quant_config) + method.create_weights(layer, model_format="torch") + + layer._up_weights = [ + paddle.arange(layer.hidden_size * layer.moe_intermediate_size * 2, dtype="float32").reshape( + [layer.hidden_size, layer.moe_intermediate_size * 2] + ) + for _ in range(layer.num_local_experts) + ] + layer._down_weights = [ + paddle.arange(layer.moe_intermediate_size * layer.hidden_size, dtype="float32").reshape( + [layer.moe_intermediate_size, layer.hidden_size] + ) + for _ in range(layer.num_local_experts) + ] + method.process_loaded_weights(layer, state_dict={}) + + kernel = DummyKernel() + monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) + + # Mock current_platform.is_cuda() to return False to trigger use_fused = False at line 313 + # This should trigger gate_out.cast("float32") at line 315 + monkeypatch.setattr(backend, "current_platform", types.SimpleNamespace(is_cuda=lambda: False)) + + x = paddle.randn([2, layer.hidden_size], dtype="float32") + gate = DummyGate(layer.num_local_experts) + + def fake_get_moe_scores(*args, **kwargs): + gate_out = args[0] + token_num = gate_out.shape[0] + top_k = args[3] + topk_ids = paddle.zeros([token_num, top_k], dtype="int64") + topk_weights = paddle.ones([token_num, top_k], dtype="float32") + return gate_out, topk_weights, topk_ids + + monkeypatch.setattr(backend, "get_moe_scores", fake_get_moe_scores) + + captured = {} + + def hook(topk_ids): + captured["topk_ids"] = topk_ids + + _ = method.apply(layer, x, gate, topk_ids_hookfunc=hook) + assert "topk_ids" in captured From 80e85bd915fd0860e197202c153c85c522a777c1 Mon Sep 17 00:00:00 2001 From: ShaneGZhu <1092841848@qq.com> Date: Mon, 11 May 2026 19:52:00 +0800 Subject: [PATCH 2/3] Add unit_test file --- tests/operators/test_grouped_topk_op.py | 485 ++++++++++++++++++++++++ 1 file changed, 485 insertions(+) create mode 100644 tests/operators/test_grouped_topk_op.py diff --git a/tests/operators/test_grouped_topk_op.py b/tests/operators/test_grouped_topk_op.py new file mode 100644 index 00000000000..1e76328eb93 --- /dev/null +++ b/tests/operators/test_grouped_topk_op.py @@ -0,0 +1,485 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the `grouped_topk` custom CUDA op (low-level interface). + +grouped_topk fuses sigmoid into the kernel and accepts raw logits directly, +unlike noaux_tc which requires Python-side sigmoid preprocessing. + +Algorithm: + 1. scores = sigmoid(gating_output) [fused inside kernel] + 2. scores_with_bias = scores + e_score_correction_bias + 3. group_scores = sum of top-2 biased expert scores per group + 4. Select top-topk_group groups + 5. Within selected groups select top-topk experts by biased score + 6. Gather unbiased sigmoid scores for selected experts as topk_values + 7. Optionally renormalize and scale by routed_scaling_factor + +Model configs covered: + DeepSeek-V3 / R1 num_experts=256, n_group=8, topk_group=4, topk=8, renorm=True, scale=2.5 + GLM-4.5-Air num_experts=128, n_group=1, topk_group=1, topk=8, renorm=True, scale=1.0 + Qwen3-30B-A3B num_experts=128, n_group=4, topk_group=2, topk=8, renorm=False, scale=1.0 + Kimi-K2 num_experts=384, n_group=8, topk_group=2, topk=8, renorm=False, scale=1.0 +""" + +import unittest + +import numpy as np +import paddle + +try: + from fastdeploy.model_executor.ops.gpu import grouped_topk + + _GROUPED_TOPK_AVAILABLE = True +except Exception: + _GROUPED_TOPK_AVAILABLE = False + + +def native_grouped_topk( + gating_output: paddle.Tensor, + e_score_correction_bias: paddle.Tensor, + n_group: int, + topk_group: int, + topk: int, + renormalize: bool, + routed_scaling_factor: float, +): + """Pure-Python reference that mirrors the grouped_topk kernel semantics. 
+ + Args: + gating_output: raw logits, shape [num_tokens, num_experts] + e_score_correction_bias: bias added to sigmoid scores, shape [1, num_experts] or [num_experts] + n_group: number of expert groups + topk_group: number of groups selected per token + topk: number of experts selected per token + renormalize: whether to L1-normalise the selected weights + routed_scaling_factor: multiplicative scale applied after renorm + + Returns: + (scores_out, topk_values, topk_indices) + scores_out – sparse score tensor, shape [num_tokens, num_experts] + topk_values – weights for selected experts, shape [num_tokens, topk] + topk_indices – expert indices, shape [num_tokens, topk] (int64) + """ + num_tokens, num_experts = gating_output.shape + experts_per_group = num_experts // n_group + + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias + + # Step 1: group scores = sum of top-2 biased scores per group + biased = scores_with_bias.reshape([num_tokens, n_group, experts_per_group]) + group_scores = biased.topk(min(2, experts_per_group), axis=-1)[0].sum(axis=-1) + + # Step 2: select top-topk_group groups + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] + group_mask = paddle.zeros_like(group_scores) + group_mask.put_along_axis_(group_idx, paddle.ones_like(group_idx, dtype=group_mask.dtype), axis=-1) + score_mask = ( + group_mask.unsqueeze(-1).expand([num_tokens, n_group, experts_per_group]).reshape([num_tokens, num_experts]) + ) + + # Step 3: select top-topk experts within selected groups (biased score) + tmp_scores = scores_with_bias.masked_fill(~score_mask.cast(paddle.bool), float("-inf")) + topk_indices = paddle.topk(tmp_scores, topk, axis=-1)[1] + + # Step 4: gather unbiased sigmoid scores + topk_values = paddle.take_along_axis(scores, topk_indices, axis=1) + + # Step 5: renormalize + scale + if renormalize: + topk_values = topk_values / (topk_values.sum(axis=-1, keepdim=True) + 1e-20) + if routed_scaling_factor != 1.0: + topk_values = topk_values * routed_scaling_factor + + scores_out = paddle.zeros_like(scores) + scores_out.put_along_axis_(topk_indices, topk_values, axis=1) + + return scores_out, topk_values, topk_indices.cast(paddle.int64) + + +@unittest.skipUnless(_GROUPED_TOPK_AVAILABLE, "grouped_topk custom op not available (not compiled)") +class TestGroupedTopkOp(unittest.TestCase): + """Tests for the grouped_topk custom CUDA op.""" + + ATOL = 1e-3 + RTOL = 1e-3 + + def setUp(self): + paddle.seed(42) + + # ------------------------------------------------------------------ + # Parametrised helper + # ------------------------------------------------------------------ + def _run_case( + self, + num_tokens: int, + num_experts: int, + n_group: int, + topk_group: int, + topk: int, + renormalize: bool, + routed_scaling_factor: float, + input_dtype=paddle.float32, + bias_scale: float = 0.1, + seed: int = 42, + ): + paddle.seed(seed) + gating = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = (paddle.rand([1, num_experts], dtype=paddle.float32) - 0.5) * bias_scale + + # Reference always runs in fp32 + gating_fp32 = gating.cast(paddle.float32) if input_dtype != paddle.float32 else gating + ref_scores, ref_tv, ref_ti = native_grouped_topk( + gating_fp32.clone(), + bias.clone(), + n_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) + + op_scores, op_tv, op_ti = grouped_topk( + gating.clone(), + bias.clone(), + n_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, 
+ ) + + label = ( + f"T={num_tokens}, E={num_experts}, n_group={n_group}, " + f"topk_group={topk_group}, topk={topk}, " + f"renorm={renormalize}, scale={routed_scaling_factor}, dtype={input_dtype}" + ) + + self.assertEqual(op_tv.shape, [num_tokens, topk], f"[{label}] topk_values shape") + self.assertEqual(op_ti.shape, [num_tokens, topk], f"[{label}] topk_indices shape") + self.assertEqual(op_ti.dtype, paddle.int64, f"[{label}] topk_indices dtype") + self.assertEqual(op_tv.dtype, paddle.float32, f"[{label}] topk_values dtype") + + # Compare set-level index match (position order not guaranteed) + ref_sorted = paddle.sort(ref_ti, axis=-1) + op_sorted = paddle.sort(op_ti, axis=-1) + if not paddle.equal_all(ref_sorted, op_sorted).item(): + n_diff = (ref_sorted != op_sorted).sum().item() + self.fail(f"[{label}] topk_indices set mismatch: {n_diff} positions differ") + + # Align values by expert index before comparing + ref_ord = paddle.argsort(ref_ti, axis=-1) + op_ord = paddle.argsort(op_ti, axis=-1) + ref_tv_s = paddle.take_along_axis(ref_tv, ref_ord, axis=-1) + op_tv_s = paddle.take_along_axis(op_tv, op_ord, axis=-1) + if not paddle.allclose(op_tv_s, ref_tv_s, atol=self.ATOL, rtol=self.RTOL).item(): + max_diff = (op_tv_s - ref_tv_s).abs().max().item() + self.fail(f"[{label}] topk_values max_diff={max_diff:.2e}") + + # ------------------------------------------------------------------ + # GLM-4.5-Air: n_experts=128, n_group=1, topk_group=1, topk=8, renorm=True + # ------------------------------------------------------------------ + def test_glm45air_T1(self): + self._run_case(1, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T32(self): + self._run_case(32, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T128(self): + self._run_case(128, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T512(self): + self._run_case(512, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T1024(self): + self._run_case(1024, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T4096(self): + self._run_case(4096, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T8192(self): + self._run_case(8192, 128, 1, 1, 8, True, 1.0) + + # ------------------------------------------------------------------ + # DeepSeek-V3 / R1: n_experts=256, n_group=8, topk_group=4, topk=8, + # renorm=True, scale=2.5 + # ------------------------------------------------------------------ + def test_deepseek_v3_T1(self): + self._run_case(1, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T32(self): + self._run_case(32, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T128(self): + self._run_case(128, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T512(self): + self._run_case(512, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T4096(self): + self._run_case(4096, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T8192(self): + self._run_case(8192, 256, 8, 4, 8, True, 2.5) + + # ------------------------------------------------------------------ + # Qwen3-30B-A3B: n_experts=128, n_group=4, topk_group=2, topk=8, + # renorm=False + # ------------------------------------------------------------------ + def test_qwen3_30b_T1(self): + self._run_case(1, 128, 4, 2, 8, False, 1.0) + + def test_qwen3_30b_T128(self): + self._run_case(128, 128, 4, 2, 8, False, 1.0) + + def test_qwen3_30b_T512(self): + self._run_case(512, 128, 4, 2, 8, False, 1.0) + + def test_qwen3_30b_T4096(self): + self._run_case(4096, 128, 4, 2, 8, False, 1.0) + + # ------------------------------------------------------------------ + # Kimi-K2: n_experts=384, n_group=8, topk_group=2, topk=8, 
renorm=False + # ------------------------------------------------------------------ + def test_kimi_k2_T1(self): + self._run_case(1, 384, 8, 2, 8, False, 1.0) + + def test_kimi_k2_T128(self): + self._run_case(128, 384, 8, 2, 8, False, 1.0) + + def test_kimi_k2_T512(self): + self._run_case(512, 384, 8, 2, 8, False, 1.0) + + def test_kimi_k2_T4096(self): + self._run_case(4096, 384, 8, 2, 8, False, 1.0) + + # ------------------------------------------------------------------ + # bfloat16 input path: kernel should cast internally to fp32 + # ------------------------------------------------------------------ + def test_bf16_input_glm45air(self): + self._run_case(128, 128, 1, 1, 8, True, 1.0, input_dtype=paddle.bfloat16) + + def test_bf16_input_deepseek_v3(self): + self._run_case(128, 256, 8, 4, 8, True, 2.5, input_dtype=paddle.bfloat16) + + def test_bf16_input_qwen3_30b(self): + self._run_case(128, 128, 4, 2, 8, False, 1.0, input_dtype=paddle.bfloat16) + + # ------------------------------------------------------------------ + # Output shape and dtype sanity + # ------------------------------------------------------------------ + def test_output_shapes(self): + """Verify output shapes for various (T, E, topk) combinations.""" + cases = [ + (1, 128, 1, 1, 8), + (32, 256, 8, 4, 8), + (64, 384, 8, 2, 8), + ] + for T, E, ng, tkg, topk in cases: + gating = paddle.randn([T, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + self.assertEqual(tv.shape, [T, topk], f"T={T},E={E}: topk_values shape") + self.assertEqual(ti.shape, [T, topk], f"T={T},E={E}: topk_indices shape") + + def test_output_dtype_is_float32(self): + """topk_values must always be float32 regardless of input dtype.""" + for dtype in [paddle.float32, paddle.bfloat16]: + gating = paddle.randn([16, 128], dtype=dtype) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, 1, 1, 8, True, 1.0) + self.assertEqual(tv.dtype, paddle.float32, f"dtype={dtype}: topk_values not float32") + self.assertEqual(ti.dtype, paddle.int64, f"dtype={dtype}: topk_indices not int64") + + # ------------------------------------------------------------------ + # Correctness invariants + # ------------------------------------------------------------------ + def test_topk_indices_in_valid_range(self): + """All selected expert indices must lie in [0, num_experts).""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8), (384, 8, 2, 8)]: + gating = paddle.randn([64, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, _, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + self.assertTrue((ti >= 0).all().item(), f"E={E}: negative index found") + self.assertTrue((ti < E).all().item(), f"E={E}: index >= num_experts") + + def test_no_duplicate_experts_per_token(self): + """Each token must select exactly topk distinct experts.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + gating = paddle.randn([32, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, _, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + for row in ti.numpy(): + self.assertEqual(len(set(row.tolist())), topk, f"E={E}: duplicate expert indices in row {row}") + + def test_topk_values_non_negative(self): + """Sigmoid output is in (0,1); routing weights must be >= 0.""" + gating = paddle.randn([64, 128], dtype=paddle.float32) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + 
_, tv, _ = grouped_topk(gating, bias, 1, 1, 8, True, 1.0) + self.assertTrue((tv >= 0).all().item(), "topk_values contains negative weights") + + def test_renormalized_weights_sum_to_one(self): + """With renormalize=True and scale=1.0, per-token weights sum ≈ 1.""" + num_tokens = 64 + gating = paddle.randn([num_tokens, 128], dtype=paddle.float32) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, 1, 1, 8, True, 1.0) + row_sums = tv.sum(axis=-1).numpy() + np.testing.assert_allclose( + row_sums, + np.ones(num_tokens, dtype=np.float32), + atol=1e-3, + err_msg="Renormalized weights do not sum to 1 per token", + ) + + def test_scaled_weights_sum(self): + """With renormalize=True and scale=2.5, per-token weights sum ≈ 2.5.""" + num_tokens, scale = 64, 2.5 + gating = paddle.randn([num_tokens, 256], dtype=paddle.float32) + bias = paddle.zeros([1, 256], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, 8, 4, 8, True, scale) + row_sums = tv.sum(axis=-1).numpy() + np.testing.assert_allclose( + row_sums, + np.full(num_tokens, scale, dtype=np.float32), + atol=1e-2, + err_msg=f"Scaled weights do not sum to {scale} per token", + ) + + def test_no_renorm_weights_are_raw_sigmoid(self): + """With renormalize=False, topk_values must equal sigmoid(logits) at selected positions.""" + num_tokens, E = 32, 128 + gating = paddle.randn([num_tokens, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, 1, 1, 8, False, 1.0) + expected = paddle.take_along_axis(paddle.nn.functional.sigmoid(gating), ti, axis=1) + np.testing.assert_allclose( + tv.numpy(), + expected.numpy(), + atol=1e-4, + err_msg="Without renorm, topk_values should equal sigmoid(gating) at selected positions", + ) + + def test_deterministic(self): + """Two identical calls must produce bit-for-bit identical outputs.""" + gating = paddle.randn([32, 256], dtype=paddle.float32) + bias = (paddle.rand([1, 256], dtype=paddle.float32) - 0.5) * 0.1 + _, tv1, ti1 = grouped_topk(gating.clone(), bias.clone(), 8, 4, 8, True, 2.5) + _, tv2, ti2 = grouped_topk(gating.clone(), bias.clone(), 8, 4, 8, True, 2.5) + self.assertTrue( + paddle.allclose(tv1, tv2, atol=0.0, rtol=0.0).item(), + "topk_values not deterministic across two identical calls", + ) + self.assertTrue( + paddle.equal_all(ti1, ti2).item(), + "topk_indices not deterministic across two identical calls", + ) + + def test_zero_bias(self): + """All-zero bias: biased == unbiased; reference and op must agree.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + paddle.seed(16) + gating = paddle.randn([32, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, ref_tv, ref_ti = native_grouped_topk(gating.clone(), bias, ng, tkg, topk, True, 1.0) + _, op_tv, op_ti = grouped_topk(gating.clone(), bias, ng, tkg, topk, True, 1.0) + ref_s = paddle.sort(ref_ti, axis=-1) + op_s = paddle.sort(op_ti, axis=-1) + self.assertTrue( + paddle.equal_all(ref_s, op_s).item(), + f"E={E}/zero_bias: topk_indices set mismatch", + ) + + def test_large_bias_steers_routing(self): + """Large positive bias on first half of experts must dominate selection.""" + E, topk = 128, 8 + paddle.seed(17) + gating = paddle.randn([64, E], dtype=paddle.float32) + bias = paddle.concat( + [ + paddle.full([1, E // 2], 2.0, dtype=paddle.float32), + paddle.full([1, E // 2], -2.0, dtype=paddle.float32), + ], + axis=1, + ) + _, _, ti = grouped_topk(gating, bias, 1, 1, topk, True, 1.0) + self.assertTrue( + 
(ti < E // 2).all().item(), + "Large positive bias on experts [0, E/2) did not steer all selections there", + ) + + def test_extreme_logits_no_nan_inf(self): + """Very large logits must not produce NaN or Inf in outputs.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + paddle.seed(18) + gating = paddle.randn([8, E], dtype=paddle.float32) * 50.0 + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, ng, tkg, topk, False, 1.0) + self.assertFalse(paddle.isnan(tv).any().item(), f"E={E}: NaN in topk_values") + self.assertFalse(paddle.isinf(tv).any().item(), f"E={E}: Inf in topk_values") + + def test_single_expert_selected(self): + """topk=1: each token selects exactly one expert; weight == 1.0 with renorm.""" + num_tokens = 16 + gating = paddle.randn([num_tokens, 128], dtype=paddle.float32) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, 1, 1, 1, True, 1.0) + self.assertEqual(tv.shape, [num_tokens, 1]) + self.assertEqual(ti.shape, [num_tokens, 1]) + np.testing.assert_allclose( + tv.numpy(), + np.ones((num_tokens, 1), dtype=np.float32), + atol=1e-5, + err_msg="With topk=1 and renorm=True, each weight should be 1.0", + ) + + def test_sparse_scores_consistency(self): + """Sparse scores tensor: non-zero at selected positions must equal topk_values; zero elsewhere.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + gating = paddle.randn([16, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + s, tv, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + gathered = paddle.take_along_axis(s, ti, axis=1) + np.testing.assert_allclose( + gathered.numpy(), + tv.numpy(), + atol=1e-6, + err_msg=f"E={E}: sparse scores at topk positions != topk_values", + ) + nonzero_count = (s != 0).sum(axis=-1) + self.assertTrue( + (nonzero_count == topk).all().item(), + f"E={E}: non-zero count per token != topk", + ) + + def test_irregular_token_counts(self): + """Non-power-of-2 token counts must produce correct shapes and values.""" + irregular_T = [3, 7, 15, 33, 65, 127, 129, 257, 511, 513, 900] + for T in irregular_T: + gating = paddle.randn([T, 128], dtype=paddle.float32) + bias = (paddle.rand([1, 128], dtype=paddle.float32) - 0.5) * 0.1 + _, ref_tv, ref_ti = native_grouped_topk(gating.clone(), bias.clone(), 1, 1, 8, True, 1.0) + _, op_tv, op_ti = grouped_topk(gating.clone(), bias.clone(), 1, 1, 8, True, 1.0) + self.assertEqual(op_tv.shape, [T, 8], f"T={T}: topk_values shape mismatch") + self.assertEqual(op_ti.shape, [T, 8], f"T={T}: topk_indices shape mismatch") + ref_s = paddle.sort(ref_ti, axis=-1) + op_s = paddle.sort(op_ti, axis=-1) + if not paddle.equal_all(ref_s, op_s).item(): + n_diff = (ref_s != op_s).sum().item() + self.fail(f"T={T}: topk_indices mismatch, {n_diff} positions differ") + + +if __name__ == "__main__": + unittest.main() From 280e2e48d6b2413d87c471f87fc7402adf84cdc0 Mon Sep 17 00:00:00 2001 From: ShaneGZhu <1092841848@qq.com> Date: Mon, 11 May 2026 18:20:08 +0800 Subject: [PATCH 3/3] [Ops][Optimization]Kernel fusion: cast+sigmoid+bias+noauxtc --- custom_ops/gpu_ops/cpp_extensions.cc | 11 + custom_ops/gpu_ops/grouped_topk_kernels.cu | 759 ++++++++++++++++++++ custom_ops/setup_ops.py | 2 + fastdeploy/model_executor/layers/moe/moe.py | 35 +- 4 files changed, 795 insertions(+), 12 deletions(-) create mode 100644 custom_ops/gpu_ops/grouped_topk_kernels.cu diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 
591cf363f06..e2a2fc1b92f 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -691,6 +691,15 @@ std::vector NoauxTc(paddle::Tensor& scores, bool renormalize, float routed_scaling_factor); +std::vector grouped_topk( + paddle::Tensor& gating_output, + paddle::Tensor& e_score_correction_bias, + int n_group, + int topk_group, + int topk, + bool renormalize, + float routed_scaling_factor); + std::vector FusedCastSigmoidBias(const paddle::Tensor& input, const paddle::Tensor& bias, std::string cast_type); @@ -1704,6 +1713,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); + m.def("grouped_topk", &grouped_topk, "fused grouped topk for MoE routing"); + m.def("fused_cast_sigmoid_bias", &FusedCastSigmoidBias, "Fused cast+sigmoid+bias for MoE gating scores", diff --git a/custom_ops/gpu_ops/grouped_topk_kernels.cu b/custom_ops/gpu_ops/grouped_topk_kernels.cu new file mode 100644 index 00000000000..d9c99908e72 --- /dev/null +++ b/custom_ops/gpu_ops/grouped_topk_kernels.cu @@ -0,0 +1,759 @@ + +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "helper.h" + +namespace cg = cooperative_groups; + +constexpr unsigned FUSED_FULL_WARP_MASK = 0xffffffff; + +template +__device__ inline T_OUT cuda_cast(T_IN val) { + return val; +} + +template <> +__device__ inline float cuda_cast(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) { + return __float2bfloat16(val); +} + +template <> +__device__ inline float cuda_cast(__half val) { + return __half2float(val); +} + +template <> +__device__ inline __half cuda_cast<__half, float>(float val) { + return __float2half(val); +} + +// Numerically stable sigmoid via tanh: σ(x) = 0.5 * tanh(0.5*x) + 0.5 +template +__device__ __forceinline__ T sigmoid_device(T x) { + float xf = cuda_cast(x); + return cuda_cast(0.5f * tanhf(0.5f * xf) + 0.5f); +} + +// Sigmoid matching fused_cast_sigmoid_bias: 1 / (1 + exp(-x)). +// Must use the same formula to get bit-identical results when comparing +// against the fused_cast_sigmoid_bias + noaux_tc path. 
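+// Note (illustrative, not part of the op contract): sigmoid_device above uses
+// the tanh form; it is mathematically identical to 1/(1+exp(-x)) but may
+// differ in the last ulp for some inputs, so the exp-based helper below is
+// the one used wherever bit-exact parity with the unfused reference path is
+// required.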
+template +__device__ __forceinline__ float sigmoid_to_float(InT x) { + float xf = cuda_cast(x); + return 1.0f / (1.0f + expf(-xf)); +} + +template +__device__ inline T neg_inf() { + return cuda_cast(-cuda::std::numeric_limits::infinity()); +} + +template +__device__ inline bool is_finite_val(T val) { +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800) + return cuda::std::isfinite(val); +#else + return isfinite(cuda_cast(val)); +#endif +} + +namespace warp_topk_fused { + +template +__host__ __device__ constexpr T round_up_to_multiple_of(T len) { + if (len == 0) return 0; + return ((len - 1) / size + 1) * size; +} + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + return (v && !(v & (v - 1))); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline) { + return (val > baseline && greater) || (val < baseline && !greater); +} + +template +__forceinline__ __device__ bool is_better_than(T val, + T baseline, + idxT index, + idxT baseline_index) { + bool res = (val > baseline && greater) || (val < baseline && !greater); + if (val == baseline) + res = (index < baseline_index && greater) || + (index < baseline_index && !greater); + return res; +} + +template +struct BitonicMerge { + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + constexpr int stride = arr_len / 2; + for (int i = 0; i < stride; ++i) { + int const other_i = i + stride; + T& val = val_arr[i]; + T& other_val = val_arr[other_i]; + bool is_better; + if constexpr (is_stable) + is_better = is_better_than( + val, other_val, idx_arr[i], idx_arr[other_i]); + else + is_better = is_better_than(val, other_val); + if (is_better) { + T tmp = val; + val = other_val; + other_val = tmp; + idxT tmp2 = idx_arr[i]; + idx_arr[i] = idx_arr[other_i]; + idx_arr[other_i] = tmp2; + } + } + BitonicMerge::merge( + val_arr, idx_arr); + BitonicMerge::merge( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + } +}; + +template +struct BitonicSort { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + BitonicSort::sort(val_arr, idx_arr); + BitonicSort::sort( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + BitonicMerge::merge( + val_arr, idx_arr); + } +}; + +template +struct BitonicSort<32, ascending, T, idxT, is_stable> { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + for (int stage = 0; stage < 4; ++stage) { + for (int stride = (1 << stage); stride > 0; stride /= 2) { + bool reverse = (lane >> stage) & 2; + bool is_second = lane & stride; + T other = __shfl_xor_sync(FUSED_FULL_WARP_MASK, *val_arr, stride); + idxT other_idx = + __shfl_xor_sync(FUSED_FULL_WARP_MASK, *idx_arr, stride); + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) != + (reverse != is_second); + else + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) != + (reverse != is_second); + } else { + is_better = (*val_arr != other && + (*val_arr > other) != (reverse != is_second)); + } + if (is_better) { + *val_arr = other; + *idx_arr = other_idx; + } + } + } + BitonicMerge<32, ascending, ascending, T, 
idxT, is_stable>::merge(val_arr, + idx_arr); + } +}; + +template +struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> { + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) { + bool is_second = lane & stride; + T& val = *val_arr; + T other = __shfl_xor_sync(FUSED_FULL_WARP_MASK, val, stride); + idxT& idx = *idx_arr; + idxT other_idx = __shfl_xor_sync(FUSED_FULL_WARP_MASK, idx, stride); + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) == + (reverse != is_second); + else + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) == + (reverse != is_second); + } else { + is_better = + (val != other && ((val > other) == (ascending != is_second))); + } + if (is_better) { + val = other; + idx = other_idx; + } + } + } +}; + +template +class WarpSort { + public: + __device__ WarpSort(idxT k, T dummy) + : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) { + static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity)); + for (int i = 0; i < max_arr_len_; ++i) { + val_arr_[i] = dummy_; + idx_arr_[i] = 0; + } + } + + __device__ __forceinline__ idxT get_idx(int i = 0) const { + return idx_arr_[i]; + } + __device__ __forceinline__ T get_val(int i = 0) const { return val_arr_[i]; } + + protected: + static constexpr int max_arr_len_ = capacity / WARP_SIZE; + T val_arr_[max_arr_len_]; + idxT idx_arr_[max_arr_len_]; + int const lane_; + idxT const k_; + T const dummy_; +}; + +// WarpSelect WITHOUT __syncthreads() in done() — safe when only one warp is +// active. +template +class WarpSelect : public WarpSort { + public: + __device__ WarpSelect(idxT k, T dummy) + : WarpSort(k, dummy), + k_th_(dummy), + k_th_idx_(0), + k_th_lane_((k - 1) % WARP_SIZE) { + extern __shared__ char smem_buf[]; + int const num_of_warp = blockDim.x / WARP_SIZE; + int const warp_id = threadIdx.x / WARP_SIZE; + val_smem_ = reinterpret_cast(smem_buf); + val_smem_ += warp_id * WARP_SIZE; + idx_smem_ = reinterpret_cast( + smem_buf + + round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE)); + idx_smem_ += warp_id * WARP_SIZE; + } + + __device__ void add(T val, idxT idx) { + bool do_add; + if constexpr (is_stable) + do_add = is_better_than(val, k_th_, idx, k_th_idx_); + else + do_add = is_better_than(val, k_th_); + + uint32_t mask = __ballot_sync(FUSED_FULL_WARP_MASK, do_add); + if (mask == 0) return; + + int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1)); + if (do_add && pos < WARP_SIZE) { + val_smem_[pos] = val; + idx_smem_[pos] = idx; + do_add = false; + } + smem_buf_len_ += __popc(mask); + if (smem_buf_len_ >= WARP_SIZE) { + __syncwarp(); + merge_buf_(val_smem_[lane_], idx_smem_[lane_]); + smem_buf_len_ -= WARP_SIZE; + } + if (do_add) { + pos -= WARP_SIZE; + val_smem_[pos] = val; + idx_smem_[pos] = idx; + } + __syncwarp(); + } + + // NOTE: no __syncthreads() here — callers must sync externally if needed. + __device__ void done() { + if (smem_buf_len_) { + T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_; + idxT idx = (lane_ < smem_buf_len_) ? 
idx_smem_[lane_] : 0; + merge_buf_(val, idx); + } + } + + private: + __device__ void set_k_th_() { + k_th_ = __shfl_sync( + FUSED_FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_); + if constexpr (is_stable) + k_th_idx_ = __shfl_sync( + FUSED_FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_); + } + + __device__ void merge_buf_(T val, idxT idx) { + BitonicSort::sort(&val, &idx); + T& old = val_arr_[max_arr_len_ - 1]; + bool is_better; + if constexpr (is_stable) + is_better = + is_better_than(val, old, idx, idx_arr_[max_arr_len_ - 1]); + else + is_better = is_better_than(val, old); + if (is_better) { + old = val; + idx_arr_[max_arr_len_ - 1] = idx; + } + BitonicMerge::merge( + val_arr_, idx_arr_); + set_k_th_(); + } + + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; + + T* val_smem_; + idxT* idx_smem_; + int smem_buf_len_ = 0; + T k_th_; + idxT k_th_idx_; + int const k_th_lane_; +}; + +} // namespace warp_topk_fused + +// --------------------------------------------------------------------------- +// Fused kernel: group-score computation + group selection + expert topk +// + sparse scores write-back, in one kernel launch. +// +// gridDim = num_tokens (one block per token) +// blockDim = n_group * WARP_SIZE (one warp per group) +// --------------------------------------------------------------------------- +template +__global__ void grouped_topk_fused_kernel( + float* scores, // output: sparse routing weights [num_tokens, num_experts] + float* topk_values, // output: topk routing weights [num_tokens, topk] + IdxT* topk_indices, // output: topk expert indices [num_tokens, topk] + InT const* gating_output, // input: raw logits (float or bf16) + // [num_tokens, num_experts] + float const* e_score_correction_bias, // input: bias [num_experts] + int64_t const num_tokens, + int64_t const num_experts, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + bool const renormalize, + double routed_scaling_factor) { + int32_t const token_id = static_cast(blockIdx.x); + if (token_id >= static_cast(num_tokens)) return; + + int32_t const warp_id = threadIdx.x / WARP_SIZE; + int32_t const lane_id = threadIdx.x % WARP_SIZE; + int32_t const n_group_i32 = static_cast(n_group); + int32_t const topk_group_i32 = static_cast(topk_group); + int32_t const topk_i32 = static_cast(topk); + int32_t const num_warps = blockDim.x / WARP_SIZE; + + if (warp_id >= n_group_i32 || num_warps < n_group_i32) return; + + int32_t const num_experts_per_group = + static_cast(num_experts) / n_group_i32; + int32_t const align_epg = warp_topk_fused::round_up_to_multiple_of( + num_experts_per_group); + + InT const* gate_token = gating_output + (int64_t)token_id * num_experts; + float* scores_token = scores + (int64_t)token_id * num_experts; + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + + // smem layout: [val_staging (256B-aligned) | idx_staging | (16B pad) | + // s_group_scores] + extern __shared__ char smem_buf[]; + size_t const val_aligned = warp_topk_fused::round_up_to_multiple_of<256>( + static_cast(num_warps) * WARP_SIZE * sizeof(float)); + size_t const idx_bytes = + static_cast(num_warps) * WARP_SIZE * sizeof(int32_t); + uintptr_t ptr = + (reinterpret_cast(smem_buf + val_aligned + idx_bytes) + 15) & + ~static_cast(15); + float* s_group_scores = reinterpret_cast(ptr); + +#if (defined(__CUDA_ARCH__) && 
(__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + + // ------------------------------------------------------------------ + // Phase 1 (all warps): compute group score = top2 sum of (gate + bias) + // ------------------------------------------------------------------ + { + int32_t const offset = warp_id * num_experts_per_group; + InT const* gate_g = gate_token + offset; + float const* bias_g = e_score_correction_bias + offset; + + float largest = neg_inf(); + float second_largest = neg_inf(); + + if (num_experts_per_group > WARP_SIZE) { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + float val = sigmoid_to_float(gate_g[i]) + bias_g[i]; + if (val > largest) { + second_largest = largest; + largest = val; + } else if (val > second_largest) { + second_largest = val; + } + } + } else { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) + largest = sigmoid_to_float(gate_g[i]) + bias_g[i]; + } + + float max1 = cg::reduce(tile, largest, cg::greater()); + float max2 = max1; + int cnt = __popc(__ballot_sync(FUSED_FULL_WARP_MASK, largest == max1)); + if (cnt == 1) { + largest = (largest == max1) ? second_largest : largest; + max2 = cg::reduce(tile, largest, cg::greater()); + } + if (lane_id == 0) s_group_scores[warp_id] = max1 + max2; + } + + __syncthreads(); + + // ------------------------------------------------------------------ + // Phase 2 (warp0 only): group selection → expert selection → output + // ------------------------------------------------------------------ + if (warp_id != 0) return; + + topk_values += (int64_t)token_id * topk; + topk_indices += (int64_t)token_id * topk; + + // Select top-topk_group groups + warp_topk_fused::WarpSelect group_sel( + topk_group_i32, neg_inf()); + + float gscore = + (lane_id < n_group_i32) ? s_group_scores[lane_id] : neg_inf(); + group_sel.add(gscore, lane_id); + group_sel.done(); // no __syncthreads() — only warp0 is active here + + // Check if enough valid groups exist + bool proceed = false; + if (topk_group_i32 > 0) { + float kth = __shfl_sync( + FUSED_FULL_WARP_MASK, group_sel.get_val(0), topk_group_i32 - 1); + proceed = (kth != neg_inf()); + } + + if (!proceed) { + // Fallback: zero scores, uniform topk + for (int i = lane_id; i < static_cast(num_experts); i += WARP_SIZE) + scores_token[i] = 0.0f; + __syncwarp(); + for (int i = lane_id; i < topk_i32; i += WARP_SIZE) { + topk_indices[i] = static_cast(i); + topk_values[i] = 1.0f / static_cast(topk_i32); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif + return; + } + + // Select top-topk experts from selected groups (using biased scores as + // candidates) + warp_topk_fused::WarpSelect expert_sel( + topk_i32, neg_inf()); // reuses same smem — group_sel is done + + int32_t sel_gid = (lane_id < topk_group_i32) ? 
group_sel.get_idx(0) : 0; + for (int32_t g = 0; g < topk_group_i32; ++g) { + int32_t gid = __shfl_sync(FUSED_FULL_WARP_MASK, sel_gid, g); + int32_t offset = gid * num_experts_per_group; + for (int32_t i = lane_id; i < align_epg; i += WARP_SIZE) { + float cand = neg_inf(); + int32_t idx = 0; + if (i < num_experts_per_group) { + idx = offset + i; + float biased = + sigmoid_to_float(gate_token[idx]) + e_score_correction_bias[idx]; + if (is_finite_val(biased)) cand = biased; + } + expert_sel.add(cand, idx); + } + } + expert_sel.done(); + + // Compute routing weights from unbiased scores + float lane_score = 0.0f; + IdxT lane_idx = 0; + if (lane_id < topk_i32) { + lane_idx = static_cast(expert_sel.get_idx(0)); + lane_score = sigmoid_to_float(gate_token[static_cast(lane_idx)]); + } + + float topk_sum = 1e-20f; + if (renormalize) topk_sum += cg::reduce(tile, lane_score, cg::plus()); + + float scale = static_cast(routed_scaling_factor); + if (renormalize) scale /= topk_sum; + + // Fill sparse scores: first zero out, then write selected experts' weights + for (int i = lane_id; i < static_cast(num_experts); i += WARP_SIZE) + scores_token[i] = 0.0f; + __syncwarp(); + + if (lane_id < topk_i32) { + float val = lane_score * scale; + scores_token[static_cast(lane_idx)] = val; + topk_indices[lane_id] = lane_idx; + topk_values[lane_id] = val; + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +// --------------------------------------------------------------------------- +// Launch wrapper +// --------------------------------------------------------------------------- +template +void invokeFusedNoAuxTc(InT* gating_output, + float* e_score_correction_bias, + float* scores, + float* topk_values, + IdxT* topk_indices, + int64_t const num_tokens, + int64_t const num_experts, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + bool const renormalize, + double const routed_scaling_factor, + cudaStream_t const stream) { + auto* kernel = &grouped_topk_fused_kernel; + + // blockDim = n_group * WARP_SIZE (one warp per group) + int32_t const num_warps = static_cast(n_group); + + // smem = WarpSelect staging (float) + 16B pad + group_scores buffer (float) + size_t const val_aligned = warp_topk_fused::round_up_to_multiple_of<256>( + static_cast(num_warps) * WARP_SIZE * sizeof(float)); + size_t const idx_bytes = + static_cast(num_warps) * WARP_SIZE * sizeof(int32_t); + size_t const extra_bytes = 16 + static_cast(n_group) * sizeof(float); + size_t const smem_bytes = val_aligned + idx_bytes + extra_bytes; + + cudaLaunchConfig_t config; + config.gridDim = static_cast(num_tokens); + config.blockDim = static_cast(n_group) * WARP_SIZE; + config.dynamicSmemBytes = smem_bytes; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = false; + config.numAttrs = 1; + config.attrs = attrs; + + cudaLaunchKernelEx(&config, + kernel, + scores, + topk_values, + topk_indices, + gating_output, + e_score_correction_bias, + num_tokens, + num_experts, + n_group, + topk_group, + topk, + renormalize, + routed_scaling_factor); +} + +#define INSTANTIATE_FUSED_NOAUX_TC(InT, IdxT) \ + template void invokeFusedNoAuxTc( \ + InT * gating_output, \ + float* e_score_correction_bias, \ + float* scores, \ + float* topk_values, \ + IdxT* topk_indices, \ + int64_t const num_tokens, \ + int64_t const num_experts, \ + int64_t const 
n_group, \ + int64_t const topk_group, \ + int64_t const topk, \ + bool const renormalize, \ + double const routed_scaling_factor, \ + cudaStream_t const stream); + +INSTANTIATE_FUSED_NOAUX_TC(float, int64_t); +INSTANTIATE_FUSED_NOAUX_TC(__nv_bfloat16, int64_t); +INSTANTIATE_FUSED_NOAUX_TC(__half, int64_t); + +// --------------------------------------------------------------------------- +// Paddle op wrapper +// --------------------------------------------------------------------------- +std::vector grouped_topk( + paddle::Tensor& gating_output, + paddle::Tensor& e_score_correction_bias, + int n_group, + int topk_group, + int topk, + bool renormalize, + float routed_scaling_factor) { + auto input_shape = gating_output.shape(); + PD_CHECK(input_shape.size() == 2); + int64_t num_tokens = input_shape[0]; + int64_t num_experts = input_shape[1]; + auto place = gating_output.place(); + PD_CHECK(n_group <= 32, "grouped_topk: n_group must be <= 32"); + PD_CHECK(topk <= 32, "grouped_topk: topk must be <= WARP_SIZE (32)"); + + // Outputs are always float32 regardless of input dtype + auto scores = paddle::empty( + {num_tokens, num_experts}, paddle::DataType::FLOAT32, place); + auto topk_values = + paddle::empty({num_tokens, topk}, paddle::DataType::FLOAT32, place); + auto topk_indices = + paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place); + + auto stream = gating_output.stream(); + auto dtype = gating_output.dtype(); + + float* scores_ptr = reinterpret_cast(scores.data()); + float* topk_values_ptr = reinterpret_cast(topk_values.data()); + int64_t* topk_idx_ptr = + reinterpret_cast(topk_indices.data()); + float* bias_ptr = + reinterpret_cast(e_score_correction_bias.data()); + + if (dtype == paddle::DataType::BFLOAT16) { + invokeFusedNoAuxTc<__nv_bfloat16, int64_t>( + reinterpret_cast<__nv_bfloat16*>( + gating_output.data()), + bias_ptr, + scores_ptr, + topk_values_ptr, + topk_idx_ptr, + num_tokens, + num_experts, + static_cast(n_group), + static_cast(topk_group), + static_cast(topk), + renormalize, + static_cast(routed_scaling_factor), + stream); + } else if (dtype == paddle::DataType::FLOAT16) { + invokeFusedNoAuxTc<__half, int64_t>( + reinterpret_cast<__half*>(gating_output.data()), + bias_ptr, + scores_ptr, + topk_values_ptr, + topk_idx_ptr, + num_tokens, + num_experts, + static_cast(n_group), + static_cast(topk_group), + static_cast(topk), + renormalize, + static_cast(routed_scaling_factor), + stream); + } else { + PD_CHECK( + dtype == paddle::DataType::FLOAT32, + "grouped_topk: gating_output must be float32, float16, or bfloat16"); + invokeFusedNoAuxTc( + reinterpret_cast(gating_output.data()), + bias_ptr, + scores_ptr, + topk_values_ptr, + topk_idx_ptr, + num_tokens, + num_experts, + static_cast(n_group), + static_cast(topk_group), + static_cast(topk), + renormalize, + static_cast(routed_scaling_factor), + stream); + } + + return {scores, topk_values, topk_indices}; +} + +std::vector GroupedTopkInferDtype( + const paddle::DataType& /*gating_output_dtype*/, + const paddle::DataType& /*e_score_correction_bias_dtype*/) { + // Outputs are always float32: cast is fused into the kernel. 
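+  // Caller-side sketch (illustrative; argument order matches the grouped_topk
+  // wrapper above):
+  //   scores, topk_values, topk_indices = grouped_topk(
+  //       gating_output,            # [num_tokens, num_experts] fp32/fp16/bf16
+  //       e_score_correction_bias,  # [num_experts] or [1, num_experts] fp32
+  //       n_group, topk_group, topk, renormalize, routed_scaling_factor)
+  //   # scores:       fp32 [num_tokens, num_experts] sparse routing weights
+  //   # topk_values:  fp32 [num_tokens, topk]
+  //   # topk_indices: int64 [num_tokens, topk]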
+ return {paddle::DataType::FLOAT32, + paddle::DataType::FLOAT32, + paddle::DataType::INT64}; +} + +std::vector> GroupedTopkInferShape( + const std::vector& gating_output_shape, + const std::vector&, + const int topk) { + auto num_tokens = gating_output_shape[0]; + auto num_experts = gating_output_shape[1]; + return {{num_tokens, num_experts}, {num_tokens, topk}, {num_tokens, topk}}; +} + +PD_BUILD_STATIC_OP(grouped_topk) + .Inputs({"gating_output", "e_score_correction_bias"}) + .Outputs({"output_tensor", "topk_values", "topk_indices"}) + .Attrs({"n_group: int", + "topk_group: int", + "topk: int", + "renormalize: bool", + "routed_scaling_factor: float"}) + .SetKernelFn(PD_KERNEL(grouped_topk)) + .SetInferShapeFn(PD_INFER_SHAPE(GroupedTopkInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GroupedTopkInferDtype)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 7ae1e964761..27043cc946e 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -330,6 +330,7 @@ def find_end_files(directory, end_str): "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", + "gpu_ops/grouped_topk_kernels.cu", "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/custom_all_reduce/all_reduce.cu", "gpu_ops/merge_prefill_decode_output.cu", @@ -685,6 +686,7 @@ def find_end_files(directory, end_str): "gpu_ops/recover_decode_task.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", + "gpu_ops/grouped_topk_kernels.cu", "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/text_image_gather_scatter.cu", diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 5dc20cac0a2..b150f588648 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -36,16 +36,18 @@ from fastdeploy.worker.experts_manager import RedundantExpertManger try: - from fastdeploy.model_executor.ops.gpu import noaux_tc, noaux_tc_redundant + from fastdeploy.model_executor.ops.gpu import ( + grouped_topk, + noaux_tc, + noaux_tc_redundant, + ) except: logger.warning("import noaux_tc Failed!") import numpy as np if current_platform.is_cuda(): - from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( - fused_cast_sigmoid_bias, - ) + pass def get_moe_method(layer=None): @@ -103,11 +105,7 @@ def get_moe_scores( compute moe scores using e_score_correction_bias. """ assert e_score_correction_bias is not None, "e_score_correction_bias is none!" 
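+    # NOTE: when use_fused_cast is set and the CUDA backend is active, the
+    # cast + sigmoid + bias + noaux_tc sequence below is collapsed into the
+    # single grouped_topk kernel; the unfused path is kept for other backends
+    # and for the redundant-expert (noaux_tc_redundant) case.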
- if use_fused_cast and current_platform.is_cuda(): - scores, scores_with_bias = fused_cast_sigmoid_bias(gating_output, e_score_correction_bias, cast_type="float32") - else: - scores = paddle.nn.functional.sigmoid(gating_output) - scores_with_bias = scores + e_score_correction_bias + use_fused = use_fused_cast and current_platform.is_cuda() if envs.FD_USE_PHI_MOE_TOPK: # calculate renormalize and routed_scaling_factor value outside the noaux_tc @@ -116,7 +114,9 @@ def get_moe_scores( renormalize = False routed_scaling_factor = 1.0 - if expert_id_to_ep_rank_array is None: + if expert_id_to_ep_rank_array is None and not use_fused: + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias scores, topk_values, topk_idx = noaux_tc( scores, scores_with_bias, @@ -126,9 +126,20 @@ def get_moe_scores( renormalize, routed_scaling_factor, ) + elif expert_id_to_ep_rank_array is None and use_fused: + # fused kernel: cast + sigmoid + add + noaux_tc + scores, topk_values, topk_idx = grouped_topk( + gating_output, + e_score_correction_bias, + n_group if n_group > 0 else 1, + topk_group if topk_group > 0 else 1, + top_k, + renormalize, + routed_scaling_factor, + ) else: - # noaux_tc_redundant returns 4 values: scores, topk_values, topk_idx, - # and tokens_per_expert_stats_list_out (inplace updated) + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias scores, topk_values, topk_idx, _ = noaux_tc_redundant( scores, scores_with_bias,