@@ -53,7 +53,7 @@ BatchInputBuilder::BatchInputBuilder(
5353 mm_data_vec_(mm_data_vec),
5454 args_(args),
5555 thread_pool_(thread_pool),
56- num_sequences_(static_cast < int32_t >( sequences.size() )),
56+ num_sequences_(sequences.size()),
5757 swap_block_transfer_infos_(swap_block_transfer_infos),
5858 batch_id_(batch_id) {
5959 // Reserve space for better performance
@@ -72,35 +72,31 @@ BatchInputBuilder::BatchInputBuilder(
7272ForwardInput BatchInputBuilder::build_forward_input (
7373 uint32_t num_decoding_tokens,
7474 uint32_t min_decoding_batch_size) {
75- process_sequences (0 , static_cast < uint32_t >(num_sequences_) );
75+ process_sequences ();
7676 padding_decode_batch_size (num_decoding_tokens, min_decoding_batch_size);
7777
7878 return state_to_forward_input ();
7979}
8080
81- RawForwardInput BatchInputBuilder::build_raw_forward_input (uint32_t start_idx,
82- uint32_t end_idx) {
83- if (!thread_pool_ ||
84- end_idx - start_idx < static_cast <uint32_t >(thread_pool_->size ())) {
85- process_sequences (start_idx, end_idx);
81+ RawForwardInput BatchInputBuilder::build_raw_forward_input () {
82+ if (!thread_pool_ || num_sequences_ < thread_pool_->size ()) {
83+ process_sequences ();
8684 } else {
87- process_sequences_multithreaded (start_idx, end_idx );
85+ process_sequences_multithreaded ();
8886 }
8987 return state_to_raw_forward_input ();
9088}
9189
92- void BatchInputBuilder::process_sequences (uint32_t start_idx,
93- uint32_t end_idx) {
94- for (int32_t i = start_idx; i < end_idx; ++i) {
90+ void BatchInputBuilder::process_sequences () {
91+ for (int32_t i = 0 ; i < num_sequences_; ++i) {
9592 process_single_sequence (i);
9693 }
9794}
9895
99- void BatchInputBuilder::process_sequences_multithreaded (uint32_t start_idx,
100- uint32_t end_idx) {
96+ void BatchInputBuilder::process_sequences_multithreaded () {
10197 const size_t threads_num = thread_pool_->size ();
10298 const size_t sequences_per_thread =
103- (end_idx - start_idx + threads_num - 1 ) / threads_num;
99+ (num_sequences_ + threads_num - 1 ) / threads_num;
104100
105101 BlockingCounter counter (threads_num);
106102
@@ -117,17 +113,17 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
117113 BuilderState& state,
118114 std::unordered_set<int32_t >& write_block_ids) {
119115 for (size_t i = thread_start_idx;
120- i < thread_end_idx && i < static_cast <size_t >(end_idx );
116+ i < thread_end_idx && i < static_cast <size_t >(num_sequences_ );
121117 ++i) {
122118 process_single_sequence (i, &state, &write_block_ids);
123119 }
124120 };
125121
126122 // Start parallel tasks
127123 for (size_t thread_idx = 0 ; thread_idx < threads_num; ++thread_idx) {
128- size_t thread_start_idx = start_idx + thread_idx * sequences_per_thread;
124+ size_t thread_start_idx = thread_idx * sequences_per_thread;
129125 size_t thread_end_idx = std::min (thread_start_idx + sequences_per_thread,
130- static_cast <size_t >(end_idx ));
126+ static_cast <size_t >(num_sequences_ ));
131127
132128 thread_pool_->schedule ([process_sequences_range,
133129 thread_start_idx,
@@ -214,7 +210,6 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
214210 state_.new_token_slot_ids .insert (state_.new_token_slot_ids .end (),
215211 state.new_token_slot_ids .begin (),
216212 state.new_token_slot_ids .end ());
217- state_.prefill_seq_len += state.prefill_seq_len ;
218213 state_.embedding_ids .insert (state_.embedding_ids .end (),
219214 state.embedding_ids .begin (),
220215 state.embedding_ids .end ());
@@ -306,11 +301,6 @@ void BatchInputBuilder::process_single_sequence(
306301 sequence, n_kv_cache_tokens, seq_len, q_seq_len, state_ptr);
307302 }
308303
309- // Track prefill sequences
310- if (sequence->is_chunked_prefill_stage ()) {
311- state.prefill_seq_len ++;
312- }
313-
314304 // Input for beam search kernel
315305 if (FLAGS_enable_beam_search_kernel && sequence->check_beam_search () &&
316306 sequence->num_generated_tokens () > 0 ) {
@@ -658,7 +648,6 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
658648 raw_forward_input.num_sequences = num_sequences_;
659649 // raw_forward_input.dp_global_token_nums = ;
660650 raw_forward_input.transfer_kv_infos = std::move (state_.transfer_kv_infos );
661- raw_forward_input.prefill_seq_len = state_.prefill_seq_len ;
662651
663652 // for flashinfer
664653 raw_forward_input.paged_kv_indptr = std::move (state_.paged_kv_indptr );
0 commit comments