@@ -2,6 +2,7 @@

 import os
 import tempfile
+from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple

 import jax
@@ -43,6 +44,26 @@
 }


+@dataclass
+class PPConfig:
+    rank: int
+    ip: str
+    prev_worker_ip: str
+    pp_world_size: int
+
+    # Default env vars for
+    # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
+    # when PP is used on a single host.
+    default_tpu_process_bounds: str = field(init=False)
+    default_tpu_chips_per_process_bounds: str = field(init=False)
+    default_tpu_visible_chips: str = field(init=False)
+
+    def __post_init__(self):
+        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+        self.default_tpu_chips_per_process_bounds = "1,1,1"
+        self.default_tpu_visible_chips = f"{self.rank}"
+
+
 class TPUWorker:

     def __init__(
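
Aside (illustration only, not part of the diff): a minimal sketch of the defaults PPConfig derives in __post_init__ for single-host PP, assuming a hypothetical 4-stage pipeline and the worker at rank 2; the IP addresses are placeholders.

# Sketch only: defaults computed by PPConfig.__post_init__ for a
# hypothetical single-host PP=4 worker with rank 2 (IPs are placeholders).
cfg = PPConfig(rank=2, ip="10.0.0.1", prev_worker_ip="10.0.0.1", pp_world_size=4)
assert cfg.default_tpu_process_bounds == "1,4,1"            # 4 single-chip processes
assert cfg.default_tpu_chips_per_process_bounds == "1,1,1"  # one chip per process
assert cfg.default_tpu_visible_chips == "2"                 # chip index == PP rank
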
@@ -82,9 +103,8 @@ def __init__(
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
-        self.ip = ip
-        self.prev_worker_ip = prev_worker_ip
-        self.pp_world_size = self.parallel_config.pipeline_parallel_size
+        self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                  self.parallel_config.pipeline_parallel_size)

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -107,8 +127,10 @@ def __init__(

         # For PP, we use MPMD so we want to profile every worker.
         if self.pp_world_size > 1 and envs.VLLM_TORCH_PROFILER_DIR:
-            self.profile_dir = os.path.join(envs.VLLM_TORCH_PROFILER_DIR,
-                                            f"rank_{self.rank}")
+            self.profile_dir = os.path.join(
+                envs.VLLM_TORCH_PROFILER_DIR,
+                f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+            )
             os.makedirs(self.profile_dir, exist_ok=True)

         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
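
Aside (illustration only, not part of the diff): with the new naming, each MPMD worker writes torch profiler traces to its own per-rank directory. Assuming a placeholder base directory /tmp/vllm_profile, rank 1 of a 2-stage pipeline would use:

# Sketch only: resulting per-worker profile directory for rank 1, PP=2,
# with the placeholder base directory /tmp/vllm_profile.
import os
rank, pp_world_size = 1, 2
profile_dir = os.path.join("/tmp/vllm_profile",
                           f"pprank_{rank}_ppworldsize_{pp_world_size}")
print(profile_dir)  # /tmp/vllm_profile/pprank_1_ppworldsize_2
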
@@ -122,24 +144,21 @@ def __init__(
             )
             jax.profiler.start_server(jax_profiler_server_port)

+        # step_counter is used to compute the uuid for transferring intermediate tensors.
         self.step_counter = 0

     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks

-    def init_device(self):
+    def init_device(self,
+                    tpu_process_bounds="",
+                    tpu_chips_per_process_bounds="",
+                    tpu_visible_chips=""):
         # set tpu visible devices for Jax runtime in single host PP.
         multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
         if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
-            # Note: Below is the setting for v6e8 host (8 chips of v6e)
-            # There are 2 ways of subslicing a v6e:
-            # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
-            # 2) 1 chip for each subslice, with at most 8 subslices,
-            #    we can do TP=1, PP=1/2/3/4/5/6/7/8
-            # Replace with your own topology.
-
             tpu_ports = [
                 jax_parallel_state.BASE_JAX_PORT + i
                 for i in range(self.pp_world_size)
@@ -149,23 +168,38 @@ def init_device(self):
             os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
             os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"

-            # first way of subslicing.
-            # os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
-            # os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = f"1,4,1"
-            # os.environ["TPU_VISIBLE_CHIPS"] = "0,1,2,3" if self.rank == 0 else "4,5,6,7"
-
-            # second way of subslicing.
-            os.environ["TPU_PROCESS_BOUNDS"] = f"1,{self.pp_world_size},1"
-            os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
-            os.environ["TPU_VISIBLE_CHIPS"] = f"{self.rank}"
+            # Note: below is the setting for a v6e8 host (8 chips of v6e).
+            # Replace with your own topology.
+            # There are 2 ways of subslicing a v6e:
+            # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
+            #    TPU_PROCESS_BOUNDS = "1,1,1"
+            #    TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+            #    TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+            # 2) 1 chip for each subslice, with at most 8 subslices,
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8
+            os.environ[
+                "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
+                if tpu_process_bounds \
+                else self.pp_config.default_tpu_process_bounds
+            os.environ[
+                "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
+                if tpu_chips_per_process_bounds \
+                else self.pp_config.default_tpu_chips_per_process_bounds
+            os.environ[
+                "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
+                if tpu_visible_chips \
+                else self.pp_config.default_tpu_visible_chips

         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
                 all_local_devices = jax.local_devices()
-                device_dict = {device.id: device for device in all_local_devices}
+                device_dict = {
+                    device.id: device
+                    for device in all_local_devices
+                }
                 self.devices = []
                 for device_index in device_indexes:
                     device = device_dict[device_index]
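
Aside (illustration only, not part of the diff): how the new init_device overrides could be used for the first v6e8 subslicing scheme described in the comments above (2 subslices of 4 chips each); `worker` is a hypothetical TPUWorker instance for PP rank 0.

# Sketch only: overriding the single-host defaults with the first
# subslicing scheme (PP=2, 4 chips per stage) on a hypothetical worker.
worker.init_device(
    tpu_process_bounds="1,1,1",
    tpu_chips_per_process_bounds="1,4,1",
    tpu_visible_chips="0,1,2,3",  # the rank-1 worker would pass "4,5,6,7"
)
# Leaving the arguments empty keeps the one-chip-per-stage defaults from
# PPConfig: "1,<pp_world_size>,1", "1,1,1", and "<rank>".
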
@@ -178,7 +212,8 @@ def init_device(self):
                 assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                assert jax.local_device_count() >= sharding_config.total_devices
+                assert jax.local_device_count(
+                ) >= sharding_config.total_devices
                 self.devices = jax.local_devices()[:sharding_config.
                                                    total_devices]
         # Initialize the vLLM distribution layer as a single chip environment,
@@ -198,7 +233,7 @@ def init_device(self):
         )

         jax_parallel_state.init_pp_distributed_environment(
-            self.ip,
+            self.pp_config.ip,
             self.rank,
             self.parallel_config.pipeline_parallel_size,
             self.devices[0],
@@ -218,7 +253,8 @@ def init_device(self):
     def initialize_pp_transfer_connect(self):
         if self.rank == 0:
             return
-        jax_parallel_state.connect(self.prev_worker_ip, self.rank - 1)
+        jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                   self.rank - 1)

     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
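
Aside (illustration only, not part of the diff): one way the per-rank PPConfig values could be assembled for a 3-stage pipeline; the worker IPs are placeholders. Rank 0 never connects backwards, and each later rank connects to the worker at rank - 1 via initialize_pp_transfer_connect.

# Sketch only: assembling PPConfig values for a hypothetical 3-stage pipeline.
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3"]  # placeholder worker addresses
configs = []
for rank, ip in enumerate(ips):
    prev_ip = ips[rank - 1] if rank > 0 else ""  # rank 0 has no upstream worker
    configs.append(PPConfig(rank, ip, prev_ip, pp_world_size=len(ips)))
# After init_device(), each worker calls initialize_pp_transfer_connect(),
# which is a no-op for rank 0 and connects rank k to rank k - 1 otherwise.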