Merged
28 changes: 22 additions & 6 deletions tests/worker/tpu_worker_test.py
@@ -25,6 +25,7 @@ def mock_vllm_config():
mock_parallel_conf = MagicMock()
mock_parallel_conf.tensor_parallel_size = 2
mock_parallel_conf.data_parallel_size = 1
mock_parallel_conf.pipeline_parallel_size = 1
mock_parallel_conf.nnodes = 1
mock_parallel_conf.nnodes_within_dp = 1

@@ -118,8 +119,14 @@ def test_init_device_with_provided_devices(

worker.init_device()

mock_jax.devices.assert_not_called()
mock_runner_cls.assert_called_once_with(mock_vllm_config, mock_devices)
mock_jax.local_devices.assert_not_called()
expected_rank = 0
expected_is_first_rank = True
expected_is_last_rank = True
mock_runner_cls.assert_called_once_with(mock_vllm_config, mock_devices,
expected_rank,
expected_is_first_rank,
expected_is_last_rank)
assert isinstance(worker.model_runner, MagicMock)

@patch('tpu_inference.worker.tpu_worker.TPUModelRunner')
@@ -137,15 +144,24 @@ def test_init_device_autodetects_devices(
distributed_init_method="test_method",
devices=[] # No devices provided, should trigger auto-detection
)
mock_jax.devices.return_value = ['tpu:0', 'tpu:1', 'tpu:2', 'tpu:3']
mock_jax.local_device_count.return_value = 4
mock_jax.local_devices.return_value = [
'tpu:0', 'tpu:1', 'tpu:2', 'tpu:3'
]

worker.init_device()

mock_jax.devices.assert_called_once()
mock_jax.local_devices.assert_called_once()
expected_devices = ['tpu:0', 'tpu:1'] # Sliced by tensor_parallel_size
assert worker.devices == expected_devices
expected_rank = 0
expected_is_first_rank = True
expected_is_last_rank = True
mock_runner_cls.assert_called_once_with(mock_vllm_config,
expected_devices)
expected_devices,
expected_rank,
expected_is_first_rank,
expected_is_last_rank)

@patch('tpu_inference.worker.tpu_worker.utils')
def test_determine_available_memory(self, mock_utils, mock_vllm_config):
@@ -194,7 +210,7 @@ def test_execute_model(self, mock_runner_cls, mock_vllm_config):

# Assert the runner was called with the scheduler output directly
worker.model_runner.execute_model.assert_called_once_with(
mock_scheduler_input)
mock_scheduler_input, None)
# Assert the final result is the concrete model output
assert result == mock_model_output

3 changes: 3 additions & 0 deletions tpu_inference/runner/tpu_runner.py
@@ -223,6 +223,9 @@ def __init__(
self,
vllm_config: VllmConfig,
devices: List[Any],
rank: int = 0,
is_first_rank: bool = True,
is_last_rank: bool = True,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
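Because the new parameters default to rank 0 with is_first_rank and is_last_rank set to True, the runner constructor stays backward compatible at the signature level for single-stage callers. A minimal sketch of that equivalence, assuming vllm_config and devices are built as elsewhere in this change:

from tpu_inference.runner.tpu_runner import TPUModelRunner

# The two calls below construct equivalent single-stage runners: the PP
# arguments default to rank 0, which is both the first and last pipeline rank.
runner_a = TPUModelRunner(vllm_config, devices)
runner_b = TPUModelRunner(vllm_config, devices, rank=0,
                          is_first_rank=True, is_last_rank=True)
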
172 changes: 149 additions & 23 deletions tpu_inference/worker/tpu_worker.py
@@ -2,6 +2,7 @@

import os
import tempfile
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional, Tuple

import jax
@@ -10,6 +11,7 @@
import jaxtyping
import vllm.envs as vllm_envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import get_pp_group
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
has_kv_transfer_group)
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -23,10 +25,13 @@
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

from tpu_inference import envs, utils
from tpu_inference.distributed import jax_parallel_state
from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
get_node_id)
from tpu_inference.layers.common.sharding import ShardingConfigManager
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.jax_intermediate_tensor import \
JaxIntermediateTensors
from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
from tpu_inference.runner.tpu_runner import TPUModelRunner

@@ -39,15 +44,39 @@
}


@dataclass
class PPConfig:
rank: int
ip: str
prev_worker_ip: str
pp_world_size: int

    # Default env vars for
    # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
    # when PP is used on a single host.
default_tpu_process_bounds: str = field(init=False)
default_tpu_chips_per_process_bounds: str = field(init=False)
default_tpu_visible_chips: str = field(init=False)

def __post_init__(self):
self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
self.default_tpu_chips_per_process_bounds = "1,1,1"
self.default_tpu_visible_chips = f"{self.rank}"
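
A minimal usage sketch of the dataclass above, for illustration only; the rank, world size, and IPs are hypothetical, and the expected values simply mirror __post_init__:

# Single-host example: rank 1 of a 2-stage pipeline (assumes PPConfig as defined above).
cfg = PPConfig(rank=1, ip="localhost", prev_worker_ip="localhost",
               pp_world_size=2)
assert cfg.default_tpu_process_bounds == "1,2,1"            # one process per PP stage
assert cfg.default_tpu_chips_per_process_bounds == "1,1,1"  # one chip per process
assert cfg.default_tpu_visible_chips == "1"                 # this rank uses chip index 1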


class TPUWorker:

def __init__(self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
devices=None):
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
devices=None,
ip: str = "localhost",
prev_worker_ip: str = "localhost",
):
# If we use vLLM's model implementation in PyTorch, we should set it
# with torch version of the dtype.
impl = envs.MODEL_IMPL_TYPE
@@ -74,6 +103,8 @@ def __init__(self,
self.devices = devices if devices is not None else []
self.device_ranks = set(device.id for device in self.devices
if isinstance(device, jaxlib._jax.Device))
self.pp_config = PPConfig(rank, ip, prev_worker_ip,
self.parallel_config.pipeline_parallel_size)

if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
@@ -86,14 +117,22 @@ def __init__(self,
# TPU Worker is initialized. The profiler server needs to start after
# MP runtime is initialized.
self.profile_dir = None
if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
if not self.devices or 0 in self.device_ranks:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self.profile_dir = vllm_envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
self.profile_dir)

# For PP, we use MPMD so we want to profile every worker.
if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
self.profile_dir = os.path.join(
vllm_envs.VLLM_TORCH_PROFILER_DIR,
f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
)
os.makedirs(self.profile_dir, exist_ok=True)

use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
# Only one instance of profiler is allowed
if use_jax_profiler_server and self.rank < 1:
@@ -105,32 +144,78 @@ def __init__(self,
)
jax.profiler.start_server(jax_profiler_server_port)

        # step_counter is used to derive the uuid for transferring intermediate tensors.
self.step_counter = 0

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

def init_device(self):
def init_device(self,
tpu_process_bounds="",
tpu_chips_per_process_bounds="",
tpu_visible_chips=""):
        # Set TPU visible devices for the JAX runtime in single-host PP.
multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
tpu_ports = [
jax_parallel_state.BASE_JAX_PORT + i
for i in range(self.pp_config.pp_world_size)
]
os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
[f"localhost:{port}" for port in tpu_ports])
os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"

# Note: Below is the setting for v6e8 host (8 chips of v6e)
# Replace with your own topology.
# There are 2 ways of subslicing a v6e
# 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
# TPU_PROCESS_BOUNDS = "1,1,1"
# TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
# TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
# 2) 1 chip for each subslice, with at most 8 subslices,
# we can do TP=1, PP=1/2/3/4/5/6/7/8
os.environ[
"TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
if tpu_process_bounds \
else self.pp_config.default_tpu_process_bounds
os.environ[
"TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
if tpu_chips_per_process_bounds \
else self.pp_config.default_tpu_chips_per_process_bounds
os.environ[
"TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
if tpu_visible_chips \
else self.pp_config.default_tpu_visible_chips

if not self.devices:
sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
device_indexes = sharding_config.device_indexes
if device_indexes is not None and len(device_indexes) > 0:
# Enforcing the devices sequence to be consistent with the specified device indexes
all_devices = jax.devices()
device_dict = {device.id: device for device in all_devices}
all_local_devices = jax.local_devices()
device_dict = {
device.id: device
for device in all_local_devices
}
self.devices = []
for device_index in device_indexes:
device = device_dict[device_index]
if device is None:
raise KeyError(
f"Device index {device_index} not found in "
f"jax.devices() with IDs {list(device_dict.keys())}!"
f"jax.local_devices() with IDs {list(device_dict.keys())}!"
)
self.devices.append(device)
assert len(self.devices) >= sharding_config.total_devices
self.devices = self.devices[:sharding_config.total_devices]
else:
self.devices = jax.devices()[:sharding_config.total_devices]

assert jax.local_device_count(
) >= sharding_config.total_devices
self.devices = jax.local_devices()[:sharding_config.
total_devices]
# Initialize the vLLM distribution layer as a single chip environment,
# we'll swap the model's parallel modules with TPU SPMD equivalents.
with set_current_vllm_config(self.vllm_config):
@@ -146,15 +231,31 @@ def init_device(self):
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)

jax_parallel_state.init_pp_distributed_environment(
self.pp_config.ip,
self.rank,
self.parallel_config.pipeline_parallel_size,
self.devices[0],
need_pp=self.parallel_config.pipeline_parallel_size > 1)

ensure_kv_transfer_initialized(self.vllm_config)
self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
self.model_runner = TPUModelRunner(
self.vllm_config, self.devices, self.rank, self.rank == 0,
self.rank == self.pp_config.pp_world_size - 1)
logger.info(f"Init worker | "
f"rank={self.rank} | "
f"node_id={get_node_id()} | "
f"is_driver_worker={self.is_driver_worker} | "
f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
vllm_utils.report_usage_stats(self.vllm_config)

def initialize_pp_transfer_connect(self):
if self.rank == 0:
return
jax_parallel_state.connect(self.pp_config.prev_worker_ip,
self.rank - 1)

def determine_available_memory(self) -> int:
gpu_memory_utilization = self.cache_config.gpu_memory_utilization
hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -194,14 +295,39 @@ def execute_model(
# deliberate, temporary compromise for the same reasons outlined in
# the `get_kv_cache_spec` method.

output = self.model_runner.execute_model(scheduler_output)

# With a connector, the scheduler expects output from all workers
# TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
if has_kv_transfer_group():
return output

return output if self.is_driver_worker else None
if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
intermediate_tensors = None
else:
# receive intermediate tensors
uuid = self.model_runner.get_uuid_for_jax_transfer(
scheduler_output, self.rank - 1, self.step_counter)
            # TODO: this method might only work for vLLM models; not sure about JAX models.
tensor_spec = self.model_runner.get_intermediate_tensor_spec(
scheduler_output.total_num_scheduled_tokens)
intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
uuid, tensor_spec)
intermediate_tensors = JaxIntermediateTensors(
intermediate_tensors_dict)

output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)

if isinstance(output, JaxIntermediateTensors):
assert self.parallel_config.pipeline_parallel_size > 1
assert not get_pp_group().is_last_rank
# send intermediate tensors
uuid = self.model_runner.get_uuid_for_jax_transfer(
scheduler_output, self.rank, self.step_counter)
get_pp_group().send_tensor_dict(uuid, output.tensors)
self.step_counter += 1
return None
else:
self.step_counter += 1
# With a connector, the scheduler expects output from all workers
# TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
if has_kv_transfer_group():
return output
return output if self.is_driver_worker else None

def sample_tokens(self,
grammar_output: GrammarOutput) -> ModelRunnerOutput:
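
As a sanity check on the control flow above, here is a small, self-contained simulation of the per-step hand-off in execute_model. The queue-backed send/recv helpers are stand-ins for get_pp_group().send_tensor_dict / recv_tensor_dict rather than the real transfer API, and the uuid is keyed analogously (sender rank plus step counter):

from queue import Queue
from typing import Dict, Optional

channels: Dict[str, Queue] = {}  # uuid -> queue of intermediate tensor dicts

def send_tensor_dict(uuid: str, tensors: dict) -> None:
    channels.setdefault(uuid, Queue()).put(tensors)

def recv_tensor_dict(uuid: str) -> dict:
    return channels.setdefault(uuid, Queue()).get()

def execute_step(rank: int, world_size: int, step: int, hidden: dict) -> Optional[dict]:
    # Non-first ranks block on the activations produced by the previous stage.
    if rank > 0:
        hidden = recv_tensor_dict(f"rank{rank - 1}_step{step}")
    output = {"hidden_states": hidden["hidden_states"] + 1}  # stand-in for model work
    if rank < world_size - 1:
        # Intermediate stages forward their activations and return None,
        # mirroring the JaxIntermediateTensors branch above.
        send_tensor_dict(f"rank{rank}_step{step}", output)
        return None
    return output  # only the last stage returns the final model output

# One scheduler step on a 2-stage pipeline, executed in rank order for clarity.
assert execute_step(rank=0, world_size=2, step=0, hidden={"hidden_states": 0}) is None
print(execute_step(rank=1, world_size=2, step=0, hidden={}))  # {'hidden_states': 2}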