
Commit ddbb94e

Author: bzgoogle
Commit message: fix bug after rebase

1 parent ca8596d · commit ddbb94e

File tree

3 files changed: +92 −53 lines changed
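
In short, per the diffs below: the rebase left several multi-line NamedSharding(...) calls in compilation_manager.py and kv_cache.py without their closing parenthesis (syntax errors), which this commit restores. Along the way it modernizes type annotations (built-in generics such as list[int], and X | None unions in place of typing.Optional/List) and regroups the tpu_inference imports in tpu_runner.py into their own block.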

tpu_inference/runner/compilation_manager.py

Lines changed: 21 additions & 17 deletions
```diff
@@ -1,18 +1,20 @@
 import os
 import time
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any
 
 import jax
 import jax.numpy as jnp
 import numpy as np
-import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 
+import vllm.envs as envs
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.jax.sample.sampling import sample
-from tpu_inference.layers.jax.sample.sampling_metadata import \
-    TPUSupportedSamplingMetadata
+from tpu_inference.layers.jax.sample.sampling_metadata import (
+    TPUSupportedSamplingMetadata,
+)
 from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.utils import device_array
@@ -36,9 +38,9 @@ def __init__(self, runner: "TPUModelRunner"):
                                           envs.VLLM_XLA_CACHE_PATH)
 
     def _create_dummy_tensor(self,
-                             shape: Tuple[int, ...],
+                             shape: tuple[int, ...],
                              dtype: Any,
-                             sharding: Optional[NamedSharding] = None) -> Any:
+                             sharding: NamedSharding | None = None) -> Any:
         """Helper to create dummy tensors for precompilation."""
         tensor = jnp.ones(shape, dtype=dtype)
         if sharding:
@@ -273,11 +275,11 @@ def _precompile_backbone_with_inputs_embeds(self) -> None:
     def _precompile_select_from_array_helper(
         self,
         name: str,
-        source_paddings: List[int],
-        indices_paddings: List[int],
+        source_paddings: list[int],
+        indices_paddings: list[int],
         hidden_dim: int,
-        input_sharding: Optional[NamedSharding] = None,
-        indices_sharding: Optional[NamedSharding] = None,
+        input_sharding: NamedSharding | None = None,
+        indices_sharding: NamedSharding | None = None,
         only_equal_paddings: bool = False,
         check_should_skip_padding: bool = True,
     ) -> None:
@@ -349,16 +351,18 @@ def _precompile_select_from_array(self) -> None:
             source_paddings=self.runner.num_logits_paddings,
             indices_paddings=self.runner.num_reqs_paddings,
             hidden_dim=vocab_size,
-            input_sharding=NamedSharding(self.runner.mesh,
-                                         PartitionSpec(None, ('model', 'expert')),
+            input_sharding=NamedSharding(
+                self.runner.mesh, PartitionSpec(None,
+                                                ('model', 'expert'))),
         )
         self._precompile_select_from_array_helper(
             name="select target tokens for spec decoding",
             source_paddings=self.runner.num_logits_paddings,
             indices_paddings=self.runner.num_logits_paddings,
             hidden_dim=vocab_size,
-            input_sharding=NamedSharding(self.runner.mesh,
-                                         PartitionSpec(None, ('model', 'expert')),
+            input_sharding=NamedSharding(
+                self.runner.mesh, PartitionSpec(None,
+                                                ('model', 'expert'))),
             only_equal_paddings=True,
         )
 
@@ -390,7 +394,7 @@ def _precompile_sampling(self) -> None:
         for num_reqs in self.runner.num_reqs_paddings:
             logits_sharding = NamedSharding(
                 self.runner.mesh,
-                PartitionSpec(ShardingAxisName.ATTN_DATA, ('model', 'expert'))
+                PartitionSpec(ShardingAxisName.ATTN_DATA, ('model', 'expert')))
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             sampling_metadata_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(
@@ -479,8 +483,8 @@ def _precompile_rejection_sampler(self) -> None:
         vocab_size = self.runner.model_config.get_vocab_size()
         for num_logits in self.runner.num_logits_paddings:
             for num_reqs in self.runner.num_reqs_paddings:
-                sharding = NamedSharding(self.runner.mesh,
-                                         PartitionSpec(None, ('model', 'expert')))
+                sharding = NamedSharding(
+                    self.runner.mesh, PartitionSpec(None, ('model', 'expert')))
                 target_probs = self._create_dummy_tensor(
                     (num_logits, vocab_size), jnp.bfloat16, sharding)
                 draft_token_ids = self._create_dummy_tensor((num_logits, ),
```
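
For readers skimming the precompile changes, here is a minimal, self-contained sketch of the pattern these helpers rely on: build a dummy activation and pin it to a NamedSharding so each shape/layout combination is compiled ahead of time. The 4-device flat mesh, axis name, and shapes below are illustrative assumptions, not the runner's actual configuration (which uses axes like ('model', 'expert')):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical flat mesh over 4 devices; the runner builds its real mesh elsewhere.
mesh = Mesh(np.array(jax.devices()[:4]), ("model",))

def create_dummy_tensor(shape: tuple[int, ...],
                        dtype,
                        sharding: NamedSharding | None = None) -> jax.Array:
    # Same shape/dtype/optional-sharding pattern as _create_dummy_tensor above.
    tensor = jnp.ones(shape, dtype=dtype)
    if sharding:
        tensor = jax.device_put(tensor, sharding)  # pin to the requested layout
    return tensor

# Replicate the batch dim, shard the vocab dim across 'model' -- the
# PartitionSpec(None, ...) pattern the precompile calls use for logits.
logits = create_dummy_tensor((8, 32000), jnp.bfloat16,
                             NamedSharding(mesh, PartitionSpec(None, "model")))
print(logits.sharding)
```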

tpu_inference/runner/kv_cache.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any
 
 import jax
 import jax.numpy as jnp
@@ -46,9 +46,9 @@ def create_kv_caches(
     num_kv_heads: int,
     head_size: int,
     mesh: Mesh,
-    layer_names: List[str],
+    layer_names: list[str],
     cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
-) -> List[jax.Array]:
+) -> list[jax.Array]:
     """
     Creates a list of KV cache where each array mapps to single attention layer.
 
@@ -78,8 +78,7 @@ def create_kv_caches(
 
     sharding = NamedSharding(
         mesh,
-        PartitionSpec(ShardingAxisName.ATTN_DATA, None,
-                      ('model', 'expert'))
+        PartitionSpec(ShardingAxisName.ATTN_DATA, None, ('model', 'expert')))
 
     def _allocate() -> jax.Array:
         return jnp.empty(
```
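
A hedged sketch of the allocation pattern above: one array per attention layer, empty-initialized and laid out with the three-axis PartitionSpec that the fixed call now closes correctly. The mesh axes, the toy cache shape, and the jax.device_put placement are assumptions for illustration; the real shape and placement in kv_cache.py are not shown in this diff:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative 2x2 mesh (assumes 4 devices); stands in for the runner's real mesh.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("data", "model"))

def create_kv_caches(num_blocks: int, block_size: int, kv_dim: int,
                     layer_names: list[str],
                     cache_dtype=jnp.bfloat16) -> list[jax.Array]:
    # Mirrors PartitionSpec(ATTN_DATA, None, ('model', 'expert')) from the diff:
    # blocks sharded along the data axis, per-token dim replicated,
    # head/hidden dim sharded along the model axis.
    sharding = NamedSharding(mesh, PartitionSpec("data", None, "model"))

    def _allocate() -> jax.Array:
        # Empty-initialized, placed directly in the sharded layout.
        return jax.device_put(
            jnp.empty((num_blocks, block_size, kv_dim), dtype=cache_dtype),
            sharding)

    return [_allocate() for _ in layer_names]

caches = create_kv_caches(16, 32, 256, ["layer_0", "layer_1"])
print(len(caches), caches[0].shape)
```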

tpu_inference/runner/tpu_runner.py

Lines changed: 67 additions & 31 deletions
```diff
@@ -2,37 +2,73 @@
 import functools
 import os
 import random
+from collections.abc import Callable
 from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+from typing import Any, cast
 
 import jax
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
 import torch
-import vllm.envs as envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import j2t_dtype
+
+import vllm.envs as envs
+from tpu_inference import utils as common_utils
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
+from tpu_inference.layers.jax.sample.sampling import (
+    compute_logprobs,
+    gather_logprobs,
+    sample,
+)
+from tpu_inference.layers.jax.sample.sampling_metadata import (
+    TPUSupportedSamplingMetadata,
+)
+from tpu_inference.layers.jax.sharding import ShardingAxisName, ShardingConfigManager
+from tpu_inference.logger import init_logger
+from tpu_inference.models.common.model_loader import get_model
+from tpu_inference.models.jax.utils.weight_utils import (
+    shard_put,
+    transfer_state_with_mappings,
+)
+from tpu_inference.runner import utils as runner_utils
+from tpu_inference.runner.compilation_manager import CompilationManager
+from tpu_inference.runner.input_batch_jax import CachedRequestState, InputBatch
+from tpu_inference.runner.kv_cache_manager import KVCacheManager
+from tpu_inference.runner.lora_utils import LoraUtils
+from tpu_inference.runner.multimodal_manager import MultiModalManager
+from tpu_inference.runner.persistent_batch_manager import PersistentBatchManager
+from tpu_inference.runner.speculative_decoding_manager import (
+    SpecDecodeMetadata,
+    SpeculativeDecodingManager,
+)
+from tpu_inference.runner.structured_decoding_manager import StructuredDecodingManager
+from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
+from tpu_inference.utils import device_array, make_optimized_mesh, time_function
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer import (get_kv_transfer_group,
-                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
 from vllm.forward_context import set_forward_context
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import GrammarOutput
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             DraftTokenIds, KVConnectorOutput,
-                             ModelRunnerOutput)
+from vllm.v1.outputs import (
+    EMPTY_MODEL_RUNNER_OUTPUT,
+    AsyncModelRunnerOutput,
+    DraftTokenIds,
+    KVConnectorOutput,
+    ModelRunnerOutput,
+)
 from vllm.v1.request import Request
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-from vllm.v1.worker.kv_connector_model_runner_mixin import \
-    KVConnectorModelRunnerMixin
+from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from tpu_inference import utils as common_utils
@@ -108,7 +144,7 @@ def __init__(
         next_tokens: jax.Array,
         num_reqs: int,
         discard_sampled_tokens_req_indices: list[int],
-        logits_indices_selector: Optional[List[int]] = None,
+        logits_indices_selector: list[int] | None = None,
     ):
         self._model_runner_output = model_runner_output
         self._next_tokens = next_tokens
@@ -136,7 +172,7 @@ class AsyncPreResults:
     request_seq_lens: list[tuple[int, CachedRequestState, int]]
     discard_sampled_tokens_req_indices: list[int]
     placeholder_req_id_to_index: dict[str, int]
-    logits_indices_selector: Optional[List[int]] = None
+    logits_indices_selector: list[int] | None = None
 
 
 @dataclass
@@ -146,13 +182,13 @@ class ExecuteModelState:
 
     scheduler_output: "VllmSchedulerOutput"
     attn_metadata: AttentionMetadata
-    input_ids: Optional[jax.Array]
+    input_ids: jax.Array | None
     hidden_states: jax.Array
     logits: jax.Array
-    aux_hidden_states: Optional[jax.Array]
-    spec_decode_metadata: Optional[SpecDecodeMetadata]
-    kv_connector_output: Optional[KVConnectorOutput]
-    logits_indices_selector: Optional[List[int]] = None
+    aux_hidden_states: jax.Array | None
+    spec_decode_metadata: SpecDecodeMetadata | None
+    kv_connector_output: KVConnectorOutput | None
+    logits_indices_selector: list[int] | None = None
 
 
 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -195,7 +231,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def __init__(
         self,
         vllm_config: VllmConfig,
-        devices: List[Any],
+        devices: list[Any],
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -517,7 +553,7 @@ def capture_model(self) -> None:
     def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[IntermediateTensors] = None,
+        intermediate_tensors: IntermediateTensors | None = None,
     ) -> ModelRunnerOutput | None:
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called "
@@ -746,13 +782,13 @@ def _sample_from_logits(
         self,
         scheduler_output: "VllmSchedulerOutput",
         attn_metadata: AttentionMetadata,
-        input_ids: Optional[jax.Array],
+        input_ids: jax.Array | None,
         hidden_states: jax.Array,
         logits: jax.Array,
-        aux_hidden_states: Optional[jax.Array],
-        spec_decode_metadata: Optional[SpecDecodeMetadata],
-        kv_connector_output: Optional[KVConnectorOutput],
-        logits_indices_selector: Optional[List[int]] = None,
+        aux_hidden_states: jax.Array | None,
+        spec_decode_metadata: SpecDecodeMetadata | None,
+        kv_connector_output: KVConnectorOutput | None,
+        logits_indices_selector: list[int] | None = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
         padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
             self.input_batch.num_reqs, self.max_num_reqs)
@@ -1548,26 +1584,26 @@ def _get_input_ids_embeds(self, input_ids: jax.Array,
         else:
             return input_ids, None
 
-    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+    def take_draft_token_ids(self) -> DraftTokenIds | None:
         return self.speculative_decoding_manager.take_draft_token_ids()
 
     ###### Local disagg utilities ######
 
     def get_kv_cache_for_block_ids(
         self,
-        block_ids: List[int],
-    ) -> List[jax.Array]:
+        block_ids: list[int],
+    ) -> list[jax.Array]:
         return self.kv_cache_manager.get_kv_cache_for_block_ids(block_ids)
 
     def transfer_kv_cache(self,
-                          kv_cache_slices: List[jax.Array]) -> List[jax.Array]:
+                          kv_cache_slices: list[jax.Array]) -> list[jax.Array]:
         return self.kv_cache_manager.transfer_kv_cache(kv_cache_slices)
 
     def insert_request_with_kv_cache(
         self,
         request: "Request",
-        kv_cache_slices: List[jax.Array],
-        block_ids: List[List[int]],
+        kv_cache_slices: list[jax.Array],
+        block_ids: list[list[int]],
     ):
         return self.kv_cache_manager.insert_request_with_kv_cache(
             request, kv_cache_slices, block_ids)
@@ -1577,8 +1613,8 @@ def insert_request_with_kv_cache(
     def _sync_weights(
         self,
         updated_weights: jaxtyping.PyTree,
-        mappings: Dict[str, Tuple[str, Tuple[str]]],
-        transpose_keys: Dict[str, Tuple[int]],
+        mappings: dict[str, tuple[str, tuple[str]]],
+        transpose_keys: dict[str, tuple[int]],
         reshard_fn: Callable[[jaxtyping.PyTree, jaxtyping.PyTree],
                              jaxtyping.PyTree] = None
     ) -> None:
```
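
The annotation churn across all three files follows one rule: PEP 585 built-in generics (list[int], dict[str, int]) replace typing.List/Dict, PEP 604 unions (X | None) replace Optional[X], and Callable moves to collections.abc. A small before/after sketch; the function here is invented for illustration:

```python
from __future__ import annotations  # lets the new syntax parse on older Python 3.x

from collections.abc import Callable  # preferred over typing.Callable since 3.9

# Before: from typing import Callable, Dict, List, Optional, Tuple
#   def pick(xs: Optional[List[int]],
#            fn: Callable[[int], int]) -> Dict[str, Tuple[int, ...]]: ...

# After, in the style this commit adopts:
def pick(xs: list[int] | None,
         fn: Callable[[int], int]) -> dict[str, tuple[int, ...]]:
    return {"mapped": tuple(fn(x) for x in (xs or []))}

print(pick([1, 2, 3], lambda x: x * 2))  # {'mapped': (2, 4, 6)}
```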
