From 1fa5e73234b3950dd7f25d40c253b151763e95d5 Mon Sep 17 00:00:00 2001 From: Chenyaaang Date: Sat, 8 Nov 2025 01:36:51 +0000 Subject: [PATCH] enable pp on tpu jax platform Signed-off-by: Chenyaaang --- tpu_inference/platforms/tpu_platform.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py index c19d67352..e2eb23e1a 100644 --- a/tpu_inference/platforms/tpu_platform.py +++ b/tpu_inference/platforms/tpu_platform.py @@ -184,8 +184,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() if not multihost_backend: # Single host - logger.info("Force using UniProcExecutor for JAX on single host.") - parallel_config.distributed_executor_backend = "uni" + if parallel_config.pipeline_parallel_size == 1: + logger.info("Force using UniProcExecutor for JAX on \ + single host without pipeline parallelism.") + parallel_config.distributed_executor_backend = "uni" + else: + logger.info("Force using MultiprocExecutor for JAX on \ + single host with pipeline parallelism.") + parallel_config.distributed_executor_backend = "mp" elif multihost_backend == "ray": from tpu_inference.executors.ray_distributed_executor import \ RayDistributedExecutor