diff --git a/tests/worker/tpu_worker_test.py b/tests/worker/tpu_worker_test.py
index 0ae4149c8..4801c861a 100644
--- a/tests/worker/tpu_worker_test.py
+++ b/tests/worker/tpu_worker_test.py
@@ -25,6 +25,7 @@ def mock_vllm_config():
     mock_parallel_conf = MagicMock()
     mock_parallel_conf.tensor_parallel_size = 2
     mock_parallel_conf.data_parallel_size = 1
+    mock_parallel_conf.pipeline_parallel_size = 1
     mock_parallel_conf.nnodes = 1
     mock_parallel_conf.nnodes_within_dp = 1
 
@@ -118,8 +119,14 @@ def test_init_device_with_provided_devices(
 
         worker.init_device()
 
-        mock_jax.devices.assert_not_called()
-        mock_runner_cls.assert_called_once_with(mock_vllm_config, mock_devices)
+        mock_jax.local_devices.assert_not_called()
+        expected_rank = 0
+        expected_is_first_rank = True
+        expected_is_last_rank = True
+        mock_runner_cls.assert_called_once_with(mock_vllm_config, mock_devices,
+                                                expected_rank,
+                                                expected_is_first_rank,
+                                                expected_is_last_rank)
         assert isinstance(worker.model_runner, MagicMock)
 
     @patch('tpu_inference.worker.tpu_worker.TPUModelRunner')
@@ -137,15 +144,24 @@ def test_init_device_autodetects_devices(
             distributed_init_method="test_method",
             devices=[]  # No devices provided, should trigger auto-detection
         )
-        mock_jax.devices.return_value = ['tpu:0', 'tpu:1', 'tpu:2', 'tpu:3']
+        mock_jax.local_device_count.return_value = 4
+        mock_jax.local_devices.return_value = [
+            'tpu:0', 'tpu:1', 'tpu:2', 'tpu:3'
+        ]
 
         worker.init_device()
 
-        mock_jax.devices.assert_called_once()
+        mock_jax.local_devices.assert_called_once()
         expected_devices = ['tpu:0', 'tpu:1']  # Sliced by tensor_parallel_size
         assert worker.devices == expected_devices
+        expected_rank = 0
+        expected_is_first_rank = True
+        expected_is_last_rank = True
         mock_runner_cls.assert_called_once_with(mock_vllm_config,
-                                                expected_devices)
+                                                expected_devices,
+                                                expected_rank,
+                                                expected_is_first_rank,
+                                                expected_is_last_rank)
 
     @patch('tpu_inference.worker.tpu_worker.utils')
     def test_determine_available_memory(self, mock_utils, mock_vllm_config):
@@ -194,7 +210,7 @@ def test_execute_model(self, mock_runner_cls, mock_vllm_config):
 
         # Assert the runner was called with the scheduler output directly
         worker.model_runner.execute_model.assert_called_once_with(
-            mock_scheduler_input)
+            mock_scheduler_input, None)
 
         # Assert the final result is the concrete model output
         assert result == mock_model_output
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index f3c6d7899..e4f8e791e 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -223,6 +223,9 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
+        rank: int = 0,
+        is_first_rank: bool = True,
+        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py
index e207fa1a7..efab89e07 100644
--- a/tpu_inference/worker/tpu_worker.py
+++ b/tpu_inference/worker/tpu_worker.py
@@ -2,6 +2,7 @@
 
 import os
 import tempfile
+from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple
 
 import jax
@@ -10,6 +11,7 @@
 import jaxtyping
 import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -23,10 +25,13 @@
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 from tpu_inference import envs, utils
+from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner
 
@@ -39,15 +44,39 @@
 }
 
 
+@dataclass
+class PPConfig:
+    rank: int
+    ip: str
+    prev_worker_ip: str
+    pp_world_size: int
+
+    # Default env vars for
+    # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
+    # when PP is used on a single host.
+    default_tpu_process_bounds: str = field(init=False)
+    default_tpu_chips_per_process_bounds: str = field(init=False)
+    default_tpu_visible_chips: str = field(init=False)
+
+    def __post_init__(self):
+        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+        self.default_tpu_chips_per_process_bounds = "1,1,1"
+        self.default_tpu_visible_chips = f"{self.rank}"
+
+
 class TPUWorker:
 
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 local_rank: int,
-                 rank: int,
-                 distributed_init_method: str,
-                 is_driver_worker: bool = False,
-                 devices=None):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+        devices=None,
+        ip: str = "localhost",
+        prev_worker_ip: str = "localhost",
+    ):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
         impl = envs.MODEL_IMPL_TYPE
