Skip to content

Commit 28607fc

Browse files
authored
Centralizes environment variable access by routing variable reads through the envs.py module (#1147)
Signed-off-by: Xing Liu <xingliu14@gmail.com>
1 parent f362289 commit 28607fc

File tree

11 files changed

+64
-33
lines changed

11 files changed

+64
-33
lines changed

examples/offline_inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import os
55

6-
import vllm.envs as envs
6+
import vllm.envs as vllm_envs
77
from vllm import LLM, EngineArgs
88
from vllm.utils.argparse_utils import FlexibleArgumentParser
99

@@ -87,10 +87,10 @@ def main(args: dict):
8787
'Who wrote the novel "Pride and Prejudice"?',
8888
]
8989

90-
if envs.VLLM_TORCH_PROFILER_DIR is not None:
90+
if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
9191
llm.start_profile()
9292
outputs = llm.generate(prompts, sampling_params)
93-
if envs.VLLM_TORCH_PROFILER_DIR is not None:
93+
if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
9494
llm.stop_profile()
9595

9696
# Print the outputs.

examples/offline_lora_inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import time
66

7-
import vllm.envs as envs
7+
import vllm.envs as vllm_envs
88
from vllm import LLM, EngineArgs
99
from vllm.lora.request import LoRARequest
1010
from vllm.utils.argparse_utils import FlexibleArgumentParser
@@ -55,13 +55,13 @@ def main(args: dict):
5555
"lora_adapter_3", 3,
5656
"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_3_adapter")
5757

58-
if envs.VLLM_TORCH_PROFILER_DIR is not None:
58+
if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
5959
llm.start_profile()
6060
start = time.perf_counter()
6161
outputs = llm.generate(prompt,
6262
sampling_params=sampling_params,
6363
lora_request=lora_request)
64-
if envs.VLLM_TORCH_PROFILER_DIR is not None:
64+
if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
6565
llm.stop_profile()
6666

6767
# Print the outputs.

examples/offline_safety_model_inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import os
2222

23-
import vllm.envs as envs
23+
import vllm.envs as vllm_envs
2424
from vllm import LLM, EngineArgs
2525
from vllm.utils.argparse_utils import FlexibleArgumentParser
2626

@@ -170,7 +170,7 @@ def main(args: dict):
170170

171171
prompts.append(TokensPrompt(prompt_token_ids=tokenized_prompt))
172172

173-
if envs.VLLM_TORCH_PROFILER_DIR is not None:
173+
if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
174174
llm.start_profile()
175175

176176
outputs = llm.generate(
@@ -179,7 +179,7 @@ def main(args: dict):
179179
use_tqdm=True,
180180
)
181181

182-
if envs.VLLM_TORCH_PROFILER_DIR is not None:
182+
if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
183183
llm.stop_profile()
184184

185185
passed_tests = 0

tests/test_envs.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,26 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
5656

5757

5858
def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
59+
# Ensure clean environment for boolean vars by setting to default "0"
60+
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
61+
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
62+
monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
63+
monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
64+
5965
# Test SKIP_JAX_PRECOMPILE (default False)
6066
assert envs.SKIP_JAX_PRECOMPILE is False
6167
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
6268
assert envs.SKIP_JAX_PRECOMPILE is True
6369
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
6470
assert envs.SKIP_JAX_PRECOMPILE is False
6571

72+
# Test VLLM_XLA_CHECK_RECOMPILATION (default False)
73+
assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
74+
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1")
75+
assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
76+
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
77+
assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
78+
6679
# Test NEW_MODEL_DESIGN (default False)
6780
assert envs.NEW_MODEL_DESIGN is False
6881
monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
@@ -75,12 +88,23 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
7588

7689

7790
def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
91+
# Ensure clean environment for integer vars by setting to defaults
92+
monkeypatch.setenv("PYTHON_TRACER_LEVEL", "1")
93+
monkeypatch.setenv("NUM_SLICES", "1")
94+
7895
assert envs.PYTHON_TRACER_LEVEL == 1
7996
monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
8097
assert envs.PYTHON_TRACER_LEVEL == 3
8198
monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
8299
assert envs.PYTHON_TRACER_LEVEL == 0
83100

101+
# Test NUM_SLICES (default 1)
102+
assert envs.NUM_SLICES == 1
103+
monkeypatch.setenv("NUM_SLICES", "2")
104+
assert envs.NUM_SLICES == 2
105+
monkeypatch.setenv("NUM_SLICES", "4")
106+
assert envs.NUM_SLICES == 4
107+
84108

