From 42dfaf9d39ea64edf245d110513a239f24b1e90c Mon Sep 17 00:00:00 2001 From: zhangxiao35 Date: Tue, 2 Jun 2026 13:28:20 +0800 Subject: [PATCH] [SOT] Support flashinfer_allreduce --- custom_ops/gpu_ops/cpp_extensions.cc | 39 +++- custom_ops/gpu_ops/trtllm_allreduce_op.cc | 182 ++++++++++++++++++ custom_ops/setup_ops.py | 14 +- .../layers/flashinfer_comm_fusion.py | 64 +++--- .../layers/flashinfer_comm_op.py | 163 ++++++++++++++++ 5 files changed, 426 insertions(+), 36 deletions(-) create mode 100644 custom_ops/gpu_ops/trtllm_allreduce_op.cc create mode 100644 fastdeploy/model_executor/layers/flashinfer_comm_op.py diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 6718e97d56c..80f511a75c3 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -2072,5 +2072,42 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("config_for_attention", &ConfigForAttention, "config for attention function"); -#endif + + /** + * trtllm_allreduce_op.cc + * FlashInfer fused allreduce + residual + RMSNorm via trtllm kernel + */ + bool InitTrtllmSo(const std::string& so_path); + std::vector TrtllmAllreduceResidualRmsnorm( + const paddle::Tensor& input_tensor, + const paddle::Tensor& residual, + const paddle::Tensor& weight, + const paddle::Tensor& workspace_ptrs, + int64_t world_size, + int64_t world_rank, + bool use_oneshot, + bool trigger_completion_at_end, + bool fp32_acc, + double rms_eps); + + m.def("init_trtllm_so", + &InitTrtllmSo, + py::arg("so_path"), + "dlopen trtllm_comm.so and register TVM FFI functions. " + "Must be called once before trtllm_allreduce_residual_rmsnorm."); + + m.def( + "trtllm_allreduce_residual_rmsnorm", + &TrtllmAllreduceResidualRmsnorm, + py::arg("input_tensor"), + py::arg("residual"), + py::arg("weight"), + py::arg("workspace_ptrs"), + py::arg("world_size"), + py::arg("world_rank"), + py::arg("use_oneshot"), + py::arg("trigger_completion_at_end"), + py::arg("fp32_acc"), + py::arg("rms_eps"), + "Fused allreduce + residual add + RMSNorm via FlashInfer trtllm kernel."); } diff --git a/custom_ops/gpu_ops/trtllm_allreduce_op.cc b/custom_ops/gpu_ops/trtllm_allreduce_op.cc new file mode 100644 index 00000000000..7e9ee613807 --- /dev/null +++ b/custom_ops/gpu_ops/trtllm_allreduce_op.cc @@ -0,0 +1,182 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// C++ wrapper for flashinfer trtllm_allreduce_fusion. +// Converts paddle::Tensor -> DLTensor -> tvm::ffi::TensorView, then calls +// the registered TVM FFI function from the pre-compiled trtllm_comm.so. + +#include +#include +#include +#include + +#include "paddle/extension.h" + +// TVM FFI C++ API (umbrella header) +#include "tvm/ffi/tvm_ffi.h" + +// --------------------------------------------------------------------------- +// Global state: dlopen handle for trtllm_comm.so +// --------------------------------------------------------------------------- + +static void* g_trtllm_so_handle = nullptr; + +static bool LoadTrtllmSo(const std::string& so_path) { + if (g_trtllm_so_handle != nullptr) return true; + g_trtllm_so_handle = dlopen(so_path.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (!g_trtllm_so_handle) { + return false; + } + return true; +} + +// --------------------------------------------------------------------------- +// Paddle DataType -> DLDataType +// --------------------------------------------------------------------------- + +static DLDataType PaddleDTypeToDL(paddle::DataType dtype) { + switch (dtype) { + case paddle::DataType::FLOAT16: + return DLDataType{kDLFloat, 16, 1}; + case paddle::DataType::BFLOAT16: + return DLDataType{kDLBfloat, 16, 1}; + case paddle::DataType::FLOAT32: + return DLDataType{kDLFloat, 32, 1}; + case paddle::DataType::INT64: + return DLDataType{kDLInt, 64, 1}; + case paddle::DataType::INT32: + return DLDataType{kDLInt, 32, 1}; + case paddle::DataType::UINT8: + return DLDataType{kDLUInt, 8, 1}; + default: + throw std::runtime_error("PaddleDTypeToDL: unsupported dtype"); + } +} + +// --------------------------------------------------------------------------- +// Helper: paddle::Tensor -> DLTensor (shape stored in out_shape vector) +// +// shape_buf must outlive the returned DLTensor (shape ptr points into it). +// Call MakeTensorView() to get a TensorView (TensorView copies DLTensor). +// --------------------------------------------------------------------------- + +static DLTensor MakeDLTensor(const paddle::Tensor& t, + std::vector& shape_buf) { + int ndim = static_cast(t.dims().size()); + shape_buf.resize(ndim); + for (int i = 0; i < ndim; ++i) shape_buf[i] = t.dims()[i]; + + DLTensor dl; + dl.data = const_cast(t.data()); + dl.device = DLDevice{kDLCUDA, t.place().GetDeviceId()}; + dl.ndim = ndim; + dl.shape = shape_buf.data(); + dl.strides = nullptr; + dl.byte_offset = 0; + dl.dtype = PaddleDTypeToDL(t.dtype()); + return dl; +} + +// TensorView copies DLTensor by value (shape ptr NOT copied — still points +// into the original DLTensor's shape array). Keep the DLTensor and its +// shape_buf alive for the duration of the TVM FFI call. +static tvm::ffi::TensorView MakeTensorView(const DLTensor& dl) { + return tvm::ffi::TensorView(&dl); +} + +// --------------------------------------------------------------------------- +// Main kernel: paddle::Tensor in -> (norm_out, residual_out) +// --------------------------------------------------------------------------- +// +// AllReduceFusionPattern::kARResidualRMSNorm == 1 +// (from flashinfer/comm/trtllm_allreduce_fusion.cuh) +// +static constexpr int64_t kARResidualRMSNorm = 1; + +std::vector TrtllmAllreduceResidualRmsnorm( + const paddle::Tensor& input_tensor, + const paddle::Tensor& residual, + const paddle::Tensor& weight, + const paddle::Tensor& workspace_ptrs, + int64_t world_size, + int64_t world_rank, + bool use_oneshot, + bool trigger_completion_at_end, + bool fp32_acc, + double rms_eps) { + if (g_trtllm_so_handle == nullptr) { + throw std::runtime_error( + "TrtllmAllreduceResidualRmsnorm: trtllm_comm.so not loaded. " + "Call init_trtllm_so(path) first."); + } + + auto norm_out = paddle::empty_like(input_tensor); + auto residual_out = paddle::empty_like(residual); + + int64_t token_num = input_tensor.dims()[0]; + // Support empty tensor + if (token_num == 0) { + return {norm_out, residual_out}; + } + + // Build DLTensors + keep shape buffers alive for the duration of the call. + std::vector sh_input, sh_residual, sh_weight, sh_workspace, + sh_residual_out, sh_norm_out; + DLTensor dl_input = MakeDLTensor(input_tensor, sh_input); + DLTensor dl_residual = MakeDLTensor(residual, sh_residual); + DLTensor dl_weight = MakeDLTensor(weight, sh_weight); + DLTensor dl_workspace = MakeDLTensor(workspace_ptrs, sh_workspace); + DLTensor dl_residual_out = MakeDLTensor(residual_out, sh_residual_out); + DLTensor dl_norm_out = MakeDLTensor(norm_out, sh_norm_out); + + int64_t hidden_dim = input_tensor.dims()[1]; + + // Look up the TVM FFI registered function (registered when .so was dlopen'd) + auto fn_opt = tvm::ffi::Function::GetGlobal("trtllm_allreduce_fusion"); + if (!fn_opt.has_value()) { + throw std::runtime_error( + "TrtllmAllreduceResidualRmsnorm: 'trtllm_allreduce_fusion' not found " + "in TVM FFI global registry."); + } + auto& fn = *fn_opt; + + using OptTV = tvm::ffi::Optional; + + fn(MakeTensorView(dl_input), + world_size, + world_rank, + token_num, + hidden_dim, + MakeTensorView(dl_workspace), + /*launch_with_pdl=*/true, + use_oneshot, + trigger_completion_at_end, + fp32_acc, + kARResidualRMSNorm, + OptTV(std::nullopt), // allreduce_out = None + OptTV(MakeTensorView(dl_residual)), // residual_in + OptTV(MakeTensorView(dl_residual_out)), // residual_out + OptTV(MakeTensorView(dl_norm_out)), // norm_out + OptTV(std::nullopt), // quant_out = None + OptTV(std::nullopt), // scale_out = None + OptTV(MakeTensorView(dl_weight)), // rms_gamma + tvm::ffi::Optional(rms_eps), + OptTV(std::nullopt), // scale_factor = None + tvm::ffi::Optional(std::nullopt) // layout_code = None + ); + + return {norm_out, residual_out}; +} + +bool InitTrtllmSo(const std::string& so_path) { return LoadTrtllmSo(so_path); } diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index a7b42f3ae74..019f745c516 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -341,6 +341,7 @@ def find_end_files(directory, end_str): "gpu_ops/grouped_topk_kernels.cu", "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/custom_all_reduce/all_reduce.cu", + "gpu_ops/trtllm_allreduce_op.cc", "gpu_ops/merge_prefill_decode_output.cu", "gpu_ops/limit_thinking_content_length.cu", "gpu_ops/update_attn_mask_offsets.cu", @@ -407,6 +408,15 @@ def find_end_files(directory, end_str): "-Igpu_ops", "-Ithird_party/nlohmann_json/include", ] + + # tvm_ffi include for trtllm_allreduce_op.cc + import tvm_ffi as _tvm_ffi_mod + + _tvm_ffi_base = os.path.dirname(_tvm_ffi_mod.__file__) + _tvm_ffi_include = os.path.join(_tvm_ffi_base, "include") + _tvm_ffi_lib = os.path.join(_tvm_ffi_base, "lib") + cc_compile_args += [f"-I{_tvm_ffi_include}"] + nvcc_compile_args += [f"-I{_tvm_ffi_include}"] max_jobs, nvcc_threads = get_compile_parallelism() print(f"MAX_JOBS = {max_jobs}, nvcc -t = {nvcc_threads}") nvcc_compile_args += ["-t", str(nvcc_threads)] @@ -588,8 +598,8 @@ def find_end_files(directory, end_str): ext_modules=CUDAExtension( sources=sources, extra_compile_args={"cxx": cc_compile_args, "nvcc": nvcc_compile_args}, - libraries=["cublasLt"], - extra_link_args=["-lcuda", "-lnvidia-ml"], + libraries=["cublasLt", "tvm_ffi"], + extra_link_args=["-lcuda", "-lnvidia-ml", f"-L{_tvm_ffi_lib}", f"-Wl,-rpath,{_tvm_ffi_lib}", "-ldl"], ), packages=find_packages(where="third_party/DeepGEMM"), package_dir={"": "third_party/DeepGEMM"}, diff --git a/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py index 7f27b52975d..44d783a9ef3 100644 --- a/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py +++ b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py @@ -28,6 +28,10 @@ _flashinfer_comm = None _workspace_manager = None +# Oneshot heuristics from flashinfer: world_size -> comm_size_mb threshold +# Source: flashinfer/comm/trtllm_ar.py::_use_oneshot_heuristics +_USE_ONESHOT_HEURISTICS = {2: 512, 4: 64, 8: 42} + def _get_flashinfer_comm(): """Lazily import flashinfer.comm to avoid side effects at module load time.""" @@ -145,7 +149,11 @@ def flashinfer_allreduce_residual_rmsnorm( fp32_acc: bool = False, ) -> Tuple[paddle.Tensor, paddle.Tensor]: """ - Use FlashInfer's fused allreduce + residual + RMS norm operation + Use FlashInfer's fused allreduce + residual + RMS norm operation. + + The actual kernel is dispatched through trtllm_allreduce_residual_rmsnorm_op + (from flashinfer_comm_op.py), which is registered as a Paddle custom Python + op so that SOT treats the entire call as a single opaque node. """ comm = _get_flashinfer_comm() if not has_flashinfer() or comm is None: @@ -158,8 +166,6 @@ def flashinfer_allreduce_residual_rmsnorm( logger.debug("Single GPU, no need for allreduce fusion") return None, None - assert input_tensor.shape[0] <= max_token_num - if not ensure_workspace_initialized( fd_config=fd_config, max_token_num=max_token_num, @@ -169,38 +175,30 @@ def flashinfer_allreduce_residual_rmsnorm( logger.debug("FlashInfer workspace not available") return None, None - token_num, hidden_dim = input_tensor.shape - - residual_out = paddle.empty_like(residual) - norm_out = paddle.empty_like(input_tensor) - # support empty tensor - if input_tensor.shape[0] == 0: - return norm_out, residual_out - comm.trtllm_allreduce_fusion( - allreduce_in=input_tensor, - world_size=world_size, - world_rank=dist.get_rank(), - token_num=token_num, - hidden_dim=hidden_dim, - workspace_ptrs=_workspace_manager.workspace_tensor, - launch_with_pdl=True, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=(comm.AllReduceFusionPattern.kARResidualRMSNorm), - allreduce_out=None, - residual_in=residual, - residual_out=residual_out, - norm_out=norm_out, - quant_out=None, - scale_out=None, - rms_gamma=weight, - rms_eps=eps, - scale_factor=None, - layout_code=None, + # Compute use_oneshot from concrete values before SOT-traced region. + # max_token_num is always a concrete Python int (model config parameter). + # input_tensor.shape[-1] is hidden_size, always a concrete model constant. + if use_oneshot is None: + hidden_dim = input_tensor.shape[-1] + comm_size_mb = max_token_num * hidden_dim * 2 * world_size * 2 / 1024 / 1024 + use_oneshot = comm_size_mb <= _USE_ONESHOT_HEURISTICS.get(world_size, 0) + + from fastdeploy.model_executor.layers.flashinfer_comm_op import ( + trtllm_allreduce_residual_rmsnorm_op, ) - return norm_out, residual_out + return trtllm_allreduce_residual_rmsnorm_op( + input_tensor, + residual, + weight, + _workspace_manager.workspace_tensor, + world_size, + dist.get_rank(), + bool(use_oneshot), + trigger_completion_at_end, + fp32_acc, + eps, + ) def cleanup_flashinfer_workspace(): diff --git a/fastdeploy/model_executor/layers/flashinfer_comm_op.py b/fastdeploy/model_executor/layers/flashinfer_comm_op.py new file mode 100644 index 00000000000..2db6a225901 --- /dev/null +++ b/fastdeploy/model_executor/layers/flashinfer_comm_op.py @@ -0,0 +1,163 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Python glue layer for the flashinfer trtllm_allreduce_residual_rmsnorm op. + +Loads trtllm_comm.so via tvm_ffi.load_module, retrieves the +trtllm_allreduce_fusion function from the TVM module, then wraps the kernel +call with @register_custom_python_op so that Paddle SOT treats it as a +single opaque op node (no tracing into the function body). +""" + +from typing import Tuple + +import paddle + +from fastdeploy.utils import get_logger, register_custom_python_op + +logger = get_logger("flashinfer", "flashinfer.log") + +# --------------------------------------------------------------------------- +# Lazy init: load trtllm_comm.so via tvm_ffi.load_module() +# --------------------------------------------------------------------------- + +_initialized = False +_trtllm_fn = None # tvm_ffi.core.Function for trtllm_allreduce_fusion + + +def _ensure_trtllm_so_loaded(): + global _initialized, _trtllm_fn + if _initialized: + return True + + try: + import paddle as _paddle + import tvm_ffi + + with _paddle.use_compat_guard(enable=True, scope={"flashinfer"}): + from flashinfer.jit import env as jit_env + from flashinfer.jit.comm import gen_trtllm_comm_module + + so_path = jit_env.FLASHINFER_JIT_DIR / "trtllm_comm" / "trtllm_comm.so" + if not so_path.exists(): + from flashinfer.jit.core import build_jit_specs + + build_jit_specs([gen_trtllm_comm_module()]) + + mod = tvm_ffi.load_module(str(so_path)) + _trtllm_fn = mod["trtllm_allreduce_fusion"] + + _initialized = True + logger.info("flashinfer_comm_op: trtllm_comm.so loaded via tvm_ffi.") + except Exception as e: + logger.warning(f"flashinfer_comm_op: init failed: {e}") + return False + + return True + + +# --------------------------------------------------------------------------- +# infer_meta for register_custom_python_op +# --------------------------------------------------------------------------- + + +def _trtllm_ar_rmsnorm_infer_meta( + input_tensor: "paddle.static.MetaTensor", + residual: "paddle.static.MetaTensor", + weight: "paddle.static.MetaTensor", + workspace_ptrs: "paddle.static.MetaTensor", + world_size: int, + world_rank: int, + use_oneshot: bool, + trigger_completion_at_end: bool, + fp32_acc: bool, + rms_eps: float, +) -> Tuple["paddle.static.MetaTensor", "paddle.static.MetaTensor"]: + norm_out = paddle.static.MetaTensor(shape=input_tensor.shape, dtype=input_tensor.dtype) + residual_out = paddle.static.MetaTensor(shape=residual.shape, dtype=residual.dtype) + return norm_out, residual_out + + +# --------------------------------------------------------------------------- +# @register_custom_python_op wrapper +# +# SOT sees this as a single opaque Paddle op and calls _infer_meta for shape +# inference; the function body is never traced. +# --------------------------------------------------------------------------- + +# AllReduceFusionPattern::kARResidualRMSNorm == 1 +_kARResidualRMSNorm = 1 + + +@register_custom_python_op( + name="trtllm_allreduce_residual_rmsnorm", + infer_meta=_trtllm_ar_rmsnorm_infer_meta, + input_names=["input_tensor", "residual", "weight", "workspace_ptrs"], + output_names=["norm_out", "residual_out"], + inplace_map={}, +) +def trtllm_allreduce_residual_rmsnorm_op( + input_tensor: paddle.Tensor, + residual: paddle.Tensor, + weight: paddle.Tensor, + workspace_ptrs: paddle.Tensor, + world_size: int, + world_rank: int, + use_oneshot: bool, + trigger_completion_at_end: bool, + fp32_acc: bool, + rms_eps: float, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Fused allreduce + residual add + RMSNorm via FlashInfer trtllm C++ kernel. + + This function is opaque to Paddle SOT — _trtllm_fn is a tvm_ffi.core.Function + (C++ kernel), called directly without any Python-level tensor operations that + SOT would trace. + """ + if not _ensure_trtllm_so_loaded(): + return None, None + + token_num = input_tensor.shape[0] + hidden_dim = input_tensor.shape[1] + + norm_out = paddle.empty_like(input_tensor) + residual_out = paddle.empty_like(residual) + + if token_num == 0: + return norm_out, residual_out + + _trtllm_fn( + input_tensor, + world_size, + world_rank, + token_num, + hidden_dim, + workspace_ptrs, + True, # launch_with_pdl + use_oneshot, + trigger_completion_at_end, + fp32_acc, + _kARResidualRMSNorm, + None, # allreduce_out + residual, # residual_in + residual_out, # residual_out + norm_out, # norm_out + None, # quant_out + None, # scale_out + weight, # rms_gamma + rms_eps, + None, # scale_factor + None, # layout_code + ) + return norm_out, residual_out