@@ -90,6 +90,9 @@ def form_vllm_batch_prefill(
9090 new_request_data_list = []
9191 num_scheduled_tokens : Dict [str , int ] = {}
9292 total_tokens = 0
93+
94+ # Check if this is a non-first peer (IntermediateRequest with hidden_states)
95+ is_first_peer = model_runner .is_first_peer if hasattr (model_runner , "is_first_peer" ) else True
9396
9497 for req in batched_requests :
9598 sampling_params = transform_sampling_params_to_vllm (req .sampling_params )
@@ -99,7 +102,14 @@ def form_vllm_batch_prefill(
99102
100103 computed_blocks , num_computed_tokens = kv_cache_manager .get_computed_blocks (vllm_req )
101104
102- prompt_token_ids = getattr (req , "input_ids" , None ) or []
105+ # For non-first peers, use hidden_states shape instead of input_ids length
106+ if not is_first_peer and hasattr (req , "hidden_states" ) and req .hidden_states is not None :
107+ # hidden_states shape: (num_tokens, hidden_size)
108+ num_tokens = req .hidden_states .shape [0 ]
109+ prompt_token_ids = req .input_ids [:num_tokens ] if req .input_ids else list (range (num_tokens ))
110+ else :
111+ prompt_token_ids = getattr (req , "input_ids" , None ) or []
112+
103113 num_new_tokens = max (len (prompt_token_ids ) - num_computed_tokens , 0 )
104114 if num_new_tokens > 0 :
105115 new_blocks = kv_cache_manager .allocate_slots (
@@ -123,7 +133,7 @@ def form_vllm_batch_prefill(
123133
124134 new_req_data = NewRequestData (
125135 req_id = req .request_id ,
126- prompt_token_ids = req . input_ids ,
136+ prompt_token_ids = prompt_token_ids ,
127137 mm_features = [],
128138 sampling_params = sampling_params ,
129139 pooling_params = None ,
0 commit comments