22import functools
33import os
44import random
5+ from collections .abc import Callable
56from contextlib import nullcontext
67from dataclasses import dataclass
7- from typing import Any , Callable , Dict , List , Optional , Tuple , cast
8+ from typing import Any , cast
89
910import jax
1011import jax .numpy as jnp
1112import jaxtyping
1213import numpy as np
1314import torch
14- import vllm .envs as envs
1515from flax import nnx
1616from jax .experimental import mesh_utils
1717from jax .sharding import NamedSharding , PartitionSpec
1818from torchax .ops .mappings import j2t_dtype
19+
20+ import vllm .envs as envs
21+ from tpu_inference import utils as common_utils
22+ from tpu_inference .layers .common .attention_metadata import AttentionMetadata
23+ from tpu_inference .layers .jax .sample .rejection_sampler import RejectionSampler
24+ from tpu_inference .layers .jax .sample .sampling import (
25+ compute_logprobs ,
26+ gather_logprobs ,
27+ sample ,
28+ )
29+ from tpu_inference .layers .jax .sample .sampling_metadata import (
30+ TPUSupportedSamplingMetadata ,
31+ )
32+ from tpu_inference .layers .jax .sharding import ShardingAxisName , ShardingConfigManager
33+ from tpu_inference .logger import init_logger
34+ from tpu_inference .models .common .model_loader import get_model
35+ from tpu_inference .models .jax .utils .weight_utils import (
36+ shard_put ,
37+ transfer_state_with_mappings ,
38+ )
39+ from tpu_inference .runner import utils as runner_utils
40+ from tpu_inference .runner .compilation_manager import CompilationManager
41+ from tpu_inference .runner .input_batch_jax import CachedRequestState , InputBatch
42+ from tpu_inference .runner .kv_cache_manager import KVCacheManager
43+ from tpu_inference .runner .lora_utils import LoraUtils
44+ from tpu_inference .runner .multimodal_manager import MultiModalManager
45+ from tpu_inference .runner .persistent_batch_manager import PersistentBatchManager
46+ from tpu_inference .runner .speculative_decoding_manager import (
47+ SpecDecodeMetadata ,
48+ SpeculativeDecodingManager ,
49+ )
50+ from tpu_inference .runner .structured_decoding_manager import StructuredDecodingManager
51+ from tpu_inference .spec_decode .jax .eagle3 import Eagle3Proposer
52+ from tpu_inference .utils import device_array , make_optimized_mesh , time_function
1953from vllm .config import VllmConfig
20- from vllm .distributed .kv_transfer import (get_kv_transfer_group ,
21- has_kv_transfer_group )
54+ from vllm .distributed .kv_transfer import get_kv_transfer_group , has_kv_transfer_group
2255from vllm .forward_context import set_forward_context
2356from vllm .sequence import IntermediateTensors
2457from vllm .tasks import SupportedTask
3164 LogprobsTensors , ModelRunnerOutput )
3265from vllm .v1 .request import Request
3366from vllm .v1 .spec_decode .ngram_proposer import NgramProposer
34- from vllm .v1 .worker .kv_connector_model_runner_mixin import \
35- KVConnectorModelRunnerMixin
67+ from vllm .v1 .worker .kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
3668from vllm .v1 .worker .lora_model_runner_mixin import LoRAModelRunnerMixin
3769
3870from tpu_inference import utils as common_utils
@@ -108,7 +140,7 @@ def __init__(
108140 next_tokens : jax .Array ,
109141 num_reqs : int ,
110142 discard_sampled_tokens_req_indices : list [int ],
111- logits_indices_selector : Optional [ List [ int ]] = None ,
143+ logits_indices_selector : list [ int ] | None = None ,
112144 ):
113145 self ._model_runner_output = model_runner_output
114146 self ._next_tokens = next_tokens
@@ -137,7 +169,7 @@ class AsyncPreResults:
137169 request_seq_lens : list [tuple [int , CachedRequestState , int ]]
138170 discard_sampled_tokens_req_indices : list [int ]
139171 placeholder_req_id_to_index : dict [str , int ]
140- logits_indices_selector : Optional [ List [ int ]] = None
172+ logits_indices_selector : list [ int ] | None = None
141173
142174
143175@dataclass
@@ -147,7 +179,7 @@ class ExecuteModelState:
147179
148180 scheduler_output : "VllmSchedulerOutput"
149181 attn_metadata : AttentionMetadata
150- input_ids : Optional [ jax .Array ]
182+ input_ids : jax .Array | None
151183 hidden_states : jax .Array
152184 logits : jax .Array
153185 aux_hidden_states : Optional [jax .Array ]
@@ -552,7 +584,7 @@ def capture_model(self) -> None:
552584 def execute_model (
553585 self ,
554586 scheduler_output : "VllmSchedulerOutput" ,
555- intermediate_tensors : Optional [ IntermediateTensors ] = None ,
587+ intermediate_tensors : IntermediateTensors | None = None ,
556588 ) -> ModelRunnerOutput | None :
557589 if self .execute_model_state is not None :
558590 raise RuntimeError ("State error: sample_tokens() must be called "
@@ -797,7 +829,7 @@ def _sample_from_logits(
797829 self ,
798830 scheduler_output : "VllmSchedulerOutput" ,
799831 attn_metadata : AttentionMetadata ,
800- input_ids : Optional [ jax .Array ] ,
832+ input_ids : jax .Array | None ,
801833 hidden_states : jax .Array ,
802834 logits : jax .Array ,
803835 aux_hidden_states : Optional [jax .Array ],
@@ -1617,26 +1649,26 @@ def _get_input_ids_embeds(self, input_ids: jax.Array,
16171649 else :
16181650 return input_ids , None
16191651
1620- def take_draft_token_ids (self ) -> Optional [ DraftTokenIds ] :
1652+ def take_draft_token_ids (self ) -> DraftTokenIds | None :
16211653 return self .speculative_decoding_manager .take_draft_token_ids ()
16221654
16231655 ###### Local disagg utilities ######
16241656
16251657 def get_kv_cache_for_block_ids (
16261658 self ,
1627- block_ids : List [int ],
1628- ) -> List [jax .Array ]:
1659+ block_ids : list [int ],
1660+ ) -> list [jax .Array ]:
16291661 return self .kv_cache_manager .get_kv_cache_for_block_ids (block_ids )
16301662
16311663 def transfer_kv_cache (self ,
1632- kv_cache_slices : List [jax .Array ]) -> List [jax .Array ]:
1664+ kv_cache_slices : list [jax .Array ]) -> list [jax .Array ]:
16331665 return self .kv_cache_manager .transfer_kv_cache (kv_cache_slices )
16341666
16351667 def insert_request_with_kv_cache (
16361668 self ,
16371669 request : "Request" ,
1638- kv_cache_slices : List [jax .Array ],
1639- block_ids : List [ List [int ]],
1670+ kv_cache_slices : list [jax .Array ],
1671+ block_ids : list [ list [int ]],
16401672 ):
16411673 return self .kv_cache_manager .insert_request_with_kv_cache (
16421674 request , kv_cache_slices , block_ids )
@@ -1646,8 +1678,8 @@ def insert_request_with_kv_cache(
16461678 def _sync_weights (
16471679 self ,
16481680 updated_weights : jaxtyping .PyTree ,
1649- mappings : Dict [str , Tuple [str , Tuple [str ]]],
1650- transpose_keys : Dict [str , Tuple [int ]],
1681+ mappings : dict [str , tuple [str , tuple [str ]]],
1682+ transpose_keys : dict [str , tuple [int ]],
16511683 reshard_fn : Callable [[jaxtyping .PyTree , jaxtyping .PyTree ],
16521684 jaxtyping .PyTree ] = None
16531685 ) -> None :
0 commit comments