@@ -10,6 +10,7 @@
 import jaxtyping
 import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -23,10 +24,13 @@
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 from tpu_inference import envs, utils
+from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner
@@ -41,13 +45,17 @@
 
 class TPUWorker:
 
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 local_rank: int,
-                 rank: int,
-                 distributed_init_method: str,
-                 is_driver_worker: bool = False,
-                 devices=None):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+        devices=None,
+        ip: str = "localhost",
+        prev_worker_ip: str = "localhost",
+    ):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
         impl = envs.MODEL_IMPL_TYPE
@@ -74,6 +82,9 @@ def __init__(self,
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
+        self.ip = ip
+        self.prev_worker_ip = prev_worker_ip
+        self.pp_world_size = self.parallel_config.pipeline_parallel_size
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -86,14 +97,20 @@ def __init__(self,
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_world_size == 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
                 self.profile_dir = vllm_envs.VLLM_TORCH_PROFILER_DIR
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)
 
+        # For PP, we use MPMD so we want to profile every worker.
+        if self.pp_world_size > 1 and envs.VLLM_TORCH_PROFILER_DIR:
+            self.profile_dir = os.path.join(envs.VLLM_TORCH_PROFILER_DIR,
+                                            f"rank_{self.rank}")
+            os.makedirs(self.profile_dir, exist_ok=True)
+
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
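
For reference, a minimal sketch of the per-rank trace layout the hunk above produces when pipeline parallelism is enabled; the root directory below is a hypothetical value of VLLM_TORCH_PROFILER_DIR, not one taken from the change:

    import os

    profiler_root = "/tmp/tpu_profiles"  # hypothetical VLLM_TORCH_PROFILER_DIR
    pp_world_size = 2                    # hypothetical pipeline_parallel_size

    # Under MPMD every PP worker profiles itself, so each rank gets its own
    # subdirectory, e.g. /tmp/tpu_profiles/rank_0 and /tmp/tpu_profiles/rank_1.
    for rank in range(pp_world_size):
        os.makedirs(os.path.join(profiler_root, f"rank_{rank}"), exist_ok=True)
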
@@ -105,18 +122,49 @@ def __init__(self,
             )
             jax.profiler.start_server(jax_profiler_server_port)
 
+        self.step_counter = 0
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
     def init_device(self):
+        # Set TPU-visible devices for the JAX runtime in single-host PP.
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+            # Note: below is the setting for a v6e-8 host (8 chips of v6e).
+            # There are two ways of subslicing a v6e-8:
+            # 1) two subslices with 4 TPU chips each; we can do PP=2, TP=1/2/3/4.
+            # 2) one chip per subslice, with at most 8 subslices;
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8.
+            # Replace these settings with your own topology.
+
+            tpu_ports = [
+                jax_parallel_state.BASE_JAX_PORT + i
+                for i in range(self.pp_world_size)
+            ]
+            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                [f"localhost:{port}" for port in tpu_ports])
+            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+            # First way of subslicing:
+            # os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
+            # os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,4,1"
+            # os.environ["TPU_VISIBLE_CHIPS"] = "0,1,2,3" if self.rank == 0 else "4,5,6,7"
+
+            # Second way of subslicing:
+            os.environ["TPU_PROCESS_BOUNDS"] = f"1,{self.pp_world_size},1"
+            os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
+            os.environ["TPU_VISIBLE_CHIPS"] = f"{self.rank}"
+
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-                all_devices = jax.devices()
+                all_devices = jax.local_devices()
                 device_dict = {device.id: device for device in all_devices}
                 self.devices = []
                 for device_index in device_indexes:
@@ -127,10 +175,12 @@ def init_device(self):
                             f"jax.devices() with IDs {list(device_dict.keys())}!"
                         )
                     self.devices.append(device)
+                assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                self.devices = jax.devices()[:sharding_config.total_devices]
-
+                assert jax.local_device_count() >= sharding_config.total_devices
+                self.devices = jax.local_devices()[:sharding_config.
+                                                   total_devices]
         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
         with set_current_vllm_config(self.vllm_config):
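
To make the single-host subslicing in init_device above concrete, here is a minimal sketch of the environment the second scheme (one chip per process) would produce for PP=2 on a v6e-8 host; the port constant is a placeholder, not the real jax_parallel_state.BASE_JAX_PORT value:

    import os

    BASE_JAX_PORT = 8476   # placeholder; the real constant lives in jax_parallel_state
    pp_world_size = 2      # PP=2, TP=1
    rank = 1               # the worker being configured

    tpu_ports = [BASE_JAX_PORT + i for i in range(pp_world_size)]
    os.environ.update({
        # Every PP worker advertises its TPU runtime address to the others.
        "TPU_PROCESS_ADDRESSES": ",".join(f"localhost:{p}" for p in tpu_ports),
        "TPU_PROCESS_PORT": str(tpu_ports[rank]),
        "CLOUD_TPU_TASK_ID": str(rank),
        # A 1 x pp_world_size x 1 process grid with one visible chip per process.
        "TPU_PROCESS_BOUNDS": f"1,{pp_world_size},1",
        "TPU_CHIPS_PER_PROCESS_BOUNDS": "1,1,1",
        "TPU_VISIBLE_CHIPS": str(rank),
    })
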
@@ -146,15 +196,30 @@ def init_device(self):
                 tensor_model_parallel_size=1,
                 pipeline_model_parallel_size=1,
             )
+
+        jax_parallel_state.init_pp_distributed_environment(
+            self.ip,
+            self.rank,
+            self.parallel_config.pipeline_parallel_size,
+            self.devices[0],
+            need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
+        self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+                                           self.rank, self.rank == 0,
+                                           self.rank == self.pp_world_size - 1)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "
                     f"is_driver_worker={self.is_driver_worker} | "
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
         vllm_utils.report_usage_stats(self.vllm_config)
 
+    def initialize_pp_transfer_connect(self):
+        if self.rank == 0:
+            return
+        jax_parallel_state.connect(self.prev_worker_ip, self.rank - 1)
+
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
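
The connection order implied by initialize_pp_transfer_connect above: rank 0 only listens, and every other rank dials its predecessor, so activations flow stage by stage down the pipeline. A toy, self-contained sketch of that ordering (the class below is a stand-in for illustration, not the real TPUWorker or jax_parallel_state API):

    class _ToyPPWorker:
        """Stand-in mirroring the rank/prev_worker_ip wiring, for illustration only."""

        def __init__(self, rank: int, prev_worker_ip: str):
            self.rank = rank
            self.prev_worker_ip = prev_worker_ip

        def initialize_pp_transfer_connect(self) -> None:
            if self.rank == 0:
                return  # the first stage has no upstream peer to dial
            print(f"rank {self.rank} connects to rank {self.rank - 1} "
                  f"at {self.prev_worker_ip}")

    # A PP=2 launcher would construct one worker per stage and let each
    # non-zero rank connect to the stage before it.
    workers = [_ToyPPWorker(0, "n/a"), _ToyPPWorker(1, "localhost")]
    for worker in workers:
        worker.initialize_pp_transfer_connect()
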
@@ -194,14 +259,39 @@ def execute_model(
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.
 
-        output = self.model_runner.execute_model(scheduler_output)
-
-        # With a connector, the scheduler expects output from all workers
-        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
-        if has_kv_transfer_group():
-            return output
-
-        return output if self.is_driver_worker else None
+        if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+            intermediate_tensors = None
+        else:
+            # Receive intermediate tensors from the previous PP stage.
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank - 1, self.step_counter)
+            # TODO: this method might only work for vLLM models; not sure about JAX models.
+            tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                scheduler_output.total_num_scheduled_tokens)
+            intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                uuid, tensor_spec)
+            intermediate_tensors = JaxIntermediateTensors(
+                intermediate_tensors_dict)
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if isinstance(output, JaxIntermediateTensors):
+            assert self.parallel_config.pipeline_parallel_size > 1
+            assert not get_pp_group().is_last_rank
+            # Send intermediate tensors to the next PP stage.
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank, self.step_counter)
+            get_pp_group().send_tensor_dict(uuid, output.tensors)
+            self.step_counter += 1
+            return None
+        else:
+            self.step_counter += 1
+            # With a connector, the scheduler expects output from all workers
+            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+            if has_kv_transfer_group():
+                return output
+            return output if self.is_driver_worker else None
 
     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
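
To summarize the control flow the new execute_model branches implement, here is a minimal, self-contained sketch of the same step protocol; the function and the in-memory channel below are stand-ins for illustration, not the real get_pp_group() or JaxIntermediateTensors APIs:

    def _pp_step(rank, pp_world_size, step, recv_from_prev, send_to_next, run_model):
        # First stage (or no PP at all): nothing to receive.
        if pp_world_size == 1 or rank == 0:
            intermediate = None
        else:
            # Transfers are keyed by a uuid derived from (step, sender rank).
            intermediate = recv_from_prev((step, rank - 1))

        output = run_model(intermediate)

        # Non-last stages forward their activations and return nothing;
        # only the last stage hands sampled output back to the scheduler.
        if pp_world_size > 1 and rank < pp_world_size - 1:
            send_to_next((step, rank), output)
            return None
        return output

    # Tiny demo: two stages chained through an in-memory "channel" dict.
    channel = {}
    first = _pp_step(0, 2, 0, channel.get, channel.__setitem__,
                     lambda _: {"hidden": 1})
    last = _pp_step(1, 2, 0, channel.pop, channel.__setitem__,
                    lambda t: {"tokens": [t["hidden"] + 1]})
    assert first is None and last == {"tokens": [2]}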