diff --git a/custom_ops/cpu_ops/get_padding_offset.cc b/custom_ops/cpu_ops/get_padding_offset.cc index 50af5a2951d..45ef2623050 100644 --- a/custom_ops/cpu_ops/get_padding_offset.cc +++ b/custom_ops/cpu_ops/get_padding_offset.cc @@ -58,9 +58,8 @@ std::vector GetPaddingOffset(const paddle::Tensor &input_ids, const int bsz = seq_len.shape()[0]; const int seq_length = input_ids_shape[1]; auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false); - auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); - - const int token_num_data = cpu_token_num.data()[0]; + // token num is cpu tensor + const int token_num_data = token_num.data()[0]; auto x_remove_padding = paddle::empty( {token_num_data}, paddle::DataType::INT64, input_ids.place()); auto padding_offset = paddle::empty( diff --git a/custom_ops/xpu_ops/src/ops/adjust_batch.cc b/custom_ops/xpu_ops/src/ops/adjust_batch.cc index ce94e9b809b..5bd9dc424c5 100644 --- a/custom_ops/xpu_ops/src/ops/adjust_batch.cc +++ b/custom_ops/xpu_ops/src/ops/adjust_batch.cc @@ -24,8 +24,7 @@ template std::vector AdjustBatchKernel( - const paddle::Tensor &x, // [token_num, dim_embed] - const paddle::Tensor &cum_offsets, // [bsz, 1] + const paddle::Tensor &x, // [token_num, dim_embed] const paddle::Tensor &encoder_seq_lod, const paddle::Tensor &decoder_seq_lod, const paddle::Tensor &encoder_batch_idx, @@ -49,7 +48,6 @@ std::vector AdjustBatchKernel( using data_t = typename PDTraits::data_t; const int token_num = x.dims()[0]; const int dim = x.dims()[1]; - const int bsz = cum_offsets.shape()[0]; int enc_batch = len_info_cpu.data()[0]; int dec_batch = len_info_cpu.data()[1]; @@ -87,8 +85,7 @@ std::vector AdjustBatchKernel( } using AdjustBatchKernelFuncPtr = std::vector (*)( - const paddle::Tensor &x, // [token_num, dim_embed] - const paddle::Tensor &cum_offsets, // [bsz, 1] + const paddle::Tensor &x, // [token_num, dim_embed] const paddle::Tensor &encoder_seq_lod, const paddle::Tensor &decoder_seq_lod, const paddle::Tensor &encoder_batch_idx, @@ -102,8 +99,7 @@ using AdjustBatchKernelFuncPtr = std::vector (*)( int max_input_length); std::vector AdjustBatch( - const paddle::Tensor &x, // [token_num, dim_embed] - const paddle::Tensor &cum_offsets, // [bsz, 1] + const paddle::Tensor &x, // [token_num, dim_embed] const paddle::Tensor &encoder_seq_lod, const paddle::Tensor &decoder_seq_lod, const paddle::Tensor &encoder_batch_idx, @@ -135,7 +131,6 @@ std::vector AdjustBatch( } return func(x, - cum_offsets, encoder_seq_lod, decoder_seq_lod, encoder_batch_idx, @@ -151,7 +146,6 @@ std::vector AdjustBatch( std::vector> AdjustBatchInferShape( const std::vector &x_shape, - const std::vector &cum_offsets_shape, const std::vector &encoder_seq_lod_shape, const std::vector &decoder_seq_lod_shape, const std::vector &encoder_batch_idx_shape, @@ -172,7 +166,6 @@ std::vector> AdjustBatchInferShape( std::vector AdjustBatchInferDtype( const paddle::DataType &x_dtype, - const paddle::DataType &cum_offsets_dtype, const paddle::DataType &encoder_seq_lod_dtype, const paddle::DataType &decoder_seq_lod_dtype, const paddle::DataType &encoder_batch_idx_dtype, @@ -188,7 +181,6 @@ std::vector AdjustBatchInferDtype( PD_BUILD_STATIC_OP(adjust_batch) .Inputs({"x", - "cum_offsets", "encoder_seq_lod", "decoder_seq_lod", "encoder_batch_idx", diff --git a/custom_ops/xpu_ops/src/ops/block_attn.cc b/custom_ops/xpu_ops/src/ops/block_attn.cc index c055bfb873d..b14ac36f7df 100644 --- a/custom_ops/xpu_ops/src/ops/block_attn.cc +++ b/custom_ops/xpu_ops/src/ops/block_attn.cc @@ -66,7 +66,6 @@ std::vector BlockAttnKernel( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache, - const paddle::Tensor& cum_offsets, const paddle::Tensor& rotary_embs, const paddle::Tensor& block_tables, const paddle::Tensor& prefix_block_tables, @@ -122,7 +121,6 @@ std::vector BlockAttnKernel( auto qkv_shape = qkv.dims(); auto cache_shape = key_cache.dims(); auto block_table_shape = block_tables.dims(); - const int bsz = cum_offsets.dims()[0]; const int block_batch = block_table_shape[0]; const int max_block_per_seq = block_table_shape[1]; const int kv_num_heads = cache_shape[1]; @@ -984,7 +982,6 @@ std::vector BlockAttn( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache, - const paddle::Tensor& cum_offsets, const paddle::Tensor& rotary_embs, const paddle::Tensor& block_tables, const paddle::Tensor& prefix_block_tables, @@ -1023,7 +1020,6 @@ std::vector BlockAttn( return BlockAttnKernel(qkv, \ key_cache, \ value_cache, \ - cum_offsets, \ rotary_embs, \ block_tables, \ prefix_block_tables, \ @@ -1099,7 +1095,6 @@ PD_BUILD_STATIC_OP(block_attn) .Inputs({"qkv", "key_cache", "value_cache", - "cum_offsets", "rotary_embs", "block_tables", "prefix_block_tables", diff --git a/custom_ops/xpu_ops/src/ops/gather_next_token.cc b/custom_ops/xpu_ops/src/ops/gather_next_token.cc index 186a4d12bad..31c2142ca07 100644 --- a/custom_ops/xpu_ops/src/ops/gather_next_token.cc +++ b/custom_ops/xpu_ops/src/ops/gather_next_token.cc @@ -22,8 +22,7 @@ #endif std::vector GatherNextToken( - const paddle::Tensor& x, // [token_num, dim_embed] - const paddle::Tensor& cum_offsets, // [bsz, 1] + const paddle::Tensor& x, // [token_num, dim_embed] const paddle::Tensor& encoder_seq_lod, const paddle::Tensor& decoder_seq_lod, const paddle::Tensor& encoder_batch_map, @@ -46,7 +45,7 @@ std::vector GatherNextToken( typedef paddle::bfloat16 data_t; const int dim = x.dims()[1]; const int token_num = x.shape()[0]; - int bsz = cum_offsets.shape()[0]; + int bsz = -1; int enc_batch = len_info_cpu.data()[0]; int dec_batch = len_info_cpu.data()[1]; if (max_bsz > 0) { @@ -116,7 +115,6 @@ std::vector GatherNextToken( std::vector> GatherNextTokenInferShape( const std::vector& x_shape, - const std::vector& cum_offsets_shape, const std::vector& encoder_seq_lod_shape, const std::vector& decoder_seq_lod_shape, const std::vector& encoder_batch_map_shape, @@ -130,19 +128,18 @@ std::vector> GatherNextTokenInferShape( // if (output_padding_offset_shape) { // PD_THROW("speculative decoding is not supported in XPU."); // } - int64_t bsz = cum_offsets_shape[0]; + // int64_t bsz = cum_offsets_shape[0]; + int64_t bsz = 0; int64_t dim_embed = x_shape[1]; if (output_padding_offset_shape) { return {{-1, dim_embed}}; } else { - int64_t bsz = cum_offsets_shape[0]; return {{bsz, dim_embed}}; } } std::vector GatherNextTokenInferDtype( const paddle::DataType& x_dtype, - const paddle::DataType& cum_offsets_dtype, const paddle::DataType& encoder_seq_lod_dtype, const paddle::DataType& decoder_seq_lod_dtype, const paddle::DataType& encoder_batch_map_dtype, @@ -158,7 +155,6 @@ std::vector GatherNextTokenInferDtype( PD_BUILD_STATIC_OP(gather_next_token) .Inputs({"x", - "cum_offsets", "encoder_seq_lod", "decoder_seq_lod", "encoder_batch_map", diff --git a/custom_ops/xpu_ops/src/ops/mtp/draft_model_update.cc b/custom_ops/xpu_ops/src/ops/mtp/draft_model_update.cc index 29ab89084f2..b006ed26e34 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/draft_model_update.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/draft_model_update.cc @@ -28,7 +28,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& step_idx, - const paddle::Tensor& output_cum_offsets, + const paddle::Tensor& cu_seqlens_q_output, const paddle::Tensor& stop_flags, const paddle::Tensor& not_need_stop, const paddle::Tensor& max_dec_len, @@ -72,7 +72,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens, const_cast(seq_lens_encoder.data()), const_cast(seq_lens_decoder.data()), const_cast(step_idx.data()), - output_cum_offsets.data(), + cu_seqlens_q_output.data(), const_cast(stop_flags.data()), const_cast(not_need_stop_device.data()), max_dec_len.data(), @@ -102,7 +102,7 @@ PD_BUILD_STATIC_OP(draft_model_update) "seq_lens_encoder", "seq_lens_decoder", "step_idx", - "output_cum_offsets", + "cu_seqlens_q_output", "stop_flags", "not_need_stop", "max_dec_len", diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_preprocess.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_preprocess.cc new file mode 100644 index 00000000000..30d4e0f425e --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_preprocess.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +namespace api = baidu::xpu::api; + +std::vector SpeculatePreProcess( + const int64_t cpu_token_num, + const paddle::Tensor &input_ids, + const paddle::Tensor &seq_len, + const paddle::Tensor &draft_tokens, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + api::Context *ctx = xpu_ctx->x_context(); + + // just for ut to run base line + std::unique_ptr cpu_ctx; + if (input_ids.place().GetType() == phi::AllocationType::CPU) { + cpu_ctx = std::make_unique(baidu::xpu::api::kCPU); + ctx = cpu_ctx.get(); + } + + std::vector input_ids_shape = input_ids.shape(); + const int bsz = seq_len.shape()[0]; + const int max_seq_len = input_ids_shape[1]; + const int token_num_data = cpu_token_num; + auto ids_remove_padding = paddle::empty( + {token_num_data}, paddle::DataType::INT64, input_ids.place()); + auto batch_id_per_token = paddle::empty( + {token_num_data}, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_q = + paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_k = + paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place()); + const int max_draft_tokens_per_batch = draft_tokens.shape()[1]; + + auto seq_lens_output = + paddle::empty({bsz}, paddle::DataType::INT32, input_ids.place()); + auto cu_seq_lens_q_output = + paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place()); + auto batch_id_per_token_output = + paddle::empty({bsz * max_draft_tokens_per_batch}, + paddle::DataType::INT32, + input_ids.place()); + auto real_output_token_num = + paddle::empty({1}, paddle::DataType::INT32, input_ids.place()); + if (token_num_data == 0) { + return {ids_remove_padding, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + cu_seq_lens_q_output, + batch_id_per_token_output, + real_output_token_num}; + } + + int64_t *ids_remove_padding_ptr = ids_remove_padding.data(); + int *batch_id_per_token_ptr = batch_id_per_token.data(); + int *cu_seqlens_q_ptr = cu_seqlens_q.data(); + int *cu_seqlens_k_ptr = cu_seqlens_k.data(); + int *seq_lens_output_ptr = seq_lens_output.data(); + int *cu_seq_lens_q_output_ptr = cu_seq_lens_q_output.data(); + int *batch_id_per_token_output_ptr = batch_id_per_token_output.data(); + int *real_output_token_num_ptr = real_output_token_num.data(); + const int64_t *input_data_ptr = input_ids.data(); + const int *seq_len_ptr = seq_len.data(); + const int64_t *draft_tokens_ptr = draft_tokens.data(); + const int *seq_lens_encoder_ptr = seq_lens_encoder.data(); + + int r = + fastdeploy::plugin::speculate_preprocess(ctx, + ids_remove_padding_ptr, + batch_id_per_token_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + seq_lens_output_ptr, + cu_seq_lens_q_output_ptr, + batch_id_per_token_output_ptr, + real_output_token_num_ptr, + input_data_ptr, + seq_len_ptr, + draft_tokens_ptr, + seq_lens_encoder_ptr, + max_seq_len, + max_draft_tokens_per_batch, + token_num_data, + bsz); + + return {ids_remove_padding, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + cu_seq_lens_q_output, + batch_id_per_token_output, + real_output_token_num}; +} + +PD_BUILD_STATIC_OP(speculate_pre_process) + .Inputs({"input_ids", + "seq_len", + "draft_tokens", + "seq_lens_encoder", + "seq_lens_decoder"}) + .Outputs({"ids_remove_padding", + "batch_id_per_token", + "cu_seqlens_q", + "cu_seqlens_k", + "cu_seq_lens_q_output", + "batch_id_per_token_output", + "real_output_token_num"}) + .Attrs({"cpu_token_num: int64_t"}) + .SetKernelFn(PD_KERNEL(SpeculatePreProcess)); diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_token_penalty_multi_scores.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_token_penalty_multi_scores.cc index be6e66420e7..3cb50c1788a 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_token_penalty_multi_scores.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_token_penalty_multi_scores.cc @@ -33,8 +33,8 @@ void SpeculateTokenPenaltyMultiScores( const paddle::Tensor& min_len, const paddle::Tensor& eos_token_id, const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& output_padding_offset, - const paddle::Tensor& output_cum_offsets, + const paddle::Tensor& batch_id_per_token_output, + const paddle::Tensor& cu_seqlens_q_output, const int max_seq_len) { namespace api = baidu::xpu::api; phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); @@ -72,8 +72,8 @@ void SpeculateTokenPenaltyMultiScores( min_len.data(), eos_token_id.data(), bad_tokens.data(), - output_padding_offset.data(), - output_cum_offsets.data(), + batch_id_per_token_output.data(), + cu_seqlens_q_output.data(), bs, length, length_id, @@ -100,8 +100,8 @@ void SpeculateTokenPenaltyMultiScores( min_len.data(), eos_token_id.data(), bad_tokens.data(), - output_padding_offset.data(), - output_cum_offsets.data(), + batch_id_per_token_output.data(), + cu_seqlens_q_output.data(), bs, length, length_id, @@ -125,8 +125,8 @@ void SpeculateTokenPenaltyMultiScores( min_len.data(), eos_token_id.data(), bad_tokens.data(), - output_padding_offset.data(), - output_cum_offsets.data(), + batch_id_per_token_output.data(), + cu_seqlens_q_output.data(), bs, length, length_id, @@ -157,8 +157,8 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores) "min_len", "eos_token_id", "seq_lens_this_time", - "output_padding_offset", - "output_cum_offsets"}) + "batch_id_per_token_output", + "cu_seqlens_q_output"}) .Outputs({"logits_out"}) .Attrs({"max_seq_len: int"}) .SetInplaceMap({{"logits", "logits_out"}}) diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc index 8452b233e50..ae2ddc0ce8c 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc @@ -26,24 +26,24 @@ namespace api = baidu::xpu::api; -void SpeculateVerify(const paddle::Tensor& sampled_token_ids, - const paddle::Tensor& accept_tokens, - const paddle::Tensor& accept_num, - const paddle::Tensor& step_idx, - const paddle::Tensor& stop_flags, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& draft_tokens, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& verify_tokens, - const paddle::Tensor& verify_scores, - const paddle::Tensor& max_dec_len, - const paddle::Tensor& end_tokens, - const paddle::Tensor& is_block_step, - const paddle::Tensor& output_cum_offsets, - const paddle::Tensor& actual_candidate_len, - const paddle::Tensor& actual_draft_token_nums, - const paddle::Tensor& topp, +void SpeculateVerify(const paddle::Tensor &sampled_token_ids, + const paddle::Tensor &accept_tokens, + const paddle::Tensor &accept_num, + const paddle::Tensor &step_idx, + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &draft_tokens, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &verify_tokens, + const paddle::Tensor &verify_scores, + const paddle::Tensor &max_dec_len, + const paddle::Tensor &end_tokens, + const paddle::Tensor &is_block_step, + const paddle::Tensor &cu_seqlens_q_output, + const paddle::Tensor &actual_candidate_len, + const paddle::Tensor &actual_draft_token_nums, + const paddle::Tensor &topp, int max_seq_len, int verify_window, bool enable_topp, @@ -57,7 +57,8 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids, phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - api::Context* ctx = static_cast(dev_ctx)->x_context(); + api::Context *ctx = + static_cast(dev_ctx)->x_context(); bool xpu_ctx_flag = true; if (draft_tokens.is_cpu()) { ctx = new api::Context(api::kCPU); @@ -65,17 +66,17 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids, } bool use_topk = false; - char* env_var = getenv("SPECULATE_VERIFY_USE_TOPK"); + char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK"); if (env_var) { use_topk = static_cast(std::stoi(env_var)); } bool use_target_sampling = false; - char* env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING"); + char *env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING"); if (env_var_1) { use_target_sampling = static_cast(std::stoi(env_var_1)); } bool prefill_one_step_stop = false; - if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { + if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { // std::cout << "Your PATH is: " << env_p << '\n'; if (env_p[0] == '1') { prefill_one_step_stop = true; @@ -90,7 +91,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids, std::mt19937_64 engine(infer_seed[i]); dev_curand_states_cpu.push_back(dist(engine)); } - float* dev_curand_states = dev_curand_states_cpu.data(); + float *dev_curand_states = dev_curand_states_cpu.data(); auto dev_curand_states_tensor = paddle::empty({static_cast(dev_curand_states_cpu.size())}, paddle::DataType::FLOAT32, @@ -110,10 +111,10 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids, ret = fastdeploy::plugin::speculate_verify( ctx, sampled_token_ids.data(), - const_cast(accept_tokens.data()), - const_cast(accept_num.data()), - const_cast(step_idx.data()), - const_cast(stop_flags.data()), + const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + const_cast(step_idx.data()), + const_cast(stop_flags.data()), seq_lens_encoder.data(), seq_lens_decoder.data(), draft_tokens.data(), @@ -126,7 +127,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids, max_dec_len.data(), end_tokens.data(), is_block_step.data(), - output_cum_offsets.data(), + cu_seqlens_q_output.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, @@ -143,10 +144,10 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids, ret = fastdeploy::plugin::speculate_verify( ctx, sampled_token_ids.data(), - const_cast(accept_tokens.data()), - const_cast(accept_num.data()), - const_cast(step_idx.data()), - const_cast(stop_flags.data()), + const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + const_cast(step_idx.data()), + const_cast(stop_flags.data()), seq_lens_encoder.data(), seq_lens_decoder.data(), draft_tokens.data(), @@ -159,7 +160,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids, max_dec_len.data(), end_tokens.data(), is_block_step.data(), - output_cum_offsets.data(), + cu_seqlens_q_output.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, @@ -178,10 +179,10 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids, ret = fastdeploy::plugin::speculate_verify( ctx, sampled_token_ids.data(), - const_cast(accept_tokens.data()), - const_cast(accept_num.data()), - const_cast(step_idx.data()), - const_cast(stop_flags.data()), + const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + const_cast(step_idx.data()), + const_cast(stop_flags.data()), seq_lens_encoder.data(), seq_lens_decoder.data(), draft_tokens.data(), @@ -194,7 +195,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids, max_dec_len.data(), end_tokens.data(), is_block_step.data(), - output_cum_offsets.data(), + cu_seqlens_q_output.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, @@ -211,10 +212,10 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids, ret = fastdeploy::plugin::speculate_verify( ctx, sampled_token_ids.data(), - const_cast(accept_tokens.data()), - const_cast(accept_num.data()), - const_cast(step_idx.data()), - const_cast(stop_flags.data()), + const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + const_cast(step_idx.data()), + const_cast(stop_flags.data()), seq_lens_encoder.data(), seq_lens_decoder.data(), draft_tokens.data(), @@ -227,7 +228,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids, max_dec_len.data(), end_tokens.data(), is_block_step.data(), - output_cum_offsets.data(), + cu_seqlens_q_output.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, @@ -262,7 +263,7 @@ PD_BUILD_STATIC_OP(speculate_verify) "max_dec_len", "end_tokens", "is_block_step", - "output_cum_offsets", + "cu_seqlens_q_output", "actual_candidate_len", "actual_draft_token_nums", "topp"}) diff --git a/custom_ops/xpu_ops/src/ops/mtp/top_p_candidates.cc b/custom_ops/xpu_ops/src/ops/mtp/top_p_candidates.cc index f54761bd64d..a404bb1f728 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/top_p_candidates.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/top_p_candidates.cc @@ -41,7 +41,7 @@ namespace api = baidu::xpu::api; std::vector TopPCandidates( const paddle::Tensor& probs, const paddle::Tensor& top_p, - const paddle::Tensor& output_padding_offset, + const paddle::Tensor& batch_id_per_token_output, int candidates_len, int max_seq_len) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); @@ -77,7 +77,7 @@ std::vector TopPCandidates( ctx, reinterpret_cast(probs.data()), reinterpret_cast(top_p.data()), - output_padding_offset.data(), + batch_id_per_token_output.data(), verify_tokens.data(), reinterpret_cast( verify_scores.data()), @@ -100,7 +100,7 @@ std::vector TopPCandidates( ctx, reinterpret_cast(probs.data()), reinterpret_cast(top_p.data()), - output_padding_offset.data(), + batch_id_per_token_output.data(), verify_tokens.data(), reinterpret_cast( verify_scores.data()), @@ -120,7 +120,7 @@ std::vector TopPCandidates( ctx, probs.data(), top_p.data(), - output_padding_offset.data(), + batch_id_per_token_output.data(), verify_tokens.data(), verify_scores.data(), actual_candidate_lens.data(), @@ -139,7 +139,7 @@ std::vector TopPCandidates( std::vector> TopPCandidatesInferShape( const std::vector& probs_shape, const std::vector& top_p_shape, - const std::vector& output_padding_offset_shape, + const std::vector& batch_id_per_token_output_shape, int max_candidates_len) { int token_num = probs_shape[0]; return {{token_num, max_candidates_len}, @@ -150,12 +150,12 @@ std::vector> TopPCandidatesInferShape( std::vector TopPCandidatesInferDtype( const paddle::DataType& probs_dtype, const paddle::DataType& top_p_dtype, - const paddle::DataType& output_padding_offset_dtype) { + const paddle::DataType& batch_id_per_token_output_dtype) { return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32}; } PD_BUILD_STATIC_OP(top_p_candidates) - .Inputs({"probs", "top_p", "output_padding_offset"}) + .Inputs({"probs", "top_p", "batch_id_per_token_output"}) .Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"}) .Attrs({"candidates_len: int", "max_seq_len: int"}) .SetKernelFn(PD_KERNEL(TopPCandidates)) diff --git a/custom_ops/xpu_ops/src/ops/mtp/unified_update_model_status.cc b/custom_ops/xpu_ops/src/ops/mtp/unified_update_model_status.cc new file mode 100644 index 00000000000..4945e935c8c --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/unified_update_model_status.cc @@ -0,0 +1,149 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "paddle/common/flags.h" +#include "paddle/extension.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "xpu/internal/infra_op.h" +#include "xpu/plugin.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +namespace api = baidu::xpu::api; +void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &has_running_seqs, + const paddle::Tensor &step_input_ids, + const paddle::Tensor &adaptive_step_input_len, + const paddle::Tensor &step_output_ids, + const paddle::Tensor &step_output_len, + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &is_paused, + const paddle::Tensor &mask_rollback, + const paddle::Tensor &token_ids_all, + const paddle::Tensor &prompt_lens, + const paddle::Tensor &step_idx, + const paddle::Tensor &end_tokens, + const paddle::Tensor &max_dec_len, + const bool is_naive_mode, + const bool prefill_one_step_stop) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + api::Context *ctx = xpu_ctx->x_context(); + + // just for ut to run base line + std::unique_ptr cpu_ctx; + if (seq_lens_encoder.place().GetType() == phi::AllocationType::CPU) { + cpu_ctx = std::make_unique(baidu::xpu::api::kCPU); + ctx = cpu_ctx.get(); + } + + const int real_bsz = seq_lens_this_time.shape()[0]; + const int max_bsz = stop_flags.shape()[0]; + PADDLE_ENFORCE_LE( + max_bsz, + 1024, + phi::errors::InvalidArgument( + "unified_update_model_status: max_bsz (%d) must be <= 1024 " + "(single-block launch limit).", + max_bsz)); + const int max_step_tokens = step_input_ids.shape()[1]; + const int max_model_len = token_ids_all.shape()[1]; + const int num_end_tokens = end_tokens.shape()[0]; + + // has_running_seqs is CPU tensor, need to copy to GPU first + auto has_running_seqs_xpu = + has_running_seqs.copy_to(seq_lens_this_time.place(), false); + int r = fastdeploy::plugin::unified_update_model_status( + ctx, + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(has_running_seqs_xpu.data()), + const_cast(mask_rollback.data()), + const_cast(step_input_ids.data()), + const_cast(adaptive_step_input_len.data()), + const_cast(step_output_ids.data()), + const_cast(step_output_len.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(is_paused.data()), + const_cast(token_ids_all.data()), + prompt_lens.data(), + const_cast(step_idx.data()), + end_tokens.data(), + max_dec_len.data(), + real_bsz, + max_bsz, + max_step_tokens, + max_model_len, + num_end_tokens, + is_naive_mode, + prefill_one_step_stop); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "unified_update_model_status"); + // Copy result back to CPU + auto has_running_seqs_cpu = + has_running_seqs_xpu.copy_to(has_running_seqs.place(), false); + bool *out_data = const_cast(has_running_seqs.data()); + out_data[0] = has_running_seqs_cpu.data()[0]; +} + +PD_BUILD_STATIC_OP(unified_update_model_status) + .Inputs({"seq_lens_encoder", + "seq_lens_decoder", + "has_running_seqs", + "step_input_ids", + "adaptive_step_input_len", + "step_output_ids", + "step_output_len", + "stop_flags", + "seq_lens_this_time", + "is_paused", + "mask_rollback", + "token_ids_all", + "prompt_lens", + "step_idx", + "end_tokens", + "max_dec_len"}) + .Attrs({"is_naive_mode: bool", "prefill_one_step_stop: bool"}) + .Outputs({"seq_lens_encoder_out", + "seq_lens_decoder_out", + "has_running_seqs_out", + "step_input_ids_out", + "adaptive_step_input_len_out", + "step_output_ids_out", + "step_output_len_out", + "stop_flags_out", + "seq_lens_this_time_out", + "mask_rollback_out", + "token_ids_all_out", + "step_idx_out"}) + .SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"has_running_seqs", "has_running_seqs_out"}, + {"step_input_ids", "step_input_ids_out"}, + {"adaptive_step_input_len", "adaptive_step_input_len_out"}, + {"step_output_ids", "step_output_ids_out"}, + {"step_output_len", "step_output_len_out"}, + {"stop_flags", "stop_flags_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"mask_rollback", "mask_rollback_out"}, + {"token_ids_all", "token_ids_all_out"}, + {"step_idx", "step_idx_out"}}) + .SetKernelFn(PD_KERNEL(UnifiedUpdateModelStatus)); diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 14468ddda48..01929009e5b 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -34,8 +34,7 @@ void prof_start(); void prof_stop(); std::vector AdjustBatch( - const paddle::Tensor& x, // [token_num, dim_embed] - const paddle::Tensor& cum_offsets, // [bsz, 1] + const paddle::Tensor& x, // [token_num, dim_embed] const paddle::Tensor& encoder_seq_lod, const paddle::Tensor& decoder_seq_lod, const paddle::Tensor& encoder_batch_idx, @@ -62,7 +61,6 @@ std::vector BlockAttn( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache, - const paddle::Tensor& cum_offsets, const paddle::Tensor& rotary_embs, const paddle::Tensor& block_tables, const paddle::Tensor& prefix_block_tables, @@ -210,7 +208,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& step_idx, - const paddle::Tensor& output_cum_offsets, + const paddle::Tensor& cu_seqlens_q_output, const paddle::Tensor& stop_flags, const paddle::Tensor& not_need_stop, const paddle::Tensor& max_dec_len, @@ -254,8 +252,8 @@ void SpeculateTokenPenaltyMultiScores( const paddle::Tensor& min_len, const paddle::Tensor& eos_token_id, const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& output_padding_offset, - const paddle::Tensor& output_cum_offsets, + const paddle::Tensor& batch_id_per_token_output, + const paddle::Tensor& cu_seqlens_q_output, const int max_seq_len); void SpeculateUpdateV3(const paddle::Tensor& seq_lens_encoder, @@ -413,8 +411,7 @@ std::vector EagleGetSelfHiddenStates( const paddle::Tensor& step_idx); std::vector GatherNextToken( - const paddle::Tensor& x, // [token_num, dim_embed] - const paddle::Tensor& cum_offsets, // [bsz, 1] + const paddle::Tensor& x, // [token_num, dim_embed] const paddle::Tensor& encoder_seq_lod, const paddle::Tensor& decoder_seq_lod, const paddle::Tensor& encoder_batch_map, @@ -500,6 +497,14 @@ std::vector SpeculateGetPaddingOffset( const paddle::Tensor& seq_len, const paddle::Tensor& seq_lens_encoder); +std::vector SpeculatePreProcess( + const int64_t cpu_token_num, + const paddle::Tensor& input_ids, + const paddle::Tensor& seq_len, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder); + void StepPaddle(const paddle::Tensor& stop_flags, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& ori_seq_lens_encoder, @@ -540,6 +545,25 @@ void MTPStepPaddle( const int block_size, const int max_draft_tokens); +void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& has_running_seqs, + const paddle::Tensor& step_input_ids, + const paddle::Tensor& adaptive_step_input_len, + const paddle::Tensor& step_output_ids, + const paddle::Tensor& step_output_len, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& is_paused, + const paddle::Tensor& mask_rollback, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& end_tokens, + const paddle::Tensor& max_dec_len, + const bool is_naive_mode, + const bool prefill_one_step_stop); + void SpeculateStepPaddle( const paddle::Tensor& stop_flags, const paddle::Tensor& seq_lens_this_time, @@ -682,7 +706,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("adjust_batch", &AdjustBatch, py::arg("x"), - py::arg("cum_offsets"), py::arg("encoder_seq_lod"), py::arg("decoder_seq_lod"), py::arg("encoder_batch_idx"), @@ -701,7 +724,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("qkv"), py::arg("key_cache"), py::arg("value_cache"), - py::arg("cum_offsets"), py::arg("rotary_embs"), py::arg("block_tables"), py::arg("prefix_block_tables"), @@ -812,7 +834,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("seq_lens_encoder"), // 编码器序列长度张量 py::arg("seq_lens_decoder"), // 解码器序列长度张量 py::arg("step_idx"), // 步骤索引张量 - py::arg("output_cum_offsets"), // 输出累积偏移量张量 + py::arg("cu_seqlens_q_output"), // 输出累积偏移量张量 py::arg("stop_flags"), // 停止标志张量 py::arg("not_need_stop"), // 无需停止标志张量 py::arg("max_dec_len"), // 最大解码长度张量 @@ -885,7 +907,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("gather_next_token", &GatherNextToken, py::arg("x"), - py::arg("cum_offsets"), py::arg("encoder_seq_lod"), py::arg("decoder_seq_lod"), py::arg("encoder_batch_map"), @@ -1002,6 +1023,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("redundant_ep_rank_num_plus_one"), "moe export RedundantTopKSelect function"); + m.def("unified_update_model_status", + &UnifiedUpdateModelStatus, + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("has_running_seqs"), + py::arg("step_input_ids"), + py::arg("adaptive_step_input_len"), + py::arg("step_output_ids"), + py::arg("step_output_len"), + py::arg("stop_flags"), + py::arg("seq_lens_this_time"), + py::arg("is_paused"), + py::arg("mask_rollback"), + py::arg("token_ids_all"), + py::arg("prompt_lens"), + py::arg("step_idx"), + py::arg("end_tokens"), + py::arg("max_dec_len"), + py::arg("is_naive_mode"), + py::arg("max_draft_tokens"), + "Unified update model status"); + m.def("mtp_step_paddle", &MTPStepPaddle, py::arg("base_model_stop_flags"), @@ -1117,8 +1160,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("min_len"), py::arg("eos_token_id"), py::arg("seq_lens_this_time"), - py::arg("output_padding_offset"), - py::arg("output_cum_offsets"), + py::arg("batch_id_per_token_output"), + py::arg("cu_seqlens_q_output"), py::arg("max_seq_len"), "Applies token penalty with multiple scores"); @@ -1182,7 +1225,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("max_dec_len"), py::arg("end_tokens"), py::arg("is_block_step"), - py::arg("output_cum_offsets"), + py::arg("cu_seqlens_q_output"), py::arg("actual_candidate_len"), py::arg("actual_draft_token_nums"), py::arg("topp"), @@ -1246,6 +1289,16 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("max_seq_len"), "Get output padding offset"); + m.def("speculate_pre_process", + &SpeculatePreProcess, + py::arg("cpu_token_num"), + py::arg("input_ids"), + py::arg("seq_len"), + py::arg("draft_tokens"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + "speculate pre process to remove padding and to acquire cu_seq_len"); + m.def("speculate_get_padding_offset", &SpeculateGetPaddingOffset, py::arg("input_ids"), @@ -1419,7 +1472,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &TopPCandidates, py::arg("probs"), py::arg("top_p"), - py::arg("output_padding_offset"), + py::arg("batch_id_per_token_output"), py::arg("candidates_len"), py::arg("max_seq_len"), "Generate top-p candidates based on probability distributions"); diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index 1cbd7a8029b..75bd67e23f0 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -388,8 +388,8 @@ DLL_EXPORT int speculate_token_penalty_multi_scores( const int64_t* min_len, const int64_t* eos_token_id, const int64_t* bad_words, - const int* output_padding_offset, - const int* output_cum_offsets, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t bs, const int64_t length, const int64_t length_id, @@ -432,7 +432,7 @@ DLL_EXPORT int speculate_verify(api::Context* ctx, const int64_t* max_dec_len, const int64_t* end_tokens, const bool* is_block_step, - const int* output_cum_offsets, + const int* cu_seqlens_q_output, const int* actual_candidate_len, const int real_bsz, const int max_draft_tokens, @@ -465,7 +465,7 @@ DLL_EXPORT int draft_model_update(api::Context* ctx, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - const int* output_cum_offsets, + const int* cu_seqlens_q_output, bool* stop_flags, bool* not_need_stop, const int64_t* max_dec_len, @@ -574,7 +574,7 @@ template DLL_EXPORT int top_p_candidates(api::Context* ctx, const T* src, const T* top_ps, - const int* output_padding_offset, + const int* batch_id_per_token_output, int64_t* out_id, T* out_val, int* actual_candidates_lens, @@ -630,6 +630,24 @@ DLL_EXPORT int speculate_schedule_cache(api::Context* ctx, const int block_num_per_seq, const bool prefill_one_step_stop); +DLL_EXPORT int speculate_preprocess(api::Context* ctx, + int64_t* ids_remove_padding, + int* batch_id_per_token, + int* cu_seqlens_q, + int* cu_seqlens_k, + int* seq_lens_output, + int* cu_seq_lens_q_output, + int* batch_id_per_token_output, + int* real_output_token_num, + const int64_t* input_data, + const int* seq_lens, + const int64_t* draft_tokens, + const int* seq_lens_encoder, + const int max_seq_len, + const int max_draft_tokens_per_batch, + const int token_num_data, + const int real_bs); + DLL_EXPORT int speculate_update_v3(api::Context* ctx, int* seq_lens_encoder, int* seq_lens_decoder, @@ -662,6 +680,31 @@ DLL_EXPORT int speculate_update(api::Context* ctx, const int max_bsz, const int max_draft_tokens); +DLL_EXPORT int unified_update_model_status(api::Context* ctx, + int* seq_lens_encoder, + int* seq_lens_decoder, + bool* has_running_seqs, + int* mask_rollback, + int64_t* step_input_ids, + int* adaptive_step_input_len, + int64_t* step_output_ids, + int* step_output_len, + bool* stop_flags, + int* seq_lens_this_time, + const bool* is_paused, + int64_t* token_ids_all, + const int64_t* prompt_lens, + int64_t* step_idx, + const int64_t* end_tokens, + const int64_t* max_dec_len, + int real_bsz, + int max_bsz, + int max_step_tokens, + int max_model_len, + int num_end_tokens, + bool is_naive_mode, + bool prefill_one_step_stop); + template DLL_EXPORT int rebuild_hidden_states(api::Context* ctx, const T* input, diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu index f6388ceb92a..cad6ddfc793 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu @@ -22,7 +22,7 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - const int* output_cum_offsets, + const int* cu_seqlens_q_output, bool* stop_flags, bool* not_need_stop, const int64_t* max_dec_len, @@ -45,8 +45,7 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens, auto* pre_ids_now = pre_ids + tid * pre_id_length; auto* base_model_draft_tokens_now = base_model_draft_tokens + tid * max_base_model_draft_token; - const int next_tokens_start_id = - tid * max_seq_len - output_cum_offsets[tid]; + const int next_tokens_start_id = cu_seqlens_q_output[tid]; auto* next_tokens_start = inter_next_tokens + next_tokens_start_id; auto seq_len_this_time = seq_lens_this_time[tid]; auto seq_len_encoder = seq_lens_encoder[tid]; @@ -72,8 +71,7 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens, base_model_draft_tokens_now[substep + 1] = token_this_time; } // multi_end - if (is_in_end(token_this_time, end_ids, end_ids_len) || - prefill_one_step_stop) { + if (is_in_end(token_this_time, end_ids, end_ids_len)) { stop_flags[tid] = true; stop_flag_now_int_sm[cid] += 1; // max_dec_len diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_ban_bad_words.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_ban_bad_words.xpu index b35e85ff1eb..d6b6856dc33 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_ban_bad_words.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_ban_bad_words.xpu @@ -22,7 +22,7 @@ inline __device__ void update_bad_words_logit( template __global__ void speculate_ban_bad_words(T* logits, const int64_t* bad_words_list, - const int* output_padding_offset, + const int* batch_id_per_token_output, const int64_t bs, const int64_t length, const int64_t bad_words_length, @@ -32,7 +32,7 @@ __global__ void speculate_ban_bad_words(T* logits, int nthreads = cluster_num() * core_num(); int start = -1; int end = -1; - int output_padding_offset_lm; + int batch_id_per_token_output_lm; partition(tid, nthreads, static_cast(token_num * bad_words_length), @@ -41,10 +41,10 @@ __global__ void speculate_ban_bad_words(T* logits, &end); for (int i = start; i < end; i++) { int token_idx = i / bad_words_length; - GM2LM(output_padding_offset + token_idx, - &output_padding_offset_lm, + GM2LM(batch_id_per_token_output + token_idx, + &batch_id_per_token_output_lm, sizeof(int)); - int bs_idx = (token_idx + output_padding_offset_lm) / max_seq_len; + int bs_idx = batch_id_per_token_output_lm; if (bs_idx >= bs) { continue; } @@ -63,7 +63,7 @@ __global__ void speculate_ban_bad_words(T* logits, template __global__ void speculate_ban_bad_words( \ DATA_TYPE* logits, \ const int64_t* bad_words_list, \ - const int* output_padding_offset, \ + const int* batch_id_per_token_output_lm, \ const int64_t bs, \ const int64_t length, \ const int64_t bad_words_length, \ diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_min_length_logits_process.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_min_length_logits_process.xpu index 4b11abfd331..d0861a98125 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_min_length_logits_process.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_min_length_logits_process.xpu @@ -11,8 +11,8 @@ __global__ void speculate_min_length_logits_process( const int64_t* cur_len, const int64_t* min_len, const int64_t* eos_token_id, - const int* output_padding_offset, - const int* output_cum_offsets, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t bs, const int64_t length, const int64_t length_id, @@ -29,26 +29,26 @@ __global__ void speculate_min_length_logits_process( int64_t eos_token_id_now; int64_t bi; int64_t end_num; - int output_padding_offset_now; - int output_cum_offsets_now; + int batch_id_per_token_output_now; + int cu_seqlens_q_output_now; __simd__ float float32logits_now[32]; for (int64_t i = tid; i < token_num * end_length; i += nthreads) { int64_t token_idx = i / end_length; - GM2LM(output_padding_offset + token_idx, - &output_padding_offset_now, + GM2LM(batch_id_per_token_output + token_idx, + &batch_id_per_token_output_now, sizeof(int)); - bi = (token_idx + output_padding_offset_now) / max_seq_len; + bi = batch_id_per_token_output[token_idx]; if (bi >= bs) { continue; } end_num = i % end_length; GM2LM_ASYNC( - output_cum_offsets + bi, (void*)&output_cum_offsets_now, sizeof(int)); + cu_seqlens_q_output + bi, (void*)&cu_seqlens_q_output_now, sizeof(int)); GM2LM_ASYNC(cur_len + bi, (void*)&(cur_len_now), sizeof(int64_t)); GM2LM_ASYNC(min_len + bi, (void*)&(min_len_now), sizeof(int64_t)); mfence(); - int query_start_token_idx = bi * max_seq_len - output_cum_offsets_now; + int query_start_token_idx = cu_seqlens_q_output_now; if (cur_len_now >= 0 && (cur_len_now + (token_idx - query_start_token_idx) < min_len_now)) { GM2LM( @@ -74,8 +74,8 @@ __global__ void speculate_min_length_logits_process( const int64_t* cur_len, \ const int64_t* min_len, \ const int64_t* eos_token_id, \ - const int* output_padding_offset, \ - const int* output_cum_offsets, \ + const int* batch_id_per_token_output, \ + const int* cu_seqlens_q_output, \ const int64_t bs, \ const int64_t length, \ const int64_t length_id, \ diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_preprocess.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_preprocess.xpu new file mode 100644 index 00000000000..186d7293605 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_preprocess.xpu @@ -0,0 +1,157 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +#include "xpu/kernel/cluster_debug.h" + +namespace fd_xpu3 { + +#define MAX_BATCH_SIZE 1024 + +static inline __device__ int v_reduce_sum_int32(int32x16_t& v0) { + auto v1 = vsrlp_int32x16(1 << 8, v0); + v0 = vvadd_int32x16(v0, v1); + v1 = vsrlp_int32x16(1 << 7, v0); + v0 = vvadd_int32x16(v0, v1); + v1 = vsrlp_int32x16(1 << 6, v0); + v0 = vvadd_int32x16(v0, v1); + v1 = vsrlp_int32x16(1 << 5, v0); + v0 = vvadd_int32x16(v0, v1); + return vextract_int32x16(v0, 1); +} + +inline __device__ int primitive_reduce_sum_sm(__shared_ptr__ const int* x, + int64_t len) { + int32x16_t x_l, x_h; + int32x16_t sum = vset_zero_int(); + const auto rounddown_len = rounddown32(len); + + for (int64_t i = 0; i < rounddown_len; i += 32) { + vload2_sm(x + i, x_l, x_h); + sum = vvadd_int32x16(sum, x_l); + sum = vvadd_int32x16(sum, x_h); + } + + if (rounddown_len < len) { + const auto mask = ~(-1 << (len - rounddown_len)); + vload2_sm_mz(x + rounddown_len, x_l, x_h, mask); + sum = vvadd_int32x16(sum, x_l); + sum = vvadd_int32x16(sum, x_h); + } + return v_reduce_sum_int32(sum); +} + +__global__ void speculate_preprocess_kernel( + int64_t* ids_remove_padding, + int* batch_id_per_token, + int* cu_seqlens_q, + int* cu_seqlens_k, + int* seq_lens_output, + int* cu_seq_lens_q_output, + int* batch_id_per_token_output, + int* real_output_token_num, + const int64_t* input_data, + const int* seq_lens, + const int64_t* draft_tokens, + const int* seq_lens_encoder, + const int max_seq_len, + const int max_draft_tokens_per_batch, + const int real_bs) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + __shared__ int sm_seq_lens[MAX_BATCH_SIZE]; + __shared__ int sm_seq_lens_output[MAX_BATCH_SIZE]; + __shared__ int sm_seq_lens_encoder[MAX_BATCH_SIZE]; + __shared__ int sm_cum_seq_len, sm_cum_seq_len_output; + __simd__ __shared__ int buffer_cu_seqlens[64]; + __simd__ __shared__ int buffer_cu_seqlens_output[64]; + + if (cid == 0) { + GM2SM_ASYNC(seq_lens, sm_seq_lens, sizeof(int) * real_bs); + GM2SM(seq_lens_encoder, sm_seq_lens_encoder, sizeof(int) * real_bs); + } + sync_all(); + for (int bid = cid; bid < real_bs; bid += ncores) { + if (sm_seq_lens[bid] == 0) { + sm_seq_lens_output[bid] = 0; + } else if (sm_seq_lens[bid] == 1) { + sm_seq_lens_output[bid] = 1; + } else if (sm_seq_lens_encoder[bid] != 0) { + sm_seq_lens_output[bid] = 1; + } else { + sm_seq_lens_output[bid] = sm_seq_lens[bid]; + } + } + mfence_sm(); + sync_all(); + + for (int bi = clusterid; bi < real_bs; bi += nclusters) { + int cum_seq_len = 0; + int cum_seq_len_output = 0; + for (int i = cid; i < bi + 1; i += ncores) { + cum_seq_len += sm_seq_lens[i]; + cum_seq_len_output += sm_seq_lens_output[i]; + } + buffer_cu_seqlens[cid] = cum_seq_len; + buffer_cu_seqlens_output[cid] = cum_seq_len_output; + mfence(); + sync_all(); + if (cid == 0) { + cum_seq_len = + primitive_reduce_sum_sm(buffer_cu_seqlens, min(bi + 1, ncores)); + cum_seq_len_output = primitive_reduce_sum_sm(buffer_cu_seqlens_output, + min(bi + 1, ncores)); + LM2GM_ASYNC(&cum_seq_len, cu_seqlens_q + bi + 1, sizeof(int)); + LM2GM_ASYNC(&cum_seq_len, cu_seqlens_k + bi + 1, sizeof(int)); + LM2GM_ASYNC( + &cum_seq_len_output, cu_seq_lens_q_output + bi + 1, sizeof(int)); + if (bi == real_bs - 1) { + LM2GM_ASYNC(&cum_seq_len_output, real_output_token_num, sizeof(int)); + } + sm_cum_seq_len = cum_seq_len; + sm_cum_seq_len_output = cum_seq_len_output; + } + mfence(); + sync_all(); + + const int lm_seq_lens = sm_seq_lens[bi]; + const int lm_seq_lens_encoder = sm_seq_lens_encoder[bi]; + for (int i = cid; i < lm_seq_lens; i += ncores) { + const int tgt_seq_id = sm_cum_seq_len - lm_seq_lens + i; + if (max_draft_tokens_per_batch > 0 && lm_seq_lens_encoder <= 0) { + // speculative decoding + const int src_seq_id = bi * max_draft_tokens_per_batch + i; + int64_t lm_draft_tokens; + GM2LM(draft_tokens + src_seq_id, &lm_draft_tokens, sizeof(int64_t)); + LM2GM( + &lm_draft_tokens, ids_remove_padding + tgt_seq_id, sizeof(int64_t)); + } else { + // Non-speculative decoding + const int src_seq_id = bi * max_seq_len + i; + int64_t lm_input_data; + GM2LM(input_data + src_seq_id, &lm_input_data, sizeof(int64_t)); + LM2GM(&lm_input_data, ids_remove_padding + tgt_seq_id, sizeof(int64_t)); + } + LM2GM(&bi, batch_id_per_token + tgt_seq_id, sizeof(int)); + } + + const int lm_seq_lens_output = sm_seq_lens_output[bi]; + for (int i = cid; i < lm_seq_lens_output; i += ncores) { + const int tgt_seq_id_output = + sm_cum_seq_len_output - lm_seq_lens_output + i; + LM2GM(&bi, batch_id_per_token_output + tgt_seq_id_output, sizeof(int)); + } + mfence(); + sync_all(); + } + + if (cid == 0 && clusterid == 0) { + const int lm_zero = 0; + LM2GM_ASYNC(&lm_zero, cu_seqlens_q, sizeof(int)); + LM2GM_ASYNC(&lm_zero, cu_seqlens_k, sizeof(int)); + LM2GM(&lm_zero, cu_seq_lens_q_output, sizeof(int)); + } +} +} // namespace fd_xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_repeat_times.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_repeat_times.xpu index c567003ea9a..778f5852038 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_repeat_times.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_repeat_times.xpu @@ -22,7 +22,7 @@ __device__ void speculate_update_repeat_times_normal( __global_ptr__ const int64_t *pre_ids, __global_ptr__ const int64_t *cur_len, __global_ptr__ int *repeat_times, - __global_ptr__ const int *output_padding_offset, + __global_ptr__ const int *batch_id_per_token_output, const int64_t bs, const int64_t length, const int64_t length_id, @@ -40,15 +40,17 @@ __device__ void speculate_update_repeat_times_normal( int n_length = (length + max_sm_len - 1) / max_sm_len; int64_t *cur_len_lm = (int64_t *)lm; - int output_padding_offset_now; + int batch_id_per_token_output_now; GM2LM(cur_len, cur_len_lm, bs * sizeof(int64_t)); for (int nli = 0; nli < n_length; nli++) { int step = nli * max_sm_len; int cur_length = min(max_sm_len, length - step); for (int64_t i = clusterid; i < token_num; i += nclusters) { - GM2LM(output_padding_offset + i, &output_padding_offset_now, sizeof(int)); - int64_t bi = (i + output_padding_offset_now) / max_seq_len; + GM2LM(batch_id_per_token_output + i, + &batch_id_per_token_output_now, + sizeof(int)); + int64_t bi = batch_id_per_token_output_now; if (bi >= bs || cur_len_lm[bi] < 0) { continue; } @@ -86,10 +88,10 @@ __device__ void speculate_update_repeat_times_normal( __device__ void speculate_update_repeat_times_optimized( char *lm, __shared_ptr__ char *sm, - __global_ptr__ const int64_t *pre_ids, // {bs, length_id} - __global_ptr__ const int64_t *cur_len, // {bs} - __global_ptr__ int *repeat_times, // {token_num, length} - __global_ptr__ const int *output_padding_offset, // {token_num} + __global_ptr__ const int64_t *pre_ids, // {bs, length_id} + __global_ptr__ const int64_t *cur_len, // {bs} + __global_ptr__ int *repeat_times, // {token_num, length} + __global_ptr__ const int *batch_id_per_token_output, // {token_num} const int64_t bs, const int64_t length, const int64_t length_id, @@ -108,10 +110,10 @@ __device__ void speculate_update_repeat_times_optimized( int cur_len_sm_len = 640; __shared_ptr__ int64_t *cur_len_sm = (__shared_ptr__ int64_t *)(repeat_times_sm + repeat_times_sm_len); - __shared_ptr__ int *output_padding_offset_sm = + __shared_ptr__ int *batch_id_per_token_output_sm = (__shared_ptr__ int *)(cur_len_sm + cur_len_sm_len); - DoublePtr<1, SmPtr> buffer_ptr_output_padding_offset( - (SmPtr(output_padding_offset_sm))); + DoublePtr<1, SmPtr> buffer_ptr_batch_id_per_token_output( + (SmPtr(batch_id_per_token_output_sm))); int pre_ids_lm_len = 4; int64_t *pre_ids_lm = (int64_t *)lm; DoublePtr<4, LmPtr> buffer_ptr_pre_ids((LmPtr(pre_ids_lm))); @@ -119,18 +121,18 @@ __device__ void speculate_update_repeat_times_optimized( int64_t i = clusterid; if (i < token_num && cid == 0) { GM2SM_ASYNC(cur_len, cur_len_sm, bs * sizeof(int64_t)); - buffer_ptr_output_padding_offset.gm_load_async(output_padding_offset + i, - 1); + buffer_ptr_batch_id_per_token_output.gm_load_async( + batch_id_per_token_output + i, 1); mfence_sm(); } sync_all(); for (; i < token_num; i += nclusters) { if (cid == 0 && i + nclusters < token_num) { - buffer_ptr_output_padding_offset.next().gm_load_async( - output_padding_offset + i + nclusters, 1); + buffer_ptr_batch_id_per_token_output.next().gm_load_async( + batch_id_per_token_output + i + nclusters, 1); } - int64_t bi = (i + (buffer_ptr_output_padding_offset.ptr[0])) / max_seq_len; - buffer_ptr_output_padding_offset.toggle(); + int64_t bi = buffer_ptr_batch_id_per_token_output.ptr[0]; + buffer_ptr_batch_id_per_token_output.toggle(); if (bi >= bs || cur_len_sm[bi] < 0) { mfence_sm(); sync_all(); @@ -224,15 +226,16 @@ __device__ void speculate_update_repeat_times_optimized( } } -__global__ void speculate_update_repeat_times(const int64_t *pre_ids, - const int64_t *cur_len, - int *repeat_times, - const int *output_padding_offset, - const int64_t bs, - const int64_t length, - const int64_t length_id, - const int64_t token_num, - const int64_t max_seq_len) { +__global__ void speculate_update_repeat_times( + const int64_t *pre_ids, + const int64_t *cur_len, + int *repeat_times, + const int *batch_id_per_token_output, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t token_num, + const int64_t max_seq_len) { char lm[6 * 1024]; __shared__ char sm[256 * 1024]; @@ -242,7 +245,7 @@ __global__ void speculate_update_repeat_times(const int64_t *pre_ids, pre_ids, cur_len, repeat_times, - output_padding_offset, + batch_id_per_token_output, bs, length, length_id, @@ -254,7 +257,7 @@ __global__ void speculate_update_repeat_times(const int64_t *pre_ids, pre_ids, cur_len, repeat_times, - output_padding_offset, + batch_id_per_token_output, bs, length, length_id, diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_value_by_repeat_times.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_value_by_repeat_times.xpu index f66e2919d21..3e139df84e2 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_value_by_repeat_times.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_value_by_repeat_times.xpu @@ -25,7 +25,7 @@ __global__ void speculate_update_value_by_repeat_times( const T *presence_score, const float *temperatures, T *logits, - const int *output_padding_offset, + const int *batch_id_per_token_output, const int64_t bs, const int64_t length, const int64_t token_num, @@ -46,17 +46,16 @@ __global__ void speculate_update_value_by_repeat_times( if (token_end >= token_num) { token_end = token_num - 1; } - int output_padding_offset_start_lm; - int output_padding_offset_end_lm; - GM2LM_ASYNC(output_padding_offset + token_start, - (void *)&output_padding_offset_start_lm, + int batch_id_per_token_output_start_lm; + int batch_id_per_token_output_end_lm; + GM2LM_ASYNC(batch_id_per_token_output + token_start, + (void *)&batch_id_per_token_output_start_lm, sizeof(int)); - GM2LM(output_padding_offset + token_end, - (void *)&output_padding_offset_end_lm, + GM2LM(batch_id_per_token_output + token_end, + (void *)&batch_id_per_token_output_end_lm, sizeof(int)); - int64_t bs_start = - (token_start + output_padding_offset_start_lm) / max_seq_len; - int64_t bs_end = (token_end + output_padding_offset_end_lm) / max_seq_len; + int64_t bs_start = batch_id_per_token_output_start_lm; + int64_t bs_end = batch_id_per_token_output_end_lm; const int param_len = 256; // ncores = 64 for xpu2 __shared__ __simd__ float alpha_buf[param_len * 64]; @@ -89,13 +88,13 @@ __global__ void speculate_update_value_by_repeat_times( const int buffer_len = 512; __simd__ float logits_lm[buffer_len]; int times_lm[buffer_len]; - int output_padding_offset_lm[buffer_len]; + int batch_id_per_token_output_lm[buffer_len]; for (int64_t i = start; i < end; i += buffer_len) { int read_len = min(end - i, buffer_len); GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T)); - GM2LM_ASYNC(output_padding_offset + i / length, - output_padding_offset_lm, + GM2LM_ASYNC(batch_id_per_token_output + i / length, + batch_id_per_token_output_lm, ((read_len + length - 1) / length + 1) * sizeof(int)); GM2LM(repeat_times + i, times_lm, read_len * sizeof(int)); primitive_cast((const T *)(logits_lm), logits_lm, read_len); @@ -104,7 +103,7 @@ __global__ void speculate_update_value_by_repeat_times( logit_now = logits_lm[j]; int token_idx = (i + j) / length; int bs_idx = - (token_idx + output_padding_offset_lm[token_idx - i / length]) / + (token_idx + batch_id_per_token_output_lm[token_idx - i / length]) / max_seq_len; if (bs_idx >= bs) { continue; @@ -134,7 +133,7 @@ __global__ void speculate_update_value_by_repeat_times( const DATA_TYPE *presence_score, \ const float *temperatures, \ DATA_TYPE *logits, \ - const int *output_padding_offset, \ + const int *batch_id_per_token_output, \ const int64_t bs, \ const int64_t length, \ const int64_t token_num, \ @@ -151,7 +150,7 @@ __global__ void speculate_update_value_by_repeat_times_simd( const T *presence_score, // [bs] const float *temperatures, // [bs] T *logits, // [bs * length] - const int *output_padding_offset, + const int *batch_id_per_token_output, const int64_t bs, const int64_t length, const int64_t token_num, @@ -198,7 +197,7 @@ __global__ void speculate_update_value_by_repeat_times_simd( const int buffer_len = 512; __simd__ float logits_lm[buffer_len]; __simd__ float times_lm[buffer_len]; - int output_padding_offset_lm[buffer_len]; + int batch_id_per_token_output_lm[buffer_len]; float32x16_t logits_; float32x16_t logits_tmp_0; @@ -208,8 +207,8 @@ __global__ void speculate_update_value_by_repeat_times_simd( for (int64_t i = start; i < end; i += buffer_len) { int read_len = min(end - i, buffer_len); GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T)); - GM2LM_ASYNC(output_padding_offset + i / length, - output_padding_offset_lm, + GM2LM_ASYNC(batch_id_per_token_output + i / length, + batch_id_per_token_output_lm, ((read_len + length - 1) / length + 1) * sizeof(int)); GM2LM(repeat_times + i, times_lm, read_len * sizeof(int)); primitive_cast((const T *)(logits_lm), logits_lm, read_len); @@ -220,9 +219,7 @@ __global__ void speculate_update_value_by_repeat_times_simd( time_ = vload_lm_float32x16(times_lm + j); logits_ = vload_lm_float32x16(logits_lm + j); int token_idx = (i + j) / length; - int bs_idx = - (token_idx + output_padding_offset_lm[token_idx - i / length]) / - max_seq_len; + int bs_idx = batch_id_per_token_output_lm[token_idx - i / length]; if (bs_idx >= bs) { continue; } @@ -269,7 +266,7 @@ __global__ void speculate_update_value_by_repeat_times_simd( const DATA_TYPE *presence_score, \ const float *temperatures, \ DATA_TYPE *logits, \ - const int *output_padding_offset, \ + const int *batch_id_per_token_output, \ const int64_t bs, \ const int64_t length, \ const int64_t token_num, \ diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu index 70c89bb7698..e8f8ed6883d 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu @@ -95,50 +95,31 @@ topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids, template __global__ void speculate_verify( const int64_t *sampled_token_ids, - int64_t *accept_tokens, // out [real_bsz, max_draft_tokens], 输出最终接收的 - // token(通过验证或采样) - int *accept_num, // out [real_bsz], 每个序列最终接受的 token - // 数量(只统计通过验证的) - int64_t - *step_idx, // out [real_bsz], 记录每个bid序列已经生成或接受的token数 - bool *stop_flags, // out [real_bsz], 每个序列的停止标志,遇到 - // 或长度超限时置 true - const int *seq_lens_encoder, // [real_bsz], 每个样本 encoder - // 输入长度,用于判断 prefill 阶段 - const int *seq_lens_decoder, // [real_bsz], 每个样本 decoder 输出的 token - // 数(即 draft token 数) - const int64_t * - draft_tokens, // [real_bsz, max_draft_tokens], draft model 输出的 token - const int *actual_draft_token_nums, // [real_bsz], draft_tokens - // 中实际有效的 token 数量 + int64_t *accept_tokens, // out [real_bsz, max_draft_tokens] + int *accept_num, // out [real_bsz], + int64_t *step_idx, // out [real_bsz], + bool *stop_flags, // out [real_bsz], + const int *seq_lens_encoder, // [real_bsz] + const int *seq_lens_decoder, // [real_bsz] + const int64_t *draft_tokens, // [real_bsz, max_draft_tokens], + const int *actual_draft_token_nums, // [real_bsz], 实际有效的 token 数量 const float *dev_curand_states, // used for random - const float *topp, // [real_bsz],TopP 阈值(如 - // 0.9),用于控制核采样截断概率和候选数 - const int *seq_lens_this_time, // [real_bsz], 本轮 verify - // 阶段每个样本实际参与验证的 token 数 + const float *topp, // [real_bsz], + const int *seq_lens_this_time, // [real_bsz], const int64_t - *verify_tokens, // [sum(seq_lens_this_time), max_candidate_len], verify - // decoder 输出的候选 token - const float - *verify_scores, // 同上, 每个 verify token 对应的概率分布,用于采样 + *verify_tokens, // [sum(seq_lens_this_time), max_candidate_len] + const float *verify_scores, const int64_t *max_dec_len, // [real_bsz], - // 每个样本允许生成的最大长度(超过则触发终止) - const int64_t - *end_tokens, // [end_length], 终止 token 列表(如 ),命中即终止 - const bool *is_block_step, // [real_bsz], 指示是否当前为 block step(为 - // true 时跳过 verify) - const int - *output_cum_offsets, // [real_bsz], verify_tokens 的起始偏移,用于定位 - // token 所在 verify 索引 - const int *actual_candidate_len, // [sum(seq_lens_this_time)], 每个 verify - // token 实际可用候选数(用于 TopP 截断) - const int real_bsz, // batch size - const int max_draft_tokens, // scalar, 每个样本最多允许的 draft token 数 + const int64_t *end_tokens, // [end_length] + const bool *is_block_step, // [real_bsz], + const int *cu_seqlens_q_output, + const int *actual_candidate_len, // [sum(seq_lens_this_time)], + const int real_bsz, // batch size + const int max_draft_tokens, const int end_length, - const int max_seq_len, // scalar, 每个序列的最大 token 数(用于偏移计算) - const int max_candidate_len, // scalar, 每个 verify token - // 的最大候选数(用于验证或采样) - const int verify_window, // scalar, TopK 验证窗口(允许连续 top1 匹配次数) + const int max_seq_len, + const int max_candidate_len, + const int verify_window, const bool prefill_one_step_stop, const bool benchmark_mode, const bool accept_all_drafts, @@ -151,7 +132,7 @@ __global__ void speculate_verify( if (is_block_step[bid]) { continue; } - const int start_token_id = bid * max_seq_len - output_cum_offsets[bid]; + const int start_token_id = cu_seqlens_q_output[bid]; if (stop_flags[bid]) { stop_flag_now_int = 1; } else { // 这里prefill阶段也会进入,但是因为draft diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/top_p_candidates.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/top_p_candidates.xpu index 6cee5771e85..93bbfa1cbe5 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/top_p_candidates.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/top_p_candidates.xpu @@ -10,7 +10,7 @@ __device__ void top_p_candidates_big_n( char* lm, __global_ptr__ const T* src, __global_ptr__ const T* top_ps, - __global_ptr__ const int* output_padding_offset, + __global_ptr__ const int* batch_id_per_token_output, __global_ptr__ int64_t* out_id, __global_ptr__ T* out_val, __global_ptr__ int* actual_candidates_lens, @@ -32,11 +32,13 @@ __device__ void top_p_candidates_big_n( __shared__ T sm_out_val[64 * TopPBeamTopK]; // only used in core 0 - int lm_output_padding_offset; + int lm_batch_id_per_token_output; for (int64_t i = cluster_id(); i < token_num; i += cluster_num()) { if (cid == 0) { - GM2LM(output_padding_offset + i, &lm_output_padding_offset, sizeof(int)); + GM2LM(batch_id_per_token_output + i, + &lm_batch_id_per_token_output, + sizeof(int)); } for (int64_t j = 0; j < TopPBeamTopK; j++) { lm_out_id[j] = -1; @@ -142,8 +144,7 @@ __device__ void top_p_candidates_big_n( } } - int ori_token_id = i + lm_output_padding_offset; - int bid = ori_token_id / max_seq_len; + int bid = lm_batch_id_per_token_output; T lm_top_p; GM2LM(top_ps + bid, &lm_top_p, sizeof(T)); float top_p_value = static_cast(lm_top_p); @@ -182,7 +183,7 @@ __device__ void top_p_candidates_normal( char* lm, __global_ptr__ const T* src, __global_ptr__ const T* top_ps, - __global_ptr__ const int* output_padding_offset, + __global_ptr__ const int* batch_id_per_token_output, __global_ptr__ int64_t* out_id, __global_ptr__ T* out_val, __global_ptr__ int* actual_candidates_lens, @@ -200,7 +201,7 @@ __device__ void top_p_candidates_normal( int64_t lm_out_id[TopPBeamTopK]; T lm_out_val[TopPBeamTopK]; - int lm_output_padding_offset; + int lm_batch_id_per_token_output; T lm_top_p; int64_t default_id = 0; T default_val = static_cast(0.f); @@ -236,9 +237,10 @@ __device__ void top_p_candidates_normal( } mfence_lm(); } - GM2LM(output_padding_offset + i, &lm_output_padding_offset, sizeof(int)); - int ori_token_id = i + lm_output_padding_offset; - int bid = ori_token_id / max_seq_len; + GM2LM(batch_id_per_token_output + i, + &lm_batch_id_per_token_output, + sizeof(int)); + int bid = lm_batch_id_per_token_output; GM2LM(top_ps + bid, &lm_top_p, sizeof(T)); float top_p_value = static_cast(lm_top_p); bool set_to_default_val = false; @@ -272,7 +274,7 @@ __device__ void top_p_candidates_normal( template __global__ void top_p_candidates(const T* src, const T* top_ps, - const int* output_padding_offset, + const int* batch_id_per_token_output, int64_t* out_id, T* out_val, int* actual_candidates_lens, @@ -284,29 +286,31 @@ __global__ void top_p_candidates(const T* src, if (token_num % (core_num() * cluster_num()) != 0 && vocab_size >= core_num() * (6 * 1024 / sizeof(T)) && vocab_size >= core_num() * TopPBeamTopK) { - top_p_candidates_big_n(lm, - src, - top_ps, - output_padding_offset, - out_id, - out_val, - actual_candidates_lens, - vocab_size, - token_num, - max_cadidate_len, - max_seq_len); + top_p_candidates_big_n( + lm, + src, + top_ps, + batch_id_per_token_output, + out_id, + out_val, + actual_candidates_lens, + vocab_size, + token_num, + max_cadidate_len, + max_seq_len); } else { - top_p_candidates_normal(lm, - src, - top_ps, - output_padding_offset, - out_id, - out_val, - actual_candidates_lens, - vocab_size, - token_num, - max_cadidate_len, - max_seq_len); + top_p_candidates_normal( + lm, + src, + top_ps, + batch_id_per_token_output, + out_id, + out_val, + actual_candidates_lens, + vocab_size, + token_num, + max_cadidate_len, + max_seq_len); } } @@ -314,7 +318,7 @@ __global__ void top_p_candidates(const T* src, template __global__ void top_p_candidates( \ const T* src, \ const T* top_ps, \ - const int* output_padding_offset, \ + const int* batch_id_per_token_output, \ int64_t* out_id, \ T* out_val, \ int* actual_candidates_lens, \ diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/unified_update_model_status.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/unified_update_model_status.xpu new file mode 100644 index 00000000000..d4a225950c0 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/unified_update_model_status.xpu @@ -0,0 +1,217 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +#include "xpu/kernel/cluster_debug.h" + +namespace fd_xpu3 { + +#define MAX_BATCH_SIZE 1024 + +static inline __device__ int v_reduce_sum_int32(int32x16_t &v0) { + auto v1 = vsrlp_int32x16(1 << 8, v0); + v0 = vvadd_int32x16(v0, v1); + v1 = vsrlp_int32x16(1 << 7, v0); + v0 = vvadd_int32x16(v0, v1); + v1 = vsrlp_int32x16(1 << 6, v0); + v0 = vvadd_int32x16(v0, v1); + v1 = vsrlp_int32x16(1 << 5, v0); + v0 = vvadd_int32x16(v0, v1); + return vextract_int32x16(v0, 1); +} + +inline __device__ int primitive_reduce_sum_sm(__shared_ptr__ const int *x, + int64_t len) { + int32x16_t x_l, x_h; + int32x16_t sum = vset_zero_int(); + const auto rounddown_len = rounddown32(len); + + for (int64_t i = 0; i < rounddown_len; i += 32) { + vload2_sm(x + i, x_l, x_h); + sum = vvadd_int32x16(sum, x_l); + sum = vvadd_int32x16(sum, x_h); + } + + if (rounddown_len < len) { + const auto mask = ~(-1 << (len - rounddown_len)); + vload2_sm_mz(x + rounddown_len, x_l, x_h, mask); + sum = vvadd_int32x16(sum, x_l); + sum = vvadd_int32x16(sum, x_h); + } + return v_reduce_sum_int32(sum); +} + +inline __device__ bool is_end_token(int64_t token, + __shared_ptr__ const int64_t *end_tokens, + int num_end_tokens) { +#pragma unroll 4 + for (int i = 0; i < num_end_tokens; i++) { + if (token == end_tokens[i]) return true; + } + return false; +} + +__global__ void unified_update_model_status_kernel(int *seq_lens_encoder, + int *seq_lens_decoder, + bool *has_running_seqs, + int *mask_rollback, + int64_t *step_input_ids, + int *adaptive_step_input_len, + int64_t *step_output_ids, + int *step_output_len, + bool *stop_flags, + int *seq_lens_this_time, + const bool *is_paused, + int64_t *token_ids_all, + const int64_t *prompt_lens, + int64_t *step_idx, + const int64_t *end_tokens, + const int64_t *max_dec_len, + int real_bsz, + int max_bsz, + int max_step_tokens, + int max_model_len, + int num_end_tokens, + bool is_naive_mode, + bool prefill_one_step_stop) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + if (clusterid > 0) return; + __shared__ int sm_seq_lens_encoder[MAX_BATCH_SIZE]; + __shared__ int sm_seq_lens_decoder[MAX_BATCH_SIZE]; + __shared__ bool sm_stop_flags[MAX_BATCH_SIZE]; + __shared__ int64_t sm_step_idx[MAX_BATCH_SIZE]; + __shared__ bool sm_is_paused[MAX_BATCH_SIZE]; + __shared__ int64_t sm_end_tokens[MAX_BATCH_SIZE]; + + __shared__ int sm_cum_seq_len, sm_cum_seq_len_output; + __shared__ int buffer_stop_flag_int[64]; + if (cid == 0) { + GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, sizeof(int) * max_bsz); + GM2SM_ASYNC(seq_lens_decoder, sm_seq_lens_decoder, sizeof(int) * max_bsz); + GM2SM_ASYNC(stop_flags, sm_stop_flags, sizeof(bool) * max_bsz); + GM2SM_ASYNC(step_idx, sm_step_idx, sizeof(int64_t) * max_bsz); + GM2SM_ASYNC(is_paused, sm_is_paused, sizeof(bool) * max_bsz); + GM2SM_ASYNC(end_tokens, sm_end_tokens, sizeof(int64_t) * num_end_tokens); + } + buffer_stop_flag_int[cid] = 0; + mfence_sm(); + sync_all(); + for (int batch_id = cid; batch_id < max_bsz; batch_id += ncores) { + // Read state + int cur_seq_len_encoder = sm_seq_lens_encoder[batch_id]; + int cur_seq_len_decoder = sm_seq_lens_decoder[batch_id]; + bool cur_stop_flag = sm_stop_flags[batch_id]; + int output_len = 0; + int64_t cur_step_idx = sm_step_idx[batch_id]; + bool cur_is_paused = sm_is_paused[batch_id]; + + bool is_running = !cur_stop_flag && !cur_is_paused; + + // Compute output length + if (is_running) { + if (is_naive_mode) { + output_len = 1; + } else { + output_len = step_output_len[batch_id]; + } + } + + // EOS detection + if (is_running && output_len > 0) { + bool hit_stop = false; + __global_ptr__ int64_t *output_ids = + &step_output_ids[batch_id * max_step_tokens]; + + for (int i = 0; i < output_len; i++) { + cur_step_idx++; + int64_t token = output_ids[i]; + bool is_eos = is_end_token(token, sm_end_tokens, num_end_tokens); + bool max_len_hit = (cur_step_idx >= max_dec_len[batch_id]); + + if (is_eos || max_len_hit) { + if (!is_eos) output_ids[i] = sm_end_tokens[0]; + output_len = i + 1; + cur_stop_flag = true; + hit_stop = true; + break; + } + } + + if (!hit_stop && prefill_one_step_stop && cur_seq_len_encoder > 0) { + cur_stop_flag = true; + } + } + + // Update state and write back + if (is_running) { + if (cur_stop_flag) { + buffer_stop_flag_int[cid] += 1; + if (output_len == 0) cur_seq_len_decoder = 0; + stop_flags[batch_id] = true; + mask_rollback[batch_id] = 0; + } else if (cur_seq_len_encoder == 0) { + cur_seq_len_decoder += output_len; + mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len; + } else { + mask_rollback[batch_id] = 0; + } + + if (cur_seq_len_encoder > 0) { + cur_seq_len_decoder += cur_seq_len_encoder; + cur_seq_len_encoder = 0; + } + + seq_lens_encoder[batch_id] = cur_seq_len_encoder; + seq_lens_decoder[batch_id] = cur_seq_len_decoder; + step_output_len[batch_id] = output_len; + step_idx[batch_id] = cur_step_idx; + + // Write history to token_ids_all + if (cur_step_idx > 0 && output_len > 0) { + // Bounds check: highest write index is prompt_lens + cur_step_idx + if (prompt_lens[batch_id] + cur_step_idx < max_model_len) { + __global_ptr__ int64_t *token_ids_all_now = + &token_ids_all[batch_id * max_model_len + prompt_lens[batch_id]]; + __global_ptr__ int64_t *output_ids = + &step_output_ids[batch_id * max_step_tokens]; + for (int i = 0; i < output_len; i++) { + token_ids_all_now[cur_step_idx - i] = + output_ids[output_len - 1 - i]; + } + } + } + + // Setup next input + if (output_len > 0) { + step_input_ids[batch_id * max_step_tokens] = + step_output_ids[batch_id * max_step_tokens + output_len - 1]; + } + + if (is_naive_mode) { + seq_lens_this_time[batch_id] = cur_stop_flag ? 0 : 1; + } + } else if (batch_id >= real_bsz) { + // Padding slot: just count as stopped, don't modify state + buffer_stop_flag_int[cid] += 1; + } else { + // Stopped or paused slot (batch_id < real_bsz) + buffer_stop_flag_int[cid] += 1; + stop_flags[batch_id] = true; + seq_lens_decoder[batch_id] = 0; + seq_lens_this_time[batch_id] = 0; + step_output_len[batch_id] = 0; + } + } + mfence_sm(); + sync_all(); + int stop_flag_int = 0; + if (cid == 0) { + for (int i = 0; i < ncores; i++) { + stop_flag_int += buffer_stop_flag_int[i]; + } + } + has_running_seqs[0] = stop_flag_int < max_bsz; +} +} // namespace fd_xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_update.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_update.cpp index b5d24e9d6e9..004bf17ef0e 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_update.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_update.cpp @@ -24,7 +24,7 @@ __attribute__((global)) void draft_model_update( int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - const int* output_cum_offsets, + const int* cu_seqlens_q_output, bool* stop_flags, bool* not_need_stop, const int64_t* max_dec_len, @@ -60,7 +60,7 @@ static int cpu_wrapper(api::Context* ctx, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - const int* output_cum_offsets, + const int* cu_seqlens_q_output, bool* stop_flags, bool* not_need_stop, const int64_t* max_dec_len, @@ -82,8 +82,7 @@ static int cpu_wrapper(api::Context* ctx, auto* pre_ids_now = pre_ids + tid * pre_id_length; auto* base_model_draft_tokens_now = base_model_draft_tokens + tid * max_base_model_draft_token; - const int next_tokens_start_id = - tid * max_seq_len - output_cum_offsets[tid]; + const int next_tokens_start_id = cu_seqlens_q_output[tid]; auto* next_tokens_start = inter_next_tokens + next_tokens_start_id; auto seq_len_this_time = seq_lens_this_time[tid]; auto seq_len_encoder = seq_lens_encoder[tid]; @@ -158,7 +157,7 @@ static int xpu3_wrapper(api::Context* ctx, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - const int* output_cum_offsets, + const int* cu_seqlens_q_output, bool* stop_flags, bool* not_need_stop, const int64_t* max_dec_len, @@ -182,7 +181,7 @@ static int xpu3_wrapper(api::Context* ctx, seq_lens_encoder, seq_lens_decoder, reinterpret_cast(step_idx), - output_cum_offsets, + cu_seqlens_q_output, stop_flags, not_need_stop, reinterpret_cast(max_dec_len), @@ -209,7 +208,7 @@ int draft_model_update(api::Context* ctx, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - const int* output_cum_offsets, + const int* cu_seqlens_q_output, bool* stop_flags, bool* not_need_stop, const int64_t* max_dec_len, @@ -234,7 +233,7 @@ int draft_model_update(api::Context* ctx, seq_lens_decoder); WRAPPER_DUMP_PARAM6(ctx, step_idx, - output_cum_offsets, + cu_seqlens_q_output, stop_flags, not_need_stop, max_dec_len, @@ -255,7 +254,7 @@ int draft_model_update(api::Context* ctx, WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder); WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_decoder); WRAPPER_CHECK_PTR(ctx, int64_t, bsz, step_idx); - WRAPPER_CHECK_PTR(ctx, int, bsz, output_cum_offsets); + WRAPPER_CHECK_PTR(ctx, int, bsz, cu_seqlens_q_output); WRAPPER_CHECK_PTR(ctx, bool, bsz, stop_flags); WRAPPER_CHECK_PTR(ctx, bool, 1, not_need_stop); WRAPPER_CHECK_PTR(ctx, int64_t, bsz, max_dec_len); @@ -272,7 +271,7 @@ int draft_model_update(api::Context* ctx, seq_lens_encoder, seq_lens_decoder, step_idx, - output_cum_offsets, + cu_seqlens_q_output, stop_flags, not_need_stop, max_dec_len, @@ -296,7 +295,7 @@ int draft_model_update(api::Context* ctx, seq_lens_encoder, seq_lens_decoder, step_idx, - output_cum_offsets, + cu_seqlens_q_output, stop_flags, not_need_stop, max_dec_len, diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_preprocess.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_preprocess.cpp new file mode 100644 index 00000000000..3e2faf31fcf --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_preprocess.cpp @@ -0,0 +1,244 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl/xdnn_impl.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace fd_xpu3 { + +__attribute__((global)) void speculate_preprocess_kernel( + int64_t* ids_remove_padding, + int* batch_id_per_token, + int* cu_seqlens_q, + int* cu_seqlens_k, + int* seq_lens_output, + int* cu_seq_lens_q_output, + int* batch_id_per_token_output, + int* real_output_token_num, + const int64_t* input_data, + const int* seq_lens, + const int64_t* draft_tokens, + const int* seq_lens_encoder, + const int max_seq_len, + const int max_draft_tokens_per_batch, + const int real_bs); +} // namespace fd_xpu3 + +namespace fastdeploy { +namespace plugin { + +static int cpu_wrapper(api::Context* ctx, + int64_t* ids_remove_padding, + int* batch_id_per_token, + int* cu_seqlens_q, + int* cu_seqlens_k, + int* seq_lens_output, + int* cu_seq_lens_q_output, + int* batch_id_per_token_output, + int* real_output_token_num, + const int64_t* input_data, + const int* seq_lens, + const int64_t* draft_tokens, + const int* seq_lens_encoder, + const int max_seq_len, + const int max_draft_tokens_per_batch, + const int token_num_data, + const int real_bs) { + cu_seqlens_q[0] = 0; + cu_seqlens_k[0] = 0; + for (int i = 0; i < real_bs; ++i) { + const int seq_len = seq_lens[i]; + cu_seqlens_q[i + 1] = cu_seqlens_q[i] + seq_len; + cu_seqlens_k[i + 1] = cu_seqlens_k[i] + seq_len; + } + + for (int bi = 0; bi < real_bs; ++bi) { + for (int i = 0; i < seq_lens[bi]; ++i) { + const int tgt_seq_id = cu_seqlens_q[bi + 1] - seq_lens[bi] + i; + if (max_draft_tokens_per_batch > 0 && seq_lens_encoder[bi] <= 0) { + // speculative decoding + const int src_seq_id = bi * max_draft_tokens_per_batch + i; + ids_remove_padding[tgt_seq_id] = draft_tokens[src_seq_id]; + } else { + // Non-speculative decoding + const int src_seq_id = bi * max_seq_len + i; + ids_remove_padding[tgt_seq_id] = input_data[src_seq_id]; + } + batch_id_per_token[tgt_seq_id] = bi; + } + } + + for (int bid = 0; bid < real_bs; ++bid) { + if (seq_lens[bid] == 0) { + seq_lens_output[bid] = 0; + } else if (seq_lens[bid] == 1) { + seq_lens_output[bid] = 1; + } else if (seq_lens_encoder[bid] != 0) { + seq_lens_output[bid] = 1; + } else { + seq_lens_output[bid] = seq_lens[bid]; + } + } + + cu_seq_lens_q_output[0] = 0; + for (int i = 0; i < real_bs; ++i) { + cu_seq_lens_q_output[i + 1] = cu_seq_lens_q_output[i] + seq_lens_output[i]; + } + real_output_token_num[0] = cu_seq_lens_q_output[real_bs]; + + for (int bi = 0; bi < real_bs; ++bi) { + for (int i = 0; i < seq_lens_output[bi]; ++i) { + const int tgt_seq_id_output = + cu_seq_lens_q_output[bi + 1] - seq_lens_output[bi] + i; + batch_id_per_token_output[tgt_seq_id_output] = bi; + } + } + + return api::SUCCESS; +} + +static int xpu3_wrapper(api::Context* ctx, + int64_t* ids_remove_padding, + int* batch_id_per_token, + int* cu_seqlens_q, + int* cu_seqlens_k, + int* seq_lens_output, + int* cu_seq_lens_q_output, + int* batch_id_per_token_output, + int* real_output_token_num, + const int64_t* input_data, + const int* seq_lens, + const int64_t* draft_tokens, + const int* seq_lens_encoder, + const int max_seq_len, + const int max_draft_tokens_per_batch, + const int token_num_data, + const int real_bs) { + using XPU_INT64 = typename api::XPUIndexType::type; + int32_t ret_xre = fd_xpu3:: + speculate_preprocess_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + reinterpret_cast(ids_remove_padding), + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + seq_lens_output, + cu_seq_lens_q_output, + batch_id_per_token_output, + real_output_token_num, + reinterpret_cast(input_data), + seq_lens, + reinterpret_cast(draft_tokens), + seq_lens_encoder, + max_seq_len, + max_draft_tokens_per_batch, + real_bs); + KERNEL_ASSERT_SUCCESS(ctx, ret_xre); + return api::SUCCESS; +} + +int speculate_preprocess(api::Context* ctx, + int64_t* ids_remove_padding, + int* batch_id_per_token, + int* cu_seqlens_q, + int* cu_seqlens_k, + int* seq_lens_output, + int* cu_seq_lens_q_output, + int* batch_id_per_token_output, + int* real_output_token_num, + const int64_t* input_data, + const int* seq_lens, + const int64_t* draft_tokens, + const int* seq_lens_encoder, + const int max_seq_len, + const int max_draft_tokens_per_batch, + const int token_num_data, + const int real_bs) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_preprocess", int); + WRAPPER_DUMP_PARAM6(ctx, + ids_remove_padding, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + seq_lens_output, + cu_seq_lens_q_output); + WRAPPER_DUMP_PARAM6(ctx, + batch_id_per_token_output, + real_output_token_num, + input_data, + seq_lens, + draft_tokens, + seq_lens_encoder); + WRAPPER_DUMP_PARAM3(ctx, max_seq_len, max_draft_tokens_per_batch, real_bs); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, int64_t, token_num_data, ids_remove_padding); + WRAPPER_CHECK_PTR(ctx, int, token_num_data, batch_id_per_token); + WRAPPER_CHECK_PTR(ctx, int, real_bs + 1, cu_seqlens_q); + WRAPPER_CHECK_PTR(ctx, int, real_bs + 1, cu_seqlens_k); + WRAPPER_CHECK_PTR(ctx, int, real_bs, seq_lens_output); + WRAPPER_CHECK_PTR(ctx, int, real_bs + 1, cu_seq_lens_q_output); + WRAPPER_CHECK_PTR( + ctx, int, real_bs* max_draft_tokens_per_batch, batch_id_per_token_output); + WRAPPER_CHECK_PTR(ctx, int, 1, real_output_token_num); + WRAPPER_CHECK_PTR(ctx, int64_t, real_bs * max_seq_len, input_data); + WRAPPER_CHECK_PTR(ctx, int, real_bs, seq_lens); + WRAPPER_CHECK_PTR( + ctx, int, real_bs* max_draft_tokens_per_batch, draft_tokens); + WRAPPER_CHECK_PTR(ctx, int, real_bs, seq_lens_encoder); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + ids_remove_padding, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + seq_lens_output, + cu_seq_lens_q_output, + batch_id_per_token_output, + real_output_token_num, + input_data, + seq_lens, + draft_tokens, + seq_lens_encoder, + max_seq_len, + max_draft_tokens_per_batch, + token_num_data, + real_bs); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + ids_remove_padding, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + seq_lens_output, + cu_seq_lens_q_output, + batch_id_per_token_output, + real_output_token_num, + input_data, + seq_lens, + draft_tokens, + seq_lens_encoder, + max_seq_len, + max_draft_tokens_per_batch, + token_num_data, + real_bs); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace fastdeploy diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_token_penalty_multi_scores.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_token_penalty_multi_scores.cpp index 36db3e0f530..fe3ed8a0a52 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_token_penalty_multi_scores.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_token_penalty_multi_scores.cpp @@ -25,8 +25,8 @@ __attribute__((global)) void speculate_min_length_logits_process( const int64_t* cur_len, const int64_t* min_len, const int64_t* eos_token_id, - const int* output_padding_offset, - const int* output_cum_offsets, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t bs, const int64_t length, const int64_t length_id, @@ -37,7 +37,7 @@ __attribute__((global)) void speculate_update_repeat_times( const int64_t* pre_ids, const int64_t* cur_len, int* repeat_times, - const int* output_padding_offset, + const int* batch_id_per_token_output, const int64_t bs, const int64_t length, const int64_t length_id, @@ -146,7 +146,7 @@ static int cpu_wrapper(api::Context* ctx, const int64_t* eos_token_id, const int64_t* bad_words, const int* output_padding_offset, - const int* output_cum_offsets, + const int* batch_id_per_token_output, const int64_t bs, const int64_t length, const int64_t length_id, @@ -172,7 +172,7 @@ static int cpu_wrapper(api::Context* ctx, WRAPPER_ASSERT_SUCCESS(ctx, ret); for (int64_t i = 0; i < token_num; i++) { int64_t bi = (i + output_padding_offset[i]) / max_seq_len; - int64_t query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi]; + int64_t query_start_token_idx = batch_id_per_token_output[bi]; if (bi < bs && cur_len[bi] >= 0 && (cur_len[bi] + (i - query_start_token_idx) < min_len[bi])) { for (int64_t j = 0; j < end_length; j++) { @@ -236,8 +236,8 @@ static int xpu3_wrapper(api::Context* ctx, const int64_t* min_len, const int64_t* eos_token_id, const int64_t* bad_words, - const int* output_padding_offset, - const int* output_cum_offsets, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t bs, const int64_t length, const int64_t length_id, @@ -268,7 +268,7 @@ static int xpu3_wrapper(api::Context* ctx, reinterpret_cast(pre_ids), reinterpret_cast(cur_len), repeat_times, - output_padding_offset, + batch_id_per_token_output, bs, length, length_id, @@ -282,8 +282,8 @@ static int xpu3_wrapper(api::Context* ctx, reinterpret_cast(cur_len), reinterpret_cast(min_len), reinterpret_cast(eos_token_id), - output_padding_offset, - output_cum_offsets, + batch_id_per_token_output, + cu_seqlens_q_output, bs, length, length_id, @@ -300,7 +300,7 @@ static int xpu3_wrapper(api::Context* ctx, presence_scores, temperatures, logits, - output_padding_offset, + batch_id_per_token_output, bs, length, token_num, @@ -311,7 +311,7 @@ static int xpu3_wrapper(api::Context* ctx, ret_xre = ban_bad_words_kernel<<ncluster(), 64, ctx->xpu_stream>>>( logits, reinterpret_cast(bad_words), - output_padding_offset, + batch_id_per_token_output, bs, length, length_bad_words, @@ -334,8 +334,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx, const int64_t* min_len, const int64_t* eos_token_id, const int64_t* bad_words, - const int* output_padding_offset, - const int* output_cum_offsets, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t bs, const int64_t length, const int64_t length_id, @@ -357,8 +357,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx, min_len, eos_token_id, bad_words, - output_padding_offset, - output_cum_offsets); + cu_seqlens_q_output, + batch_id_per_token_output); WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length); WRAPPER_DUMP_PARAM3(ctx, length_bad_words, token_num, max_seq_len); WRAPPER_DUMP(ctx); @@ -373,8 +373,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx, int64_t min_len_len = -1; int64_t eos_token_id_len = -1; int64_t bad_words_len = -1; - int64_t output_padding_offset_len = -1; - int64_t output_cum_offsets_len = -1; + // int64_t output_padding_offset_len = -1; + // int64_t output_cum_offsets_len = -1; WRAPPER_ASSERT_LE(ctx, bs, 640); WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id}); WRAPPER_CHECK_SHAPE(ctx, &logits_len, {token_num, length}); @@ -386,8 +386,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx, WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs}); WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length}); WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words}); - WRAPPER_CHECK_SHAPE(ctx, &output_padding_offset_len, {token_num}); - WRAPPER_CHECK_SHAPE(ctx, &output_cum_offsets_len, {bs}); + // WRAPPER_CHECK_SHAPE(ctx, &output_padding_offset_len, {token_num}); + // WRAPPER_CHECK_SHAPE(ctx, &output_cum_offsets_len, {bs}); WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids); WRAPPER_CHECK_PTR(ctx, T, logits_len, logits); WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores); @@ -398,8 +398,9 @@ int speculate_token_penalty_multi_scores(api::Context* ctx, WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len); WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id); WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words); - WRAPPER_CHECK_PTR(ctx, int, output_padding_offset_len, output_padding_offset); - WRAPPER_CHECK_PTR(ctx, int, output_cum_offsets_len, output_cum_offsets); + // WRAPPER_CHECK_PTR(ctx, int, output_padding_offset_len, + // output_padding_offset); WRAPPER_CHECK_PTR(ctx, int, output_cum_offsets_len, + // output_cum_offsets); if (ctx->dev().type() == api::kCPU) { return cpu_wrapper(ctx, pre_ids, @@ -412,8 +413,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx, min_len, eos_token_id, bad_words, - output_padding_offset, - output_cum_offsets, + batch_id_per_token_output, + cu_seqlens_q_output, bs, length, length_id, @@ -434,8 +435,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx, min_len, eos_token_id, bad_words, - output_padding_offset, - output_cum_offsets, + batch_id_per_token_output, + cu_seqlens_q_output, bs, length, length_id, @@ -459,8 +460,8 @@ template int speculate_token_penalty_multi_scores( const int64_t* min_len, const int64_t* eos_token_id, const int64_t* bad_words, - const int* output_padding_offset, - const int* output_cum_offsets, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t bs, const int64_t length, const int64_t length_id, @@ -480,8 +481,8 @@ template int speculate_token_penalty_multi_scores( const int64_t* min_len, const int64_t* eos_token_id, const int64_t* bad_words, - const int* output_padding_offset, - const int* output_cum_offsets, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t bs, const int64_t length, const int64_t length_id, @@ -501,8 +502,8 @@ template int speculate_token_penalty_multi_scores( const int64_t* min_len, const int64_t* eos_token_id, const int64_t* bad_words, - const int* output_padding_offset, - const int* output_cum_offsets, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t bs, const int64_t length, const int64_t length_id, diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp index 549ae7189cd..cdfba51763c 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp @@ -23,25 +23,25 @@ typedef uint32_t curandStatePhilox4_32_10_t; template __attribute__((global)) void speculate_verify( - const int64_t* sampled_token_ids, - int64_t* accept_tokens, - int* accept_num, - int64_t* step_idx, - bool* stop_flags, - const int* seq_lens_encoder, - const int* seq_lens_decoder, - const int64_t* draft_tokens, - const int* actual_draft_token_nums, - const float* dev_curand_states, - const float* topp, - const int* seq_lens_this_time, - const int64_t* verify_tokens, - const float* verify_scores, - const int64_t* max_dec_len, - const int64_t* end_tokens, - const bool* is_block_step, - const int* output_cum_offsets, - const int* actual_candidate_len, + const int64_t *sampled_token_ids, + int64_t *accept_tokens, + int *accept_num, + int64_t *step_idx, + bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *draft_tokens, + const int *actual_draft_token_nums, + const float *dev_curand_states, + const float *topp, + const int *seq_lens_this_time, + const int64_t *verify_tokens, + const float *verify_scores, + const int64_t *max_dec_len, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *actual_candidate_len, const int real_bsz, const int max_draft_tokens, const int end_length, @@ -58,7 +58,7 @@ namespace fastdeploy { namespace plugin { static inline bool is_in_end(const int64_t id, - const int64_t* end_ids, + const int64_t *end_ids, int length) { bool flag = false; for (int i = 0; i < length; i++) { @@ -69,7 +69,7 @@ static inline bool is_in_end(const int64_t id, return flag; } -static inline bool is_in(const int64_t* candidates, +static inline bool is_in(const int64_t *candidates, const int64_t draft, const int candidate_len) { for (int i = 0; i < candidate_len; i++) { @@ -80,7 +80,7 @@ static inline bool is_in(const int64_t* candidates, return false; } -static inline unsigned int xorwow(unsigned int& state) { // NOLINT +static inline unsigned int xorwow(unsigned int &state) { // NOLINT state ^= state >> 7; state ^= state << 9; state ^= state >> 13; @@ -89,9 +89,9 @@ static inline unsigned int xorwow(unsigned int& state) { // NOLINT typedef uint32_t curandStatePhilox4_32_10_t; -static int64_t topp_sampling_kernel(const int64_t* candidate_ids, - const float* candidate_scores, - const float* dev_curand_states, +static int64_t topp_sampling_kernel(const int64_t *candidate_ids, + const float *candidate_scores, + const float *dev_curand_states, const int candidate_len, const float topp, int tid) { @@ -111,26 +111,26 @@ static int64_t topp_sampling_kernel(const int64_t* candidate_ids, } template -static int cpu_wrapper(api::Context* ctx, - const int64_t* sampled_token_ids, - int64_t* accept_tokens, - int* accept_num, - int64_t* step_idx, - bool* stop_flags, - const int* seq_lens_encoder, - const int* seq_lens_decoder, - const int64_t* draft_tokens, - const int* actual_draft_token_nums, - const float* dev_curand_states, - const float* topp, - const int* seq_lens_this_time, - const int64_t* verify_tokens, - const float* verify_scores, - const int64_t* max_dec_len, - const int64_t* end_tokens, - const bool* is_block_step, - const int* output_cum_offsets, - const int* actual_candidate_len, +static int cpu_wrapper(api::Context *ctx, + const int64_t *sampled_token_ids, + int64_t *accept_tokens, + int *accept_num, + int64_t *step_idx, + bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *draft_tokens, + const int *actual_draft_token_nums, + const float *dev_curand_states, + const float *topp, + const int *seq_lens_this_time, + const int64_t *verify_tokens, + const float *verify_scores, + const int64_t *max_dec_len, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *actual_candidate_len, const int real_bsz, const int max_draft_tokens, const int end_length, @@ -147,7 +147,7 @@ static int cpu_wrapper(api::Context* ctx, int stop_flag_now_int = 0; if (!(is_block_step[bid] || bid >= real_bsz)) { - const int start_token_id = bid * max_seq_len - output_cum_offsets[bid]; + const int start_token_id = cu_seqlens_q_output[bid]; // printf("debug cpu bid:%d,start_token_id:%d\n",bid, start_token_id); // printf("bid %d\n", bid); @@ -155,11 +155,11 @@ static int cpu_wrapper(api::Context* ctx, stop_flag_now_int = 1; } else { // 这里prefill阶段也会进入,但是因为draft // tokens会置零,因此会直接到最后的采样阶段 - auto* verify_tokens_now = + auto *verify_tokens_now = verify_tokens + start_token_id * max_candidate_len; - auto* draft_tokens_now = draft_tokens + bid * max_draft_tokens; - auto* actual_candidate_len_now = actual_candidate_len + start_token_id; - auto* sampled_token_id_now = sampled_token_ids + start_token_id; + auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens; + auto *actual_candidate_len_now = actual_candidate_len + start_token_id; + auto *sampled_token_id_now = sampled_token_ids + start_token_id; int i = 0; // printf("seq_lens_this_time[%d]-1: %d \n",bid, @@ -306,7 +306,7 @@ static int cpu_wrapper(api::Context* ctx, // 也是从verify_tokens_now[i]中选一个 但是停止的情况不算 if (!stop_flag_now_int) { int64_t accept_token; - const float* verify_scores_now = + const float *verify_scores_now = verify_scores + start_token_id * max_candidate_len; step_idx[bid]++; if (use_target_sampling) { @@ -347,26 +347,26 @@ static int cpu_wrapper(api::Context* ctx, } template -static int xpu3_wrapper(api::Context* ctx, - const int64_t* sampled_token_ids, - int64_t* accept_tokens, - int* accept_num, - int64_t* step_idx, - bool* stop_flags, - const int* seq_lens_encoder, - const int* seq_lens_decoder, - const int64_t* draft_tokens, - const int* actual_draft_token_nums, - const float* dev_curand_states, - const float* topp, - const int* seq_lens_this_time, - const int64_t* verify_tokens, - const float* verify_scores, - const int64_t* max_dec_len, - const int64_t* end_tokens, - const bool* is_block_step, - const int* output_cum_offsets, - const int* actual_candidate_len, +static int xpu3_wrapper(api::Context *ctx, + const int64_t *sampled_token_ids, + int64_t *accept_tokens, + int *accept_num, + int64_t *step_idx, + bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *draft_tokens, + const int *actual_draft_token_nums, + const float *dev_curand_states, + const float *topp, + const int *seq_lens_this_time, + const int64_t *verify_tokens, + const float *verify_scores, + const int64_t *max_dec_len, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *actual_candidate_len, const int real_bsz, const int max_draft_tokens, const int end_length, @@ -380,24 +380,24 @@ static int xpu3_wrapper(api::Context* ctx, using XPU_INT64 = typename api::XPUIndexType::type; int32_t ret_xre = fd_xpu3::speculate_verify <<ncluster(), 64, ctx->xpu_stream>>>( - reinterpret_cast(sampled_token_ids), - reinterpret_cast(accept_tokens), + reinterpret_cast(sampled_token_ids), + reinterpret_cast(accept_tokens), accept_num, - reinterpret_cast(step_idx), + reinterpret_cast(step_idx), stop_flags, seq_lens_encoder, seq_lens_decoder, - reinterpret_cast(draft_tokens), + reinterpret_cast(draft_tokens), actual_draft_token_nums, dev_curand_states, topp, seq_lens_this_time, - reinterpret_cast(verify_tokens), + reinterpret_cast(verify_tokens), verify_scores, - reinterpret_cast(max_dec_len), - reinterpret_cast(end_tokens), + reinterpret_cast(max_dec_len), + reinterpret_cast(end_tokens), is_block_step, - output_cum_offsets, + cu_seqlens_q_output, actual_candidate_len, real_bsz, max_draft_tokens, @@ -413,26 +413,26 @@ static int xpu3_wrapper(api::Context* ctx, return api::SUCCESS; } template -int speculate_verify(api::Context* ctx, - const int64_t* sampled_token_ids, - int64_t* accept_tokens, - int* accept_num, - int64_t* step_idx, - bool* stop_flags, - const int* seq_lens_encoder, - const int* seq_lens_decoder, - const int64_t* draft_tokens, - const int* actual_draft_token_nums, - const float* dev_curand_states, - const float* topp, - const int* seq_lens_this_time, - const int64_t* verify_tokens, - const float* verify_scores, - const int64_t* max_dec_len, - const int64_t* end_tokens, - const bool* is_block_step, - const int* output_cum_offsets, - const int* actual_candidate_len, +int speculate_verify(api::Context *ctx, + const int64_t *sampled_token_ids, + int64_t *accept_tokens, + int *accept_num, + int64_t *step_idx, + bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *draft_tokens, + const int *actual_draft_token_nums, + const float *dev_curand_states, + const float *topp, + const int *seq_lens_this_time, + const int64_t *verify_tokens, + const float *verify_scores, + const int64_t *max_dec_len, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *actual_candidate_len, const int real_bsz, const int max_draft_tokens, const int end_length, @@ -462,7 +462,7 @@ int speculate_verify(api::Context* ctx, end_tokens); WRAPPER_DUMP_PARAM5(ctx, is_block_step, - output_cum_offsets, + cu_seqlens_q_output, actual_candidate_len, real_bsz, max_draft_tokens); @@ -492,7 +492,7 @@ int speculate_verify(api::Context* ctx, WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, max_dec_len); WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens); WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step); - WRAPPER_CHECK_PTR(ctx, int, real_bsz, output_cum_offsets); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_seqlens_q_output); // WRAPPER_CHECK_PTR(ctx, int, real_bsz, actual_candidate_len); // param check sm size limit @@ -525,7 +525,7 @@ int speculate_verify(api::Context* ctx, max_dec_len, end_tokens, is_block_step, - output_cum_offsets, + cu_seqlens_q_output, actual_candidate_len, real_bsz, max_draft_tokens, @@ -557,7 +557,7 @@ int speculate_verify(api::Context* ctx, max_dec_len, end_tokens, is_block_step, - output_cum_offsets, + cu_seqlens_q_output, actual_candidate_len, real_bsz, max_draft_tokens, @@ -575,36 +575,36 @@ int speculate_verify(api::Context* ctx, #define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \ template int fastdeploy::plugin::speculate_verify( \ - fastdeploy::plugin::api::Context*, /* xpu_ctx */ \ - const int64_t*, /* sampled_token_ids */ \ - int64_t*, /* accept_tokens */ \ - int*, /* accept_num */ \ - int64_t*, /* step_idx */ \ - bool*, /* stop_flags */ \ - const int*, /* seq_lens_encoder */ \ - const int*, /* seq_lens_decoder */ \ - const int64_t*, /* draft_tokens */ \ - const int*, /* actual_draft_token_nums */ \ - const float*, /* dev_curand_states or topp */ \ - const float*, /* topp or nullptr */ \ - const int*, /* seq_lens_this_time */ \ - const int64_t*, /* verify_tokens */ \ - const float*, /* verify_scores */ \ - const int64_t*, /* max_dec_len */ \ - const int64_t*, /* end_tokens */ \ - const bool*, /* is_block_step */ \ - const int*, /* output_cum_offsets */ \ - const int*, /* actual_candidate_len */ \ - int, /* real_bsz */ \ - int, /* max_draft_tokens */ \ - int, /* end_length */ \ - int, /* max_seq_len */ \ - int, /* max_candidate_len */ \ - int, /* verify_window */ \ - bool, /* prefill_one_step_stop */ \ - bool, /* benchmark_mode */ \ - bool, /* accept_all_drafts */ \ - bool /* use_target_sampling */ \ + fastdeploy::plugin::api::Context *, /* xpu_ctx */ \ + const int64_t *, /* sampled_token_ids */ \ + int64_t *, /* accept_tokens */ \ + int *, /* accept_num */ \ + int64_t *, /* step_idx */ \ + bool *, /* stop_flags */ \ + const int *, /* seq_lens_encoder */ \ + const int *, /* seq_lens_decoder */ \ + const int64_t *, /* draft_tokens */ \ + const int *, /* actual_draft_token_nums */ \ + const float *, /* dev_curand_states or topp */ \ + const float *, /* topp or nullptr */ \ + const int *, /* seq_lens_this_time */ \ + const int64_t *, /* verify_tokens */ \ + const float *, /* verify_scores */ \ + const int64_t *, /* max_dec_len */ \ + const int64_t *, /* end_tokens */ \ + const bool *, /* is_block_step */ \ + const int *, /* cu_seqlens_q_output */ \ + const int *, /* actual_candidate_len */ \ + int, /* real_bsz */ \ + int, /* max_draft_tokens */ \ + int, /* end_length */ \ + int, /* max_seq_len */ \ + int, /* max_candidate_len */ \ + int, /* verify_window */ \ + bool, /* prefill_one_step_stop */ \ + bool, /* benchmark_mode */ \ + bool, /* accept_all_drafts */ \ + bool /* use_target_sampling */ \ ); INSTANTIATE_SPECULATE_VERIFY(false, false) diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/top_p_candidates.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/top_p_candidates.cpp index af44bb7406b..d1f880ecf17 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/top_p_candidates.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/top_p_candidates.cpp @@ -17,16 +17,17 @@ namespace fd_xpu3 { template -__attribute__((global)) void top_p_candidates(const T* src, - const T* top_ps, - const int* output_padding_offset, - int64_t* out_id, - T* out_val, - int* actual_candidates_lens, - int vocab_size, - int token_num, - int max_candidate_len, - int max_seq_len); +__attribute__((global)) void top_p_candidates( + const T* src, + const T* top_ps, + const int* batch_id_per_token_output, + int64_t* out_id, + T* out_val, + int* actual_candidates_lens, + int vocab_size, + int token_num, + int max_candidate_len, + int max_seq_len); } // namespace fd_xpu3 namespace fastdeploy { @@ -36,7 +37,7 @@ template static int cpu_wrapper(api::Context* ctx, const T* src, const T* top_ps, - const int* output_padding_offset, + const int* batch_id_per_token_output, int64_t* out_id, T* out_val, int* actual_candidates_lens, @@ -70,8 +71,7 @@ static int cpu_wrapper(api::Context* ctx, } } } - int ori_token_id = i + output_padding_offset[i]; - int bid = ori_token_id / max_seq_len; + int bid = batch_id_per_token_output[i]; float top_p_value = static_cast(top_ps[bid]); bool set_to_default_val = false; for (int j = 0; j < TopPBeamTopK; j++) { @@ -97,7 +97,7 @@ template static int xpu3_wrapper(api::Context* ctx, const T* src, const T* top_ps, - const int* output_padding_offset, + const int* batch_id_per_token_output, int64_t* out_id, T* out_val, int* actual_candidates_lens, @@ -110,7 +110,7 @@ static int xpu3_wrapper(api::Context* ctx, <<ncluster(), 64, ctx->xpu_stream>>>( src, top_ps, - output_padding_offset, + batch_id_per_token_output, reinterpret_cast(out_id), out_val, actual_candidates_lens, @@ -126,7 +126,7 @@ template int top_p_candidates(api::Context* ctx, const T* src, const T* top_ps, - const int* output_padding_offset, + const int* batch_id_per_token_output, int64_t* out_id, T* out_val, int* actual_candidates_lens, @@ -136,7 +136,8 @@ int top_p_candidates(api::Context* ctx, int max_seq_len) { WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T1(ctx, "top_p_candidates", T); - WRAPPER_DUMP_PARAM5(ctx, src, top_ps, output_padding_offset, out_id, out_val); + WRAPPER_DUMP_PARAM5( + ctx, src, top_ps, batch_id_per_token_output, out_id, out_val); WRAPPER_DUMP_PARAM5(ctx, actual_candidates_lens, vocab_size, @@ -146,7 +147,7 @@ int top_p_candidates(api::Context* ctx, WRAPPER_DUMP(ctx); WRAPPER_CHECK_PTR(ctx, T, token_num * vocab_size, src); - WRAPPER_CHECK_PTR(ctx, T, token_num, output_padding_offset); + WRAPPER_CHECK_PTR(ctx, T, token_num, batch_id_per_token_output); WRAPPER_CHECK_PTR(ctx, T, token_num * candidate_len, out_id); WRAPPER_CHECK_PTR(ctx, T, token_num * candidate_len, out_val); @@ -161,7 +162,7 @@ int top_p_candidates(api::Context* ctx, return cpu_wrapper(ctx, src, top_ps, - output_padding_offset, + batch_id_per_token_output, out_id, out_val, actual_candidates_lens, @@ -173,7 +174,7 @@ int top_p_candidates(api::Context* ctx, return xpu3_wrapper(ctx, src, top_ps, - output_padding_offset, + batch_id_per_token_output, out_id, out_val, actual_candidates_lens, diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/unified_update_model_status.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/unified_update_model_status.cpp new file mode 100644 index 00000000000..062219bfb25 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/unified_update_model_status.cpp @@ -0,0 +1,376 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl/xdnn_impl.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace fd_xpu3 { + +__attribute__((global)) void unified_update_model_status_kernel( + int *seq_lens_encoder, + int *seq_lens_decoder, + bool *has_running_seqs, + int *mask_rollback, + int64_t *step_input_ids, + int *adaptive_step_input_len, + int64_t *step_output_ids, + int *step_output_len, + bool *stop_flags, + int *seq_lens_this_time, + const bool *is_paused, + int64_t *token_ids_all, + const int64_t *prompt_lens, + int64_t *step_idx, + const int64_t *end_tokens, + const int64_t *max_dec_len, + int real_bsz, + int max_bsz, + int max_step_tokens, + int max_model_len, + int num_end_tokens, + bool is_naive_mode, + bool prefill_one_step_stop); +} // namespace fd_xpu3 + +namespace fastdeploy { +namespace plugin { + +bool is_end_token(int64_t token, + const int64_t *end_tokens, + int num_end_tokens) { +#pragma unroll 4 + for (int i = 0; i < num_end_tokens; i++) { + if (token == end_tokens[i]) return true; + } + return false; +} + +static int cpu_wrapper(api::Context *ctx, + int *seq_lens_encoder, + int *seq_lens_decoder, + bool *has_running_seqs, + int *mask_rollback, + int64_t *step_input_ids, + int *adaptive_step_input_len, + int64_t *step_output_ids, + int *step_output_len, + bool *stop_flags, + int *seq_lens_this_time, + const bool *is_paused, + int64_t *token_ids_all, + const int64_t *prompt_lens, + int64_t *step_idx, + const int64_t *end_tokens, + const int64_t *max_dec_len, + int real_bsz, + int max_bsz, + int max_step_tokens, + int max_model_len, + int num_end_tokens, + bool is_naive_mode, + bool prefill_one_step_stop) { + int stop_flag_int = 0; + + for (int batch_id = 0; batch_id < max_bsz; batch_id++) { + // Read state + int cur_seq_len_encoder = seq_lens_encoder[batch_id]; + int cur_seq_len_decoder = seq_lens_decoder[batch_id]; + bool cur_stop_flag = stop_flags[batch_id]; + int output_len = 0; + int64_t cur_step_idx = step_idx[batch_id]; + bool cur_is_paused = is_paused[batch_id]; + + bool is_running = !cur_stop_flag && !cur_is_paused; + + // Compute output length + if (is_running) { + if (is_naive_mode) { + output_len = 1; + } else { + output_len = step_output_len[batch_id]; + } + } + + // EOS detection + if (is_running && output_len > 0) { + bool hit_stop = false; + int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens]; + + for (int i = 0; i < output_len; i++) { + cur_step_idx++; + int64_t token = output_ids[i]; + bool is_eos = is_end_token(token, end_tokens, num_end_tokens); + bool max_len_hit = (cur_step_idx >= max_dec_len[batch_id]); + + if (is_eos || max_len_hit) { + if (!is_eos) output_ids[i] = end_tokens[0]; + output_len = i + 1; + cur_stop_flag = true; + hit_stop = true; + break; + } + } + + if (!hit_stop && prefill_one_step_stop && cur_seq_len_encoder > 0) { + cur_stop_flag = true; + } + } + + // Update state and write back + if (is_running) { + if (cur_stop_flag) { + stop_flag_int += 1; + if (output_len == 0) cur_seq_len_decoder = 0; + stop_flags[batch_id] = true; + mask_rollback[batch_id] = 0; + } else if (cur_seq_len_encoder == 0) { + cur_seq_len_decoder += output_len; + mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len; + } else { + mask_rollback[batch_id] = 0; + } + + if (cur_seq_len_encoder > 0) { + cur_seq_len_decoder += cur_seq_len_encoder; + cur_seq_len_encoder = 0; + } + + seq_lens_encoder[batch_id] = cur_seq_len_encoder; + seq_lens_decoder[batch_id] = cur_seq_len_decoder; + step_output_len[batch_id] = output_len; + step_idx[batch_id] = cur_step_idx; + + // Write history to token_ids_all + if (cur_step_idx > 0 && output_len > 0) { + // Bounds check: highest write index is prompt_lens + cur_step_idx + if (prompt_lens[batch_id] + cur_step_idx < max_model_len) { + int64_t *token_ids_all_now = + &token_ids_all[batch_id * max_model_len + prompt_lens[batch_id]]; + int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens]; + for (int i = 0; i < output_len; i++) { + token_ids_all_now[cur_step_idx - i] = + output_ids[output_len - 1 - i]; + } + } + } + + // Setup next input + if (output_len > 0) { + step_input_ids[batch_id * max_step_tokens] = + step_output_ids[batch_id * max_step_tokens + output_len - 1]; + } + + if (is_naive_mode) { + seq_lens_this_time[batch_id] = cur_stop_flag ? 0 : 1; + } + } else if (batch_id >= real_bsz) { + // Padding slot: just count as stopped, don't modify state + stop_flag_int += 1; + } else { + // Stopped or paused slot (batch_id < real_bsz) + stop_flag_int += 1; + stop_flags[batch_id] = true; + seq_lens_decoder[batch_id] = 0; + seq_lens_this_time[batch_id] = 0; + step_output_len[batch_id] = 0; + } + } + has_running_seqs[0] = stop_flag_int < max_bsz; + return api::SUCCESS; +} + +static int xpu3_wrapper(api::Context *ctx, + int *seq_lens_encoder, + int *seq_lens_decoder, + bool *has_running_seqs, + int *mask_rollback, + int64_t *step_input_ids, + int *adaptive_step_input_len, + int64_t *step_output_ids, + int *step_output_len, + bool *stop_flags, + int *seq_lens_this_time, + const bool *is_paused, + int64_t *token_ids_all, + const int64_t *prompt_lens, + int64_t *step_idx, + const int64_t *end_tokens, + const int64_t *max_dec_len, + int real_bsz, + int max_bsz, + int max_step_tokens, + int max_model_len, + int num_end_tokens, + bool is_naive_mode, + bool prefill_one_step_stop) { + using XPU_INT64 = typename api::XPUIndexType::type; + int32_t ret_xre = + fd_xpu3::unified_update_model_status_kernel<<ncluster(), + 64, + ctx->xpu_stream>>>( + seq_lens_encoder, + seq_lens_decoder, + has_running_seqs, + mask_rollback, + reinterpret_cast(step_input_ids), + adaptive_step_input_len, + reinterpret_cast(step_output_ids), + step_output_len, + stop_flags, + seq_lens_this_time, + is_paused, + reinterpret_cast(token_ids_all), + reinterpret_cast(prompt_lens), + reinterpret_cast(step_idx), + reinterpret_cast(end_tokens), + reinterpret_cast(max_dec_len), + real_bsz, + max_bsz, + max_step_tokens, + max_model_len, + num_end_tokens, + is_naive_mode, + prefill_one_step_stop); + KERNEL_ASSERT_SUCCESS(ctx, ret_xre); + return api::SUCCESS; +} + +int unified_update_model_status(api::Context *ctx, + int *seq_lens_encoder, + int *seq_lens_decoder, + bool *has_running_seqs, + int *mask_rollback, + int64_t *step_input_ids, + int *adaptive_step_input_len, + int64_t *step_output_ids, + int *step_output_len, + bool *stop_flags, + int *seq_lens_this_time, + const bool *is_paused, + int64_t *token_ids_all, + const int64_t *prompt_lens, + int64_t *step_idx, + const int64_t *end_tokens, + const int64_t *max_dec_len, + int real_bsz, + int max_bsz, + int max_step_tokens, + int max_model_len, + int num_end_tokens, + bool is_naive_mode, + bool prefill_one_step_stop) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "unified_update_model_status", int); + WRAPPER_DUMP_PARAM6(ctx, + seq_lens_encoder, + seq_lens_decoder, + has_running_seqs, + mask_rollback, + step_input_ids, + adaptive_step_input_len); + WRAPPER_DUMP_PARAM6(ctx, + step_output_ids, + step_output_len, + stop_flags, + seq_lens_this_time, + is_paused, + token_ids_all); + WRAPPER_DUMP_PARAM6( + ctx, prompt_lens, step_idx, end_tokens, max_dec_len, real_bsz, max_bsz); + WRAPPER_DUMP_PARAM5(ctx, + max_step_tokens, + max_model_len, + num_end_tokens, + is_naive_mode, + prefill_one_step_stop); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, int, max_bsz, seq_lens_encoder); + WRAPPER_CHECK_PTR(ctx, int, max_bsz, seq_lens_decoder); + WRAPPER_CHECK_PTR(ctx, bool, 1, has_running_seqs); + WRAPPER_CHECK_PTR(ctx, int, max_bsz, mask_rollback); + WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_step_tokens, step_input_ids); + // WRAPPER_CHECK_PTR(ctx, int, 0, adaptive_step_input_len); // Temporarily + // unused + WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_step_tokens, step_output_ids); + WRAPPER_CHECK_PTR(ctx, int, max_bsz, step_output_len); + WRAPPER_CHECK_PTR(ctx, bool, max_bsz, stop_flags); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, bool, max_bsz, is_paused); + WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_model_len, token_ids_all); + WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz, prompt_lens); + WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz, step_idx); + WRAPPER_CHECK_PTR(ctx, int64_t, num_end_tokens, end_tokens); + WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz, max_dec_len); + WRAPPER_ASSERT_GE(ctx, max_bsz, real_bsz); + WRAPPER_ASSERT_GE(ctx, 1024, num_end_tokens); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + seq_lens_encoder, + seq_lens_decoder, + has_running_seqs, + mask_rollback, + step_input_ids, + adaptive_step_input_len, + step_output_ids, + step_output_len, + stop_flags, + seq_lens_this_time, + is_paused, + token_ids_all, + prompt_lens, + step_idx, + end_tokens, + max_dec_len, + real_bsz, + max_bsz, + max_step_tokens, + max_model_len, + num_end_tokens, + is_naive_mode, + prefill_one_step_stop); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + seq_lens_encoder, + seq_lens_decoder, + has_running_seqs, + mask_rollback, + step_input_ids, + adaptive_step_input_len, + step_output_ids, + step_output_len, + stop_flags, + seq_lens_this_time, + is_paused, + token_ids_all, + prompt_lens, + step_idx, + end_tokens, + max_dec_len, + real_bsz, + max_bsz, + max_step_tokens, + max_model_len, + num_end_tokens, + is_naive_mode, + prefill_one_step_stop); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace fastdeploy diff --git a/custom_ops/xpu_ops/test/test_speculate_pre_process.py b/custom_ops/xpu_ops/test/test_speculate_pre_process.py new file mode 100644 index 00000000000..059ac6860e9 --- /dev/null +++ b/custom_ops/xpu_ops/test/test_speculate_pre_process.py @@ -0,0 +1,328 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import speculate_pre_process + + +def speculate_pre_process_ref( + input_ids, + seq_lens, + draft_tokens, + seq_lens_encoder, + max_seq_len, + max_draft_tokens_per_batch, + real_bsz, + token_num, +): + """ + Python reference implementation for SpeculatePreProcessKernel. + + Returns: + ids_remove_padding: int64[token_num] + batch_id_per_token: int32[token_num] + cu_seqlens_q: int32[real_bsz + 1] + cu_seqlens_k: int32[real_bsz + 1] + seq_lens_output: int32[real_bsz] + cu_seq_lens_q_output: int32[real_bsz + 1] + batch_id_per_token_output: int32[real_bsz * max_draft_tokens_per_batch] + real_output_token_num: int32[1] + """ + # --- Part 1: ids_remove_padding, batch_id_per_token, cu_seqlens_q/k --- + ids_remove_padding = np.zeros(token_num, dtype=np.int64) + batch_id_per_token = np.zeros(token_num, dtype=np.int32) + cu_seqlens_q = np.zeros(real_bsz + 1, dtype=np.int32) + cu_seqlens_k = np.zeros(real_bsz + 1, dtype=np.int32) + + cum = 0 + for bi in range(real_bsz): + cum += seq_lens[bi] + cu_seqlens_q[bi + 1] = cum + cu_seqlens_k[bi + 1] = cum + + start = cum - seq_lens[bi] + for i in range(seq_lens[bi]): + tgt = start + i + if max_draft_tokens_per_batch > 0 and seq_lens_encoder[bi] <= 0: + src = bi * max_draft_tokens_per_batch + i + ids_remove_padding[tgt] = draft_tokens[src] + else: + src = bi * max_seq_len + i + ids_remove_padding[tgt] = input_ids[src] + batch_id_per_token[tgt] = bi + + # --- Part 2: seq_lens_output --- + seq_lens_output = np.zeros(real_bsz, dtype=np.int32) + for bid in range(real_bsz): + if seq_lens[bid] == 0: + seq_lens_output[bid] = 0 + elif seq_lens[bid] == 1: + seq_lens_output[bid] = 1 + elif seq_lens_encoder[bid] != 0: + seq_lens_output[bid] = 1 + else: + seq_lens_output[bid] = seq_lens[bid] + + # --- Part 3: cu_seq_lens_q_output, batch_id_per_token_output, real_output_token_num --- + cu_seq_lens_q_output = np.zeros(real_bsz + 1, dtype=np.int32) + batch_id_per_token_output = np.zeros(real_bsz * max_draft_tokens_per_batch, dtype=np.int32) + + cum_output = 0 + for bi in range(real_bsz): + cum_output += seq_lens_output[bi] + cu_seq_lens_q_output[bi + 1] = cum_output + + start_out = cum_output - seq_lens_output[bi] + for i in range(seq_lens_output[bi]): + batch_id_per_token_output[start_out + i] = bi + + real_output_token_num = np.array([cum_output], dtype=np.int32) + + return ( + ids_remove_padding, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + seq_lens_output, + cu_seq_lens_q_output, + batch_id_per_token_output, + real_output_token_num, + ) + + +def build_inputs( + real_bsz, + max_seq_len, + max_draft_tokens, + seq_lens_list, + seq_lens_encoder_list, + draft_tokens_data=None, + input_ids_data=None, + seed=42, +): + """ + Helper to build test inputs from explicit seq_lens and seq_lens_encoder lists. + draft_tokens_data and input_ids_data are optional; if None, random data is used. + """ + rng = np.random.default_rng(seed) + seq_lens = np.array(seq_lens_list, dtype=np.int32) + seq_lens_encoder = np.array(seq_lens_encoder_list, dtype=np.int32) + seq_lens_decoder = np.zeros(real_bsz, dtype=np.int32) # not used in kernel logic + + token_num = int(np.sum(seq_lens)) + + if input_ids_data is not None: + input_ids = np.array(input_ids_data, dtype=np.int64).reshape(real_bsz, max_seq_len) + else: + input_ids = rng.integers(1, 1000, size=(real_bsz, max_seq_len), dtype=np.int64) + + if draft_tokens_data is not None: + draft_tokens = np.array(draft_tokens_data, dtype=np.int64).reshape(real_bsz, max_draft_tokens) + else: + draft_tokens = rng.integers(1, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64) + + return { + "input_ids": input_ids, + "seq_lens": seq_lens, + "draft_tokens": draft_tokens, + "seq_lens_encoder": seq_lens_encoder, + "seq_lens_decoder": seq_lens_decoder, + "max_seq_len": max_seq_len, + "max_draft_tokens": max_draft_tokens, + "token_num": token_num, + "real_bsz": real_bsz, + } + + +def run_and_compare(tc, inputs): + """ + Call GPU op and Python reference, compare all outputs. + tc: unittest.TestCase instance (for assertion messages). + """ + real_bsz = inputs["real_bsz"] + max_seq_len = inputs["max_seq_len"] + max_draft_tokens = inputs["max_draft_tokens"] + token_num = inputs["token_num"] + + t_input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64") + t_seq_lens = paddle.to_tensor(inputs["seq_lens"], dtype="int32") + t_draft_tokens = paddle.to_tensor(inputs["draft_tokens"], dtype="int64") + t_seq_lens_encoder = paddle.to_tensor(inputs["seq_lens_encoder"], dtype="int32") + t_seq_lens_decoder = paddle.to_tensor(inputs["seq_lens_decoder"], dtype="int32") + + gpu_outs = speculate_pre_process( + token_num, t_input_ids, t_seq_lens, t_draft_tokens, t_seq_lens_encoder, t_seq_lens_decoder + ) + + ref_outs = speculate_pre_process_ref( + input_ids=inputs["input_ids"].reshape(-1), + seq_lens=inputs["seq_lens"], + draft_tokens=inputs["draft_tokens"].reshape(-1), + seq_lens_encoder=inputs["seq_lens_encoder"], + max_seq_len=max_seq_len, + max_draft_tokens_per_batch=max_draft_tokens, + real_bsz=real_bsz, + token_num=token_num, + ) + + output_names = [ + "ids_remove_padding", + "batch_id_per_token", + "cu_seqlens_q", + "cu_seqlens_k", + "cu_seq_lens_q_output", + "batch_id_per_token_output", + "real_output_token_num", + ] + # GPU op returns 7 tensors; ref returns 8 (with seq_lens_output at index 4). + # GPU output order: ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k, + # cu_seq_lens_q_output, batch_id_per_token_output, real_output_token_num + # Ref output order: ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k, + # seq_lens_output, cu_seq_lens_q_output, batch_id_per_token_output, real_output_token_num + ref_indices = [0, 1, 2, 3, 5, 6, 7] # skip seq_lens_output (index 4) for direct comparison + for name, gpu_idx, ref_idx in zip(output_names, range(7), ref_indices): + gpu_val = gpu_outs[gpu_idx].numpy() + ref_val = ref_outs[ref_idx] + # Trim batch_id_per_token_output to the valid portion (real_output_token_num) + # The kernel only writes valid positions; beyond that the content is undefined. + if name == "batch_id_per_token_output": + valid_len = int(ref_outs[7][0]) # real_output_token_num + gpu_val = gpu_val[:valid_len] + ref_val = ref_val[:valid_len] + np.testing.assert_allclose( + gpu_val, + ref_val, + err_msg=f"Mismatch in output '{name}'", + ) + + +class TestSpeculatePreProcess(unittest.TestCase): + """Unit tests for speculate_pre_process custom operator.""" + + # ---------------------------------------------------------------- + # Test 1: mixed batch covering all 4 seq_lens_output branches + # bid=0: seq_lens=0 => output=0 (skip) + # bid=1: seq_lens=1, encoder=0 => output=1, read draft_tokens + # bid=2: seq_lens=5, encoder=3 => output=1, read input_ids (prefill) + # bid=3: seq_lens=4, encoder=0 => output=4, read draft_tokens (decode) + # bid=4: seq_lens=1, encoder=2 => output=1, read input_ids (prefill single) + # bid=5: seq_lens=8, encoder=0 => output=8, read draft_tokens (decode saturated) + # ---------------------------------------------------------------- + def test_mixed_batch_all_branches(self): + inputs = build_inputs( + real_bsz=6, + max_seq_len=16, + max_draft_tokens=8, + seq_lens_list=[0, 1, 5, 4, 1, 8], + seq_lens_encoder_list=[0, 0, 3, 0, 2, 0], + ) + run_and_compare(self, inputs) + + # ---------------------------------------------------------------- + # Test 2: token_num=0 early return — verify no crash, 7 outputs + # ---------------------------------------------------------------- + def test_all_zero_seq_lens(self): + real_bsz = 3 + t_input_ids = paddle.zeros([real_bsz, 8], dtype="int64") + t_seq_lens = paddle.zeros([real_bsz], dtype="int32") + t_draft_tokens = paddle.zeros([real_bsz, 4], dtype="int64") + t_seq_lens_encoder = paddle.zeros([real_bsz], dtype="int32") + t_seq_lens_decoder = paddle.zeros([real_bsz], dtype="int32") + + gpu_outs = speculate_pre_process( + 0, t_input_ids, t_seq_lens, t_draft_tokens, t_seq_lens_encoder, t_seq_lens_decoder + ) + self.assertEqual(len(gpu_outs), 7) + self.assertIsNotNone(gpu_outs[-3]) + self.assertIsNotNone(gpu_outs[-2]) + self.assertIsNotNone(gpu_outs[-1]) + # test copy + fake_cu_seqlens_q_output = paddle.empty([real_bsz + 1], dtype="int32") + fake_batch_id_per_token_output = paddle.empty([real_bsz], dtype="int32") + fake_cu_seqlens_q_output.copy_(gpu_outs[-3]) + fake_batch_id_per_token_output.copy_(gpu_outs[-2]) + # test slice + fake_batch_id_per_token_output[: gpu_outs[-1].item()] + + # ---------------------------------------------------------------- + # Test 3: exact token values — manually verify ids_remove_padding + # bid=0: encoder=0 (decode) => draft_tokens[0][0:3] = [10,11,12] + # bid=1: encoder=5 (prefill) => input_ids[1][0:2] = [200,201] + # ---------------------------------------------------------------- + def test_exact_token_values(self): + inputs = build_inputs( + real_bsz=2, + max_seq_len=4, + max_draft_tokens=4, + seq_lens_list=[3, 2], + seq_lens_encoder_list=[0, 5], + draft_tokens_data=[[10, 11, 12, 13], [20, 21, 22, 23]], + input_ids_data=[[100, 101, 102, 103], [200, 201, 202, 203]], + ) + + t_input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64") + t_seq_lens = paddle.to_tensor(inputs["seq_lens"], dtype="int32") + t_draft_tokens = paddle.to_tensor(inputs["draft_tokens"], dtype="int64") + t_seq_lens_encoder = paddle.to_tensor(inputs["seq_lens_encoder"], dtype="int32") + t_seq_lens_decoder = paddle.to_tensor(inputs["seq_lens_decoder"], dtype="int32") + + gpu_outs = speculate_pre_process( + int(np.sum(inputs["seq_lens"])), + t_input_ids, + t_seq_lens, + t_draft_tokens, + t_seq_lens_encoder, + t_seq_lens_decoder, + ) + + np.testing.assert_allclose(gpu_outs[0].numpy(), [10, 11, 12, 200, 201]) + np.testing.assert_allclose(gpu_outs[1].numpy(), [0, 0, 0, 1, 1]) + np.testing.assert_allclose(gpu_outs[2].numpy(), [0, 3, 5]) + np.testing.assert_allclose(gpu_outs[6].numpy(), [4]) # real_output_token_num = 3+1 + + # ---------------------------------------------------------------- + # Test 4: random stress test (2 configs covering small & medium batch) + # ---------------------------------------------------------------- + def test_random_configs(self): + configs = [ + {"real_bsz": 7, "max_seq_len": 32, "max_draft_tokens": 8, "seed": 200}, + {"real_bsz": 32, "max_seq_len": 128, "max_draft_tokens": 16, "seed": 400}, + ] + for cfg in configs: + with self.subTest(**cfg): + rng = np.random.default_rng(cfg["seed"]) + real_bsz = cfg["real_bsz"] + max_draft = cfg["max_draft_tokens"] + seq_lens_list = rng.integers(0, max_draft + 1, size=real_bsz).tolist() + seq_lens_encoder_list = rng.integers(0, 3, size=real_bsz).tolist() + + inputs = build_inputs( + real_bsz=real_bsz, + max_seq_len=cfg["max_seq_len"], + max_draft_tokens=max_draft, + seq_lens_list=seq_lens_list, + seq_lens_encoder_list=seq_lens_encoder_list, + seed=cfg["seed"], + ) + if inputs["token_num"] == 0: + continue + run_and_compare(self, inputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/custom_ops/xpu_ops/test/test_unified_update_model_status.py b/custom_ops/xpu_ops/test/test_unified_update_model_status.py new file mode 100644 index 00000000000..0c0c9d47a6b --- /dev/null +++ b/custom_ops/xpu_ops/test/test_unified_update_model_status.py @@ -0,0 +1,574 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for unified_update_model_status kernel. + +Kernel semantics (from unified_update_model_status.cu): + - Launched as <<<1, 1024>>>, one thread per batch slot (max_bsz <= 1024). + - real_bsz = seq_lens_this_time.shape[0], max_bsz = stop_flags.shape[0]. + - has_running_seqs is a CPU tensor (copied to GPU, kernel writes, copied back). + - Padding slots (batch_id >= real_bsz): only counted as stopped, NO state modified. + - Stopped/paused real slots: set stop_flags=true, seq_lens_decoder=0, + seq_lens_this_time=0, step_output_len=0. + - Running slots: EOS detection → state update → token_ids_all write → next input setup. +""" + +import unittest +from typing import Any, Dict + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import unified_update_model_status + +CUDA_PLACE = paddle.XPUPlace(0) +CPU_PLACE = paddle.CPUPlace() + + +# ============================================================ +# Layer 1: Helpers — tensor creation / kernel invocation / output extraction +# ============================================================ + + +def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Convert numpy dict → paddle tensors. has_running_seqs goes to CPU.""" + paddle_inputs = {} + for k, v in inputs.items(): + if isinstance(v, (int, bool, float, str)): + paddle_inputs[k] = v + elif k == "has_running_seqs": + # Kernel host function: has_running_seqs.copy_to(GPU) → kernel → copy_to(CPU) + paddle_inputs[k] = paddle.to_tensor(v, place=CPU_PLACE) + elif v is not None: + paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE) + else: + paddle_inputs[k] = None + return paddle_inputs + + +def run_kernel(paddle_inputs: Dict[str, Any], inputs: Dict[str, Any]): + """Call unified_update_model_status kernel.""" + unified_update_model_status( + paddle_inputs["seq_lens_encoder"], + paddle_inputs["seq_lens_decoder"], + paddle_inputs["has_running_seqs"], + paddle_inputs["step_input_ids"], + paddle_inputs["adaptive_step_input_len"], + paddle_inputs["step_output_ids"], + paddle_inputs["step_output_len"], + paddle_inputs["stop_flags"], + paddle_inputs["seq_lens_this_time"], + paddle_inputs["is_paused"], + paddle_inputs["mask_rollback"], + paddle_inputs["token_ids_all"], + paddle_inputs["prompt_lens"], + paddle_inputs["step_idx"], + paddle_inputs["end_tokens"], + paddle_inputs["max_dec_len"], + inputs["is_naive_mode"], + inputs["prefill_one_step_stop"], + ) + + +# All 12 in-place output keys (from SetInplaceMap in .cu) +OUTPUT_KEYS = [ + "seq_lens_encoder", + "seq_lens_decoder", + "has_running_seqs", + "step_input_ids", + "step_output_ids", + "step_output_len", + "stop_flags", + "seq_lens_this_time", + "mask_rollback", + "token_ids_all", + "step_idx", + # adaptive_step_input_len is in InplaceMap but kernel never writes it +] + + +def get_outputs(paddle_inputs: Dict[str, Any]) -> Dict[str, np.ndarray]: + """Extract ALL in-place-modified tensors back to numpy.""" + return {k: paddle_inputs[k].numpy() for k in OUTPUT_KEYS} + + +# ============================================================ +# Layer 2: Input generation +# ============================================================ + + +def gen_inputs( + real_bsz: int = 8, + max_step_tokens: int = 16, + max_model_len: int = 256, + seed: int = 42, + is_naive_mode: bool = False, + prefill_one_step_stop: bool = False, +) -> Dict[str, Any]: + """Generate randomized test inputs for unified_update_model_status kernel. + + Shapes follow the kernel contract: + - real_bsz = seq_lens_this_time.shape[0] + - max_bsz = stop_flags.shape[0] (= real_bsz + padding) + - is_paused.shape[0] = max_bsz + """ + rng = np.random.default_rng(seed) + max_bsz = real_bsz + 4 # padding slots + + # Per-slot arrays (size=max_bsz) + seq_lens_encoder = rng.integers(0, 5, size=max_bsz, dtype=np.int32) + seq_lens_decoder = rng.integers(10, 100, size=max_bsz, dtype=np.int32) + step_input_ids = rng.integers(0, 1000, size=(max_bsz, max_step_tokens), dtype=np.int64) + adaptive_step_input_len = rng.integers(1, max_step_tokens + 1, size=max_bsz, dtype=np.int32) + step_output_ids = rng.integers(0, 1000, size=(max_bsz, max_step_tokens), dtype=np.int64) + step_output_len = rng.integers(1, max_step_tokens + 1, size=max_bsz, dtype=np.int32) + stop_flags = np.zeros(max_bsz, dtype=bool) + # Randomly stop a few real slots + stop_flags[rng.choice(real_bsz, size=min(2, real_bsz), replace=False)] = True + # Padding slots (batch_id >= real_bsz) must be stopped — kernel accesses + # seq_lens_this_time[batch_id] which is only sized real_bsz + stop_flags[real_bsz:] = True + is_paused = np.zeros(max_bsz, dtype=bool) + mask_rollback = np.zeros(max_bsz, dtype=np.int32) + prompt_lens = rng.integers(10, 50, size=max_bsz, dtype=np.int64) + token_ids_all = rng.integers(0, 1000, size=(max_bsz, max_model_len), dtype=np.int64) + step_idx = rng.integers(0, 50, size=max_bsz, dtype=np.int64) + max_dec_len = rng.integers(100, 200, size=max_bsz, dtype=np.int64) + + # Per-real-batch arrays (size=real_bsz) + seq_lens_this_time = rng.integers(1, max_step_tokens + 1, size=real_bsz, dtype=np.int32) + + # Scalar / small tensors + has_running_seqs = np.array([True], dtype=bool) + end_tokens = rng.integers(1, 1000, size=4, dtype=np.int64) + + return { + "seq_lens_encoder": seq_lens_encoder, + "seq_lens_decoder": seq_lens_decoder, + "has_running_seqs": has_running_seqs, + "step_input_ids": step_input_ids, + "adaptive_step_input_len": adaptive_step_input_len, + "step_output_ids": step_output_ids, + "step_output_len": step_output_len, + "stop_flags": stop_flags, + "seq_lens_this_time": seq_lens_this_time, + "is_paused": is_paused, + "mask_rollback": mask_rollback, + "token_ids_all": token_ids_all, + "prompt_lens": prompt_lens, + "step_idx": step_idx, + "end_tokens": end_tokens, + "max_dec_len": max_dec_len, + # Scalar configs + "real_bsz": real_bsz, + "max_bsz": max_bsz, + "max_step_tokens": max_step_tokens, + "max_model_len": max_model_len, + "is_naive_mode": is_naive_mode, + "prefill_one_step_stop": prefill_one_step_stop, + } + + +# ============================================================ +# Layer 3: Reference implementation (1:1 with CUDA kernel) +# ============================================================ + + +def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Python reference of unified_update_model_status_kernel. + + Line references are to unified_update_model_status.cu. + """ + # Deep-copy all mutable in-place tensors + seq_lens_encoder = inputs["seq_lens_encoder"].copy() + seq_lens_decoder = inputs["seq_lens_decoder"].copy() + step_output_len = inputs["step_output_len"].copy() + stop_flags = inputs["stop_flags"].copy() + seq_lens_this_time = inputs["seq_lens_this_time"].copy() + mask_rollback = inputs["mask_rollback"].copy() + token_ids_all = inputs["token_ids_all"].copy() + step_idx = inputs["step_idx"].copy() + step_input_ids = inputs["step_input_ids"].copy() + step_output_ids = inputs["step_output_ids"].copy() + + # Read-only inputs + real_bsz = inputs["real_bsz"] + max_bsz = inputs["max_bsz"] + max_model_len = inputs["max_model_len"] + is_naive_mode = inputs["is_naive_mode"] + prefill_one_step_stop = inputs["prefill_one_step_stop"] + end_tokens = inputs["end_tokens"] + num_end_tokens = len(end_tokens) + max_dec_len = inputs["max_dec_len"] + prompt_lens = inputs["prompt_lens"] + is_paused = inputs["is_paused"] + + # Block-level stop count for has_running_seqs reduction (line 175) + stop_count = 0 + + for batch_id in range(max_bsz): + # --- line 68-75: Read state --- + cur_seq_len_encoder = int(seq_lens_encoder[batch_id]) + cur_seq_len_decoder = int(seq_lens_decoder[batch_id]) + cur_stop_flag = bool(stop_flags[batch_id]) + output_len = 0 + cur_step_idx = int(step_idx[batch_id]) + cur_is_paused = bool(is_paused[batch_id]) + + # line 77 + is_running = not cur_stop_flag and not cur_is_paused + + # --- line 80-86: Compute output length --- + if is_running: + output_len = 1 if is_naive_mode else int(step_output_len[batch_id]) + + # --- line 89-110: EOS detection --- + if is_running and output_len > 0: + hit_stop = False + for i in range(output_len): + cur_step_idx += 1 # line 94 + token = int(step_output_ids[batch_id, i]) # line 95 + is_eos = any(token == end_tokens[j] for j in range(num_end_tokens)) # line 96 + max_len_hit = cur_step_idx >= int(max_dec_len[batch_id]) # line 97 + + if is_eos or max_len_hit: # line 99 + if not is_eos: + step_output_ids[batch_id, i] = end_tokens[0] # line 100 + output_len = i + 1 # line 101 + cur_stop_flag = True # line 102 + hit_stop = True # line 103 + break # line 104 + + # line 108-110 + if not hit_stop and prefill_one_step_stop and cur_seq_len_encoder > 0: + cur_stop_flag = True + + # --- line 114-166: Update state and write back --- + if is_running: + if cur_stop_flag: + # line 115-119 + stop_count += 1 + if output_len == 0: + cur_seq_len_decoder = 0 # line 117 + stop_flags[batch_id] = True # line 118 + mask_rollback[batch_id] = 0 # line 119 + elif cur_seq_len_encoder == 0: + # line 120-122 + cur_seq_len_decoder += output_len # line 121 + mask_rollback[batch_id] = int(seq_lens_this_time[batch_id]) - output_len # line 122 + else: + # line 123-124 (encoder > 0, not stopped) + mask_rollback[batch_id] = 0 + + # line 127-130: Fold encoder into decoder + if cur_seq_len_encoder > 0: + cur_seq_len_decoder += cur_seq_len_encoder # line 128 + cur_seq_len_encoder = 0 # line 129 + + # line 132-135: Write back scalar state + seq_lens_encoder[batch_id] = cur_seq_len_encoder + seq_lens_decoder[batch_id] = cur_seq_len_decoder + step_output_len[batch_id] = output_len + step_idx[batch_id] = cur_step_idx + + # line 138-145: Write history to token_ids_all + if cur_step_idx > 0 and output_len > 0: + base = int(prompt_lens[batch_id]) + for i in range(output_len): + # token_ids_all_now[cur_step_idx - i] = output_ids[output_len - 1 - i] + write_idx = base + cur_step_idx - i + if 0 <= write_idx < max_model_len: + token_ids_all[batch_id, write_idx] = step_output_ids[batch_id, output_len - 1 - i] + + # line 148-151: Setup next step_input_ids + if output_len > 0: + step_input_ids[batch_id, 0] = step_output_ids[batch_id, output_len - 1] + + # line 153-155: naive_mode → seq_lens_this_time + if is_naive_mode: + seq_lens_this_time[batch_id] = 0 if cur_stop_flag else 1 + + elif batch_id >= real_bsz: + # line 156-158: Padding slot — only count, don't modify state + stop_count += 1 + else: + # line 159-166: Stopped or paused real slot + stop_count += 1 + stop_flags[batch_id] = True # line 162 + seq_lens_decoder[batch_id] = 0 # line 163 + seq_lens_this_time[batch_id] = 0 # line 164 + step_output_len[batch_id] = 0 # line 165 + + # line 177-179: has_running_seqs = stop_sum < max_bsz + has_running_seqs = np.array([stop_count < max_bsz], dtype=bool) + + return { + "seq_lens_encoder": seq_lens_encoder, + "seq_lens_decoder": seq_lens_decoder, + "has_running_seqs": has_running_seqs, + "step_input_ids": step_input_ids, + "step_output_ids": step_output_ids, + "step_output_len": step_output_len, + "stop_flags": stop_flags, + "seq_lens_this_time": seq_lens_this_time, + "mask_rollback": mask_rollback, + "token_ids_all": token_ids_all, + "step_idx": step_idx, + } + + +# ============================================================ +# Layer 4a: TEST_CONFIGS +# ============================================================ + +TEST_CONFIGS = [ + # --- basic mode coverage --- + { + "name": "mtp_mode", + "real_bsz": 8, + "max_step_tokens": 16, + "max_model_len": 256, + "seed": 42, + "is_naive_mode": False, + }, + { + "name": "naive_mode", + "real_bsz": 8, + "max_step_tokens": 16, + "max_model_len": 256, + "seed": 42, + "is_naive_mode": True, + }, + # --- batch size --- + { + "name": "small_batch", + "real_bsz": 1, + "max_step_tokens": 8, + "max_model_len": 128, + "seed": 42, + "is_naive_mode": False, + }, + { + "name": "large_batch", + "real_bsz": 32, + "max_step_tokens": 16, + "max_model_len": 512, + "seed": 42, + "is_naive_mode": False, + }, + # --- prefill_one_step_stop --- + { + "name": "prefill_one_step_stop", + "real_bsz": 8, + "max_step_tokens": 8, + "max_model_len": 128, + "seed": 42, + "is_naive_mode": False, + "prefill_one_step_stop": True, + }, + # --- different seeds for randomized coverage --- + { + "name": "seed_100", + "real_bsz": 8, + "max_step_tokens": 16, + "max_model_len": 256, + "seed": 100, + "is_naive_mode": False, + }, + { + "name": "seed_200_naive", + "real_bsz": 8, + "max_step_tokens": 16, + "max_model_len": 256, + "seed": 200, + "is_naive_mode": True, + }, +] + + +# ============================================================ +# Layer 4b: Test suite +# ============================================================ + + +class TestUnifiedUpdateModelStatus(unittest.TestCase): + + def setUp(self): + if not paddle.is_compiled_with_xpu(): + self.skipTest("Requires XPU") + + # ------ shared helpers ------ + + def _run_and_get(self, inputs: Dict[str, Any]) -> Dict[str, np.ndarray]: + paddle_inputs = to_paddle_inputs(inputs) + run_kernel(paddle_inputs, inputs) + return get_outputs(paddle_inputs) + + def _check_all_outputs(self, inputs: Dict[str, Any], outputs: Dict[str, np.ndarray]): + """Compare ALL output tensors against reference + sanity checks.""" + ref = reference_impl(inputs) + for key in OUTPUT_KEYS: + if not np.array_equal(outputs[key], ref[key]): + diff_mask = outputs[key] != ref[key] + diff_indices = np.argwhere(diff_mask) + for idx in diff_indices[:10]: # print first 10 mismatches + idx_tuple = tuple(idx) + print( + f" [{key}] mismatch at {idx_tuple}: " + f"gpu={outputs[key][idx_tuple]} ref={ref[key][idx_tuple]}" + ) + if key == "token_ids_all": + bid = idx_tuple[0] + print( + f" batch_id={bid}, prompt_lens={inputs['prompt_lens'][bid]}, " + f"step_idx(input)={inputs['step_idx'][bid]}, " + f"step_idx(gpu)={outputs['step_idx'][bid]}, " + f"step_idx(ref)={ref['step_idx'][bid]}, " + f"step_output_len(gpu)={outputs['step_output_len'][bid]}, " + f"step_output_len(ref)={ref['step_output_len'][bid]}, " + f"stop_flags(input)={inputs['stop_flags'][bid]}, " + f"is_paused={inputs['is_paused'][bid]}, " + f"seq_lens_encoder={inputs['seq_lens_encoder'][bid]}" + ) + np.testing.assert_array_equal(outputs[key], ref[key], err_msg=f"{key} mismatch") + + # Sanity: running slots must have encoder zeroed + for i in range(inputs["real_bsz"]): + if not inputs["stop_flags"][i] and not inputs["is_paused"][i]: + self.assertEqual(outputs["seq_lens_encoder"][i], 0, f"Running slot {i} should have encoder=0") + self.assertTrue(np.all(outputs["seq_lens_decoder"] >= 0), "negative seq_lens_decoder") + self.assertTrue(np.all(outputs["step_output_len"] >= 0), "negative step_output_len") + self.assertTrue(np.all(outputs["step_idx"] >= 0), "negative step_idx") + + def _run_full_test(self, config: Dict[str, Any]) -> Dict[str, np.ndarray]: + inputs = gen_inputs(**config) + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + return outputs + + # ------ test cases ------ + + def test_configs(self): + """Run all TEST_CONFIGS via subTest.""" + for cfg in TEST_CONFIGS: + with self.subTest(name=cfg["name"]): + test_cfg = {k: v for k, v in cfg.items() if k != "name"} + self._run_full_test(test_cfg) + + def test_eos_detection(self): + """EOS token at position 2 should truncate output_len to 3.""" + inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42) + eos_token = int(inputs["end_tokens"][0]) + inputs["step_output_ids"][0, 2] = eos_token + inputs["step_output_len"][:] = [5, 3, 0, 0, 0, 0] + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_max_dec_len_stop(self): + """step_idx near max_dec_len should trigger stop and replace with end_tokens[0].""" + # Use large max_model_len to avoid token_ids_all overflow: + # kernel doesn't bounds-check prompt_lens + step_idx < max_model_len + inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=512, seed=42) + inputs["step_idx"][:] = [95, 50, 0, 0, 0, 0] + inputs["max_dec_len"][:] = 100 + inputs["step_output_len"][:] = [10, 5, 0, 0, 0, 0] + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_paused_slots(self): + """Paused slots should be treated as stopped/paused (decoder=0, output_len=0).""" + inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42) + inputs["is_paused"][:] = [True, True, False, False, False, False, False, False] + inputs["stop_flags"][: inputs["real_bsz"]] = False + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_all_stopped(self): + """All slots stopped → has_running_seqs should be False.""" + inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42) + inputs["stop_flags"][:] = True + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_encoder_to_decoder(self): + """Encoder length should fold into decoder: decoder += encoder, encoder → 0.""" + inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42) + inputs["seq_lens_encoder"][:] = [10, 0, 0, 0, 0, 0] + inputs["seq_lens_decoder"][:] = [20, 30, 0, 0, 0, 0] + inputs["step_output_len"][:] = [5, 3, 0, 0, 0, 0] + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_token_ids_all_writing(self): + """token_ids_all should be written at prompt_lens + step_idx positions.""" + inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42) + inputs["step_idx"][:] = [10, 20, 0, 0, 0, 0] + inputs["prompt_lens"][:] = [5, 5, 0, 0, 0, 0] + inputs["step_output_len"][:] = [3, 2, 0, 0, 0, 0] + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + inputs["seq_lens_encoder"][:] = 0 + # Use end_tokens that won't collide with output_ids + inputs["end_tokens"][:] = [9990, 9991, 9992, 9993] + inputs["max_dec_len"][:] = 10000 + inputs["step_output_ids"][0, :3] = [100, 200, 300] + inputs["step_output_ids"][1, :2] = [400, 500] + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_zero_output_len(self): + """Running slot with output_len=0 in MTP mode: output_len stays 0.""" + inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42) + inputs["step_output_len"][:] = [0, 5, 0, 0, 0, 0] + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_prefill_one_step_stop_with_encoder(self): + """prefill_one_step_stop + encoder>0 should stop even without EOS.""" + inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42, prefill_one_step_stop=True) + inputs["seq_lens_encoder"][:] = [5, 0, 0, 0, 0, 0, 0, 0] + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + # Ensure no accidental EOS hit + inputs["end_tokens"][:] = [9990, 9991, 9992, 9993] + inputs["max_dec_len"][:] = 10000 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + def test_mask_rollback(self): + """mask_rollback = seq_lens_this_time - output_len for running decode slots.""" + inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42) + inputs["stop_flags"][: inputs["real_bsz"]] = False + inputs["is_paused"][:] = False + inputs["seq_lens_encoder"][:] = 0 # All decode slots + inputs["seq_lens_this_time"][:] = [6, 4, 8, 3] + inputs["step_output_len"][:] = [3, 2, 5, 1, 0, 0, 0, 0] + # Avoid EOS/max_dec_len hits + inputs["end_tokens"][:] = [9990, 9991, 9992, 9993] + inputs["max_dec_len"][:] = 10000 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 9e512f32355..0e20e6ad79a 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -271,6 +271,8 @@ class XPUForwardMeta(ForwardMeta): hidden_states: Optional[paddle.Tensor] = None is_draft: bool = False + # max bs + max_num_seqs: int = 0 def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None): """ diff --git a/fastdeploy/model_executor/layers/backends/xpu/attention.py b/fastdeploy/model_executor/layers/backends/xpu/attention.py index 9295c71411a..85565d33efb 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/attention.py +++ b/fastdeploy/model_executor/layers/backends/xpu/attention.py @@ -196,7 +196,6 @@ def forward_mixed( qkv, forward_meta.caches[2 * layer.layer_id], forward_meta.caches[2 * layer.layer_id + 1], - forward_meta.cum_offsets, metadata.rotary_embs, metadata.block_tables, forward_meta.prefix_block_tables, diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 9d87e78cb1a..3ad1c6aa444 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -1069,8 +1069,8 @@ def forward_xpu( sampling_metadata.min_dec_lens, sampling_metadata.eos_token_ids, share_inputs["seq_lens_this_time"], - share_inputs["output_padding_offset"], - share_inputs["output_cum_offsets"], + share_inputs["batch_id_per_token_output"], + share_inputs["cu_seqlens_q_output"], max_model_len, sampling_metadata.pre_token_ids, ) @@ -1091,7 +1091,7 @@ def forward_xpu( verify_scores, verify_tokens, actual_candidate_len = top_p_candidates( probs, sampling_metadata.top_p, - share_inputs["output_padding_offset"], + share_inputs["batch_id_per_token_output"], self.speculative_max_candidate_len, max_model_len, ) @@ -1113,7 +1113,7 @@ def forward_xpu( share_inputs["max_dec_len"], sampling_metadata.eos_token_ids, share_inputs["is_block_step"], - share_inputs["output_cum_offsets"], + share_inputs["cu_seqlens_q_output"], actual_candidate_len, share_inputs["actual_draft_token_num"], sampling_metadata.top_p, @@ -1366,8 +1366,8 @@ def forward_xpu( sampling_metadata.min_dec_lens, sampling_metadata.eos_token_ids, share_inputs["seq_lens_this_time"], - share_inputs["output_padding_offset"], - share_inputs["output_cum_offsets"], + share_inputs["batch_id_per_token_output"], + share_inputs["cu_seqlens_q_output"], max_model_len, sampling_metadata.pre_token_ids, ) diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py index 8a449b597d0..9e32ea34876 100644 --- a/fastdeploy/model_executor/xpu_pre_and_post_process.py +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -40,9 +40,7 @@ save_output_topk, set_stop_value_multi_ends, speculate_clear_accept_nums, - speculate_get_output_padding_offset, - speculate_get_padding_offset, - speculate_get_seq_lens_output, + speculate_pre_process, speculate_save_output, speculate_set_stop_value_multi_seqs, speculate_set_value_by_flags_and_idx, @@ -109,51 +107,32 @@ def xpu_pre_process( ) -> XPUForwardMeta: """ """ max_len = input_ids.shape[1] - cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32") - token_num = paddle.sum(seq_lens_this_time) + token_num_cpu = paddle.sum(seq_lens_this_time).cpu() if use_speculate_method: ( ids_remove_padding, - cum_offsets, batch_id_per_token, cu_seqlens_q, cu_seqlens_k, - ) = speculate_get_padding_offset( - input_ids, - draft_tokens, - cum_offsets_now, - token_num, - seq_lens_this_time, - seq_lens_encoder, - ) - seq_lens_output = speculate_get_seq_lens_output( - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - ) - if isinstance(seq_lens_output, list): - seq_lens_output = seq_lens_output[0] - output_token_num = paddle.sum(seq_lens_output) - output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output, dtype="int32") - output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset( - output_cum_offsets_tmp, - output_token_num, - seq_lens_output, - max_len, + cu_seqlens_q_output, + batch_id_per_token_output, + real_output_token_num, + ) = speculate_pre_process( + token_num_cpu, input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, seq_lens_decoder ) - share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False) - share_inputs["output_padding_offset"].copy_(output_padding_offset, False) + share_inputs["cu_seqlens_q_output"] = cu_seqlens_q_output + share_inputs["batch_id_per_token_output"] = batch_id_per_token_output else: + cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32") ( ids_remove_padding, cum_offsets, batch_id_per_token, cu_seqlens_q, cu_seqlens_k, - ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) + ) = get_padding_offset(input_ids, cum_offsets_now, token_num_cpu, seq_lens_this_time) - share_inputs["cum_offsets"] = cum_offsets share_inputs["batch_id_per_token"] = batch_id_per_token share_inputs["cu_seqlens_q"] = cu_seqlens_q share_inputs["cu_seqlens_k"] = cu_seqlens_k @@ -165,12 +144,12 @@ def xpu_pre_process( seq_lens_encoder=share_inputs["seq_lens_encoder"], seq_lens_decoder=share_inputs["seq_lens_decoder"], seq_lens_this_time=share_inputs["seq_lens_this_time"], - cum_offsets=share_inputs["cum_offsets"], batch_id_per_token=share_inputs["batch_id_per_token"], cu_seqlens_q=share_inputs["cu_seqlens_q"], cu_seqlens_k=share_inputs["cu_seqlens_k"], block_tables=share_inputs["block_tables"], caches=share_inputs["caches"], + max_num_seqs=share_inputs["seq_lens_this_time"].shape[0], ) ( @@ -205,7 +184,6 @@ def xpu_pre_process( adjusted_input = adjust_batch( ids_remove_padding.reshape([-1, 1]), - cum_offsets, xpu_forward_meta.encoder_seq_lod, xpu_forward_meta.decoder_seq_lod, xpu_forward_meta.encoder_batch_idx, @@ -237,7 +215,6 @@ def xpu_pre_process( def xpu_process_output( forward_output, - cum_offsets: paddle.Tensor, xpu_forward_meta: XPUForwardMeta, share_inputs, ) -> paddle.Tensor: @@ -250,7 +227,6 @@ def xpu_process_output( hiddden_states = gather_next_token( forward_output, - cum_offsets, xpu_forward_meta.encoder_seq_lod, xpu_forward_meta.decoder_seq_lod, xpu_forward_meta.encoder_batch_map, @@ -261,7 +237,7 @@ def xpu_process_output( xpu_forward_meta.decoder_batch_map_cpu, xpu_forward_meta.len_info_cpu, output_padding_offset, # output_padding_offset - -1, # max_input_length + xpu_forward_meta.max_num_seqs, ) return hiddden_states diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 5868d0bff0c..6ce328aa283 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -813,11 +813,7 @@ def _post_process(self, sampled_token_ids): # Note(ZKK): # I strongly advise xpu student delete the fuck `output_cum_offsets` name in XPU backend # like my pr https://github.com/PaddlePaddle/FastDeploy/pull/6358 - ( - self.model_inputs["cu_seqlens_q_output"] - if current_platform.is_cuda() - else self.model_inputs["output_cum_offsets"] - ), + self.model_inputs["cu_seqlens_q_output"], self.model_inputs["stop_flags"], self.model_inputs["not_need_stop"], self.model_inputs["max_dec_len"], @@ -1111,9 +1107,7 @@ def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = Fa previous_hidden_states=self.model_inputs["target_hidden_states"], forward_meta=self.forward_meta, ) - hidden_states = xpu_process_output( - model_output, self.model_inputs["cum_offsets"], self.forward_meta, self.model_inputs - ) + hidden_states = xpu_process_output(model_output, self.forward_meta, self.model_inputs) # 4. Compute logits, Sample logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta) sampled_token_ids, sampler_output = self.sampler( diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index a984e8788c4..1446257d3ae 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -1114,6 +1114,8 @@ def _prepare_inputs(self, is_dummy_run=False) -> None: self.cache_config.block_size, self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0, ) + + # TODO(chenhuan): support cached_token_num self.forward_meta = xpu_pre_process( self.share_inputs["input_ids"], self.share_inputs["seq_lens_this_time"], @@ -1577,9 +1579,7 @@ class at the server level, which is too granular for ModelRunner. ) if self.use_cudagraph: model_output = model_output[: self.real_token_num] - hidden_states = xpu_process_output( - model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs - ) + hidden_states = xpu_process_output(model_output, self.forward_meta, self.share_inputs) # 4. Compute logits, Sample logits = self.model.compute_logits(hidden_states) sampler_output = None