@@ -52,12 +52,39 @@ void Batch::add(Sequence* sequence, uint32_t allowed_max_token) {
5252 input_embeddings_vec_.emplace_back (input_embedding);
5353
5454 const auto & mm_data = sequence->get_mm_data ();
55- // if (sequence->is_prefill_stage() && mm_data.valid()) // TODO:Compatible
56- // With Chunked Prefill
57- if ((sequence->kv_state ().kv_cache_tokens_num () <
58- sequence->num_prompt_tokens ()) &&
59- mm_data.valid ())
55+ // if (sequence->is_chunked_prefill_stage() && mm_data.valid())
56+ // TODO:Compatible With Chunked Prefill
57+ if ((sequence->stage () == SequenceStage::PREFILL) && mm_data.valid ()) {
6058 mm_data_vec_.emplace_back (mm_data);
59+ }
60+ }
61+
62+ void Batch::update_forward_type (Sequence* sequence) {
63+ auto stage = sequence->stage ();
64+ switch (batch_forward_type_.value ()) {
65+ case BatchForwardType::PREFILL:
66+ if (stage == SequenceStage::CHUNKED_PREFILL) {
67+ batch_forward_type_ = BatchForwardType::CHUNKED_PREFILL;
68+ } else if (stage == SequenceStage::DECODE) {
69+ batch_forward_type_ = BatchForwardType::MIXED;
70+ }
71+ break ;
72+ case BatchForwardType::CHUNKED_PREFILL:
73+ if (stage == SequenceStage::DECODE) {
74+ batch_forward_type_ = BatchForwardType::MIXED;
75+ }
76+ break ;
77+ case BatchForwardType::DECODE:
78+ if (stage != SequenceStage::DECODE) {
79+ batch_forward_type_ = BatchForwardType::MIXED;
80+ }
81+ break ;
82+ case BatchForwardType::MIXED:
83+ break ;
84+ case BatchForwardType::EMPTY:
85+ batch_forward_type_ = BatchForwardType (static_cast <int32_t >(stage));
86+ break ;
87+ }
6188}
6289
6390void Batch::add (const std::vector<Sequence*>& sequences) {
@@ -75,7 +102,8 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens,
75102 mm_data_vec_,
76103 swap_block_transfer_infos_,
77104 batch_id_,
78- &args);
105+ &args,
106+ batch_forward_type_);
79107 return builder.build_forward_input (num_decoding_tokens,
80108 min_decoding_batch_size);
81109}
@@ -180,6 +208,7 @@ RawForwardInput Batch::prepare_forward_input(uint32_t start_idx,
180208 swap_block_transfer_infos_,
181209 batch_id_,
182210 &args,
211+ batch_forward_type_,
183212 thread_pool);
184213 return builder.build_raw_forward_input (start_idx, end_idx);
185214}
@@ -282,7 +311,7 @@ bool Batch::update_sequence_state(Sequence* seq, bool replace_fake_token) {
282311 // prefill-or-not state of last stage, otherwise, we need the state
283312 // of current stage.
284313 if (FLAGS_enable_chunked_prefill) {
285- if (!replace_fake_token && seq->is_prefill_stage ()) {
314+ if (!replace_fake_token && seq->is_chunked_prefill_stage ()) {
286315 seq->pre_scheduled_step_prefill_queue ().push (true );
287316 // if not replace_fake_token, pop out here to avoid endless growth
288317 if (seq->pre_scheduled_step_prefill_queue ().size () > 2 ) {
0 commit comments