85109
def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
86110
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
@@ -134,6 +158,7 @@ def test_dir_returns_all_env_vars():
134158
assert "JAX_PLATFORMS" in env_vars
135159
assert "TPU_NAME" in env_vars
136160
assert "SKIP_JAX_PRECOMPILE" in env_vars
161+
assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars
137162
assert "MODEL_IMPL_TYPE" in env_vars
138163

139164

tests/worker/tpu_worker_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def test_profile_start(self, mock_jax, mock_vllm_config):
294294
args, kwargs = mock_jax.profiler.start_trace.call_args
295295
assert args[0] == "/tmp/profile_dir"
296296
# Verify options from env var were used
297-
assert kwargs['profiler_options'].python_tracer_level == '1'
297+
assert kwargs['profiler_options'].python_tracer_level == 1
298298

299299
@patch('tpu_inference.worker.tpu_worker.jax')
300300
def test_profile_stop(self, mock_jax, mock_vllm_config):

tpu_inference/envs.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
PREFILL_SLICES: str = ""
1616
DECODE_SLICES: str = ""
1717
SKIP_JAX_PRECOMPILE: bool = False
18+
VLLM_XLA_CHECK_RECOMPILATION: bool = False
1819
MODEL_IMPL_TYPE: str = "flax_nnx"
1920
NEW_MODEL_DESIGN: bool = False
2021
PHASED_PROFILING_DIR: str = ""
2122
PYTHON_TRACER_LEVEL: int = 1
2223
USE_MOE_EP_KERNEL: bool = False
24+
NUM_SLICES: int = 1
2325
RAY_USAGE_STATS_ENABLED: str = "0"
2426
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
2527

@@ -47,22 +49,28 @@
4749
lambda: os.getenv("DECODE_SLICES", ""),
4850
# Skip JAX precompilation step during initialization
4951
"SKIP_JAX_PRECOMPILE":
50-
lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
52+
lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE") or "0")),
53+
# Check for XLA recompilation during execution
54+
"VLLM_XLA_CHECK_RECOMPILATION":
55+
lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
5156
# Model implementation type (e.g., "flax_nnx")
5257
"MODEL_IMPL_TYPE":
5358
lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
5459
# Enable new experimental model design
5560
"NEW_MODEL_DESIGN":
56-
lambda: bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))),
61+
lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
5762
# Directory to store phased profiling output
5863
"PHASED_PROFILING_DIR":
5964
lambda: os.getenv("PHASED_PROFILING_DIR", ""),
6065
# Python tracer level for profiling
6166
"PYTHON_TRACER_LEVEL":
62-
lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
67+
lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
6368
# Use custom expert-parallel kernel for MoE (Mixture of Experts)
6469
"USE_MOE_EP_KERNEL":
65-
lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
70+
lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL") or "0")),
71+
# Number of TPU slices for multi-slice mesh
72+
"NUM_SLICES":
73+
lambda: int(os.getenv("NUM_SLICES") or "1"),
6674
# Enable/disable Ray usage statistics collection
6775
"RAY_USAGE_STATS_ENABLED":
6876
lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),

tpu_inference/layers/common/sharding.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import json
22
import math
3-
import os
43
from dataclasses import asdict, dataclass
54
from typing import TYPE_CHECKING, List, Optional
65

76
import jax.numpy as jnp
87
import numpy as np
98
from jax.sharding import Mesh
109

11-
from tpu_inference import utils
10+
from tpu_inference import envs, utils
1211

1312
if TYPE_CHECKING:
1413
from vllm.v1.configs.vllm_config import VllmConfig
@@ -48,7 +47,7 @@ class ShardingAxisName2D:
4847

4948

5049
try:
51-
_use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
50+
_use_base_sharding = envs.NEW_MODEL_DESIGN
5251
if _use_base_sharding:
5352
ShardingAxisName = ShardingAxisNameBase
5453
else:
@@ -167,7 +166,7 @@ def validate(cls, vllm_config, sharding_strategy):
167166
f"(DP size: {total_dp_size}). Please disable LoRA or "
168167
f"set data parallelism to 1.")
169168
if sharding_strategy.attention_data_parallelism > 1:
170-
if not os.environ.get("NEW_MODEL_DESIGN", False):
169+
if not envs.NEW_MODEL_DESIGN:
171170
raise ValueError(
172171
"Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
173172
"NEW_MODEL_DESIGN=True.")

