22import functools
33import os
44import random
5+ from collections .abc import Callable
56from contextlib import nullcontext
67from dataclasses import dataclass
7- from typing import Any , Callable , Dict , List , Optional , Tuple , cast
8+ from typing import Any , cast
89
910import jax
1011import jax .numpy as jnp
1112import jaxtyping
1213import numpy as np
1314import torch
14- import vllm .envs as envs
1515from flax import nnx
1616from jax .experimental import mesh_utils
1717from jax .sharding import NamedSharding , PartitionSpec
1818from torchax .ops .mappings import j2t_dtype
19+
20+ import vllm .envs as envs
21+ from tpu_inference import utils as common_utils
22+ from tpu_inference .layers .common .attention_metadata import AttentionMetadata
23+ from tpu_inference .layers .jax .sample .rejection_sampler import RejectionSampler
24+ from tpu_inference .layers .jax .sample .sampling import (
25+ compute_logprobs ,
26+ gather_logprobs ,
27+ sample ,
28+ )
29+ from tpu_inference .layers .jax .sample .sampling_metadata import (
30+ TPUSupportedSamplingMetadata ,
31+ )
32+ from tpu_inference .layers .jax .sharding import ShardingAxisName , ShardingConfigManager
33+ from tpu_inference .logger import init_logger
34+ from tpu_inference .models .common .model_loader import get_model
35+ from tpu_inference .models .jax .utils .weight_utils import (
36+ shard_put ,
37+ transfer_state_with_mappings ,
38+ )
39+ from tpu_inference .runner import utils as runner_utils
40+ from tpu_inference .runner .compilation_manager import CompilationManager
41+ from tpu_inference .runner .input_batch_jax import CachedRequestState , InputBatch
42+ from tpu_inference .runner .kv_cache_manager import KVCacheManager
43+ from tpu_inference .runner .lora_utils import LoraUtils
44+ from tpu_inference .runner .multimodal_manager import MultiModalManager
45+ from tpu_inference .runner .persistent_batch_manager import PersistentBatchManager
46+ from tpu_inference .runner .speculative_decoding_manager import (
47+ SpecDecodeMetadata ,
48+ SpeculativeDecodingManager ,
49+ )
50+ from tpu_inference .runner .structured_decoding_manager import StructuredDecodingManager
51+ from tpu_inference .spec_decode .jax .eagle3 import Eagle3Proposer
52+ from tpu_inference .utils import device_array , make_optimized_mesh , time_function
1953from vllm .config import VllmConfig
20- from vllm .distributed .kv_transfer import (get_kv_transfer_group ,
21- has_kv_transfer_group )
54+ from vllm .distributed .kv_transfer import get_kv_transfer_group , has_kv_transfer_group
2255from vllm .forward_context import set_forward_context
2356from vllm .sequence import IntermediateTensors
2457from vllm .tasks import SupportedTask
2558from vllm .utils .math_utils import cdiv
2659from vllm .v1 .core .sched .output import GrammarOutput
2760from vllm .v1 .core .sched .output import SchedulerOutput as VllmSchedulerOutput
2861from vllm .v1 .kv_cache_interface import KVCacheConfig
29- from vllm .v1 .outputs import (EMPTY_MODEL_RUNNER_OUTPUT , AsyncModelRunnerOutput ,
30- DraftTokenIds , KVConnectorOutput ,
31- ModelRunnerOutput )
62+ from vllm .v1 .outputs import (
63+ EMPTY_MODEL_RUNNER_OUTPUT ,
64+ AsyncModelRunnerOutput ,
65+ DraftTokenIds ,
66+ KVConnectorOutput ,
67+ ModelRunnerOutput ,
68+ )
3269from vllm .v1 .request import Request
3370from vllm .v1 .spec_decode .ngram_proposer import NgramProposer
34- from vllm .v1 .worker .kv_connector_model_runner_mixin import \
35- KVConnectorModelRunnerMixin
71+ from vllm .v1 .worker .kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
3672from vllm .v1 .worker .lora_model_runner_mixin import LoRAModelRunnerMixin
3773
3874from tpu_inference import utils as common_utils
@@ -108,7 +144,7 @@ def __init__(
108144 next_tokens : jax .Array ,
109145 num_reqs : int ,
110146 discard_sampled_tokens_req_indices : list [int ],
111- logits_indices_selector : Optional [ List [ int ]] = None ,
147+ logits_indices_selector : list [ int ] | None = None ,
112148 ):
113149 self ._model_runner_output = model_runner_output
114150 self ._next_tokens = next_tokens
@@ -136,7 +172,7 @@ class AsyncPreResults:
136172 request_seq_lens : list [tuple [int , CachedRequestState , int ]]
137173 discard_sampled_tokens_req_indices : list [int ]
138174 placeholder_req_id_to_index : dict [str , int ]
139- logits_indices_selector : Optional [ List [ int ]] = None
175+ logits_indices_selector : list [ int ] | None = None
140176
141177
142178@dataclass
@@ -146,13 +182,13 @@ class ExecuteModelState:
146182
147183 scheduler_output : "VllmSchedulerOutput"
148184 attn_metadata : AttentionMetadata
149- input_ids : Optional [ jax .Array ]
185+ input_ids : jax .Array | None
150186 hidden_states : jax .Array
151187 logits : jax .Array
152- aux_hidden_states : Optional [ jax .Array ]
153- spec_decode_metadata : Optional [ SpecDecodeMetadata ]
154- kv_connector_output : Optional [ KVConnectorOutput ]
155- logits_indices_selector : Optional [ List [ int ]] = None
188+ aux_hidden_states : jax .Array | None
189+ spec_decode_metadata : SpecDecodeMetadata | None
190+ kv_connector_output : KVConnectorOutput | None
191+ logits_indices_selector : list [ int ] | None = None
156192
157193
158194@functools .partial (jax .jit , donate_argnums = (0 , 1 , 2 ))
@@ -195,7 +231,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
195231 def __init__ (
196232 self ,
197233 vllm_config : VllmConfig ,
198- devices : List [Any ],
234+ devices : list [Any ],
199235 ):
200236 self .vllm_config = vllm_config
201237 self .model_config = vllm_config .model_config
@@ -517,7 +553,7 @@ def capture_model(self) -> None:
517553 def execute_model (
518554 self ,
519555 scheduler_output : "VllmSchedulerOutput" ,
520- intermediate_tensors : Optional [ IntermediateTensors ] = None ,
556+ intermediate_tensors : IntermediateTensors | None = None ,
521557 ) -> ModelRunnerOutput | None :
522558 if self .execute_model_state is not None :
523559 raise RuntimeError ("State error: sample_tokens() must be called "
@@ -746,13 +782,13 @@ def _sample_from_logits(
746782 self ,
747783 scheduler_output : "VllmSchedulerOutput" ,
748784 attn_metadata : AttentionMetadata ,
749- input_ids : Optional [ jax .Array ] ,
785+ input_ids : jax .Array | None ,
750786 hidden_states : jax .Array ,
751787 logits : jax .Array ,
752- aux_hidden_states : Optional [ jax .Array ] ,
753- spec_decode_metadata : Optional [ SpecDecodeMetadata ] ,
754- kv_connector_output : Optional [ KVConnectorOutput ] ,
755- logits_indices_selector : Optional [ List [ int ]] = None ,
788+ aux_hidden_states : jax .Array | None ,
789+ spec_decode_metadata : SpecDecodeMetadata | None ,
790+ kv_connector_output : KVConnectorOutput | None ,
791+ logits_indices_selector : list [ int ] | None = None ,
756792 ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput :
757793 padded_num_reqs = runner_utils .get_padded_num_reqs_with_upper_limit (
758794 self .input_batch .num_reqs , self .max_num_reqs )
@@ -1548,26 +1584,26 @@ def _get_input_ids_embeds(self, input_ids: jax.Array,
15481584 else :
15491585 return input_ids , None
15501586
1551- def take_draft_token_ids (self ) -> Optional [ DraftTokenIds ] :
1587+ def take_draft_token_ids (self ) -> DraftTokenIds | None :
15521588 return self .speculative_decoding_manager .take_draft_token_ids ()
15531589
15541590 ###### Local disagg utilities ######
15551591
15561592 def get_kv_cache_for_block_ids (
15571593 self ,
1558- block_ids : List [int ],
1559- ) -> List [jax .Array ]:
1594+ block_ids : list [int ],
1595+ ) -> list [jax .Array ]:
15601596 return self .kv_cache_manager .get_kv_cache_for_block_ids (block_ids )
15611597
15621598 def transfer_kv_cache (self ,
1563- kv_cache_slices : List [jax .Array ]) -> List [jax .Array ]:
1599+ kv_cache_slices : list [jax .Array ]) -> list [jax .Array ]:
15641600 return self .kv_cache_manager .transfer_kv_cache (kv_cache_slices )
15651601
15661602 def insert_request_with_kv_cache (
15671603 self ,
15681604 request : "Request" ,
1569- kv_cache_slices : List [jax .Array ],
1570- block_ids : List [ List [int ]],
1605+ kv_cache_slices : list [jax .Array ],
1606+ block_ids : list [ list [int ]],
15711607 ):
15721608 return self .kv_cache_manager .insert_request_with_kv_cache (
15731609 request , kv_cache_slices , block_ids )
@@ -1577,8 +1613,8 @@ def insert_request_with_kv_cache(
15771613 def _sync_weights (
15781614 self ,
15791615 updated_weights : jaxtyping .PyTree ,
1580- mappings : Dict [str , Tuple [str , Tuple [str ]]],
1581- transpose_keys : Dict [str , Tuple [int ]],
1616+ mappings : dict [str , tuple [str , tuple [str ]]],
1617+ transpose_keys : dict [str , tuple [int ]],
15821618 reshard_fn : Callable [[jaxtyping .PyTree , jaxtyping .PyTree ],
15831619 jaxtyping .PyTree ] = None
15841620 ) -> None :
0 commit comments