Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2072,5 +2072,42 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("config_for_attention",
&ConfigForAttention,
"config for attention function");
#endif

/**

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug 删除 #endif 导致 trtllm pybind11 注册落入 #ifdef ENABLE_DECODE_UNIFIED_ATTENTION 条件编译块内

上方第 2051 行有 #ifdef ENABLE_DECODE_UNIFIED_ATTENTION,原来此处的 #endif 是其闭合。删除后,新增的 init_trtllm_so / trtllm_allreduce_residual_rmsnorm pybind11 注册代码被置于该 #ifdef 内,仅在 SM ≥ 90 时编译生效。虽然当前 Python 侧通过 tvm_ffi.load_module 直接调用而未使用 pybind11 路径,但条件编译作用域错误仍应修复以避免后续维护风险。

建议修复方式:在新增代码之前恢复 #endif

        "config for attention function");
#endif

  /**
   * trtllm_allreduce_op.cc

* trtllm_allreduce_op.cc
* FlashInfer fused allreduce + residual + RMSNorm via trtllm kernel
*/
bool InitTrtllmSo(const std::string& so_path);
std::vector<paddle::Tensor> TrtllmAllreduceResidualRmsnorm(
const paddle::Tensor& input_tensor,
const paddle::Tensor& residual,
const paddle::Tensor& weight,
const paddle::Tensor& workspace_ptrs,
int64_t world_size,
int64_t world_rank,
bool use_oneshot,
bool trigger_completion_at_end,
bool fp32_acc,
double rms_eps);

m.def("init_trtllm_so",
&InitTrtllmSo,
py::arg("so_path"),
"dlopen trtllm_comm.so and register TVM FFI functions. "
"Must be called once before trtllm_allreduce_residual_rmsnorm.");

m.def(
"trtllm_allreduce_residual_rmsnorm",
&TrtllmAllreduceResidualRmsnorm,
py::arg("input_tensor"),
py::arg("residual"),
py::arg("weight"),
py::arg("workspace_ptrs"),
py::arg("world_size"),
py::arg("world_rank"),
py::arg("use_oneshot"),
py::arg("trigger_completion_at_end"),
py::arg("fp32_acc"),
py::arg("rms_eps"),
"Fused allreduce + residual add + RMSNorm via FlashInfer trtllm kernel.");
}
182 changes: 182 additions & 0 deletions custom_ops/gpu_ops/trtllm_allreduce_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// C++ wrapper for flashinfer trtllm_allreduce_fusion.
// Converts paddle::Tensor -> DLTensor -> tvm::ffi::TensorView, then calls
// the registered TVM FFI function from the pre-compiled trtllm_comm.so.

#include <dlfcn.h>
#include <cstdio>
#include <stdexcept>
#include <string>
#include <vector>

#include "paddle/extension.h"

// TVM FFI C++ API (umbrella header)
#include "tvm/ffi/tvm_ffi.h"

// ---------------------------------------------------------------------------
// Global state: dlopen handle for trtllm_comm.so
// ---------------------------------------------------------------------------

// Handle returned by dlopen() for trtllm_comm.so. Stays nullptr until a load
// succeeds; no dlclose() for it is visible in this file.
static void* g_trtllm_so_handle = nullptr;

