From 8502111954626dc3146b9b12225b52b11320292c Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Fri, 14 Nov 2025 07:05:17 +0000 Subject: [PATCH 1/7] Centralize environment variable access through envs.py Signed-off-by: Xing Liu --- tpu_inference/__init__.py | 4 ++-- tpu_inference/distributed/tpu_connector.py | 4 ++-- tpu_inference/distributed/utils.py | 5 +++-- tpu_inference/envs.py | 2 +- tpu_inference/layers/vllm/sharding.py | 4 ++-- tpu_inference/mock/vllm_envs.py | 16 +++++++++++++++- tpu_inference/models/jax/utils/weight_utils.py | 3 ++- tpu_inference/platforms/tpu_platform.py | 2 +- tpu_inference/tpu_info.py | 7 ++++--- tpu_inference/utils.py | 7 ++++--- 10 files changed, 36 insertions(+), 18 deletions(-) diff --git a/tpu_inference/__init__.py b/tpu_inference/__init__.py index d10311cb0..05a0d51c3 100644 --- a/tpu_inference/__init__.py +++ b/tpu_inference/__init__.py @@ -4,12 +4,12 @@ # modules to ensure that the environment variables are set before any # other modules are imported. import tpu_inference.env_override # noqa: F401 -from tpu_inference import tpu_info as ti +from tpu_inference import envs, tpu_info as ti from tpu_inference.logger import init_logger logger = init_logger(__name__) -if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower(): +if "proxy" in envs.JAX_PLATFORMS: logger.info("Running vLLM on TPU via Pathways proxy.") # Must run pathwaysutils.initialize() before any JAX operations try: diff --git a/tpu_inference/distributed/tpu_connector.py b/tpu_inference/distributed/tpu_connector.py index 66a50b26a..b5df18828 100644 --- a/tpu_inference/distributed/tpu_connector.py +++ b/tpu_inference/distributed/tpu_connector.py @@ -441,8 +441,8 @@ def __init__(self, vllm_config: VllmConfig): self.runner: TPUModelRunner = None self.mesh: Mesh = None - self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND", - "").lower() == "ray" + from tpu_inference import envs + self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray" # NOTE(xiang): This can not be the worker 
rank set in RayDistributedExecutor. # The worker rank is assigned with vLLM's sorting logic, which does not work # for TPU host topology. diff --git a/tpu_inference/distributed/utils.py b/tpu_inference/distributed/utils.py index cf1a0b966..61dde5e60 100644 --- a/tpu_inference/distributed/utils.py +++ b/tpu_inference/distributed/utils.py @@ -2,6 +2,7 @@ from vllm.utils.network_utils import get_ip +from tpu_inference import envs from tpu_inference.logger import init_logger logger = init_logger(__name__) @@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]): def get_kv_ips() -> str: - if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray": + if envs.TPU_MULTIHOST_BACKEND == "ray": num_nodes = len(_NODES_KV_IP_PORT) ips = [] for node_id in range(num_nodes): @@ -28,7 +29,7 @@ def get_kv_ips() -> str: def get_kv_ports() -> str: - if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray": + if envs.TPU_MULTIHOST_BACKEND == "ray": num_nodes = len(_NODES_KV_IP_PORT) ports = [] for node_id in range(num_nodes): diff --git a/tpu_inference/envs.py b/tpu_inference/envs.py index 1ef212f00..e97993204 100644 --- a/tpu_inference/envs.py +++ b/tpu_inference/envs.py @@ -26,7 +26,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # JAX platform selection (e.g., "tpu", "cpu", "proxy") "JAX_PLATFORMS": - lambda: os.getenv("JAX_PLATFORMS", ""), + lambda: os.getenv("JAX_PLATFORMS", "").lower(), # TPU accelerator type (e.g., "v5litepod-16", "v4-8") "TPU_ACCELERATOR_TYPE": lambda: os.getenv("TPU_ACCELERATOR_TYPE", None), diff --git a/tpu_inference/layers/vllm/sharding.py b/tpu_inference/layers/vllm/sharding.py index b06f8b35f..396c1658e 100644 --- a/tpu_inference/layers/vllm/sharding.py +++ b/tpu_inference/layers/vllm/sharding.py @@ -211,8 +211,8 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None: def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array: if isinstance(tensor, tuple): return tuple(_sharded_device_put(t, sharding) 
for t in tensor) - import os - multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() + from tpu_inference import envs + multihost_backend = envs.TPU_MULTIHOST_BACKEND if multihost_backend != "ray": return jax.device_put(tensor, sharding) diff --git a/tpu_inference/mock/vllm_envs.py b/tpu_inference/mock/vllm_envs.py index 1a938002a..b9f7c10c0 100644 --- a/tpu_inference/mock/vllm_envs.py +++ b/tpu_inference/mock/vllm_envs.py @@ -189,6 +189,20 @@ def maybe_convert_bool(value: Optional[str]) -> Optional[bool]: return bool(int(value)) +def _get_jax_platforms() -> str: + """Get JAX_PLATFORMS from tpu_inference.envs module. + + Returns: + The JAX_PLATFORMS value. + """ + try: + from tpu_inference import envs + return envs.JAX_PLATFORMS + except ImportError: + # Fallback if tpu_inference.envs is not available + return os.getenv("JAX_PLATFORMS", "").lower() + + def get_vllm_port() -> Optional[int]: """Get the port from VLLM_PORT environment variable. @@ -941,7 +955,7 @@ def get_vllm_port() -> Optional[int]: # Whether using Pathways "VLLM_TPU_USING_PATHWAYS": - lambda: bool("proxy" in os.getenv("JAX_PLATFORMS", "").lower()), + lambda: bool("proxy" in _get_jax_platforms()), # Allow use of DeepGemm kernels for fused moe ops. "VLLM_USE_DEEP_GEMM": diff --git a/tpu_inference/models/jax/utils/weight_utils.py b/tpu_inference/models/jax/utils/weight_utils.py index 64f026dae..0ef4e7e2b 100644 --- a/tpu_inference/models/jax/utils/weight_utils.py +++ b/tpu_inference/models/jax/utils/weight_utils.py @@ -421,7 +421,8 @@ def load_hf_weights(vllm_config, # NOTE(xiang): Disable multi-threading mode if running on multi-host. # Because multi-threading would cause different JAX processes to load # different weights at the same time. 
- if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray": + from tpu_inference import envs + if envs.TPU_MULTIHOST_BACKEND == "ray": max_workers = 1 with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [ diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py index e23d4f7e8..f2be115d0 100644 --- a/tpu_inference/platforms/tpu_platform.py +++ b/tpu_inference/platforms/tpu_platform.py @@ -183,7 +183,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config.worker_cls = \ "tpu_inference.worker.tpu_worker.TPUWorker" - multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() + multihost_backend = envs.TPU_MULTIHOST_BACKEND if not multihost_backend: # Single host if parallel_config.pipeline_parallel_size == 1: logger.info("Force using UniProcExecutor for JAX on \ diff --git a/tpu_inference/tpu_info.py b/tpu_inference/tpu_info.py index 9f5d02269..41b1a7e21 100644 --- a/tpu_inference/tpu_info.py +++ b/tpu_inference/tpu_info.py @@ -3,6 +3,7 @@ import requests +from tpu_inference import envs from tpu_inference.logger import init_logger logger = init_logger(__name__) @@ -32,14 +33,14 @@ def get_tpu_metadata(key: str = "") -> str: def get_tpu_type() -> str: - tpu_type = os.getenv("TPU_ACCELERATOR_TYPE", None) + tpu_type = envs.TPU_ACCELERATOR_TYPE if tpu_type is None: tpu_type = get_tpu_metadata(key="accelerator-type") return tpu_type def get_node_name() -> str: - tpu_name = os.getenv("TPU_NAME", None) + tpu_name = envs.TPU_NAME if not tpu_name: tpu_name = get_tpu_metadata(key="instance-id") return tpu_name @@ -47,7 +48,7 @@ def get_node_name() -> str: def get_node_worker_id() -> int: """For multi-host TPU VM, this returns the worker id for the current node.""" - worker_id = os.getenv("TPU_WORKER_ID", None) + worker_id = envs.TPU_WORKER_ID if worker_id is None: worker_id = get_tpu_metadata(key="agent-worker-number") if worker_id is None: diff --git 
a/tpu_inference/utils.py b/tpu_inference/utils.py index ea9edd20a..4fdb835fe 100644 --- a/tpu_inference/utils.py +++ b/tpu_inference/utils.py @@ -14,8 +14,9 @@ from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc from jax.sharding import Mesh, NamedSharding, PartitionSpec -from vllm import envs, utils +from vllm import envs as vllm_envs, utils +from tpu_inference import envs from tpu_inference.logger import init_logger GBYTES = 1024 * 1024 * 1024 @@ -57,10 +58,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int: def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]: usage = [] - if envs.VLLM_TPU_USING_PATHWAYS: + if vllm_envs.VLLM_TPU_USING_PATHWAYS: return pathways_hbm_usage_gb(devices) - multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() + multihost_backend = envs.TPU_MULTIHOST_BACKEND if multihost_backend == "ray": # MemoryStats is only supported for addressable PjRt devices. # Assume all the devices have similar memory usage for now. From af6727e39a2fc3878ddfdeb89fda5db9e526024e Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Fri, 14 Nov 2025 07:09:53 +0000 Subject: [PATCH 2/7] Use PREFILL_SLICES and DECODE_SLICES from envs.py in disagg_utils Signed-off-by: Xing Liu --- tpu_inference/core/disagg_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tpu_inference/core/disagg_utils.py b/tpu_inference/core/disagg_utils.py index 58528b8ad..25c40dd10 100644 --- a/tpu_inference/core/disagg_utils.py +++ b/tpu_inference/core/disagg_utils.py @@ -3,6 +3,8 @@ import os from typing import Tuple +from tpu_inference import envs + PREFILL_SLICES = 'PREFILL_SLICES' DECODE_SLICES = 'DECODE_SLICES' @@ -11,7 +13,7 @@ def is_disagg_enabled() -> bool: # We triggrer our code path as long as prefill slices are set. This # allows us to test interleave mode effectively with the code path # for comparison purposes. 
- return PREFILL_SLICES in os.environ + return bool(envs.PREFILL_SLICES) def _parse_slices(slices_str: str) -> Tuple[int, ...]: @@ -40,12 +42,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]: def get_prefill_slices() -> Tuple[int, ...]: - if PREFILL_SLICES not in os.environ: + if not envs.PREFILL_SLICES: return () - return _parse_slices(os.environ[PREFILL_SLICES]) + return _parse_slices(envs.PREFILL_SLICES) def get_decode_slices() -> Tuple[int, ...]: - if DECODE_SLICES not in os.environ: + if not envs.DECODE_SLICES: return () - return _parse_slices(os.environ[DECODE_SLICES]) + return _parse_slices(envs.DECODE_SLICES) From f97c119635ac0214b7b9ab7a099a11f259bab04f Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Fri, 14 Nov 2025 07:37:53 +0000 Subject: [PATCH 3/7] Remove unused os import from __init__.py Signed-off-by: Xing Liu --- tpu_inference/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tpu_inference/__init__.py b/tpu_inference/__init__.py index 05a0d51c3..e34016cf7 100644 --- a/tpu_inference/__init__.py +++ b/tpu_inference/__init__.py @@ -1,5 +1,3 @@ -import os - # The environment variables override should be imported before any other # modules to ensure that the environment variables are set before any # other modules are imported. From 521ba394df4c598bba1daf80d931e3e5f55f506c Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Fri, 14 Nov 2025 07:41:27 +0000 Subject: [PATCH 4/7] Fix linting issues: isort, ruff, and trailing whitespace Signed-off-by: Xing Liu --- tpu_inference/__init__.py | 3 ++- tpu_inference/mock/vllm_envs.py | 2 +- tpu_inference/utils.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tpu_inference/__init__.py b/tpu_inference/__init__.py index e34016cf7..09d2fcdd7 100644 --- a/tpu_inference/__init__.py +++ b/tpu_inference/__init__.py @@ -2,7 +2,8 @@ # modules to ensure that the environment variables are set before any # other modules are imported. 
import tpu_inference.env_override # noqa: F401 -from tpu_inference import envs, tpu_info as ti +from tpu_inference import envs +from tpu_inference import tpu_info as ti from tpu_inference.logger import init_logger logger = init_logger(__name__) diff --git a/tpu_inference/mock/vllm_envs.py b/tpu_inference/mock/vllm_envs.py index b9f7c10c0..476643579 100644 --- a/tpu_inference/mock/vllm_envs.py +++ b/tpu_inference/mock/vllm_envs.py @@ -191,7 +191,7 @@ def maybe_convert_bool(value: Optional[str]) -> Optional[bool]: def _get_jax_platforms() -> str: """Get JAX_PLATFORMS from tpu_inference.envs module. - + Returns: The JAX_PLATFORMS value. """ diff --git a/tpu_inference/utils.py b/tpu_inference/utils.py index 4fdb835fe..ca3d693da 100644 --- a/tpu_inference/utils.py +++ b/tpu_inference/utils.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -import os import time from collections import defaultdict from collections.abc import Sequence @@ -14,7 +13,8 @@ from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc from jax.sharding import Mesh, NamedSharding, PartitionSpec -from vllm import envs as vllm_envs, utils +from vllm import envs as vllm_envs +from vllm import utils from tpu_inference import envs from tpu_inference.logger import init_logger From 8d4854e6922bd579a5936f8c7376653a68569371 Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Fri, 14 Nov 2025 07:44:05 +0000 Subject: [PATCH 5/7] Remove unused os imports identified by ruff Signed-off-by: Xing Liu --- tpu_inference/core/disagg_utils.py | 1 - tpu_inference/distributed/tpu_connector.py | 1 - tpu_inference/platforms/tpu_platform.py | 3 +-- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tpu_inference/core/disagg_utils.py b/tpu_inference/core/disagg_utils.py index 25c40dd10..4db0c6d11 100644 --- a/tpu_inference/core/disagg_utils.py +++ b/tpu_inference/core/disagg_utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -import os from typing import Tuple from tpu_inference 
import envs diff --git a/tpu_inference/distributed/tpu_connector.py b/tpu_inference/distributed/tpu_connector.py index b5df18828..f259b3f0e 100644 --- a/tpu_inference/distributed/tpu_connector.py +++ b/tpu_inference/distributed/tpu_connector.py @@ -60,7 +60,6 @@ import copy import functools -import os import threading import time from concurrent.futures import Future, ThreadPoolExecutor diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py index f2be115d0..3ada517d0 100644 --- a/tpu_inference/platforms/tpu_platform.py +++ b/tpu_inference/platforms/tpu_platform.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -import os -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Optional, Tuple, Union, cast import jax.numpy as jnp import vllm.envs as vllm_envs From 3d23c0c542f2165b5ec5a3c07b235e693b79564e Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Sat, 15 Nov 2025 08:40:18 +0000 Subject: [PATCH 6/7] Add missing Any import to tpu_platform.py Signed-off-by: Xing Liu --- tpu_inference/platforms/tpu_platform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py index 3ada517d0..b3a4a7de3 100644 --- a/tpu_inference/platforms/tpu_platform.py +++ b/tpu_inference/platforms/tpu_platform.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast import jax.numpy as jnp import vllm.envs as vllm_envs From 3d6c285939ef71c9a192dde11970838eaa79ceef Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Mon, 17 Nov 2025 20:13:33 +0000 Subject: [PATCH 7/7] Move envs imports to module level and remove unused slice name constants Signed-off-by: Xing Liu --- tpu_inference/core/disagg_utils.py | 3 --- tpu_inference/distributed/tpu_connector.py | 2 +- tpu_inference/layers/vllm/sharding.py | 2 +- tpu_inference/models/jax/utils/weight_utils.py | 
3 +-- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tpu_inference/core/disagg_utils.py b/tpu_inference/core/disagg_utils.py index 4db0c6d11..ecb16e9ac 100644 --- a/tpu_inference/core/disagg_utils.py +++ b/tpu_inference/core/disagg_utils.py @@ -4,9 +4,6 @@ from tpu_inference import envs -PREFILL_SLICES = 'PREFILL_SLICES' -DECODE_SLICES = 'DECODE_SLICES' - def is_disagg_enabled() -> bool: # We triggrer our code path as long as prefill slices are set. This diff --git a/tpu_inference/distributed/tpu_connector.py b/tpu_inference/distributed/tpu_connector.py index f259b3f0e..cf09dcea7 100644 --- a/tpu_inference/distributed/tpu_connector.py +++ b/tpu_inference/distributed/tpu_connector.py @@ -85,6 +85,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request +from tpu_inference import envs from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips, get_kv_ports, get_kv_transfer_port, get_node_id, @@ -440,7 +441,6 @@ def __init__(self, vllm_config: VllmConfig): self.runner: TPUModelRunner = None self.mesh: Mesh = None - from tpu_inference import envs self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray" # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor. 
# The worker rank is assigned with vLLM's sorting logic, which does not work diff --git a/tpu_inference/layers/vllm/sharding.py b/tpu_inference/layers/vllm/sharding.py index 396c1658e..b9fd4fdd9 100644 --- a/tpu_inference/layers/vllm/sharding.py +++ b/tpu_inference/layers/vllm/sharding.py @@ -19,6 +19,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from tpu_inference import envs from tpu_inference.logger import init_logger P = PartitionSpec @@ -211,7 +212,6 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None: def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array: if isinstance(tensor, tuple): return tuple(_sharded_device_put(t, sharding) for t in tensor) - from tpu_inference import envs multihost_backend = envs.TPU_MULTIHOST_BACKEND if multihost_backend != "ray": return jax.device_put(tensor, sharding) diff --git a/tpu_inference/models/jax/utils/weight_utils.py b/tpu_inference/models/jax/utils/weight_utils.py index 0ef4e7e2b..64730748f 100644 --- a/tpu_inference/models/jax/utils/weight_utils.py +++ b/tpu_inference/models/jax/utils/weight_utils.py @@ -18,7 +18,7 @@ from jax.sharding import PartitionSpec as P from safetensors import safe_open -from tpu_inference import utils +from tpu_inference import envs, utils from tpu_inference.logger import init_logger from tpu_inference.models.jax.utils import file_utils @@ -421,7 +421,6 @@ def load_hf_weights(vllm_config, # NOTE(xiang): Disable multi-threading mode if running on multi-host. # Because multi-threading would cause different JAX processes to load # different weights at the same time. - from tpu_inference import envs if envs.TPU_MULTIHOST_BACKEND == "ray": max_workers = 1 with ThreadPoolExecutor(max_workers=max_workers) as executor: