Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions custom_ops/cpu_ops/get_padding_offset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,8 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false);
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);

const int token_num_data = cpu_token_num.data<int64_t>()[0];
// token num is cpu tensor
const int token_num_data = token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::empty(
{token_num_data}, paddle::DataType::INT64, input_ids.place());
auto padding_offset = paddle::empty(
Expand Down
14 changes: 3 additions & 11 deletions custom_ops/xpu_ops/src/ops/adjust_batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@

template <paddle::DataType T>
std::vector<paddle::Tensor> AdjustBatchKernel(
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &decoder_seq_lod,
const paddle::Tensor &encoder_batch_idx,
Expand All @@ -49,7 +48,6 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
using data_t = typename PDTraits<T>::data_t;
const int token_num = x.dims()[0];
const int dim = x.dims()[1];
const int bsz = cum_offsets.shape()[0];
int enc_batch = len_info_cpu.data<int32_t>()[0];
int dec_batch = len_info_cpu.data<int32_t>()[1];

Expand Down Expand Up @@ -87,8 +85,7 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
}

using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &decoder_seq_lod,
const paddle::Tensor &encoder_batch_idx,
Expand All @@ -102,8 +99,7 @@ using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
int max_input_length);

std::vector<paddle::Tensor> AdjustBatch(
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &decoder_seq_lod,
const paddle::Tensor &encoder_batch_idx,
Expand Down Expand Up @@ -135,7 +131,6 @@ std::vector<paddle::Tensor> AdjustBatch(
}

return func(x,
cum_offsets,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_idx,
Expand All @@ -151,7 +146,6 @@ std::vector<paddle::Tensor> AdjustBatch(

std::vector<std::vector<int64_t>> AdjustBatchInferShape(
const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &cum_offsets_shape,
const std::vector<int64_t> &encoder_seq_lod_shape,
const std::vector<int64_t> &decoder_seq_lod_shape,
const std::vector<int64_t> &encoder_batch_idx_shape,
Expand All @@ -172,7 +166,6 @@ std::vector<std::vector<int64_t>> AdjustBatchInferShape(

std::vector<paddle::DataType> AdjustBatchInferDtype(
const paddle::DataType &x_dtype,
const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &encoder_seq_lod_dtype,
const paddle::DataType &decoder_seq_lod_dtype,
const paddle::DataType &encoder_batch_idx_dtype,
Expand All @@ -188,7 +181,6 @@ std::vector<paddle::DataType> AdjustBatchInferDtype(

PD_BUILD_STATIC_OP(adjust_batch)
.Inputs({"x",
"cum_offsets",
"encoder_seq_lod",
"decoder_seq_lod",
"encoder_batch_idx",
Expand Down
5 changes: 0 additions & 5 deletions custom_ops/xpu_ops/src/ops/block_attn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ std::vector<paddle::Tensor> BlockAttnKernel(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& rotary_embs,
const paddle::Tensor& block_tables,
const paddle::Tensor& prefix_block_tables,
Expand Down Expand Up @@ -122,7 +121,6 @@ std::vector<paddle::Tensor> BlockAttnKernel(
auto qkv_shape = qkv.dims();
auto cache_shape = key_cache.dims();
auto block_table_shape = block_tables.dims();
const int bsz = cum_offsets.dims()[0];
const int block_batch = block_table_shape[0];
const int max_block_per_seq = block_table_shape[1];
const int kv_num_heads = cache_shape[1];
Expand Down Expand Up @@ -984,7 +982,6 @@ std::vector<paddle::Tensor> BlockAttn(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& rotary_embs,
const paddle::Tensor& block_tables,
const paddle::Tensor& prefix_block_tables,
Expand Down Expand Up @@ -1023,7 +1020,6 @@ std::vector<paddle::Tensor> BlockAttn(
return BlockAttnKernel<TX, TC, TS>(qkv, \
key_cache, \
value_cache, \
cum_offsets, \
rotary_embs, \
block_tables, \
prefix_block_tables, \
Expand Down Expand Up @@ -1099,7 +1095,6 @@ PD_BUILD_STATIC_OP(block_attn)
.Inputs({"qkv",
"key_cache",
"value_cache",
"cum_offsets",
"rotary_embs",
"block_tables",
"prefix_block_tables",
Expand Down
12 changes: 4 additions & 8 deletions custom_ops/xpu_ops/src/ops/gather_next_token.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
#endif

std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& x, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& x, // [token_num, dim_embed]
const paddle::Tensor& encoder_seq_lod,
const paddle::Tensor& decoder_seq_lod,
const paddle::Tensor& encoder_batch_map,
Expand All @@ -46,7 +45,7 @@ std::vector<paddle::Tensor> GatherNextToken(
typedef paddle::bfloat16 data_t;
const int dim = x.dims()[1];
const int token_num = x.shape()[0];
int bsz = cum_offsets.shape()[0];
int bsz = -1;
int enc_batch = len_info_cpu.data<int32_t>()[0];
int dec_batch = len_info_cpu.data<int32_t>()[1];
if (max_bsz > 0) {
Expand Down Expand Up @@ -116,7 +115,6 @@ std::vector<paddle::Tensor> GatherNextToken(

std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& encoder_seq_lod_shape,
const std::vector<int64_t>& decoder_seq_lod_shape,
const std::vector<int64_t>& encoder_batch_map_shape,
Expand All @@ -130,19 +128,18 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
// if (output_padding_offset_shape) {
// PD_THROW("speculative decoding is not supported in XPU.");
// }
int64_t bsz = cum_offsets_shape[0];
// int64_t bsz = cum_offsets_shape[0];
int64_t bsz = 0;
int64_t dim_embed = x_shape[1];
if (output_padding_offset_shape) {
return {{-1, dim_embed}};
} else {
int64_t bsz = cum_offsets_shape[0];
return {{bsz, dim_embed}};
}
}

std::vector<paddle::DataType> GatherNextTokenInferDtype(
const paddle::DataType& x_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& encoder_seq_lod_dtype,
const paddle::DataType& decoder_seq_lod_dtype,
const paddle::DataType& encoder_batch_map_dtype,
Expand All @@ -158,7 +155,6 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(

PD_BUILD_STATIC_OP(gather_next_token)
.Inputs({"x",
"cum_offsets",
"encoder_seq_lod",
"decoder_seq_lod",
"encoder_batch_map",
Expand Down
6 changes: 3 additions & 3 deletions custom_ops/xpu_ops/src/ops/mtp/draft_model_update.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& max_dec_len,
Expand Down Expand Up @@ -72,7 +72,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
output_cum_offsets.data<int>(),
cu_seqlens_q_output.data<int>(),
const_cast<bool*>(stop_flags.data<bool>()),
const_cast<bool*>(not_need_stop_device.data<bool>()),
max_dec_len.data<int64_t>(),
Expand Down Expand Up @@ -102,7 +102,7 @@ PD_BUILD_STATIC_OP(draft_model_update)
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"output_cum_offsets",
"cu_seqlens_q_output",
"stop_flags",
"not_need_stop",
"max_dec_len",
Expand Down
133 changes: 133 additions & 0 deletions custom_ops/xpu_ops/src/ops/mtp/speculate_preprocess.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

namespace api = baidu::xpu::api;

// Strips padding from speculative-decoding inputs and builds the cumulative
// sequence-length / token-to-batch mapping tensors consumed by later XPU
// kernels.
//
// cpu_token_num:    total number of valid (non-padding) input tokens, passed
//                   as a host-side attribute (no device->host copy needed).
// input_ids:        [bsz, max_seq_len] padded token ids.
// seq_len:          [bsz] valid length of each sequence.
// draft_tokens:     [bsz, max_draft_tokens_per_batch] draft token ids.
// seq_lens_encoder: [bsz] encoder sequence lengths.
// seq_lens_decoder: unused by this kernel body; kept to preserve the op
//                   signature — TODO(review): confirm whether it can go.
//
// Returns {ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k,
//          cu_seq_lens_q_output, batch_id_per_token_output,
//          real_output_token_num}.
std::vector<paddle::Tensor> SpeculatePreProcess(
    const int64_t cpu_token_num,
    const paddle::Tensor &input_ids,
    const paddle::Tensor &seq_len,
    const paddle::Tensor &draft_tokens,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
  api::Context *ctx = xpu_ctx->x_context();

  // Just for unit tests running the baseline: when inputs live on host
  // memory, drive the plugin with a CPU context instead of the XPU one.
  std::unique_ptr<baidu::xpu::api::Context> cpu_ctx;
  if (input_ids.place().GetType() == phi::AllocationType::CPU) {
    cpu_ctx = std::make_unique<baidu::xpu::api::Context>(baidu::xpu::api::kCPU);
    ctx = cpu_ctx.get();
  }

  std::vector<int64_t> input_ids_shape = input_ids.shape();
  const int bsz = seq_len.shape()[0];
  const int max_seq_len = input_ids_shape[1];
  // Guard the int64 -> int narrowing explicitly; token counts are expected to
  // stay far below INT_MAX in practice.
  PD_CHECK(cpu_token_num >= 0, "cpu_token_num must be non-negative.");
  const int token_num_data = static_cast<int>(cpu_token_num);
  auto ids_remove_padding = paddle::empty(
      {token_num_data}, paddle::DataType::INT64, input_ids.place());
  auto batch_id_per_token = paddle::empty(
      {token_num_data}, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_q =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_k =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  const int max_draft_tokens_per_batch = draft_tokens.shape()[1];

  // seq_lens_output is scratch written by the plugin; it is not part of this
  // op's outputs.
  auto seq_lens_output =
      paddle::empty({bsz}, paddle::DataType::INT32, input_ids.place());
  auto cu_seq_lens_q_output =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  auto batch_id_per_token_output =
      paddle::empty({bsz * max_draft_tokens_per_batch},
                    paddle::DataType::INT32,
                    input_ids.place());
  auto real_output_token_num =
      paddle::empty({1}, paddle::DataType::INT32, input_ids.place());
  // Empty batch: skip the plugin call and return the (uninitialized) output
  // tensors unchanged — preserves prior behavior.
  if (token_num_data == 0) {
    return {ids_remove_padding,
            batch_id_per_token,
            cu_seqlens_q,
            cu_seqlens_k,
            cu_seq_lens_q_output,
            batch_id_per_token_output,
            real_output_token_num};
  }

  int64_t *ids_remove_padding_ptr = ids_remove_padding.data<int64_t>();
  int *batch_id_per_token_ptr = batch_id_per_token.data<int>();
  int *cu_seqlens_q_ptr = cu_seqlens_q.data<int>();
  int *cu_seqlens_k_ptr = cu_seqlens_k.data<int>();
  int *seq_lens_output_ptr = seq_lens_output.data<int>();
  int *cu_seq_lens_q_output_ptr = cu_seq_lens_q_output.data<int>();
  int *batch_id_per_token_output_ptr = batch_id_per_token_output.data<int>();
  int *real_output_token_num_ptr = real_output_token_num.data<int>();
  const int64_t *input_data_ptr = input_ids.data<int64_t>();
  const int *seq_len_ptr = seq_len.data<int>();
  const int64_t *draft_tokens_ptr = draft_tokens.data<int64_t>();
  const int *seq_lens_encoder_ptr = seq_lens_encoder.data<int>();

  const int r =
      fastdeploy::plugin::speculate_preprocess(ctx,
                                               ids_remove_padding_ptr,
                                               batch_id_per_token_ptr,
                                               cu_seqlens_q_ptr,
                                               cu_seqlens_k_ptr,
                                               seq_lens_output_ptr,
                                               cu_seq_lens_q_output_ptr,
                                               batch_id_per_token_output_ptr,
                                               real_output_token_num_ptr,
                                               input_data_ptr,
                                               seq_len_ptr,
                                               draft_tokens_ptr,
                                               seq_lens_encoder_ptr,
                                               max_seq_len,
                                               max_draft_tokens_per_batch,
                                               token_num_data,
                                               bsz);
  // Surface plugin failures instead of silently returning garbage (the
  // original stored this code in an unused variable).
  PD_CHECK(r == 0, "fastdeploy::plugin::speculate_preprocess failed.");

  return {ids_remove_padding,
          batch_id_per_token,
          cu_seqlens_q,
          cu_seqlens_k,
          cu_seq_lens_q_output,
          batch_id_per_token_output,
          real_output_token_num};
}

// Registers the speculate_pre_process custom op. Inputs mirror the tensor
// parameters of SpeculatePreProcess() (cpu_token_num travels as an attribute,
// not a tensor input); outputs match its return vector, in order.
// NOTE(review): "seq_lens_decoder" is declared as an input but the kernel
// body above never reads it — confirm whether it is required by callers.
PD_BUILD_STATIC_OP(speculate_pre_process)
    .Inputs({"input_ids",
             "seq_len",
             "draft_tokens",
             "seq_lens_encoder",
             "seq_lens_decoder"})
    .Outputs({"ids_remove_padding",
              "batch_id_per_token",
              "cu_seqlens_q",
              "cu_seqlens_k",
              "cu_seq_lens_q_output",
              "batch_id_per_token_output",
              "real_output_token_num"})
    .Attrs({"cpu_token_num: int64_t"})
    .SetKernelFn(PD_KERNEL(SpeculatePreProcess));
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ void SpeculateTokenPenaltyMultiScores(
const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& cu_seqlens_q_output,
const int max_seq_len) {
namespace api = baidu::xpu::api;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
Expand Down Expand Up @@ -72,8 +72,8 @@ void SpeculateTokenPenaltyMultiScores(
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bad_tokens.data<int64_t>(),
output_padding_offset.data<int>(),
output_cum_offsets.data<int>(),
batch_id_per_token_output.data<int>(),
cu_seqlens_q_output.data<int>(),
bs,
length,
length_id,
Expand All @@ -100,8 +100,8 @@ void SpeculateTokenPenaltyMultiScores(
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bad_tokens.data<int64_t>(),
output_padding_offset.data<int>(),
output_cum_offsets.data<int>(),
batch_id_per_token_output.data<int>(),
cu_seqlens_q_output.data<int>(),
bs,
length,
length_id,
Expand All @@ -125,8 +125,8 @@ void SpeculateTokenPenaltyMultiScores(
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bad_tokens.data<int64_t>(),
output_padding_offset.data<int>(),
output_cum_offsets.data<int>(),
batch_id_per_token_output.data<int>(),
cu_seqlens_q_output.data<int>(),
bs,
length,
length_id,
Expand Down Expand Up @@ -157,8 +157,8 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
"min_len",
"eos_token_id",
"seq_lens_this_time",
"output_padding_offset",
"output_cum_offsets"})
"batch_id_per_token_output",
"cu_seqlens_q_output"})
.Outputs({"logits_out"})
.Attrs({"max_seq_len: int"})
.SetInplaceMap({{"logits", "logits_out"}})
Expand Down
Loading
Loading