Skip to content

Commit 2e2a304

Browse files
authored
feat: add batch forward type. (#430)
1 parent 3bce8a9 commit 2e2a304

21 files changed

+195
-40
lines changed

xllm/core/framework/batch/batch.cpp

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,39 @@ void Batch::add(Sequence* sequence, uint32_t allowed_max_token) {
5252
input_embeddings_vec_.emplace_back(input_embedding);
5353

5454
const auto& mm_data = sequence->get_mm_data();
55-
// if (sequence->is_prefill_stage() && mm_data.valid()) // TODO:Compatible
56-
// With Chunked Prefill
57-
if ((sequence->kv_state().kv_cache_tokens_num() <
58-
sequence->num_prompt_tokens()) &&
59-
mm_data.valid())
55+
// if (sequence->is_chunked_prefill_stage() && mm_data.valid())
56+
// TODO:Compatible With Chunked Prefill
57+
if ((sequence->stage() == SequenceStage::PREFILL) && mm_data.valid()) {
6058
mm_data_vec_.emplace_back(mm_data);
59+
}
60+
}
61+
62+
void Batch::update_forward_type(Sequence* sequence) {
63+
auto stage = sequence->stage();
64+
switch (batch_forward_type_.value()) {
65+
case BatchForwardType::PREFILL:
66+
if (stage == SequenceStage::CHUNKED_PREFILL) {
67+
batch_forward_type_ = BatchForwardType::CHUNKED_PREFILL;
68+
} else if (stage == SequenceStage::DECODE) {
69+
batch_forward_type_ = BatchForwardType::MIXED;
70+
}
71+
break;
72+
case BatchForwardType::CHUNKED_PREFILL:
73+
if (stage == SequenceStage::DECODE) {
74+
batch_forward_type_ = BatchForwardType::MIXED;
75+
}
76+
break;
77+
case BatchForwardType::DECODE:
78+
if (stage != SequenceStage::DECODE) {
79+
batch_forward_type_ = BatchForwardType::MIXED;
80+
}
81+
break;
82+
case BatchForwardType::MIXED:
83+
break;
84+
case BatchForwardType::EMPTY:
85+
batch_forward_type_ = BatchForwardType(static_cast<int32_t>(stage));
86+
break;
87+
}
6188
}
6289

6390
void Batch::add(const std::vector<Sequence*>& sequences) {
@@ -75,7 +102,8 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens,
75102
mm_data_vec_,
76103
swap_block_transfer_infos_,
77104
batch_id_,
78-
&args);
105+
&args,
106+
batch_forward_type_);
79107
return builder.build_forward_input(num_decoding_tokens,
80108
min_decoding_batch_size);
81109
}
@@ -180,6 +208,7 @@ RawForwardInput Batch::prepare_forward_input(uint32_t start_idx,
180208
swap_block_transfer_infos_,
181209
batch_id_,
182210
&args,
211+
batch_forward_type_,
183212
thread_pool);
184213
return builder.build_raw_forward_input(start_idx, end_idx);
185214
}
@@ -282,7 +311,7 @@ bool Batch::update_sequence_state(Sequence* seq, bool replace_fake_token) {
282311
// prefill-or-not state of last stage, otherwise, we need the state
283312
// of current stage.
284313
if (FLAGS_enable_chunked_prefill) {
285-
if (!replace_fake_token && seq->is_prefill_stage()) {
314+
if (!replace_fake_token && seq->is_chunked_prefill_stage()) {
286315
seq->pre_scheduled_step_prefill_queue().push(true);
287316
// if not replace_fake_token, pop out here to avoid endless growth
288317
if (seq->pre_scheduled_step_prefill_queue().size() > 2) {

xllm/core/framework/batch/batch.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License.
2323
#include <limits>
2424
#include <vector>
2525

26+
#include "framework/batch/batch_forward_type.h"
2627
#include "framework/request/mm_data.h"
2728
#include "framework/request/request.h"
2829
#include "framework/request/sequence.h"
@@ -53,6 +54,8 @@ class Batch {
5354
sequence_groups_.push_back(sequence_group);
5455
}
5556

57+
void update_forward_type(Sequence* sequence);
58+
5659
void set_swap_block_transfer_infos(
5760
std::vector<BlockTransferInfo>* swap_block_transfer_infos) {
5861
swap_block_transfer_infos_ = swap_block_transfer_infos;
@@ -113,12 +116,6 @@ class Batch {
113116
return allowed_max_tokens_;
114117
}
115118

116-
void set_batch_prefill_status(const bool all_seqs_in_prefill) {
117-
all_seqs_in_prefill_ = all_seqs_in_prefill;
118-
}
119-
120-
bool get_batch_prefill_status() const { return all_seqs_in_prefill_; }
121-
122119
std::map<uint32_t, uint32_t> cal_seq_exchange_index_test(
123120
std::vector<uint32_t>& kv_cache_tokens_num) {
124121
return cal_seq_exchange_index(kv_cache_tokens_num);
@@ -152,8 +149,7 @@ class Batch {
152149
// mm_data in the batch
153150
std::vector<MMData> mm_data_vec_;
154151

155-
// all sequences in this batch are in prefill stage
156-
bool all_seqs_in_prefill_ = false;
152+
BatchForwardType batch_forward_type_;
157153

158154
uint64_t batch_id_ = UNINITIALIZED_BATCH_ID;
159155
};

xllm/core/framework/batch/batch_factory.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,7 @@ std::vector<Batch> BatchFactory::create_batches(
5656
// if dp enabled, each sequence is required to
5757
// dispatch to the same rank in the whole lifetime
5858
batches[sequence->dp_rank()].add(sequence, token_budget);
59-
if (!((sequence->stage() == SequenceStage::DECODE) &&
60-
(sequence->kv_state().kv_cache_tokens_num() > 0))) {
61-
batches[sequence->dp_rank()].set_batch_prefill_status(true);
62-
}
59+
batches[sequence->dp_rank()].update_forward_type(sequence);
6360
}
6461

6562
if (is_beam_search(running_requests)) {
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================*/
16+
17+
#pragma once
18+
19+
namespace xllm {
20+
21+
class BatchForwardType {
22+
public:
23+
enum Value : int32_t {
24+
// Prefill without using kv cache.
25+
PREFILL = 0,
26+
// Chunked prefill using kv cache.
27+
// No decode sequence in this type.
28+
CHUNKED_PREFILL = 1,
29+
// Decode one token.
30+
// No prefill sequence in this type.
31+
DECODE = 2,
32+
// Mixed prefill and decode in one batch when doing chunked prefill.
33+
MIXED = 3,
34+
// No sequence to forward.
35+
EMPTY = 4,
36+
};
37+
38+
BatchForwardType() : value_(EMPTY) {}
39+
40+
BatchForwardType(int32_t v) : value_(static_cast<Value>(v)) {}
41+
42+
constexpr BatchForwardType(Value v) : value_(v) {}
43+
44+
BatchForwardType& operator=(Value v) {
45+
value_ = v;
46+
return *this;
47+
}
48+
49+
int32_t value() const { return value_; }
50+
51+
bool is_prefill() const { return (value_ == PREFILL); }
52+
53+
bool is_chunked_prefill() const { return (value_ == CHUNKED_PREFILL); }
54+
55+
bool no_decode() const {
56+
return (value_ == PREFILL || value_ == CHUNKED_PREFILL);
57+
}
58+
59+
bool has_decode() const { return (value_ == DECODE || value_ == MIXED); }
60+
61+
bool is_decode() const { return (value_ == DECODE); }
62+
63+
bool is_mixed() const { return (value_ == MIXED); }
64+
65+
bool is_empty() const { return (value_ == EMPTY); }
66+
67+
std::string to_string() const {
68+
switch (value_) {
69+
case PREFILL:
70+
return "PREFILL";
71+
case CHUNKED_PREFILL:
72+
return "CHUNKED_PREFILL";
73+
case DECODE:
74+
return "DECODE";
75+
case MIXED:
76+
return "MIXED";
77+
case EMPTY:
78+
return "EMPTY";
79+
default:
80+
return "UNKNOWN";
81+
}
82+
}
83+
84+
private:
85+
Value value_;
86+
};
87+
} // namespace xllm

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ BatchInputBuilder::BatchInputBuilder(
4545
std::vector<BlockTransferInfo>* swap_block_transfer_infos,
4646
const uint64_t batch_id,
4747
const ModelArgs* args,
48+
BatchForwardType batch_forward_type,
4849
ThreadPool* thread_pool)
4950
: sequences_(sequences),
5051
allowed_max_tokens_(allowed_max_tokens),
@@ -65,6 +66,7 @@ BatchInputBuilder::BatchInputBuilder(
6566
use_mrope_ = (args_->rope_scaling_rope_type() == "mrope");
6667
}
6768
write_block_ids_.clear();
69+
state_.batch_forward_type = batch_forward_type;
6870
}
6971

7072
ForwardInput BatchInputBuilder::build_forward_input(
@@ -305,7 +307,7 @@ void BatchInputBuilder::process_single_sequence(
305307
}
306308

307309
// Track prefill sequences
308-
if (sequence->is_prefill_stage()) {
310+
if (sequence->is_chunked_prefill_stage()) {
309311
state.prefill_seq_len++;
310312
}
311313

@@ -552,6 +554,7 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
552554

553555
auto& input_params = forward_input.input_params;
554556
input_params.empty_kv_cache = state_.empty_kv_cache;
557+
input_params.batch_forward_type = state_.batch_forward_type;
555558
input_params.num_sequences = state_.block_tables_vec.size();
556559
input_params.kv_max_seq_len = state_.max_seq_len;
557560
input_params.q_max_seq_len = state_.q_max_seq_len;
@@ -645,7 +648,7 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
645648
raw_forward_input.unique_token_lens_vec =
646649
std::move(state_.unique_token_lens_vec);
647650
raw_forward_input.empty_kv_cache = state_.empty_kv_cache;
648-
// raw_forward_input.global_empty_kv_cache = ;
651+
raw_forward_input.batch_forward_type = state_.batch_forward_type;
649652
raw_forward_input.max_seq_len = state_.max_seq_len;
650653
raw_forward_input.q_max_seq_len = state_.q_max_seq_len;
651654
raw_forward_input.seq_lens = std::move(state_.seq_lens);

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class BatchInputBuilder {
4141
std::vector<BlockTransferInfo>* swap_block_transfer_infos,
4242
const uint64_t batch_id,
4343
const ModelArgs* args,
44+
BatchForwardType batch_forward_type,
4445
ThreadPool* thread_pool = nullptr);
4546

4647
ForwardInput build_forward_input(uint32_t num_decoding_tokens,
@@ -77,6 +78,7 @@ class BatchInputBuilder {
7778
std::vector<int32_t> unique_token_lens_vec;
7879

7980
// Sequence metadata
81+
BatchForwardType batch_forward_type;
8082
bool empty_kv_cache = true;
8183
uint32_t max_seq_len = 0;
8284
uint32_t q_max_seq_len = 0;

xllm/core/framework/batch/mposition.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ limitations under the License.
2020
namespace xllm {
2121

2222
torch::Tensor MPositionHelper::get_positions() {
23-
// if (seq_.is_prefill_stage()) {
23+
// if (seq_.is_chunked_prefill_stage()) {
2424
if (seq_.kv_state().kv_cache_tokens_num() < seq_.num_prompt_tokens()) {
2525
auto& mm_data = seq_.get_mm_data();
2626

xllm/core/framework/model/model_input_params.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#if defined(USE_NPU)
2222
#include "platform/npu/npu_layer_synchronizer.h"
2323
#endif
24+
#include "framework/batch/batch_forward_type.h"
2425
#include "framework/request/mm_data.h"
2526
#include "npu_dp_ep_padding.h"
2627
#include "util/tensor_helper.h"
@@ -88,6 +89,7 @@ struct ModelInputParams {
8889
ModelInputParams params;
8990
params.empty_kv_cache = empty_kv_cache;
9091
params.global_empty_kv_cache = global_empty_kv_cache;
92+
params.batch_forward_type = batch_forward_type;
9193
params.num_sequences = num_sequences;
9294
params.kv_max_seq_len = kv_max_seq_len;
9395
params.q_max_seq_len = q_max_seq_len;
@@ -163,6 +165,7 @@ struct ModelInputParams {
163165
}
164166
// whether the kv-cache is empty for all sequences.
165167
bool empty_kv_cache = true;
168+
BatchForwardType batch_forward_type;
166169

167170
// total number of sequences in the batch
168171
int32_t num_sequences = 0;

xllm/core/framework/request/request.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ class Request : public RequestBase {
9292
return state_.sampling_param.beam_width > 1;
9393
}
9494

95-
bool is_prefill_stage() const { return sequences_group_->is_prefill_stage(); }
95+
bool is_chunked_prefill_stage() const {
96+
return sequences_group_->is_chunked_prefill_stage();
97+
}
9698

9799
private:
98100
RequestState state_;

xllm/core/framework/request/sequence.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ void Sequence::append_token(const Token& token) {
101101
CHECK_LT(num_tokens_, tokens_.size())
102102
<< "exceed the token capacity of the sequence";
103103
CHECK(!finished_) << "cannot append token to a finished sequence";
104-
CHECK(kv_state_.kv_cache_tokens_num() > 0 && !is_prefill_stage())
104+
CHECK(kv_state_.kv_cache_tokens_num() > 0 && !is_chunked_prefill_stage())
105105
<< "cannot append token to a prefill sequence";
106106

107107
if (!sequence_params_.enable_schedule_overlap) {

0 commit comments

Comments
 (0)