From 3f646d6a19bde19ff5d591a5cf6e6edddafd985d Mon Sep 17 00:00:00 2001 From: Chenyaaang Date: Fri, 7 Nov 2025 17:59:12 +0000 Subject: [PATCH 1/4] worker changes for pp Signed-off-by: Chenyaaang --- tpu_inference/worker/tpu_worker.py | 130 ++++++++++++++++++++++++----- 1 file changed, 110 insertions(+), 20 deletions(-) diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index e207fa1a7..b98b9b846 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -10,6 +10,7 @@ import jaxtyping import vllm.envs as vllm_envs from vllm.config import VllmConfig, set_current_vllm_config +from vllm.distributed import get_pp_group from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, has_kv_transfer_group) from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, @@ -23,10 +24,13 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from tpu_inference import envs, utils +from tpu_inference.distributed import jax_parallel_state from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port, get_node_id) from tpu_inference.layers.common.sharding import ShardingConfigManager from tpu_inference.logger import init_logger +from tpu_inference.models.jax.jax_intermediate_tensor import \ + JaxIntermediateTensors from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes from tpu_inference.runner.tpu_runner import TPUModelRunner @@ -41,13 +45,17 @@ class TPUWorker: - def __init__(self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False, - devices=None): + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + devices=None, + ip: str = "localhost", + prev_worker_ip: str = "localhost", + ): # If we use vLLM's model implementation in PyTorch, we should set it # with torch version of the dtype. impl = envs.MODEL_IMPL_TYPE @@ -74,6 +82,9 @@ def __init__(self, self.devices = devices if devices is not None else [] self.device_ranks = set(device.id for device in self.devices if isinstance(device, jaxlib._jax.Device)) + self.ip = ip + self.prev_worker_ip = prev_worker_ip + self.pp_world_size = self.parallel_config.pipeline_parallel_size if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -86,7 +97,7 @@ def __init__(self, # TPU Worker is initialized. The profiler server needs to start after # MP runtime is initialized. self.profile_dir = None - if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1: + if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_world_size == 1: if not self.devices or 0 in self.device_ranks: # For TPU, we can only have 1 active profiler session for 1 profiler # server. So we only profile on rank0. @@ -94,6 +105,12 @@ def __init__(self, logger.info("Profiling enabled. Traces will be saved to: %s", self.profile_dir) + # For PP, we use MPMD so we want to profile every worker. 
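+        # Each PP rank runs as its own process under MPMD, so every rank writes traces to its own rank-suffixed directory below.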
+ if self.pp_world_size > 1 and envs.VLLM_TORCH_PROFILER_DIR: + self.profile_dir = os.path.join(envs.VLLM_TORCH_PROFILER_DIR, + f"rank_{self.rank}") + os.makedirs(self.profile_dir, exist_ok=True) + use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False) # Only one instance of profiler is allowed if use_jax_profiler_server and self.rank < 1: @@ -105,18 +122,49 @@ def __init__(self, ) jax.profiler.start_server(jax_profiler_server_port) + self.step_counter = 0 + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks def init_device(self): + # set tpu visible devices for Jax runtime in single host PP. + multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() + if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1: + # Note: Below is the setting for v6e8 host (8 chips of v6e) + # There are 2 ways of subslicing a v6e: + # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4 + # 2) 1 chip for each subslice, with at most 8 subslices, + # we can do TP=1, PP=1/2/3/4/5/6/7/8 + # Replace with your own topology. + + tpu_ports = [ + jax_parallel_state.BASE_JAX_PORT + i + for i in range(self.pp_world_size) + ] + os.environ["TPU_PROCESS_ADDRESSES"] = ",".join( + [f"localhost:{port}" for port in tpu_ports]) + os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}" + os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}" + + # first way of subslicing. + # os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1" + # os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = f"1,4,1" + # os.environ["TPU_VISIBLE_CHIPS"] = "0,1,2,3" if self.rank == 0 else "4,5,6,7" + + # second way of subslicing. + os.environ["TPU_PROCESS_BOUNDS"] = f"1,{self.pp_world_size},1" + os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1" + os.environ["TPU_VISIBLE_CHIPS"] = f"{self.rank}" + if not self.devices: sharding_config: ShardingConfigManager = self.vllm_config.sharding_config device_indexes = sharding_config.device_indexes if device_indexes is not None and len(device_indexes) > 0: # Enforcing the devices sequence to be consistent with the specified device indexes - all_devices = jax.devices() + all_devices = jax.local_devices() device_dict = {device.id: device for device in all_devices} self.devices = [] for device_index in device_indexes: @@ -127,10 +175,12 @@ def init_device(self): f"jax.devices() with IDs {list(device_dict.keys())}!" ) self.devices.append(device) + assert len(self.devices) >= sharding_config.total_devices self.devices = self.devices[:sharding_config.total_devices] else: - self.devices = jax.devices()[:sharding_config.total_devices] - + assert jax.local_device_count() >= sharding_config.total_devices + self.devices = jax.local_devices()[:sharding_config. + total_devices] # Initialize the vLLM distribution layer as a single chip environment, # we'll swap the model's parallel modules with TPU SPMD equivalents. 
with set_current_vllm_config(self.vllm_config): @@ -146,8 +196,18 @@ def init_device(self): tensor_model_parallel_size=1, pipeline_model_parallel_size=1, ) + + jax_parallel_state.init_pp_distributed_environment( + self.ip, + self.rank, + self.parallel_config.pipeline_parallel_size, + self.devices[0], + need_pp=self.parallel_config.pipeline_parallel_size > 1) + ensure_kv_transfer_initialized(self.vllm_config) - self.model_runner = TPUModelRunner(self.vllm_config, self.devices) + self.model_runner = TPUModelRunner(self.vllm_config, self.devices, + self.rank, self.rank == 0, + self.rank == self.pp_world_size - 1) logger.info(f"Init worker | " f"rank={self.rank} | " f"node_id={get_node_id()} | " @@ -155,6 +215,11 @@ def init_device(self): f"hbm={utils.hbm_usage_gb(self.devices)}GiB") vllm_utils.report_usage_stats(self.vllm_config) + def initialize_pp_transfer_connect(self): + if self.rank == 0: + return + jax_parallel_state.connect(self.prev_worker_ip, self.rank - 1) + def determine_available_memory(self) -> int: gpu_memory_utilization = self.cache_config.gpu_memory_utilization hbm_usage = utils.hbm_usage_bytes(self.devices) @@ -194,14 +259,39 @@ def execute_model( # deliberate, temporary compromise for the same reasons outlined in # the `get_kv_cache_spec` method. - output = self.model_runner.execute_model(scheduler_output) - - # With a connector, the scheduler expects output from all workers - # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866 - if has_kv_transfer_group(): - return output - - return output if self.is_driver_worker else None + if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0: + intermediate_tensors = None + else: + # receive intermediate tensors + uuid = self.model_runner.get_uuid_for_jax_transfer( + scheduler_output, self.rank - 1, self.step_counter) + # TODO: this method might only works for vllm model, not sure about jax models. 
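+            # Build the expected tensor spec for this step, then receive the matching tensors sent by the previous PP rank (keyed by uuid).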
+ tensor_spec = self.model_runner.get_intermediate_tensor_spec( + scheduler_output.total_num_scheduled_tokens) + intermediate_tensors_dict = get_pp_group().recv_tensor_dict( + uuid, tensor_spec) + intermediate_tensors = JaxIntermediateTensors( + intermediate_tensors_dict) + + output = self.model_runner.execute_model(scheduler_output, + intermediate_tensors) + + if isinstance(output, JaxIntermediateTensors): + assert self.parallel_config.pipeline_parallel_size > 1 + assert not get_pp_group().is_last_rank + # send intermediate tensors + uuid = self.model_runner.get_uuid_for_jax_transfer( + scheduler_output, self.rank, self.step_counter) + get_pp_group().send_tensor_dict(uuid, output.tensors) + self.step_counter += 1 + return None + else: + self.step_counter += 1 + # With a connector, the scheduler expects output from all workers + # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866 + if has_kv_transfer_group(): + return output + return output if self.is_driver_worker else None def sample_tokens(self, grammar_output: GrammarOutput) -> ModelRunnerOutput: From b4efa5d8827e10dc9375f3098f2b5bc2c5c66876 Mon Sep 17 00:00:00 2001 From: Chenyaaang Date: Wed, 12 Nov 2025 20:28:29 +0000 Subject: [PATCH 2/4] resolve comments Signed-off-by: Chenyaaang --- tpu_inference/worker/tpu_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index b98b9b846..d7e501a58 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -164,8 +164,8 @@ def init_device(self): device_indexes = sharding_config.device_indexes if device_indexes is not None and len(device_indexes) > 0: # Enforcing the devices sequence to be consistent with the specified device indexes - all_devices = jax.local_devices() - device_dict = {device.id: device for device in all_devices} + all_local_devices = jax.local_devices() + device_dict = {device.id: device for device in all_local_devices} self.devices = [] for device_index in device_indexes: device = device_dict[device_index] From f7f2b5283281f8dcdf78066a146928c6729080d5 Mon Sep 17 00:00:00 2001 From: Chenyaaang Date: Fri, 14 Nov 2025 03:36:49 +0000 Subject: [PATCH 3/4] fix comments Signed-off-by: Chenyaaang --- tpu_inference/worker/tpu_worker.py | 88 +++++++++++++++++++++--------- 1 file changed, 62 insertions(+), 26 deletions(-) diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index d7e501a58..65252a4d4 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -2,6 +2,7 @@ import os import tempfile +from dataclasses import dataclass, field from typing import Callable, Dict, Optional, Tuple import jax @@ -43,6 +44,26 @@ } +@dataclass +class PPConfig: + rank: int + ip: str + prev_worker_ip: str + pp_world_size: int + + # default env vars for + # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS + # if PP is used in single host. 
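+    # These defaults are derived from rank and pp_world_size in __post_init__; init_device() can override them via its arguments.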
+ default_tpu_process_bounds: str = field(init=False) + default_tpu_chips_per_process_bounds: str = field(init=False) + default_tpu_visible_chips: str = field(init=False) + + def __post_init__(self): + self.default_tpu_process_bounds = f"1,{self.pp_world_size},1" + self.default_tpu_chips_per_process_bounds = "1,1,1" + self.default_tpu_visible_chips = f"{self.rank}" + + class TPUWorker: def __init__( @@ -82,9 +103,8 @@ def __init__( self.devices = devices if devices is not None else [] self.device_ranks = set(device.id for device in self.devices if isinstance(device, jaxlib._jax.Device)) - self.ip = ip - self.prev_worker_ip = prev_worker_ip - self.pp_world_size = self.parallel_config.pipeline_parallel_size + self.pp_config = PPConfig(rank, ip, prev_worker_ip, + self.parallel_config.pipeline_parallel_size) if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -107,8 +127,10 @@ def __init__( # For PP, we use MPMD so we want to profile every worker. if self.pp_world_size > 1 and envs.VLLM_TORCH_PROFILER_DIR: - self.profile_dir = os.path.join(envs.VLLM_TORCH_PROFILER_DIR, - f"rank_{self.rank}") + self.profile_dir = os.path.join( + envs.VLLM_TORCH_PROFILER_DIR, + f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}" + ) os.makedirs(self.profile_dir, exist_ok=True) use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False) @@ -122,6 +144,7 @@ def __init__( ) jax.profiler.start_server(jax_profiler_server_port) + # step_counter is used to calculate uuid to transfer intermediate tensors. self.step_counter = 0 def initialize_cache(self, num_gpu_blocks: int, @@ -129,17 +152,13 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - def init_device(self): + def init_device(self, + tpu_process_bounds="", + tpu_chips_per_process_bounds="", + tpu_visible_chips=""): # set tpu visible devices for Jax runtime in single host PP. multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1: - # Note: Below is the setting for v6e8 host (8 chips of v6e) - # There are 2 ways of subslicing a v6e: - # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4 - # 2) 1 chip for each subslice, with at most 8 subslices, - # we can do TP=1, PP=1/2/3/4/5/6/7/8 - # Replace with your own topology. - tpu_ports = [ jax_parallel_state.BASE_JAX_PORT + i for i in range(self.pp_world_size) @@ -149,15 +168,27 @@ def init_device(self): os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}" os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}" - # first way of subslicing. - # os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1" - # os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = f"1,4,1" - # os.environ["TPU_VISIBLE_CHIPS"] = "0,1,2,3" if self.rank == 0 else "4,5,6,7" - - # second way of subslicing. - os.environ["TPU_PROCESS_BOUNDS"] = f"1,{self.pp_world_size},1" - os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1" - os.environ["TPU_VISIBLE_CHIPS"] = f"{self.rank}" + # Note: Below is the setting for v6e8 host (8 chips of v6e) + # Replace with your own topology. 
+ # There are 2 ways of subslicing a v6e + # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4 + # TPU_PROCESS_BOUNDS = "1,1,1" + # TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1" + # TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7" + # 2) 1 chip for each subslice, with at most 8 subslices, + # we can do TP=1, PP=1/2/3/4/5/6/7/8 + os.environ[ + "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \ + if tpu_process_bounds \ + else self.pp_config.default_tpu_process_bounds + os.environ[ + "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \ + if tpu_chips_per_process_bounds \ + else self.pp_config.default_tpu_chips_per_process_bounds + os.environ[ + "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \ + if tpu_visible_chips \ + else self.pp_config.default_tpu_visible_chips if not self.devices: sharding_config: ShardingConfigManager = self.vllm_config.sharding_config @@ -165,7 +196,10 @@ def init_device(self): if device_indexes is not None and len(device_indexes) > 0: # Enforcing the devices sequence to be consistent with the specified device indexes all_local_devices = jax.local_devices() - device_dict = {device.id: device for device in all_local_devices} + device_dict = { + device.id: device + for device in all_local_devices + } self.devices = [] for device_index in device_indexes: device = device_dict[device_index] @@ -178,7 +212,8 @@ def init_device(self): assert len(self.devices) >= sharding_config.total_devices self.devices = self.devices[:sharding_config.total_devices] else: - assert jax.local_device_count() >= sharding_config.total_devices + assert jax.local_device_count( + ) >= sharding_config.total_devices self.devices = jax.local_devices()[:sharding_config. total_devices] # Initialize the vLLM distribution layer as a single chip environment, @@ -198,7 +233,7 @@ def init_device(self): ) jax_parallel_state.init_pp_distributed_environment( - self.ip, + self.pp_config.ip, self.rank, self.parallel_config.pipeline_parallel_size, self.devices[0], @@ -218,7 +253,8 @@ def init_device(self): def initialize_pp_transfer_connect(self): if self.rank == 0: return - jax_parallel_state.connect(self.prev_worker_ip, self.rank - 1) + jax_parallel_state.connect(self.pp_config.prev_worker_ip, + self.rank - 1) def determine_available_memory(self) -> int: gpu_memory_utilization = self.cache_config.gpu_memory_utilization From 8ba5cc99f50b11c2977f944b490a25aa15ec37e0 Mon Sep 17 00:00:00 2001 From: Chenyaaang Date: Wed, 19 Nov 2025 22:03:13 +0000 Subject: [PATCH 4/4] fix unit tests Signed-off-by: Chenyaaang --- tests/worker/tpu_worker_test.py | 28 ++++++++++++++++++++++------ tpu_inference/runner/tpu_runner.py | 3 +++ tpu_inference/worker/tpu_worker.py | 16 ++++++++-------- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/tests/worker/tpu_worker_test.py b/tests/worker/tpu_worker_test.py index 0ae4149c8..4801c861a 100644 --- a/tests/worker/tpu_worker_test.py +++ b/tests/worker/tpu_worker_test.py @@ -25,6 +25,7 @@ def mock_vllm_config(): mock_parallel_conf = MagicMock() mock_parallel_conf.tensor_parallel_size = 2 mock_parallel_conf.data_parallel_size = 1 + mock_parallel_conf.pipeline_parallel_size = 1 mock_parallel_conf.nnodes = 1 mock_parallel_conf.nnodes_within_dp = 1 @@ -118,8 +119,14 @@ def test_init_device_with_provided_devices( worker.init_device() - mock_jax.devices.assert_not_called() - mock_runner_cls.assert_called_once_with(mock_vllm_config, mock_devices) + mock_jax.local_devices.assert_not_called() + expected_rank = 0 + expected_is_first_rank = True + expected_is_last_rank = True + 
mock_runner_cls.assert_called_once_with(mock_vllm_config, mock_devices, + expected_rank, + expected_is_first_rank, + expected_is_last_rank) assert isinstance(worker.model_runner, MagicMock) @patch('tpu_inference.worker.tpu_worker.TPUModelRunner') @@ -137,15 +144,24 @@ def test_init_device_autodetects_devices( distributed_init_method="test_method", devices=[] # No devices provided, should trigger auto-detection ) - mock_jax.devices.return_value = ['tpu:0', 'tpu:1', 'tpu:2', 'tpu:3'] + mock_jax.local_device_count.return_value = 4 + mock_jax.local_devices.return_value = [ + 'tpu:0', 'tpu:1', 'tpu:2', 'tpu:3' + ] worker.init_device() - mock_jax.devices.assert_called_once() + mock_jax.local_devices.assert_called_once() expected_devices = ['tpu:0', 'tpu:1'] # Sliced by tensor_parallel_size assert worker.devices == expected_devices + expected_rank = 0 + expected_is_first_rank = True + expected_is_last_rank = True mock_runner_cls.assert_called_once_with(mock_vllm_config, - expected_devices) + expected_devices, + expected_rank, + expected_is_first_rank, + expected_is_last_rank) @patch('tpu_inference.worker.tpu_worker.utils') def test_determine_available_memory(self, mock_utils, mock_vllm_config): @@ -194,7 +210,7 @@ def test_execute_model(self, mock_runner_cls, mock_vllm_config): # Assert the runner was called with the scheduler output directly worker.model_runner.execute_model.assert_called_once_with( - mock_scheduler_input) + mock_scheduler_input, None) # Assert the final result is the concrete model output assert result == mock_model_output diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index f3c6d7899..e4f8e791e 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -223,6 +223,9 @@ def __init__( self, vllm_config: VllmConfig, devices: List[Any], + rank: int = 0, + is_first_rank: bool = True, + is_last_rank: bool = True, ): self.vllm_config = vllm_config self.model_config = vllm_config.model_config diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index 65252a4d4..efab89e07 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -117,7 +117,7 @@ def __init__( # TPU Worker is initialized. The profiler server needs to start after # MP runtime is initialized. self.profile_dir = None - if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_world_size == 1: + if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1: if not self.devices or 0 in self.device_ranks: # For TPU, we can only have 1 active profiler session for 1 profiler # server. So we only profile on rank0. @@ -126,9 +126,9 @@ def __init__( self.profile_dir) # For PP, we use MPMD so we want to profile every worker. 
- if self.pp_world_size > 1 and envs.VLLM_TORCH_PROFILER_DIR: + if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR: self.profile_dir = os.path.join( - envs.VLLM_TORCH_PROFILER_DIR, + vllm_envs.VLLM_TORCH_PROFILER_DIR, f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}" ) os.makedirs(self.profile_dir, exist_ok=True) @@ -161,7 +161,7 @@ def init_device(self, if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1: tpu_ports = [ jax_parallel_state.BASE_JAX_PORT + i - for i in range(self.pp_world_size) + for i in range(self.pp_config.pp_world_size) ] os.environ["TPU_PROCESS_ADDRESSES"] = ",".join( [f"localhost:{port}" for port in tpu_ports]) @@ -206,7 +206,7 @@ def init_device(self, if device is None: raise KeyError( f"Device index {device_index} not found in " - f"jax.devices() with IDs {list(device_dict.keys())}!" + f"jax.local_devices() with IDs {list(device_dict.keys())}!" ) self.devices.append(device) assert len(self.devices) >= sharding_config.total_devices @@ -240,9 +240,9 @@ def init_device(self, need_pp=self.parallel_config.pipeline_parallel_size > 1) ensure_kv_transfer_initialized(self.vllm_config) - self.model_runner = TPUModelRunner(self.vllm_config, self.devices, - self.rank, self.rank == 0, - self.rank == self.pp_world_size - 1) + self.model_runner = TPUModelRunner( + self.vllm_config, self.devices, self.rank, self.rank == 0, + self.rank == self.pp_config.pp_world_size - 1) logger.info(f"Init worker | " f"rank={self.rank} | " f"node_id={get_node_id()} | "