// Loads the shared library at `so_path` with RTLD_NOW | RTLD_GLOBAL and caches
// the handle in g_trtllm_so_handle.
//
// Returns true when a handle is already cached (in that case `so_path` is
// ignored and nothing is reloaded) or when dlopen() succeeds. Returns false
// when dlopen() fails; the handle is left as nullptr so a later call can
// retry.
//
// On failure the dlerror() text is written to stderr together with the path.
// Without it the caller only sees `false` and cannot tell a missing file from
// an unresolved symbol.
static bool LoadTrtllmSo(const std::string& so_path) {
  if (g_trtllm_so_handle != nullptr) {
    return true;
  }

  g_trtllm_so_handle = dlopen(so_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
  if (g_trtllm_so_handle == nullptr) {
    // dlerror() may itself return NULL; never hand that to "%s".
    const char* reason = dlerror();
    std::fprintf(stderr,
                 "LoadTrtllmSo: dlopen(\"%s\") failed: %s\n",
                 so_path.c_str(),
                 reason != nullptr ? reason : "unknown error");
    return false;
  }
  return true;
}

// ---------------------------------------------------------------------------
// Paddle DataType -> DLDataType
// ---------------------------------------------------------------------------

// Translates a Paddle element type into the matching DLPack descriptor
// {type code, bit width, lanes}. Lanes is always 1.
//
// Handled: FLOAT16, BFLOAT16, FLOAT32, INT64, INT32, UINT8.
// Throws std::runtime_error for every other dtype.
static DLDataType PaddleDTypeToDL(paddle::DataType dtype) {
  using PD = paddle::DataType;

  if (dtype == PD::FLOAT16) {
    return DLDataType{kDLFloat, 16, 1};
  }
  if (dtype == PD::BFLOAT16) {
    return DLDataType{kDLBfloat, 16, 1};
  }
  if (dtype == PD::FLOAT32) {
    return DLDataType{kDLFloat, 32, 1};
  }
  if (dtype == PD::INT64) {
    return DLDataType{kDLInt, 64, 1};
  }
  if (dtype == PD::INT32) {
    return DLDataType{kDLInt, 32, 1};
  }
  if (dtype == PD::UINT8) {
    return DLDataType{kDLUInt, 8, 1};
  }
  throw std::runtime_error("PaddleDTypeToDL: unsupported dtype");
}

// ---------------------------------------------------------------------------
// Helper: paddle::Tensor -> DLTensor (shape stored in out_shape vector)
//
// shape_buf must outlive the returned DLTensor (shape ptr points into it).
// Call MakeTensorView() to get a TensorView (TensorView copies DLTensor).
// ---------------------------------------------------------------------------

// Describes `t` as a DLTensor without copying any element data.
//
// `shape_buf` is resized to the tensor's rank and filled with its dimensions;
// the returned DLTensor's `shape` pointer refers to that vector's storage, so
// the vector must outlive the DLTensor and must not be resized while it is in
// use. `strides` is left null and `byte_offset` is 0.
//
// The device type is always reported as kDLCUDA.
// NOTE(review): presumably callers only pass GPU tensors — confirm.
//
// Throws std::runtime_error (from PaddleDTypeToDL) for unsupported dtypes.
static DLTensor MakeDLTensor(const paddle::Tensor& t,
                             std::vector<int64_t>& shape_buf) {
  const auto& extents = t.dims();
  const int rank = static_cast<int>(extents.size());

  shape_buf.resize(rank);
  for (int axis = 0; axis < rank; ++axis) {
    shape_buf[axis] = extents[axis];
  }

  DLTensor desc;
  // DLPack carries a mutable data pointer even for read-only inputs.
  desc.data = const_cast<void*>(t.data());
  desc.device = DLDevice{kDLCUDA, t.place().GetDeviceId()};
  desc.ndim = rank;
  desc.shape = shape_buf.data();
  desc.strides = nullptr;
  desc.byte_offset = 0;
  desc.dtype = PaddleDTypeToDL(t.dtype());
  return desc;
}

// Wraps an existing DLTensor in a tvm::ffi::TensorView.
//
// Lifetime: the view takes the DLTensor's fields, but the shape pointer is
// shared rather than deep-copied, so it still refers to the caller's shape
// array. Keep both `dl` and the shape buffer it points into alive until the
// TVM FFI call that consumes the view has returned.
static tvm::ffi::TensorView MakeTensorView(const DLTensor& dl) {
  const DLTensor* source = &dl;
  return tvm::ffi::TensorView(source);
}

// ---------------------------------------------------------------------------
// Main kernel: paddle::Tensor in -> (norm_out, residual_out)
// ---------------------------------------------------------------------------
//
// AllReduceFusionPattern::kARResidualRMSNorm == 1
// (from flashinfer/comm/trtllm_allreduce_fusion.cuh)
//
// Pattern code handed to trtllm_allreduce_fusion. Mirrors
// AllReduceFusionPattern::kARResidualRMSNorm (== 1) in
// flashinfer/comm/trtllm_allreduce_fusion.cuh; keep the two in sync.
static constexpr int64_t kARResidualRMSNorm = 1;

// Fused allreduce + residual add + RMSNorm, dispatched through the TVM FFI
// global function "trtllm_allreduce_fusion".
//
// Parameters
//   input_tensor   2-D tensor [token_num, hidden_dim]; passed as the allreduce
//                  input.
//   residual       passed as residual_in.
//   weight         passed as rms_gamma.
//   workspace_ptrs passed through as the kernel's workspace tensor.
//   world_size, world_rank, use_oneshot, trigger_completion_at_end, fp32_acc,
//   rms_eps        forwarded unchanged.
// launch_with_pdl is always true. allreduce_out, quant_out, scale_out,
// scale_factor and layout_code are always passed as "none".
//
// Returns {norm_out, residual_out}, allocated with paddle::empty_like from
// input_tensor and residual respectively. When input_tensor has zero rows the
// two uninitialised outputs are returned without calling the kernel.
//
// Throws std::runtime_error when
//   * trtllm_comm.so has not been loaded via init_trtllm_so,
//   * input_tensor is not 2-D (previously dims()[0] / dims()[1] were read
//     without any rank check, silently yielding a wrong hidden_dim),
//   * "trtllm_allreduce_fusion" is missing from the TVM FFI global registry,
//   * any tensor has a dtype PaddleDTypeToDL does not support.
std::vector<paddle::Tensor> TrtllmAllreduceResidualRmsnorm(
    const paddle::Tensor& input_tensor,
    const paddle::Tensor& residual,
    const paddle::Tensor& weight,
    const paddle::Tensor& workspace_ptrs,
    int64_t world_size,
    int64_t world_rank,
    bool use_oneshot,
    bool trigger_completion_at_end,
    bool fp32_acc,
    double rms_eps) {
  if (g_trtllm_so_handle == nullptr) {
    throw std::runtime_error(
        "TrtllmAllreduceResidualRmsnorm: trtllm_comm.so not loaded. "
        "Call init_trtllm_so(path) first.");
  }

  const auto& in_dims = input_tensor.dims();
  const int in_rank = static_cast<int>(in_dims.size());
  if (in_rank < 1) {
    throw std::runtime_error(
        "TrtllmAllreduceResidualRmsnorm: input_tensor must be 2-D "
        "[token_num, hidden_dim], got a 0-D tensor.");
  }

  auto norm_out = paddle::empty_like(input_tensor);
  auto residual_out = paddle::empty_like(residual);

  const int64_t token_num = in_dims[0];
  // Support empty tensor: nothing to reduce, skip the kernel launch.
  if (token_num == 0) {
    return {norm_out, residual_out};
  }

  // hidden_dim is read from axis 1, so any rank other than 2 would feed the
  // kernel a meaningless size.
  if (in_rank != 2) {
    throw std::runtime_error(
        "TrtllmAllreduceResidualRmsnorm: input_tensor must be 2-D "
        "[token_num, hidden_dim], got rank " +
        std::to_string(in_rank) + ".");
  }
  const int64_t hidden_dim = in_dims[1];

  // Build DLTensors. The shape vectors back the DLTensor::shape pointers and
  // must stay alive until the FFI call below has returned.
  std::vector<int64_t> sh_input, sh_residual, sh_weight, sh_workspace,
      sh_residual_out, sh_norm_out;
  DLTensor dl_input = MakeDLTensor(input_tensor, sh_input);
  DLTensor dl_residual = MakeDLTensor(residual, sh_residual);
  DLTensor dl_weight = MakeDLTensor(weight, sh_weight);
  DLTensor dl_workspace = MakeDLTensor(workspace_ptrs, sh_workspace);
  DLTensor dl_residual_out = MakeDLTensor(residual_out, sh_residual_out);
  DLTensor dl_norm_out = MakeDLTensor(norm_out, sh_norm_out);

  // Look up the TVM FFI registered function (registered when the .so was
  // dlopen'd).
  auto fn_opt = tvm::ffi::Function::GetGlobal("trtllm_allreduce_fusion");
  if (!fn_opt.has_value()) {
    throw std::runtime_error(
        "TrtllmAllreduceResidualRmsnorm: 'trtllm_allreduce_fusion' not found "
        "in TVM FFI global registry.");
  }
  auto& fn = *fn_opt;

  using OptTV = tvm::ffi::Optional<tvm::ffi::TensorView>;

  fn(MakeTensorView(dl_input),
     world_size,
     world_rank,
     token_num,
     hidden_dim,
     MakeTensorView(dl_workspace),
     /*launch_with_pdl=*/true,
     use_oneshot,
     trigger_completion_at_end,
     fp32_acc,
     kARResidualRMSNorm,
     OptTV(std::nullopt),                        // allreduce_out = None
     OptTV(MakeTensorView(dl_residual)),         // residual_in
     OptTV(MakeTensorView(dl_residual_out)),     // residual_out
     OptTV(MakeTensorView(dl_norm_out)),         // norm_out
     OptTV(std::nullopt),                        // quant_out = None
     OptTV(std::nullopt),                        // scale_out = None
     OptTV(MakeTensorView(dl_weight)),           // rms_gamma
     tvm::ffi::Optional<double>(rms_eps),
     OptTV(std::nullopt),                        // scale_factor = None
     tvm::ffi::Optional<int64_t>(std::nullopt)   // layout_code = None
  );

  return {norm_out, residual_out};
}

// Entry point exposed to Python as `init_trtllm_so`: loads the shared library
// at `so_path` via LoadTrtllmSo and reports whether a handle is available.
bool InitTrtllmSo(const std::string& so_path) {
  const bool loaded = LoadTrtllmSo(so_path);
  return loaded;
}
14 changes: 12 additions & 2 deletions custom_ops/setup_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def find_end_files(directory, end_str):
"gpu_ops/grouped_topk_kernels.cu",
"gpu_ops/fused_cast_sigmoid_bias.cu",
"gpu_ops/custom_all_reduce/all_reduce.cu",
"gpu_ops/trtllm_allreduce_op.cc",
"gpu_ops/merge_prefill_decode_output.cu",
"gpu_ops/limit_thinking_content_length.cu",
"gpu_ops/update_attn_mask_offsets.cu",
Expand Down Expand Up @@ -407,6 +408,15 @@ def find_end_files(directory, end_str):
"-Igpu_ops",
"-Ithird_party/nlohmann_json/include",
]

