Commit 7a9a4ae

Centralizes environment variable access by routing variable reads through the envs.py module. (#1102)
Signed-off-by: Xing Liu <xingliu14@gmail.com>
1 parent 75408c1 commit 7a9a4ae

File tree

11 files changed (+43, -31 lines)

tpu_inference/__init__.py

Lines changed: 2 additions & 3 deletions
@@ -1,15 +1,14 @@
-import os
-
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
 import tpu_inference.env_override  # noqa: F401
+from tpu_inference import envs
 from tpu_inference import tpu_info as ti
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
 
-if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
+if "proxy" in envs.JAX_PLATFORMS:
     logger.info("Running vLLM on TPU via Pathways proxy.")
     # Must run pathwaysutils.initialize() before any JAX operations
     try:
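
Note: since envs.JAX_PLATFORMS now returns the value already lowercased (see the tpu_inference/envs.py hunk below), the call site can use a plain substring test. A minimal sketch of the equivalence, using only standard os.environ semantics (the example value is hypothetical):

    import os

    os.environ["JAX_PLATFORMS"] = "Proxy"  # hypothetical, mixed-case value

    # Old call-site logic: normalize at every use.
    old = "proxy" in os.environ.get("JAX_PLATFORMS", "").lower()

    # New logic: the envs accessor lowercases once at the source,
    # so `"proxy" in envs.JAX_PLATFORMS` performs the same test.
    new = "proxy" in os.environ["JAX_PLATFORMS"].lower()

    assert old == new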

tpu_inference/core/disagg_utils.py

Lines changed: 6 additions & 8 deletions
@@ -1,17 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 from typing import Tuple
 
-PREFILL_SLICES = 'PREFILL_SLICES'
-DECODE_SLICES = 'DECODE_SLICES'
+from tpu_inference import envs
 
 
 def is_disagg_enabled() -> bool:
     # We trigger our code path as long as prefill slices are set. This
     # allows us to test interleave mode effectively with the code path
     # for comparison purposes.
-    return PREFILL_SLICES in os.environ
+    return bool(envs.PREFILL_SLICES)
 
 
 def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -40,12 +38,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:
 
 
 def get_prefill_slices() -> Tuple[int, ...]:
-    if PREFILL_SLICES not in os.environ:
+    if not envs.PREFILL_SLICES:
         return ()
-    return _parse_slices(os.environ[PREFILL_SLICES])
+    return _parse_slices(envs.PREFILL_SLICES)
 
 
 def get_decode_slices() -> Tuple[int, ...]:
-    if DECODE_SLICES not in os.environ:
+    if not envs.DECODE_SLICES:
         return ()
-    return _parse_slices(os.environ[DECODE_SLICES])
+    return _parse_slices(envs.DECODE_SLICES)
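
Note: one behavioral nuance in this file: the old guard tested key presence, the new one tests truthiness, so an explicitly empty PREFILL_SLICES="" now counts as disabled. A small sketch of the difference (plain Python, no tpu_inference imports needed):

    import os

    os.environ["PREFILL_SLICES"] = ""  # set, but empty

    # Old check: key presence -> enabled even for an empty value.
    assert ("PREFILL_SLICES" in os.environ) is True

    # New check: truthiness of the fetched string -> empty means disabled.
    assert bool(os.environ.get("PREFILL_SLICES", "")) is False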

tpu_inference/distributed/tpu_connector.py

Lines changed: 2 additions & 3 deletions
@@ -60,7 +60,6 @@
 
 import copy
 import functools
-import os
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -86,6 +85,7 @@
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.request import Request
 
+from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports,
                                              get_kv_transfer_port, get_node_id,
@@ -441,8 +441,7 @@ def __init__(self, vllm_config: VllmConfig):
 
         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
-        self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND",
-                                    "").lower() == "ray"
+        self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
         # The worker rank is assigned with vLLM's sorting logic, which does not work
         # for TPU host topology.
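
Note: this and the following call sites drop their own .lower() normalization. For exact behavior parity this relies on the envs.TPU_MULTIHOST_BACKEND accessor lowercasing the raw value itself; its definition is unchanged and not shown in this diff, so treat the accessor body below as an assumption that mirrors the JAX_PLATFORMS entry:

    import os

    def _tpu_multihost_backend() -> str:
        # Assumed accessor shape: lowercase once at the source so
        # call sites can compare against "ray" directly.
        return os.getenv("TPU_MULTIHOST_BACKEND", "").lower()

    os.environ["TPU_MULTIHOST_BACKEND"] = "Ray"  # illustrative value
    assert _tpu_multihost_backend() == "ray"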

tpu_inference/distributed/utils.py

Lines changed: 3 additions & 2 deletions
@@ -2,6 +2,7 @@
 
 from vllm.utils.network_utils import get_ip
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
 
 
 def get_kv_ips() -> str:
-    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ips = []
         for node_id in range(num_nodes):
@@ -28,7 +29,7 @@
 
 
 def get_kv_ports() -> str:
-    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ports = []
         for node_id in range(num_nodes):

tpu_inference/envs.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
-    lambda: os.getenv("JAX_PLATFORMS", ""),
+    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
     "TPU_ACCELERATOR_TYPE":
     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),

tpu_inference/layers/vllm/sharding.py

Lines changed: 2 additions & 2 deletions
@@ -19,6 +19,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -211,8 +212,7 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
 def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
     if isinstance(tensor, tuple):
         return tuple(_sharded_device_put(t, sharding) for t in tensor)
-    import os
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend != "ray":
         return jax.device_put(tensor, sharding)
 

tpu_inference/mock/vllm_envs.py

Lines changed: 15 additions & 1 deletion
@@ -189,6 +189,20 @@ def maybe_convert_bool(value: Optional[str]) -> Optional[bool]:
     return bool(int(value))
 
 
+def _get_jax_platforms() -> str:
+    """Get JAX_PLATFORMS from tpu_inference.envs module.
+
+    Returns:
+        The JAX_PLATFORMS value.
+    """
+    try:
+        from tpu_inference import envs
+        return envs.JAX_PLATFORMS
+    except ImportError:
+        # Fallback if tpu_inference.envs is not available
+        return os.getenv("JAX_PLATFORMS", "").lower()
+
+
 def get_vllm_port() -> Optional[int]:
     """Get the port from VLLM_PORT environment variable.
@@ -941,7 +955,7 @@ def get_vllm_port() -> Optional[int]:
 
     # Whether using Pathways
     "VLLM_TPU_USING_PATHWAYS":
-    lambda: bool("proxy" in os.getenv("JAX_PLATFORMS", "").lower()),
+    lambda: bool("proxy" in _get_jax_platforms()),
 
     # Allow use of DeepGemm kernels for fused moe ops.
     "VLLM_USE_DEEP_GEMM":

tpu_inference/models/jax/utils/weight_utils.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils import file_utils
 
@@ -421,7 +421,7 @@ def load_hf_weights(vllm_config,
     # NOTE(xiang): Disable multi-threading mode if running on multi-host.
     # Because multi-threading would cause different JAX processes to load
     # different weights at the same time.
-    if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         max_workers = 1
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         futures = [

tpu_inference/platforms/tpu_platform.py

Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
@@ -183,7 +182,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             parallel_config.worker_cls = \
                 "tpu_inference.worker.tpu_worker.TPUWorker"
 
-        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend: # Single host
             if parallel_config.pipeline_parallel_size == 1:
                 logger.info("Force using UniProcExecutor for JAX on \

tpu_inference/tpu_info.py

Lines changed: 4 additions & 3 deletions
@@ -3,6 +3,7 @@
 
 import requests
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -32,22 +33,22 @@ def get_tpu_metadata(key: str = "") -> str:
 
 
 def get_tpu_type() -> str:
-    tpu_type = os.getenv("TPU_ACCELERATOR_TYPE", None)
+    tpu_type = envs.TPU_ACCELERATOR_TYPE
     if tpu_type is None:
         tpu_type = get_tpu_metadata(key="accelerator-type")
     return tpu_type
 
 
 def get_node_name() -> str:
-    tpu_name = os.getenv("TPU_NAME", None)
+    tpu_name = envs.TPU_NAME
     if not tpu_name:
         tpu_name = get_tpu_metadata(key="instance-id")
     return tpu_name
 
 
 def get_node_worker_id() -> int:
     """For multi-host TPU VM, this returns the worker id for the current node."""
-    worker_id = os.getenv("TPU_WORKER_ID", None)
+    worker_id = envs.TPU_WORKER_ID
     if worker_id is None:
         worker_id = get_tpu_metadata(key="agent-worker-number")
     if worker_id is None:
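
Note: a practical upside of routing these reads through envs is testability: a test can stub one well-known table entry instead of patching os.environ in every module. A pytest-style sketch, assuming envs resolves attributes lazily from its environment_variables table (see the note under tpu_inference/envs.py above); the test itself is illustrative, not part of this commit:

    def test_get_tpu_type_prefers_env(monkeypatch):
        from tpu_inference import envs, tpu_info

        # Override the callable; every later envs.TPU_ACCELERATOR_TYPE
        # read then returns the stub instead of consulting os.environ.
        monkeypatch.setitem(envs.environment_variables,
                            "TPU_ACCELERATOR_TYPE", lambda: "v5litepod-16")
        assert tpu_info.get_tpu_type() == "v5litepod-16"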
