Skip to content

Commit 5f8c45b

Browse files
committed
update batch info prefill
1 parent 29511eb commit 5f8c45b

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/parallax/vllm/batch_info.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,9 @@ def form_vllm_batch_prefill(
9090
new_request_data_list = []
9191
num_scheduled_tokens: Dict[str, int] = {}
9292
total_tokens = 0
93+
94+
# Check if this is a non-first peer (IntermediateRequest with hidden_states)
95+
is_first_peer = model_runner.is_first_peer if hasattr(model_runner, "is_first_peer") else True
9396

9497
for req in batched_requests:
9598
sampling_params = transform_sampling_params_to_vllm(req.sampling_params)
@@ -99,7 +102,14 @@ def form_vllm_batch_prefill(
99102

100103
computed_blocks, num_computed_tokens = kv_cache_manager.get_computed_blocks(vllm_req)
101104

102-
prompt_token_ids = getattr(req, "input_ids", None) or []
105+
# For non-first peers, use hidden_states shape instead of input_ids length
106+
if not is_first_peer and hasattr(req, "hidden_states") and req.hidden_states is not None:
107+
# hidden_states shape: (num_tokens, hidden_size)
108+
num_tokens = req.hidden_states.shape[0]
109+
prompt_token_ids = req.input_ids[:num_tokens] if req.input_ids else list(range(num_tokens))
110+
else:
111+
prompt_token_ids = getattr(req, "input_ids", None) or []
112+
103113
num_new_tokens = max(len(prompt_token_ids) - num_computed_tokens, 0)
104114
if num_new_tokens > 0:
105115
new_blocks = kv_cache_manager.allocate_slots(
@@ -123,7 +133,7 @@ def form_vllm_batch_prefill(
123133

124134
new_req_data = NewRequestData(
125135
req_id=req.request_id,
126-
prompt_token_ids=req.input_ids,
136+
prompt_token_ids=prompt_token_ids,
127137
mm_features=[],
128138
sampling_params=sampling_params,
129139
pooling_params=None,

0 commit comments

Comments (0)