
Commit f7f2b52

fix comments
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent: b4efa5d

File tree

1 file changed: +62 -26 lines changed


tpu_inference/worker/tpu_worker.py

Lines changed: 62 additions & 26 deletions
@@ -2,6 +2,7 @@

 import os
 import tempfile
+from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple

 import jax
@@ -43,6 +44,26 @@
 }


+@dataclass
+class PPConfig:
+    rank: int
+    ip: str
+    prev_worker_ip: str
+    pp_world_size: int
+
+    # default env vars for
+    # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
+    # if PP is used in single host.
+    default_tpu_process_bounds: str = field(init=False)
+    default_tpu_chips_per_process_bounds: str = field(init=False)
+    default_tpu_visible_chips: str = field(init=False)
+
+    def __post_init__(self):
+        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+        self.default_tpu_chips_per_process_bounds = "1,1,1"
+        self.default_tpu_visible_chips = f"{self.rank}"
+
+
 class TPUWorker:

     def __init__(
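Note on the PPConfig dataclass added above: a minimal standalone sketch (the IP values are hypothetical placeholders) of what __post_init__ derives for a worker in a 4-stage pipeline on a single host.

from dataclasses import dataclass, field

# Standalone copy of the PPConfig introduced in this commit, so the sketch runs on its own.
@dataclass
class PPConfig:
    rank: int
    ip: str
    prev_worker_ip: str
    pp_world_size: int
    default_tpu_process_bounds: str = field(init=False)
    default_tpu_chips_per_process_bounds: str = field(init=False)
    default_tpu_visible_chips: str = field(init=False)

    def __post_init__(self):
        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
        self.default_tpu_chips_per_process_bounds = "1,1,1"
        self.default_tpu_visible_chips = f"{self.rank}"

# Hypothetical worker: rank 2 of a 4-stage pipeline.
cfg = PPConfig(rank=2, ip="10.0.0.3", prev_worker_ip="10.0.0.2", pp_world_size=4)
assert cfg.default_tpu_process_bounds == "1,4,1"            # 1x4x1 process grid, one process per PP rank
assert cfg.default_tpu_chips_per_process_bounds == "1,1,1"  # each process owns a single chip
assert cfg.default_tpu_visible_chips == "2"                 # this worker only sees its own chip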
@@ -82,9 +103,8 @@ def __init__(
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
-        self.ip = ip
-        self.prev_worker_ip = prev_worker_ip
-        self.pp_world_size = self.parallel_config.pipeline_parallel_size
+        self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                  self.parallel_config.pipeline_parallel_size)

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -107,8 +127,10 @@ def __init__(

         # For PP, we use MPMD so we want to profile every worker.
         if self.pp_world_size > 1 and envs.VLLM_TORCH_PROFILER_DIR:
-            self.profile_dir = os.path.join(envs.VLLM_TORCH_PROFILER_DIR,
-                                            f"rank_{self.rank}")
+            self.profile_dir = os.path.join(
+                envs.VLLM_TORCH_PROFILER_DIR,
+                f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+            )
             os.makedirs(self.profile_dir, exist_ok=True)

         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
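For illustration of the profiler-directory renaming in the hunk above (the base directory below is a hypothetical stand-in for envs.VLLM_TORCH_PROFILER_DIR):

import os

profiler_base = "/tmp/vllm_profiles"  # hypothetical value of envs.VLLM_TORCH_PROFILER_DIR
rank, pp_world_size = 1, 4

old_path = os.path.join(profiler_base, f"rank_{rank}")
new_path = os.path.join(profiler_base, f"pprank_{rank}_ppworldsize_{pp_world_size}")
print(old_path)  # /tmp/vllm_profiles/rank_1
print(new_path)  # /tmp/vllm_profiles/pprank_1_ppworldsize_4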
@@ -122,24 +144,21 @@ def __init__(
             )
             jax.profiler.start_server(jax_profiler_server_port)

+        # step_counter is used to calculate uuid to transfer intermediate tensors.
         self.step_counter = 0

     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks

-    def init_device(self):
+    def init_device(self,
+                    tpu_process_bounds="",
+                    tpu_chips_per_process_bounds="",
+                    tpu_visible_chips=""):
         # set tpu visible devices for Jax runtime in single host PP.
         multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
         if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
-            # Note: Below is the setting for v6e8 host (8 chips of v6e)
-            # There are 2 ways of subslicing a v6e:
-            # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
-            # 2) 1 chip for each subslice, with at most 8 subslices,
-            #    we can do TP=1, PP=1/2/3/4/5/6/7/8
-            # Replace with your own topology.
-
             tpu_ports = [
                 jax_parallel_state.BASE_JAX_PORT + i
                 for i in range(self.pp_world_size)
@@ -149,23 +168,38 @@ def init_device(self):
             os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
             os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"

-            # first way of subslicing.
-            # os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
-            # os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = f"1,4,1"
-            # os.environ["TPU_VISIBLE_CHIPS"] = "0,1,2,3" if self.rank == 0 else "4,5,6,7"
-
-            # second way of subslicing.
-            os.environ["TPU_PROCESS_BOUNDS"] = f"1,{self.pp_world_size},1"
-            os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
-            os.environ["TPU_VISIBLE_CHIPS"] = f"{self.rank}"
+            # Note: Below is the setting for v6e8 host (8 chips of v6e)
+            # Replace with your own topology.
+            # There are 2 ways of subslicing a v6e
+            # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
+            # TPU_PROCESS_BOUNDS = "1,1,1"
+            # TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+            # TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+            # 2) 1 chip for each subslice, with at most 8 subslices,
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8
+            os.environ[
+                "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
+                if tpu_process_bounds \
+                else self.pp_config.default_tpu_process_bounds
+            os.environ[
+                "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
+                if tpu_chips_per_process_bounds \
+                else self.pp_config.default_tpu_chips_per_process_bounds
+            os.environ[
+                "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
+                if tpu_visible_chips \
+                else self.pp_config.default_tpu_visible_chips

         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
                 all_local_devices = jax.local_devices()
-                device_dict = {device.id: device for device in all_local_devices}
+                device_dict = {
+                    device.id: device
+                    for device in all_local_devices
+                }
                 self.devices = []
                 for device_index in device_indexes:
                     device = device_dict[device_index]
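The env-var block in the hunk above keeps the caller-supplied topology when init_device() receives non-empty arguments and otherwise falls back to the PPConfig defaults. A rough standalone sketch of that fallback pattern, using the first v6e subslicing layout from the comment (the helper name and the values are illustrative, not part of the change):

import os

def apply_tpu_topology(defaults, tpu_process_bounds="",
                       tpu_chips_per_process_bounds="", tpu_visible_chips=""):
    # Same effect as the conditional expressions in init_device():
    # use the argument when non-empty, else the single-host PP default.
    os.environ["TPU_PROCESS_BOUNDS"] = (
        tpu_process_bounds or defaults["process_bounds"])
    os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = (
        tpu_chips_per_process_bounds or defaults["chips_per_process_bounds"])
    os.environ["TPU_VISIBLE_CHIPS"] = (
        tpu_visible_chips or defaults["visible_chips"])

# Defaults for rank 0 of PP=2 on a single v6e8 host (second subslicing way).
defaults = {"process_bounds": "1,2,1",
            "chips_per_process_bounds": "1,1,1",
            "visible_chips": "0"}

# Explicit override for the first subslicing way: PP=2 with 4 chips per stage, rank 0.
apply_tpu_topology(defaults,
                   tpu_process_bounds="1,1,1",
                   tpu_chips_per_process_bounds="1,4,1",
                   tpu_visible_chips="0,1,2,3")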
@@ -178,7 +212,8 @@ def init_device(self):
                 assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                assert jax.local_device_count() >= sharding_config.total_devices
+                assert jax.local_device_count(
+                ) >= sharding_config.total_devices
                 self.devices = jax.local_devices()[:sharding_config.
                                                    total_devices]
         # Initialize the vLLM distribution layer as a single chip environment,
@@ -198,7 +233,7 @@ def init_device(self):
             )

             jax_parallel_state.init_pp_distributed_environment(
-                self.ip,
+                self.pp_config.ip,
                 self.rank,
                 self.parallel_config.pipeline_parallel_size,
                 self.devices[0],
@@ -218,7 +253,8 @@ def init_device(self):
     def initialize_pp_transfer_connect(self):
         if self.rank == 0:
             return
-        jax_parallel_state.connect(self.prev_worker_ip, self.rank - 1)
+        jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                   self.rank - 1)

     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
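For reference, the last hunk routes the pipeline-parallel connect call through pp_config: every rank except 0 opens one upstream connection to the previous stage. A tiny standalone sketch of the resulting chain for a hypothetical 4-worker deployment (IPs are placeholders; the real worker calls jax_parallel_state.connect(prev_worker_ip, rank - 1)):

worker_ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]  # hypothetical worker IPs

for rank in range(len(worker_ips)):
    if rank == 0:
        continue  # first pipeline stage has no upstream worker, mirroring the early return
    prev_worker_ip = worker_ips[rank - 1]
    print(f"rank {rank} connects to rank {rank - 1} at {prev_worker_ip}")
# rank 1 connects to rank 0 at 10.0.0.1
# rank 2 connects to rank 1 at 10.0.0.2
# rank 3 connects to rank 2 at 10.0.0.3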

Comments (0)