tpu_inference/runner/compilation_manager.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
import os
21
import time
32
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
43

54
import jax
65
import jax.numpy as jnp
76
import numpy as np
8-
import vllm.envs as envs
7+
import vllm.envs as vllm_envs
98
from jax.sharding import NamedSharding, PartitionSpec
109

10+
import tpu_inference.envs as envs
1111
from tpu_inference.core.disagg_utils import is_disagg_enabled
1212
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
1313
from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -30,10 +30,10 @@ class CompilationManager:
3030

3131
def __init__(self, runner: "TPUModelRunner"):
3232
self.runner = runner
33-
if not envs.VLLM_DISABLE_COMPILE_CACHE:
33+
if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
3434
logger.info("Enabling JAX compile cache.")
3535
jax.config.update("jax_compilation_cache_dir",
36-
envs.VLLM_XLA_CACHE_PATH)
36+
vllm_envs.VLLM_XLA_CACHE_PATH)
3737

3838
def _create_dummy_tensor(self,
3939
shape: Tuple[int, ...],
@@ -67,8 +67,7 @@ def _run_compilation(self, name: str, fn: Callable, *args,
6767
logger.info("Compilation finished in %.2f [secs].", end - start)
6868

6969
def capture_model(self) -> None:
70-
if os.getenv("SKIP_JAX_PRECOMPILE",
71-
False) or self.runner.model_config.enforce_eager:
70+
if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
7271
return
7372
logger.info("Precompile all the subgraphs with possible input shapes.")
7473

tpu_inference/runner/tpu_runner.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import copy
22
import functools
3-
import os
43
import random
54
from contextlib import nullcontext
65
from dataclasses import dataclass
@@ -11,7 +10,7 @@
1110
import jaxtyping
1211
import numpy as np
1312
import torch
14-
import vllm.envs as envs
13+
import vllm.envs as vllm_envs
1514
from flax import nnx
1615
from jax.experimental import mesh_utils
1716
from jax.sharding import NamedSharding, PartitionSpec
@@ -35,6 +34,7 @@
3534
KVConnectorModelRunnerMixin
3635
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
3736

37+
import tpu_inference.envs as envs
3838
from tpu_inference import utils as common_utils
3939
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
4040
from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -291,7 +291,7 @@ def _init_random(self):
291291
self.rng_key = jax.random.key(self.model_config.seed)
292292

293293
def _init_mesh(self) -> None:
294-
if os.getenv("NEW_MODEL_DESIGN", False):
294+
if envs.NEW_MODEL_DESIGN:
295295
self.mesh = self._create_new_model_mesh()
296296
else:
297297
# NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -302,7 +302,7 @@ def _init_mesh(self) -> None:
302302
logger.info(f"Init mesh | mesh={self.mesh}")
303303

304304
def _create_new_model_mesh(self) -> jax.sharding.Mesh:
305-
num_slices = int(os.environ.get('NUM_SLICES', 1))
305+
num_slices = envs.NUM_SLICES
306306

307307
logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
308308
f"num_slices={num_slices}")
@@ -371,7 +371,7 @@ def _create_2d_mesh(self) -> jax.sharding.Mesh:
371371
devices=self.devices)
372372

373373
def _init_phased_profiling(self) -> None:
374-
self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
374+
self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
375375
self.phase_based_profiler = None
376376
if self.phased_profiling_dir:
377377
self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -413,7 +413,7 @@ def _init_inputs(self) -> None:
413413
min_token_size=max(16, self.dp_size),
414414
max_token_size=scheduler_config.max_num_batched_tokens *
415415
self.dp_size,
416-
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
416+
padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
417417
self.num_tokens_paddings_per_dp = [
418418
padding // self.dp_size for padding in self.num_tokens_paddings
419419
]

tpu_inference/runner/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from jax._src.interpreters import pxla
1616
from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
1717

18+
from tpu_inference import envs
1819
from tpu_inference.logger import init_logger
1920
from tpu_inference.runner.input_batch import InputBatch
2021

@@ -306,8 +307,7 @@ def __init__(self, profile_dir: str):
306307
InferencePhase.BALANCED: False
307308
}
308309
self.default_profiling_options = jax.profiler.ProfileOptions()
309-
self.default_profiling_options.python_tracer_level = os.getenv(
310-
"PYTHON_TRACER_LEVEL", 0)
310+
self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
311311

312312
self.current_phase: str = ""
313313

0 commit comments

Comments
 (0)