# tvm_ffi include for trtllm_allreduce_op.cc
import tvm_ffi as _tvm_ffi_mod

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 无保护 import tvm_ffi 使其成为全量 custom ops 编译的硬依赖

tvm_ffi 仅供 trtllm_allreduce_op.cc 编译使用,但此处无 try-except 保护。若环境中未安装 tvm_ffi,整个 setup_ops.py 将 ImportError 中断,导致所有 custom ops 无法编译——即使用户不需要 flashinfer allreduce 功能。

建议修复方式:用 try-except 包裹,失败时跳过 tvm_ffi 相关编译参数并从 sources 中移除 trtllm_allreduce_op.cc

try:
    import tvm_ffi as _tvm_ffi_mod
    _tvm_ffi_base = os.path.dirname(_tvm_ffi_mod.__file__)
    _tvm_ffi_include = os.path.join(_tvm_ffi_base, "include")
    _tvm_ffi_lib = os.path.join(_tvm_ffi_base, "lib")
    cc_compile_args += [f"-I{_tvm_ffi_include}"]
    nvcc_compile_args += [f"-I{_tvm_ffi_include}"]
except ImportError:
    _tvm_ffi_lib = None
    sources = [s for s in sources if "trtllm_allreduce_op" not in s]


_tvm_ffi_base = os.path.dirname(_tvm_ffi_mod.__file__)
_tvm_ffi_include = os.path.join(_tvm_ffi_base, "include")
_tvm_ffi_lib = os.path.join(_tvm_ffi_base, "lib")
cc_compile_args += [f"-I{_tvm_ffi_include}"]
nvcc_compile_args += [f"-I{_tvm_ffi_include}"]
max_jobs, nvcc_threads = get_compile_parallelism()
print(f"MAX_JOBS = {max_jobs}, nvcc -t = {nvcc_threads}")
nvcc_compile_args += ["-t", str(nvcc_threads)]
Expand Down Expand Up @@ -588,8 +598,8 @@ def find_end_files(directory, end_str):
ext_modules=CUDAExtension(
sources=sources,
extra_compile_args={"cxx": cc_compile_args, "nvcc": nvcc_compile_args},
libraries=["cublasLt"],
extra_link_args=["-lcuda", "-lnvidia-ml"],
libraries=["cublasLt", "tvm_ffi"],
extra_link_args=["-lcuda", "-lnvidia-ml", f"-L{_tvm_ffi_lib}", f"-Wl,-rpath,{_tvm_ffi_lib}", "-ldl"],
),
packages=find_packages(where="third_party/DeepGEMM"),
package_dir={"": "third_party/DeepGEMM"},
Expand Down
64 changes: 31 additions & 33 deletions fastdeploy/model_executor/layers/flashinfer_comm_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
_flashinfer_comm = None
_workspace_manager = None

