
Commit 0e023e9

njhill authored and khluu committed
[BugFix] Fix PP performance and PP kv connector output regression (#28768)
Signed-off-by: Nick Hill <nhill@redhat.com>
(cherry picked from commit 7765e5b)
1 parent 4ae8b23 commit 0e023e9

File tree

4 files changed, +108 −104 lines


vllm/v1/engine/core.py

Lines changed: 66 additions & 84 deletions
@@ -61,7 +61,6 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.structured_output import StructuredOutputManager
-from vllm.v1.utils import record_function_or_nullcontext
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -179,11 +178,13 @@ def __init__(
             logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
             self.batch_queue = deque(maxlen=self.batch_queue_size)
 
+        self.ec_producer = (
+            vllm_config.ec_transfer_config is not None
+            and vllm_config.ec_transfer_config.is_ec_producer
+        )
+
         self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
-        if (
-            self.vllm_config.cache_config.enable_prefix_caching
-            or kv_connector is not None
-        ):
+        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
             caching_hash_fn = get_hash_fn_by_name(
                 vllm_config.cache_config.prefix_caching_hash_algo
             )
@@ -239,7 +240,7 @@ def _initialize_kv_caches(
 
         elapsed = time.time() - start
         logger.info_once(
-            ("init engine (profile, create kv cache, warmup model) took %.2f seconds"),
+            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
             elapsed,
             scope="local",
         )
@@ -305,6 +306,16 @@ def log_error_detail(self, scheduler_output: SchedulerOutput):
             )
             raise err
 
+    def _log_err_callback(self, scheduler_output: SchedulerOutput):
+        """Log error details of a future that's not expected to return a result."""
+
+        def callback(f, sched_output=scheduler_output):
+            with self.log_error_detail(sched_output):
+                result = f.result()
+                assert result is None
+
+        return callback
+
     def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         """Schedule, execute, and make output.
@@ -316,21 +327,17 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         # or finished and not yet removed from the batch.
         if not self.scheduler.has_requests():
             return {}, False
-        with record_function_or_nullcontext("core step: schedule"):
-            scheduler_output = self.scheduler.schedule()
-
-        with record_function_or_nullcontext("core step: execute_model"):
-            future = self.model_executor.execute_model(scheduler_output, non_block=True)
-            grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
-            with self.log_error_detail(scheduler_output):
-                model_output = future.result()
-            if model_output is None:
-                model_output = self.model_executor.sample_tokens(grammar_output)
-
-        with record_function_or_nullcontext("core step: update_from_output"):
-            engine_core_outputs = self.scheduler.update_from_output(
-                scheduler_output, model_output
-            )
+        scheduler_output = self.scheduler.schedule()
+        future = self.model_executor.execute_model(scheduler_output, non_block=True)
+        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
+        with self.log_error_detail(scheduler_output):
+            model_output = future.result()
+        if model_output is None:
+            model_output = self.model_executor.sample_tokens(grammar_output)
+
+        engine_core_outputs = self.scheduler.update_from_output(
+            scheduler_output, model_output
+        )
 
         return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
 
@@ -368,52 +375,34 @@ def step_with_batch_queue(
         model_executed = False
         deferred_scheduler_output = None
         if self.scheduler.has_requests():
-            with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
-                scheduler_output = self.scheduler.schedule()
-            with record_function_or_nullcontext(
-                "core step_with_batch_queue: execute_model"
-            ):
-                exec_future = self.model_executor.execute_model(
-                    scheduler_output, non_block=True
-                )
-            model_executed = scheduler_output.total_num_scheduled_tokens > 0
+            scheduler_output = self.scheduler.schedule()
+            exec_future = self.model_executor.execute_model(
+                scheduler_output, non_block=True
+            )
+            if not self.ec_producer:
+                model_executed = scheduler_output.total_num_scheduled_tokens > 0
 
-            if scheduler_output.pending_structured_output_tokens:
-                with record_function_or_nullcontext(
-                    "core step_with_batch_queue: pending_structured_output_tokens"
-                ):
-                    # We need to defer sampling until we have processed the model output
-                    # from the prior step.
-                    deferred_scheduler_output = scheduler_output
-                    # Block-wait for execute to return
-                    # (continues running async on the GPU).
-                    with self.log_error_detail(scheduler_output):
-                        exec_result = exec_future.result()
-                        assert exec_result is None
+            if not model_executed:
+                # No sampling required (no requests scheduled).
+                future = cast(Future[ModelRunnerOutput], exec_future)
             else:
-                with record_function_or_nullcontext(
-                    "core step_with_batch_queue: get_grammar_bitmask"
-                ):
-                    # We aren't waiting for any tokens, get any grammar
-                    # output immediately.
+                exec_future.add_done_callback(self._log_err_callback(scheduler_output))
+
+                if not scheduler_output.pending_structured_output_tokens:
+                    # We aren't waiting for any tokens, get any grammar output
+                    # and sample immediately.
                     grammar_output = self.scheduler.get_grammar_bitmask(
                         scheduler_output
                     )
-                # Block-wait for execute to return (continues running async on the GPU).
-                with self.log_error_detail(scheduler_output):
-                    exec_result = exec_future.result()
-
-                if exec_result is None:
-                    with record_function_or_nullcontext(
-                        "core step_with_batch_queue: sample_tokens"
-                    ):
-                        # Call sample tokens.
-                        future = self.model_executor.sample_tokens(
-                            grammar_output, non_block=True
-                        )
+                    future = self.model_executor.sample_tokens(
+                        grammar_output, non_block=True
+                    )
                 else:
-                    # No sampling required (e.g. all requests finished).
-                    future = cast(Future[ModelRunnerOutput], exec_future)
+                    # We need to defer sampling until we have processed the model output
+                    # from the prior step.
+                    deferred_scheduler_output = scheduler_output
+
+            if not deferred_scheduler_output:
                 # Add this step's future to the queue.
                 batch_queue.appendleft((future, scheduler_output))
                 if (
@@ -430,34 +419,27 @@ def step_with_batch_queue(
         # only be called when the scheduler contains requests or the queue
         # is non-empty.
             return None, False
-        with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
-            # Block until the next result is available.
-            future, scheduler_output = batch_queue.pop()
-            with self.log_error_detail(scheduler_output):
-                model_output = future.result()
-        with record_function_or_nullcontext(
-            "core step_with_batch_queue: update_from_output"
-        ):
-            engine_core_outputs = self.scheduler.update_from_output(
-                scheduler_output, model_output
-            )
+
+        # Block until the next result is available.
+        future, scheduler_output = batch_queue.pop()
+        with self.log_error_detail(scheduler_output):
+            model_output = future.result()
+
+        engine_core_outputs = self.scheduler.update_from_output(
+            scheduler_output, model_output
+        )
 
         # NOTE(nick): We can either handle the deferred tasks here or save
         # in a field and do it immediately once step_with_batch_queue is
         # re-called. The latter slightly favors TTFT over TPOT/throughput.
         if deferred_scheduler_output:
-            with record_function_or_nullcontext(
-                "core step_with_batch_queue: deferred_scheduler_output"
-            ):
-                # We now have the tokens needed to compute the bitmask for the
-                # deferred request. Get the bitmask and call sample tokens.
-                grammar_output = self.scheduler.get_grammar_bitmask(
-                    deferred_scheduler_output
-                )
-                future = self.model_executor.sample_tokens(
-                    grammar_output, non_block=True
-                )
-                batch_queue.appendleft((future, deferred_scheduler_output))
+            # We now have the tokens needed to compute the bitmask for the
+            # deferred request. Get the bitmask and call sample tokens.
+            grammar_output = self.scheduler.get_grammar_bitmask(
+                deferred_scheduler_output
+            )
+            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
+            batch_queue.appendleft((future, deferred_scheduler_output))
 
         return engine_core_outputs, model_executed
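The PP performance part of the fix is visible in `step_with_batch_queue`: the engine no longer block-waits on `exec_future.result()` before queuing the step; error reporting is attached via `Future.add_done_callback` instead. Below is a minimal, hedged sketch of that pattern using only standard `concurrent.futures` — the names `make_error_logger` and `step_id` are hypothetical stand-ins, not vLLM APIs.

from concurrent.futures import Future, ThreadPoolExecutor
import time


def make_error_logger(step_id: int):
    """Return a done-callback that surfaces failures of a fire-and-forget future."""

    def callback(f: Future, step=step_id):
        try:
            result = f.result()  # re-raises any exception from the worker
            assert result is None
        except Exception as e:
            print(f"step {step} failed: {e!r}")  # stand-in for log_error_detail()

    return callback


executor = ThreadPoolExecutor(max_workers=1)

# Non-blocking submit: the caller keeps scheduling the next batch instead of
# waiting on the result here; errors are still reported via the callback.
future = executor.submit(time.sleep, 0.05)
future.add_done_callback(make_error_logger(step_id=0))

executor.shutdown(wait=True)

Removing the synchronous wait is what lets the batch queue stay full under pipeline parallelism, which is where the regression showed up.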

vllm/v1/executor/ray_executor.py

Lines changed: 20 additions & 1 deletion
@@ -99,6 +99,11 @@ def _init_executor(self) -> None:
         # KV connector setup
         self.has_connector = self.vllm_config.kv_transfer_config is not None
 
+        self.ec_producer = (
+            self.vllm_config.ec_transfer_config is not None
+            and self.vllm_config.ec_transfer_config.is_ec_producer
+        )
+
         self.scheduler_output: SchedulerOutput | None = None
 
     @property
@@ -395,6 +400,12 @@ def execute_model( # type: ignore[override]
                 "State error: sample_tokens() must be called "
                 "after execute_model() returns None."
             )
+
+        if self.ec_producer or not scheduler_output.total_num_scheduled_tokens:
+            # Model will not execute, call model runner immediately.
+            return self._execute_dag(scheduler_output, None, non_block)
+
+        # Model will execute, defer to sample_tokens() call.
         self.scheduler_output = scheduler_output
         return COMPLETED_NONE_FUTURE if non_block else None
 
@@ -417,10 +428,18 @@ def sample_tokens( # type: ignore[override]
         """
         scheduler_output = self.scheduler_output
         if scheduler_output is None:
-            return None  # noqa
+            return COMPLETED_NONE_FUTURE if non_block else None  # noqa
 
         self.scheduler_output = None
 
+        return self._execute_dag(scheduler_output, grammar_output, non_block)
+
+    def _execute_dag(
+        self,
+        scheduler_output: SchedulerOutput,
+        grammar_output: "GrammarOutput | None",
+        non_block: bool = False,
+    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
         # Build the compiled DAG for the first time.
         if self.forward_dag is None: # type: ignore
             self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
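The Ray executor change routes both paths through a shared `_execute_dag()` helper: `execute_model()` launches immediately only when nothing will be sampled (EC producer, or zero scheduled tokens), and otherwise stashes the `SchedulerOutput` until `sample_tokens()` supplies the grammar output. A rough sketch of that two-phase protocol, with toy types (`ToySchedulerOutput`, `TwoPhaseExecutor`) in place of vLLM's, assuming this reading of the control flow:

from dataclasses import dataclass


@dataclass
class ToySchedulerOutput:
    total_num_scheduled_tokens: int


class TwoPhaseExecutor:
    """Toy executor mirroring the execute_model()/sample_tokens() split."""

    def __init__(self, ec_producer: bool = False):
        self.ec_producer = ec_producer
        self.scheduler_output: ToySchedulerOutput | None = None

    def execute_model(self, scheduler_output: ToySchedulerOutput):
        if self.ec_producer or not scheduler_output.total_num_scheduled_tokens:
            # Nothing will be sampled; run the "DAG" immediately.
            return self._run(scheduler_output, grammar_output=None)
        # Model will execute; defer the launch until sample_tokens().
        self.scheduler_output = scheduler_output
        return None

    def sample_tokens(self, grammar_output=None):
        scheduler_output, self.scheduler_output = self.scheduler_output, None
        if scheduler_output is None:
            return None
        return self._run(scheduler_output, grammar_output)

    def _run(self, scheduler_output, grammar_output):
        # Stand-in for launching the compiled Ray DAG on the workers.
        return {"tokens": scheduler_output.total_num_scheduled_tokens,
                "grammar": grammar_output}


ex = TwoPhaseExecutor()
assert ex.execute_model(ToySchedulerOutput(total_num_scheduled_tokens=8)) is None
print(ex.sample_tokens(grammar_output="bitmask"))  # {'tokens': 8, 'grammar': 'bitmask'}

The early-return branch in `execute_model()` is what keeps EC-producer and empty-schedule steps from getting stuck waiting for a `sample_tokens()` call that carries no work.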

vllm/v1/worker/gpu_model_runner.py

Lines changed: 21 additions & 5 deletions
@@ -7,7 +7,7 @@
 from collections import defaultdict
 from collections.abc import Iterator
 from contextlib import contextmanager
-from copy import deepcopy
+from copy import copy, deepcopy
 from functools import reduce
 from itertools import product
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
@@ -236,7 +236,7 @@ class ExecuteModelState(NamedTuple):
     hidden_states: torch.Tensor
     sample_hidden_states: torch.Tensor
     aux_hidden_states: list[torch.Tensor] | None
-    kv_connector_output: KVConnectorOutput | None
+    ec_connector_output: ECConnectorOutput | None
 
 
 class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
@@ -542,6 +542,7 @@ def __init__(
 
         # Ephemeral state transferred between execute_model() and sample_tokens().
         self.execute_model_state: ExecuteModelState | None = None
+        self.kv_connector_output: KVConnectorOutput | None = None
 
     def reset_mm_cache(self) -> None:
         if self.mm_budget:
@@ -2673,6 +2674,7 @@ def execute_model(
             # Return the intermediate tensors.
             assert isinstance(hidden_states, IntermediateTensors)
             hidden_states.kv_connector_output = kv_connector_output
+            self.kv_connector_output = kv_connector_output
             return hidden_states
 
         if self.is_pooling_model:
@@ -2723,17 +2725,31 @@ def execute_model(
             hidden_states,
             sample_hidden_states,
             aux_hidden_states,
-            kv_connector_output,
+            ec_connector_output,
         )
+        self.kv_connector_output = kv_connector_output
         return None
 
     @torch.inference_mode
     def sample_tokens(
         self, grammar_output: "GrammarOutput | None"
     ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
+        kv_connector_output = self.kv_connector_output
+        self.kv_connector_output = None
+
         if self.execute_model_state is None:
             # Nothing to do (PP non-final rank case), output isn't used.
-            return None  # noqa
+            if not kv_connector_output:
+                return None  # noqa
+
+            # In case of PP with kv transfer, we need to pass through the
+            # kv_connector_output
+            if kv_connector_output.is_empty():
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
+            output.kv_connector_output = kv_connector_output
+            return output
 
         # Unpack ephemeral state.
         (
@@ -2744,7 +2760,7 @@ def sample_tokens(
             hidden_states,
             sample_hidden_states,
             aux_hidden_states,
-            kv_connector_output,
+            ec_connector_output,
         ) = self.execute_model_state
         # Clear ephemeral state.
         self.execute_model_state = None
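This is the PP kv-connector output regression: on a pipeline-parallel non-final rank there are no sampled tokens, but the runner must still surface KV-connector status so the scheduler can complete async KV transfers; the pass-through now lives in the model runner's `sample_tokens()` instead of the worker. A minimal sketch of the pass-through logic with toy types (`ToyKVConnectorOutput`, `ToyRunnerOutput` are illustrative, not vLLM's classes):

from copy import copy
from dataclasses import dataclass, field


@dataclass
class ToyKVConnectorOutput:
    finished_sending: set[str] = field(default_factory=set)
    finished_recving: set[str] = field(default_factory=set)

    def is_empty(self) -> bool:
        return not (self.finished_sending or self.finished_recving)


@dataclass
class ToyRunnerOutput:
    sampled_token_ids: list = field(default_factory=list)
    kv_connector_output: ToyKVConnectorOutput | None = None


EMPTY_OUTPUT = ToyRunnerOutput()


def non_final_rank_output(kv_connector_output: ToyKVConnectorOutput | None):
    """What a PP non-final rank returns from sample_tokens() in this sketch."""
    if not kv_connector_output:
        return None  # nothing to report
    if kv_connector_output.is_empty():
        return EMPTY_OUTPUT
    # Copy the shared empty output so the singleton isn't mutated.
    out = copy(EMPTY_OUTPUT)
    out.kv_connector_output = kv_connector_output
    return out


print(non_final_rank_output(ToyKVConnectorOutput(finished_sending={"req-1"})))

Stashing `kv_connector_output` on the runner between `execute_model()` and `sample_tokens()` is what restores the behaviour that previously lived in `gpu_worker.py` (see the removal below).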

vllm/v1/worker/gpu_worker.py

Lines changed: 1 addition & 14 deletions
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A GPU worker class."""
 
-import copy
 import gc
 import os
 from contextlib import AbstractContextManager, nullcontext
@@ -44,7 +43,6 @@
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import (
-    EMPTY_MODEL_RUNNER_OUTPUT,
     AsyncModelRunnerOutput,
     DraftTokenIds,
     ModelRunnerOutput,
@@ -572,18 +570,7 @@ def execute_model(
             all_gather_tensors=all_gather_tensors,
         )
 
-        kv_connector_output = output.kv_connector_output
-        if not kv_connector_output:
-            return None
-
-        # In case of PP with kv transfer, we need to pass through the
-        # kv_connector_output
-        if kv_connector_output.is_empty():
-            return EMPTY_MODEL_RUNNER_OUTPUT
-
-        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-        output.kv_connector_output = kv_connector_output
-        return output
+        return None
 
     def take_draft_token_ids(self) -> DraftTokenIds | None:
         return self.model_runner.take_draft_token_ids()
