diff --git a/tests/worker/tpu_worker_test.py b/tests/worker/tpu_worker_test.py
index 54eb53e62..6845cd07f 100644
--- a/tests/worker/tpu_worker_test.py
+++ b/tests/worker/tpu_worker_test.py
@@ -63,7 +63,7 @@ def test_init_success(self, mock_vllm_config):
         assert worker.profile_dir is None
         assert worker.devices == ['tpu:0']
 
-    @patch('tpu_inference.worker.tpu_worker.envs')
+    @patch('tpu_inference.worker.tpu_worker.vllm_envs')
     def test_init_with_profiler_on_rank_zero(self, mock_envs,
                                              mock_vllm_config):
         """Tests that the profiler directory is set correctly on rank 0."""
@@ -74,7 +74,7 @@ def test_init_with_profiler_on_rank_zero(self, mock_envs,
                            distributed_init_method="test_method")
         assert worker.profile_dir == "/tmp/profiles"
 
-    @patch('tpu_inference.worker.tpu_worker.envs')
+    @patch('tpu_inference.worker.tpu_worker.vllm_envs')
     def test_init_with_profiler_on_other_ranks(self, mock_envs,
                                                mock_vllm_config):
         """Tests that the profiler directory is NOT set on non-rank 0 workers."""
diff --git a/tpu_inference/envs.py b/tpu_inference/envs.py
new file mode 100644
index 000000000..1ef212f00
--- /dev/null
+++ b/tpu_inference/envs.py
@@ -0,0 +1,107 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
+
+import functools
+import os
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    JAX_PLATFORMS: str = ""
+    TPU_ACCELERATOR_TYPE: str | None = None
+    TPU_NAME: str | None = None
+    TPU_WORKER_ID: str | None = None
+    TPU_MULTIHOST_BACKEND: str = ""
+    PREFILL_SLICES: str = ""
+    DECODE_SLICES: str = ""
+    SKIP_JAX_PRECOMPILE: bool = False
+    MODEL_IMPL_TYPE: str = "flax_nnx"
+    NEW_MODEL_DESIGN: bool = False
+    PHASED_PROFILING_DIR: str = ""
+    PYTHON_TRACER_LEVEL: int = 1
+    USE_MOE_EP_KERNEL: bool = False
+    RAY_USAGE_STATS_ENABLED: str = "0"
+    VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
+
+environment_variables: dict[str, Callable[[], Any]] = {
+    # JAX platform selection (e.g., "tpu", "cpu", "proxy")
+    "JAX_PLATFORMS":
+    lambda: os.getenv("JAX_PLATFORMS", ""),
+    # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
+    "TPU_ACCELERATOR_TYPE":
+    lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
+    # Name of the TPU resource
+    "TPU_NAME":
+    lambda: os.getenv("TPU_NAME", None),
+    # Worker ID for multi-host TPU setups
+    "TPU_WORKER_ID":
+    lambda: os.getenv("TPU_WORKER_ID", None),
+    # Backend for multi-host communication on TPU
+    "TPU_MULTIHOST_BACKEND":
+    lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
+    # Slice configuration for disaggregated prefill workers
+    "PREFILL_SLICES":
+    lambda: os.getenv("PREFILL_SLICES", ""),
+    # Slice configuration for disaggregated decode workers
+    "DECODE_SLICES":
+    lambda: os.getenv("DECODE_SLICES", ""),
+    # Skip JAX precompilation step during initialization
+    "SKIP_JAX_PRECOMPILE":
+    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
+    # Model implementation type (e.g., "flax_nnx")
+    "MODEL_IMPL_TYPE":
+    lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
+    # Enable new experimental model design
+    "NEW_MODEL_DESIGN":
+    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))),
+    # Directory to store phased profiling output
+    "PHASED_PROFILING_DIR":
+    lambda: os.getenv("PHASED_PROFILING_DIR", ""),
+    # Python tracer level for profiling
+    "PYTHON_TRACER_LEVEL":
+    lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
+    # Use custom expert-parallel kernel for MoE (Mixture of Experts)
+    "USE_MOE_EP_KERNEL":
+    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
"0"))), + # Enable/disable Ray usage statistics collection + "RAY_USAGE_STATS_ENABLED": + lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"), + # Ray compiled DAG channel type for TPU + "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": + lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm"), +} + + +def __getattr__(name: str) -> Any: + """ + Gets environment variables lazily. + + NOTE: After enable_envs_cache() invocation (which triggered after service + initialization), all environment variables will be cached. + """ + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def enable_envs_cache() -> None: + """ + Enables caching of environment variables by wrapping the module's __getattr__ + function with functools.cache(). This improves performance by avoiding + repeated re-evaluation of environment variables. + + NOTE: This should be called after service initialization. Once enabled, + environment variable values are cached and will not reflect changes to + os.environ until the process is restarted. + """ + # Tag __getattr__ with functools.cache + global __getattr__ + __getattr__ = functools.cache(__getattr__) + + # Cache all environment variables + for key in environment_variables: + __getattr__(key) + + +def __dir__() -> list[str]: + return list(environment_variables.keys()) diff --git a/tpu_inference/layers/vllm/quantization/unquantized.py b/tpu_inference/layers/vllm/quantization/unquantized.py index 7881332f7..5ae06d9e2 100644 --- a/tpu_inference/layers/vllm/quantization/unquantized.py +++ b/tpu_inference/layers/vllm/quantization/unquantized.py @@ -1,4 +1,3 @@ -import os from typing import Any, Callable, Optional, Union import jax @@ -22,6 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from tpu_inference import envs from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded from tpu_inference.layers.vllm.linear_common import ( @@ -164,7 +164,7 @@ def __init__(self, ep_axis_name: str = 'model'): super().__init__(moe) self.mesh = mesh - self.use_kernel = bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))) + self.use_kernel = envs.USE_MOE_EP_KERNEL self.ep_axis_name = ep_axis_name # TODO: Use autotune table once we have it. 
         self.block_size = {
diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py
index c0b1546b6..72ec188b9 100644
--- a/tpu_inference/models/common/model_loader.py
+++ b/tpu_inference/models/common/model_loader.py
@@ -1,5 +1,4 @@
 import functools
-import os
 from typing import Any, Optional
 
 import jax
@@ -11,6 +10,7 @@ from vllm.config import VllmConfig
 from vllm.utils.func_utils import supports_kw
 
+from tpu_inference import envs
 from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.quantization.quantization_utils import (
@@ -314,7 +314,7 @@ def get_model(
     mesh: Mesh,
     is_draft_model: bool = False,
 ) -> Any:
-    impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+    impl = envs.MODEL_IMPL_TYPE
     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
 
     if impl == "flax_nnx":
diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py
index 6b5fcf425..0f47ed60f 100644
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -4,13 +4,14 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
 from vllm.sampling_params import SamplingParams, SamplingType
 
+from tpu_inference import envs
 from tpu_inference.layers.jax.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
@@ -71,7 +72,7 @@ def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         try:
-            if envs.VLLM_TPU_USING_PATHWAYS:
+            if vllm_envs.VLLM_TPU_USING_PATHWAYS:
                 # Causes mutliprocess accessing IFRT when calling jax.devices()
                 return "TPU v6 lite"
             else:
@@ -87,7 +88,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return not envs.VLLM_USE_V1
+        return not vllm_envs.VLLM_USE_V1
 
     @classmethod
     def get_punica_wrapper(cls) -> str:
@@ -118,11 +119,11 @@ def _initialize_sharding_config(cls, vllm_config: VllmConfig) -> None:
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
-        if not envs.VLLM_USE_V1:
+        if not vllm_envs.VLLM_USE_V1:
             raise RuntimeError("VLLM_USE_V1=1 must be set for JAX backend.")
 
-        if envs.VLLM_TPU_USING_PATHWAYS:
-            assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
+        if vllm_envs.VLLM_TPU_USING_PATHWAYS:
+            assert not vllm_envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
                 "VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using Pathways(JAX_PLATFORMS=proxy)"
             )
         cls._initialize_sharding_config(vllm_config)
@@ -144,7 +145,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             compilation_config.backend = "openxla"
 
         # If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        impl = envs.MODEL_IMPL_TYPE
 
         # NOTE(xiang): convert dtype to jnp.dtype
         # NOTE(wenlong): skip this logic for mm model preprocessing
@@ -164,7 +165,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             vllm_config.model_config.dtype = j2t_dtype(
                 vllm_config.model_config.dtype.dtype)
 
-        if envs.VLLM_USE_V1:
+        if vllm_envs.VLLM_USE_V1:
             # TODO(cuiq): remove this dependency.
             from vllm.v1.attention.backends.pallas import \
                 PallasAttentionBackend
@@ -250,7 +251,7 @@ def validate_request(
         """Raises if this request is unsupported on this platform"""
 
         if isinstance(params, SamplingParams):
-            if params.structured_outputs is not None and not envs.VLLM_USE_V1:
+            if params.structured_outputs is not None and not vllm_envs.VLLM_USE_V1:
                 raise ValueError("Structured output is not supported on "
                                  f"{cls.device_name} V0.")
             if params.sampling_type == SamplingType.RANDOM_SEED:
diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py
index 7e1fae002..172f7ab44 100644
--- a/tpu_inference/worker/tpu_worker.py
+++ b/tpu_inference/worker/tpu_worker.py
@@ -8,7 +8,7 @@
 import jax.numpy as jnp
 import jaxlib
 import jaxtyping
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
@@ -22,7 +22,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.jax.sharding import ShardingConfigManager
@@ -50,7 +50,7 @@ def __init__(self, devices=None):
 
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        impl = envs.MODEL_IMPL_TYPE
         if impl != "vllm":
             # vllm-pytorch implementation does not need this conversion
             # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
@@ -86,11 +86,11 @@ def __init__(self,
 
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
            if not self.devices or 0 in self.device_ranks:
                # For TPU, we can only have 1 active profiler session for 1 profiler
                # server. So we only profile on rank0.
-                self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
+                self.profile_dir = vllm_envs.VLLM_TORCH_PROFILER_DIR
                logger.info("Profiling enabled. Traces will be saved to: %s",
                            self.profile_dir)
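
Usage sketch for the new lazy env module (illustrative only, not part of the diff): tpu_inference/envs.py resolves variables through the module-level __getattr__, so each attribute access re-reads os.environ until enable_envs_cache() is called. The diff does not show where enable_envs_cache() is invoked in the service startup path, so the call below is only an assumed placement.

import os

from tpu_inference import envs
from tpu_inference.envs import enable_envs_cache

# Before caching, every access re-evaluates the corresponding os.getenv lambda.
os.environ["SKIP_JAX_PRECOMPILE"] = "1"
assert envs.SKIP_JAX_PRECOMPILE is True

# After service initialization, freeze the values: __getattr__ is wrapped with
# functools.cache and every known variable is read once and memoized.
enable_envs_cache()

os.environ["SKIP_JAX_PRECOMPILE"] = "0"
assert envs.SKIP_JAX_PRECOMPILE is True  # cached value; later env changes are ignored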