@@ -74,6 +103,8 @@ def __init__(self,
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
+        self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                  self.parallel_config.pipeline_parallel_size)
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -86,7 +117,7 @@ def __init__(self,
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
@@ -94,6 +125,14 @@ def __init__(self,
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)
 
+        # For PP, we use MPMD, so we want to profile every worker.
+        if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
+            self.profile_dir = os.path.join(
+                vllm_envs.VLLM_TORCH_PROFILER_DIR,
+                f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+            )
+            os.makedirs(self.profile_dir, exist_ok=True)
+
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
@@ -105,32 +144,78 @@ def __init__(self,
             )
             jax.profiler.start_server(jax_profiler_server_port)
 
+        # step_counter is used to compute the uuid for transferring intermediate tensors.
+        self.step_counter = 0
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
-    def init_device(self):
+    def init_device(self,
+                    tpu_process_bounds="",
+                    tpu_chips_per_process_bounds="",
+                    tpu_visible_chips=""):
+        # Set TPU visible devices for the JAX runtime in single-host PP.
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+            tpu_ports = [
+                jax_parallel_state.BASE_JAX_PORT + i
+                for i in range(self.pp_config.pp_world_size)
+            ]
+            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                [f"localhost:{port}" for port in tpu_ports])
+            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+            # Note: Below is the setting for a v6e8 host (8 chips of v6e).
+            # Replace it with your own topology.
+            # There are 2 ways of subslicing a v6e:
+            # 1) 2 slices with 4 TPU chips each; we can do PP=2, TP=1/2/3/4
+            #    TPU_PROCESS_BOUNDS = "1,1,1"
+            #    TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+            #    TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+            # 2) 1 chip for each subslice, with at most 8 subslices;
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8
+            os.environ[
+                "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
+                if tpu_process_bounds \
+                else self.pp_config.default_tpu_process_bounds
+            os.environ[
+                "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
+                if tpu_chips_per_process_bounds \
+                else self.pp_config.default_tpu_chips_per_process_bounds
+            os.environ[
+                "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
+                if tpu_visible_chips \
+                else self.pp_config.default_tpu_visible_chips
+
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-                all_devices = jax.devices()
-                device_dict = {device.id: device for device in all_devices}
+                all_local_devices = jax.local_devices()
+                device_dict = {
+                    device.id: device
+                    for device in all_local_devices
+                }
                 self.devices = []
                 for device_index in device_indexes:
                     device = device_dict[device_index]
                     if device is None:
                         raise KeyError(
                             f"Device index {device_index} not found in "
-                            f"jax.devices() with IDs {list(device_dict.keys())}!"
+                            f"jax.local_devices() with IDs {list(device_dict.keys())}!"
                         )
                     self.devices.append(device)
+                assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                self.devices = jax.devices()[:sharding_config.total_devices]
-
+                assert jax.local_device_count(
+                ) >= sharding_config.total_devices
+                self.devices = jax.local_devices()[:sharding_config.
+                                                   total_devices]
         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
         with set_current_vllm_config(self.vllm_config):
@@ -146,8 +231,18 @@ def init_device(self):
             tensor_model_parallel_size=1,
             pipeline_model_parallel_size=1,
         )
+
+        jax_parallel_state.init_pp_distributed_environment(
+            self.pp_config.ip,
+            self.rank,
+            self.parallel_config.pipeline_parallel_size,
+            self.devices[0],
+            need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
+        self.model_runner = TPUModelRunner(
+            self.vllm_config, self.devices, self.rank, self.rank == 0,
+            self.rank == self.pp_config.pp_world_size - 1)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "
@@ -155,6 +250,12 @@ def init_device(self):
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
         vllm_utils.report_usage_stats(self.vllm_config)
 
+    def initialize_pp_transfer_connect(self):
+        if self.rank == 0:
+            return
+        jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                   self.rank - 1)
+
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -194,14 +295,39 @@ def execute_model(
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.
 
-        output = self.model_runner.execute_model(scheduler_output)
-
-        # With a connector, the scheduler expects output from all workers
-        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
-        if has_kv_transfer_group():
-            return output
-
-        return output if self.is_driver_worker else None
+        if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+            intermediate_tensors = None
+        else:
+            # receive intermediate tensors
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank - 1, self.step_counter)
+            # TODO: this method might only work for vLLM models, not sure about JAX models.
+            tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                scheduler_output.total_num_scheduled_tokens)
+            intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                uuid, tensor_spec)
+            intermediate_tensors = JaxIntermediateTensors(
+                intermediate_tensors_dict)
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if isinstance(output, JaxIntermediateTensors):
+            assert self.parallel_config.pipeline_parallel_size > 1
+            assert not get_pp_group().is_last_rank
+            # send intermediate tensors
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank, self.step_counter)
+            get_pp_group().send_tensor_dict(uuid, output.tensors)
+            self.step_counter += 1
+            return None
+        else:
+            self.step_counter += 1
+            # With a connector, the scheduler expects output from all workers
+            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+            if has_kv_transfer_group():
+                return output
+            return output if self.is_driver_worker else None
 
     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput: