Skip to content

Commit 507e7e9

Browse files
authored
Merge branch 'main' into feature/rec_251126_pr_v2
2 parents 5ac51ad + cc75854 commit 507e7e9

File tree

20 files changed

+63
-102
lines changed

20 files changed

+63
-102
lines changed

xllm/api_service/api_service.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ void CommonCompletionsImpl(std::unique_ptr<Service>& service,
139139
return;
140140
}
141141

142-
auto call = std::make_shared<Call>(ctrl, guard.release(), req_pb, resp_pb);
142+
auto call = std::make_shared<Call>(
143+
ctrl, guard.release(), req_pb, resp_pb, arena != nullptr);
143144
service->process_async(call);
144145
}
145146
} // namespace

xllm/core/common/global_flags.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ DEFINE_double(max_memory_utilization,
135135

136136
DEFINE_int32(max_tokens_per_batch, 20480, "Max number of tokens per batch.");
137137

138-
DEFINE_int32(max_seqs_per_batch, 256, "Max number of sequences per batch.");
138+
DEFINE_int32(max_seqs_per_batch, 1024, "Max number of sequences per batch.");
139139

140140
DEFINE_bool(enable_schedule_overlap,
141141
true,
@@ -172,7 +172,7 @@ DEFINE_int32(ep_size, 1, "Expert parallel size for MoE model.");
172172

173173
DEFINE_string(
174174
communication_backend,
175-
"lccl",
175+
"hccl",
176176
"NPU communication backend.(e.g. lccl, hccl). When enable dp, use hccl.");
177177

178178
// --- ep load balance config ---

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -136,24 +136,15 @@ void WorkerService::step(ForwardInput& fwd_input,
136136
}
137137
}
138138
} else {
139+
auto int_options = torch::TensorOptions().device(torch::kCPU);
139140
if (worker_->is_driver()) {
140141
// construct fake output tensor
141-
auto options =
142-
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
143-
auto total_prefill_seq_len = 0;
144-
auto total_num_sequences = 0;
145-
146-
total_num_sequences += fwd_input.input_params.num_sequences;
147-
total_prefill_seq_len += fwd_input.input_params.prefill_seq_len;
148-
149-
next_tokens =
150-
torch::arange(-1,
151-
-1 * (total_num_sequences - total_prefill_seq_len + 1),
152-
-1,
153-
options);
142+
int32_t num_decode_seqs = fwd_input.sampling_params.sample_idxes.size(0);
143+
next_tokens = torch::arange(
144+
-1, -1 * (num_decode_seqs + 1), -1, int_options.dtype(torch::kInt32));
154145
std::move(future).deferValue([](auto&&) {});
155146
}
156-
expert_load_data = torch::zeros({1, 1}).to(torch::kInt64).contiguous();
147+
expert_load_data = torch::zeros({1, 1}, int_options.dtype(torch::kInt64));
157148
}
158149
}
159150

xllm/core/framework/batch/batch.cpp

100755100644
Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,7 @@ std::map<uint32_t, uint32_t> Batch::cal_seq_exchange_index(
196196
return index_shift;
197197
}
198198

199-
RawForwardInput Batch::prepare_forward_input(uint32_t start_idx,
200-
uint32_t end_idx,
201-
const ModelArgs& args,
199+
RawForwardInput Batch::prepare_forward_input(const ModelArgs& args,
202200
ThreadPool* thread_pool) {
203201
dp_balance_shuffle_seqs();
204202
BatchInputBuilder builder(sequences_,
@@ -210,7 +208,7 @@ RawForwardInput Batch::prepare_forward_input(uint32_t start_idx,
210208
&args,
211209
batch_forward_type_,
212210
thread_pool);
213-
return builder.build_raw_forward_input(start_idx, end_idx);
211+
return builder.build_raw_forward_input();
214212
}
215213

216214
void Batch::process_sample_output(const RawForwardOutput& raw_output,
@@ -341,7 +339,7 @@ void Batch::append_token_for_sequence(Sequence* seq,
341339
seq->pre_scheduled_step_prefill_queue().pop();
342340
}
343341
}
344-
} else {
342+
} else if (!seq->cancelled()) {
345343
// truely update the real token if replace_fake_token
346344
seq->update_last_step_token(token, token_idx);
347345
if (FLAGS_enable_chunked_prefill && token_idx == 0) {

xllm/core/framework/batch/batch.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ class Batch {
8585
const ModelArgs& args);
8686

8787
// Convert Batch to pb type, which will be pass to remote worker.
88-
RawForwardInput prepare_forward_input(uint32_t start_idx,
89-
uint32_t end_idx,
90-
const ModelArgs& args,
88+
RawForwardInput prepare_forward_input(const ModelArgs& args,
9189
ThreadPool* thread_pool);
9290

9391
// process output

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ BatchInputBuilder::BatchInputBuilder(
5353
mm_data_vec_(mm_data_vec),
5454
args_(args),
5555
thread_pool_(thread_pool),
56-
num_sequences_(static_cast<int32_t>(sequences.size())),
56+
num_sequences_(sequences.size()),
5757
swap_block_transfer_infos_(swap_block_transfer_infos),
5858
batch_id_(batch_id) {
5959
// Reserve space for better performance
@@ -72,35 +72,31 @@ BatchInputBuilder::BatchInputBuilder(
7272
ForwardInput BatchInputBuilder::build_forward_input(
7373
uint32_t num_decoding_tokens,
7474
uint32_t min_decoding_batch_size) {
75-
process_sequences(0, static_cast<uint32_t>(num_sequences_));
75+
process_sequences();
7676
padding_decode_batch_size(num_decoding_tokens, min_decoding_batch_size);
7777

7878
return state_to_forward_input();
7979
}
8080

81-
RawForwardInput BatchInputBuilder::build_raw_forward_input(uint32_t start_idx,
82-
uint32_t end_idx) {
83-
if (!thread_pool_ ||
84-
end_idx - start_idx < static_cast<uint32_t>(thread_pool_->size())) {
85-
process_sequences(start_idx, end_idx);
81+
RawForwardInput BatchInputBuilder::build_raw_forward_input() {
82+
if (!thread_pool_ || num_sequences_ < thread_pool_->size()) {
83+
process_sequences();
8684
} else {
87-
process_sequences_multithreaded(start_idx, end_idx);
85+
process_sequences_multithreaded();
8886
}
8987
return state_to_raw_forward_input();
9088
}
9189

92-
void BatchInputBuilder::process_sequences(uint32_t start_idx,
93-
uint32_t end_idx) {
94-
for (int32_t i = start_idx; i < end_idx; ++i) {
90+
void BatchInputBuilder::process_sequences() {
91+
for (int32_t i = 0; i < num_sequences_; ++i) {
9592
process_single_sequence(i);
9693
}
9794
}
9895

99-
void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
100-
uint32_t end_idx) {
96+
void BatchInputBuilder::process_sequences_multithreaded() {
10197
const size_t threads_num = thread_pool_->size();
10298
const size_t sequences_per_thread =
103-
(end_idx - start_idx + threads_num - 1) / threads_num;
99+
(num_sequences_ + threads_num - 1) / threads_num;
104100

105101
BlockingCounter counter(threads_num);
106102

@@ -117,17 +113,17 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
117113
BuilderState& state,
118114
std::unordered_set<int32_t>& write_block_ids) {
119115
for (size_t i = thread_start_idx;
120-
i < thread_end_idx && i < static_cast<size_t>(end_idx);
116+
i < thread_end_idx && i < static_cast<size_t>(num_sequences_);
121117
++i) {
122118
process_single_sequence(i, &state, &write_block_ids);
123119
}
124120
};
125121

126122
// Start parallel tasks
127123
for (size_t thread_idx = 0; thread_idx < threads_num; ++thread_idx) {
128-
size_t thread_start_idx = start_idx + thread_idx * sequences_per_thread;
124+
size_t thread_start_idx = thread_idx * sequences_per_thread;
129125
size_t thread_end_idx = std::min(thread_start_idx + sequences_per_thread,
130-
static_cast<size_t>(end_idx));
126+
static_cast<size_t>(num_sequences_));
131127

132128
thread_pool_->schedule([process_sequences_range,
133129
thread_start_idx,
@@ -214,7 +210,6 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
214210
state_.new_token_slot_ids.insert(state_.new_token_slot_ids.end(),
215211
state.new_token_slot_ids.begin(),
216212
state.new_token_slot_ids.end());
217-
state_.prefill_seq_len += state.prefill_seq_len;
218213
state_.embedding_ids.insert(state_.embedding_ids.end(),
219214
state.embedding_ids.begin(),
220215
state.embedding_ids.end());
@@ -306,11 +301,6 @@ void BatchInputBuilder::process_single_sequence(
306301
sequence, n_kv_cache_tokens, seq_len, q_seq_len, state_ptr);
307302
}
308303

309-
// Track prefill sequences
310-
if (sequence->is_chunked_prefill_stage()) {
311-
state.prefill_seq_len++;
312-
}
313-
314304
// Input for beam search kernel
315305
if (FLAGS_enable_beam_search_kernel && sequence->check_beam_search() &&
316306
sequence->num_generated_tokens() > 0) {
@@ -658,7 +648,6 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
658648
raw_forward_input.num_sequences = num_sequences_;
659649
// raw_forward_input.dp_global_token_nums = ;
660650
raw_forward_input.transfer_kv_infos = std::move(state_.transfer_kv_infos);
661-
raw_forward_input.prefill_seq_len = state_.prefill_seq_len;
662651

663652
// for flashinfer
664653
raw_forward_input.paged_kv_indptr = std::move(state_.paged_kv_indptr);

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ class BatchInputBuilder {
4747
ForwardInput build_forward_input(uint32_t num_decoding_tokens,
4848
uint32_t min_decoding_batch_size);
4949

50-
RawForwardInput build_raw_forward_input(uint32_t start_idx, uint32_t end_idx);
50+
RawForwardInput build_raw_forward_input();
5151

5252
private:
5353
// Core building methods
54-
void process_sequences(uint32_t start_idx, uint32_t end_idx);
55-
void process_sequences_multithreaded(uint32_t start_idx, uint32_t end_idx);
54+
void process_sequences();
55+
void process_sequences_multithreaded();
5656
void padding_decode_batch_size(uint32_t num_decoding_tokens,
5757
uint32_t min_decoding_batch_size);
5858
ForwardInput state_to_forward_input();
@@ -100,7 +100,6 @@ class BatchInputBuilder {
100100
// Additional data
101101
std::vector<int32_t> embedding_ids;
102102
std::vector<int32_t> extra_token_ids;
103-
uint32_t prefill_seq_len = 0;
104103
std::vector<TransferKVInfo> transfer_kv_infos;
105104

106105
// for continuous kvcache
@@ -153,7 +152,7 @@ class BatchInputBuilder {
153152

154153
// Configuration
155154
bool use_mrope_ = false;
156-
int32_t num_sequences_ = 0;
155+
uint32_t num_sequences_ = 0;
157156

158157
// copy in and out cache contents
159158
std::unordered_set<int32_t> write_block_ids_;

xllm/core/framework/model/model_input_params.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ struct ModelInputParams {
110110

111111
params.mm_data = MMData::to(mm_data, device);
112112
params.dp_global_token_nums = dp_global_token_nums;
113-
params.prefill_seq_len = prefill_seq_len;
114113
params.embedding_ids = std::move(embedding_ids);
115114
params.extra_token_ids = std::move(extra_token_ids);
116115
params.dp_ep_padding_data = dp_ep_padding_data;
@@ -151,8 +150,7 @@ struct ModelInputParams {
151150
<< " , global_empty_kv_cache is " << global_empty_kv_cache
152151
<< " , num_sequences is " << num_sequences
153152
<< " , kv_max_seq_len is " << kv_max_seq_len
154-
<< " , q_max_seq_len is " << q_max_seq_len
155-
<< " , prefill_seq_len is " << prefill_seq_len;
153+
<< " , q_max_seq_len is " << q_max_seq_len;
156154
LOG(INFO) << "ModelInputParams: kv_seq_lens_vec is " << kv_seq_lens_vec;
157155
LOG(INFO) << "ModelInputParams: q_seq_lens_vec is " << q_seq_lens_vec;
158156
LOG(INFO) << "ModelInputParams: decode_seq_range is " << decode_seq_range;
@@ -209,9 +207,6 @@ struct ModelInputParams {
209207
// whether the kv-cache is empty for all sequences,mainly used for dp case
210208
bool global_empty_kv_cache = true;
211209

212-
// num of prefill sequence in chunked prefill case
213-
uint32_t prefill_seq_len = 0;
214-
215210
// embedding ids of each sequence
216211
std::vector<int32_t> embedding_ids;
217212

xllm/core/framework/request/request.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ size_t Request::total_num_blocks() {
125125
return num;
126126
}
127127

128+
void Request::set_cancel() {
129+
cancelled_.store(true, std::memory_order_relaxed);
130+
for (const auto& seq : sequences()) {
131+
seq->set_cancel();
132+
}
133+
}
134+
128135
RequestOutput Request::generate_output(const Tokenizer& tokenizer,
129136
ThreadPool* thread_pool) {
130137
// summarize statistics for all sequences
@@ -159,7 +166,7 @@ void Request::update_connection_status() {
159166
if (!is_disconnected) {
160167
return;
161168
}
162-
cancelled_.store(true, std::memory_order_relaxed);
169+
set_cancel();
163170
}
164171

165172
} // namespace xllm

xllm/core/framework/request/request.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class Request : public RequestBase {
5555

5656
SequencesGroup* sequence_group() { return sequences_group_.get(); }
5757

58-
void set_cancel() { cancelled_.store(true, std::memory_order_relaxed); }
58+
void set_cancel();
5959

6060
bool cancelled() const { return cancelled_.load(std::memory_order_relaxed); }
6161

0 commit comments

Comments
 (0)