# Oneshot heuristics from flashinfer: world_size -> comm_size_mb threshold
# Source: flashinfer/comm/trtllm_ar.py::_use_oneshot_heuristics
_USE_ONESHOT_HEURISTICS = {2: 512, 4: 64, 8: 42}


def _get_flashinfer_comm():
"""Lazily import flashinfer.comm to avoid side effects at module load time."""
Expand Down Expand Up @@ -145,7 +149,11 @@ def flashinfer_allreduce_residual_rmsnorm(
fp32_acc: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""
Use FlashInfer's fused allreduce + residual + RMS norm operation
Use FlashInfer's fused allreduce + residual + RMS norm operation.

The actual kernel is dispatched through trtllm_allreduce_residual_rmsnorm_op
(from flashinfer_comm_op.py), which is registered as a Paddle custom Python
op so that SOT treats the entire call as a single opaque node.
"""
comm = _get_flashinfer_comm()
if not has_flashinfer() or comm is None:
Expand All @@ -158,8 +166,6 @@ def flashinfer_allreduce_residual_rmsnorm(
logger.debug("Single GPU, no need for allreduce fusion")
return None, None

assert input_tensor.shape[0] <= max_token_num

if not ensure_workspace_initialized(
fd_config=fd_config,
max_token_num=max_token_num,
Expand All @@ -169,38 +175,30 @@ def flashinfer_allreduce_residual_rmsnorm(
logger.debug("FlashInfer workspace not available")
return None, None

token_num, hidden_dim = input_tensor.shape

residual_out = paddle.empty_like(residual)
norm_out = paddle.empty_like(input_tensor)
# support empty tensor
if input_tensor.shape[0] == 0:
return norm_out, residual_out
comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
world_size=world_size,
world_rank=dist.get_rank(),
token_num=token_num,
hidden_dim=hidden_dim,
workspace_ptrs=_workspace_manager.workspace_tensor,
launch_with_pdl=True,
use_oneshot=use_oneshot,
trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc,
pattern_code=(comm.AllReduceFusionPattern.kARResidualRMSNorm),
allreduce_out=None,
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
quant_out=None,
scale_out=None,
rms_gamma=weight,
rms_eps=eps,
scale_factor=None,
layout_code=None,
# Compute use_oneshot from concrete values before SOT-traced region.
# max_token_num is always a concrete Python int (model config parameter).
# input_tensor.shape[-1] is hidden_size, always a concrete model constant.
if use_oneshot is None:
hidden_dim = input_tensor.shape[-1]
comm_size_mb = max_token_num * hidden_dim * 2 * world_size * 2 / 1024 / 1024
use_oneshot = comm_size_mb <= _USE_ONESHOT_HEURISTICS.get(world_size, 0)

from fastdeploy.model_executor.layers.flashinfer_comm_op import (
trtllm_allreduce_residual_rmsnorm_op,
)

return norm_out, residual_out
return trtllm_allreduce_residual_rmsnorm_op(
input_tensor,
residual,
weight,
_workspace_manager.workspace_tensor,
world_size,
dist.get_rank(),
bool(use_oneshot),
trigger_completion_at_end,
fp32_acc,
eps,
)


def cleanup_flashinfer_workspace():
Expand Down
Loading
Loading