
Commit 15ee8d3

fixes

1 parent 0e4dc9c commit 15ee8d3

File tree

tpu_inference/core/disagg_utils.py
tpu_inference/distributed/tpu_connector.py
tpu_inference/layers/vllm/sharding.py
tpu_inference/models/jax/utils/weight_utils.py

4 files changed: +3 -7 lines changed

tpu_inference/core/disagg_utils.py

Lines changed: 0 additions & 3 deletions

@@ -4,9 +4,6 @@
 
 from tpu_inference import envs
 
-PREFILL_SLICES = 'PREFILL_SLICES'
-DECODE_SLICES = 'DECODE_SLICES'
-
 
 def is_disagg_enabled() -> bool:
     # We triggrer our code path as long as prefill slices are set. This

tpu_inference/distributed/tpu_connector.py

Lines changed: 1 addition & 1 deletion

@@ -85,6 +85,7 @@
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.request import Request
 
+from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports,
                                              get_kv_transfer_port, get_node_id,
@@ -440,7 +441,6 @@ def __init__(self, vllm_config: VllmConfig):
 
         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
-        from tpu_inference import envs
         self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
         # The worker rank is assigned with vLLM's sorting logic, which does not work

tpu_inference/layers/vllm/sharding.py

Lines changed: 1 addition & 1 deletion

@@ -19,6 +19,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -211,7 +212,6 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
 def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
     if isinstance(tensor, tuple):
         return tuple(_sharded_device_put(t, sharding) for t in tensor)
-    from tpu_inference import envs
     multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend != "ray":
         return jax.device_put(tensor, sharding)

tpu_inference/models/jax/utils/weight_utils.py

Lines changed: 1 addition & 2 deletions

@@ -18,7 +18,7 @@
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils import file_utils
 
@@ -421,7 +421,6 @@ def load_hf_weights(vllm_config,
     # NOTE(xiang): Disable multi-threading mode if running on multi-host.
     # Because multi-threading would cause different JAX processes to load
     # different weights at the same time.
-    from tpu_inference import envs
    if envs.TPU_MULTIHOST_BACKEND == "ray":
        max_workers = 1
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
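
The change repeated across tpu_connector.py, sharding.py, and weight_utils.py is the same pattern: a function-local `from tpu_inference import envs` import is hoisted to a single module-level import, while the call sites that read `envs.TPU_MULTIHOST_BACKEND` stay unchanged. A minimal sketch of that before/after pattern in isolation (the helper function name here is illustrative, not from the repository):

# Before: the import is repeated inside the function body.
def multihost_backend_is_ray() -> bool:
    from tpu_inference import envs  # local import, re-executed on every call
    return envs.TPU_MULTIHOST_BACKEND == "ray"

# After: one module-level import, matching what this commit applies.
from tpu_inference import envs

def multihost_backend_is_ray() -> bool:
    return envs.TPU_MULTIHOST_BACKEND == "ray"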
