22import functools
33import os
44import random
5+ from collections .abc import Callable
56from contextlib import nullcontext
67from dataclasses import dataclass
7- from typing import Any , Callable , Dict , List , Optional , Tuple , cast
8+ from typing import Any , cast
89
910import jax
1011import jax .numpy as jnp
1112import jaxtyping
1213import numpy as np
1314import torch
14- import vllm .envs as envs
1515from flax import nnx
1616from jax .sharding import NamedSharding , PartitionSpec
1717from torchax .ops .mappings import j2t_dtype
18- from vllm .config import VllmConfig
19- from vllm .distributed .kv_transfer import (get_kv_transfer_group ,
20- has_kv_transfer_group )
21- from vllm .forward_context import set_forward_context
22- from vllm .sequence import IntermediateTensors
23- from vllm .tasks import SupportedTask
24- from vllm .utils .math_utils import cdiv
25- from vllm .v1 .core .sched .output import GrammarOutput
26- from vllm .v1 .core .sched .output import SchedulerOutput as VllmSchedulerOutput
27- from vllm .v1 .kv_cache_interface import KVCacheConfig
28- from vllm .v1 .outputs import (EMPTY_MODEL_RUNNER_OUTPUT , AsyncModelRunnerOutput ,
29- DraftTokenIds , KVConnectorOutput ,
30- ModelRunnerOutput )
31- from vllm .v1 .request import Request
32- from vllm .v1 .spec_decode .ngram_proposer import NgramProposer
33- from vllm .v1 .worker .kv_connector_model_runner_mixin import \
34- KVConnectorModelRunnerMixin
35- from vllm .v1 .worker .lora_model_runner_mixin import LoRAModelRunnerMixin
3618
19+ import vllm .envs as envs
3720from tpu_inference import utils as common_utils
3821from tpu_inference .layers .common .attention_metadata import AttentionMetadata
3922from tpu_inference .layers .jax .sample .rejection_sampler import RejectionSampler
40- from tpu_inference .layers .jax .sample .sampling import (compute_logprobs ,
41- gather_logprobs , sample )
42- from tpu_inference .layers .jax .sample .sampling_metadata import \
43- TPUSupportedSamplingMetadata
44- from tpu_inference .layers .jax .sharding import (ShardingAxisName ,
45- ShardingConfigManager )
23+ from tpu_inference .layers .jax .sample .sampling import (
24+ compute_logprobs ,
25+ gather_logprobs ,
26+ sample ,
27+ )
28+ from tpu_inference .layers .jax .sample .sampling_metadata import (
29+ TPUSupportedSamplingMetadata ,
30+ )
31+ from tpu_inference .layers .jax .sharding import ShardingAxisName , ShardingConfigManager
4632from tpu_inference .logger import init_logger
4733from tpu_inference .models .common .model_loader import get_model
4834from tpu_inference .models .jax .utils .weight_utils import (
49- shard_put , transfer_state_with_mappings )
35+ shard_put ,
36+ transfer_state_with_mappings ,
37+ )
5038from tpu_inference .runner import utils as runner_utils
5139from tpu_inference .runner .compilation_manager import CompilationManager
5240from tpu_inference .runner .input_batch_jax import CachedRequestState , InputBatch
5341from tpu_inference .runner .kv_cache_manager import KVCacheManager
5442from tpu_inference .runner .lora_utils import LoraUtils
5543from tpu_inference .runner .multimodal_manager import MultiModalManager
56- from tpu_inference .runner .persistent_batch_manager import \
57- PersistentBatchManager
44+ from tpu_inference .runner .persistent_batch_manager import PersistentBatchManager
5845from tpu_inference .runner .speculative_decoding_manager import (
59- SpecDecodeMetadata , SpeculativeDecodingManager )
60- from tpu_inference .runner .structured_decoding_manager import \
61- StructuredDecodingManager
46+ SpecDecodeMetadata ,
47+ SpeculativeDecodingManager ,
48+ )
49+ from tpu_inference .runner .structured_decoding_manager import StructuredDecodingManager
6250from tpu_inference .spec_decode .jax .eagle3 import Eagle3Proposer
63- from tpu_inference .utils import (device_array , make_optimized_mesh ,
64- time_function )
51+ from tpu_inference .utils import device_array , make_optimized_mesh , time_function
52+ from vllm .config import VllmConfig
53+ from vllm .distributed .kv_transfer import get_kv_transfer_group , has_kv_transfer_group
54+ from vllm .forward_context import set_forward_context
55+ from vllm .sequence import IntermediateTensors
56+ from vllm .tasks import SupportedTask
57+ from vllm .utils .math_utils import cdiv
58+ from vllm .v1 .core .sched .output import GrammarOutput
59+ from vllm .v1 .core .sched .output import SchedulerOutput as VllmSchedulerOutput
60+ from vllm .v1 .kv_cache_interface import KVCacheConfig
61+ from vllm .v1 .outputs import (
62+ EMPTY_MODEL_RUNNER_OUTPUT ,
63+ AsyncModelRunnerOutput ,
64+ DraftTokenIds ,
65+ KVConnectorOutput ,
66+ ModelRunnerOutput ,
67+ )
68+ from vllm .v1 .request import Request
69+ from vllm .v1 .spec_decode .ngram_proposer import NgramProposer
70+ from vllm .v1 .worker .kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
71+ from vllm .v1 .worker .lora_model_runner_mixin import LoRAModelRunnerMixin
6572
6673logger = init_logger (__name__ )
6774
@@ -105,7 +112,7 @@ def __init__(
105112 next_tokens : jax .Array ,
106113 num_reqs : int ,
107114 discard_sampled_tokens_req_indices : list [int ],
108- logits_indices_selector : Optional [ List [ int ]] = None ,
115+ logits_indices_selector : list [ int ] | None = None ,
109116 ):
110117 self ._model_runner_output = model_runner_output
111118 self ._next_tokens = next_tokens
@@ -133,7 +140,7 @@ class AsyncPreResults:
133140 request_seq_lens : list [tuple [int , CachedRequestState , int ]]
134141 discard_sampled_tokens_req_indices : list [int ]
135142 placeholder_req_id_to_index : dict [str , int ]
136- logits_indices_selector : Optional [ List [ int ]] = None
143+ logits_indices_selector : list [ int ] | None = None
137144
138145
139146@dataclass
@@ -143,13 +150,13 @@ class ExecuteModelState:
143150
144151 scheduler_output : "VllmSchedulerOutput"
145152 attn_metadata : AttentionMetadata
146- input_ids : Optional [ jax .Array ]
153+ input_ids : jax .Array | None
147154 hidden_states : jax .Array
148155 logits : jax .Array
149- aux_hidden_states : Optional [ jax .Array ]
150- spec_decode_metadata : Optional [ SpecDecodeMetadata ]
151- kv_connector_output : Optional [ KVConnectorOutput ]
152- logits_indices_selector : Optional [ List [ int ]] = None
156+ aux_hidden_states : jax .Array | None
157+ spec_decode_metadata : SpecDecodeMetadata | None
158+ kv_connector_output : KVConnectorOutput | None
159+ logits_indices_selector : list [ int ] | None = None
153160
154161
155162@functools .partial (jax .jit , donate_argnums = (0 , 1 , 2 ))
@@ -192,7 +199,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
192199 def __init__ (
193200 self ,
194201 vllm_config : VllmConfig ,
195- devices : List [Any ],
202+ devices : list [Any ],
196203 ):
197204 self .vllm_config = vllm_config
198205 self .model_config = vllm_config .model_config
@@ -266,8 +273,8 @@ def _init_mesh(self) -> None:
266273 axis_names = ("data" , "attn_dp" , "expert" , "model" )
267274 mesh_shape = (sharding_strategy .model_dp_size ,
268275 sharding_strategy .attn_dp_size ,
269276 sharding_strategy .expert_size ,
270277 sharding_strategy . tp_size )
271278
272279 else :
273280 axis_names = ("data" , "model" )
@@ -462,7 +469,7 @@ def capture_model(self) -> None:
462469 def execute_model (
463470 self ,
464471 scheduler_output : "VllmSchedulerOutput" ,
465- intermediate_tensors : Optional [ IntermediateTensors ] = None ,
472+ intermediate_tensors : IntermediateTensors | None = None ,
466473 ) -> ModelRunnerOutput | None :
467474 if self .execute_model_state is not None :
468475 raise RuntimeError ("State error: sample_tokens() must be called "
@@ -691,13 +698,13 @@ def _sample_from_logits(
691698 self ,
692699 scheduler_output : "VllmSchedulerOutput" ,
693700 attn_metadata : AttentionMetadata ,
694- input_ids : Optional [ jax .Array ] ,
701+ input_ids : jax .Array | None ,
695702 hidden_states : jax .Array ,
696703 logits : jax .Array ,
697- aux_hidden_states : Optional [ jax .Array ] ,
698- spec_decode_metadata : Optional [ SpecDecodeMetadata ] ,
699- kv_connector_output : Optional [ KVConnectorOutput ] ,
700- logits_indices_selector : Optional [ List [ int ]] = None ,
704+ aux_hidden_states : jax .Array | None ,
705+ spec_decode_metadata : SpecDecodeMetadata | None ,
706+ kv_connector_output : KVConnectorOutput | None ,
707+ logits_indices_selector : list [ int ] | None = None ,
701708 ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput :
702709 padded_num_reqs = runner_utils .get_padded_num_reqs_with_upper_limit (
703710 self .input_batch .num_reqs , self .max_num_reqs )
@@ -1493,26 +1500,26 @@ def _get_input_ids_embeds(self, input_ids: jax.Array,
14931500 else :
14941501 return input_ids , None
14951502
1496- def take_draft_token_ids (self ) -> Optional [ DraftTokenIds ] :
1503+ def take_draft_token_ids (self ) -> DraftTokenIds | None :
14971504 return self .speculative_decoding_manager .take_draft_token_ids ()
14981505
14991506 ###### Local disagg utilities ######
15001507
15011508 def get_kv_cache_for_block_ids (
15021509 self ,
1503- block_ids : List [int ],
1504- ) -> List [jax .Array ]:
1510+ block_ids : list [int ],
1511+ ) -> list [jax .Array ]:
15051512 return self .kv_cache_manager .get_kv_cache_for_block_ids (block_ids )
15061513
15071514 def transfer_kv_cache (self ,
1508- kv_cache_slices : List [jax .Array ]) -> List [jax .Array ]:
1515+ kv_cache_slices : list [jax .Array ]) -> list [jax .Array ]:
15091516 return self .kv_cache_manager .transfer_kv_cache (kv_cache_slices )
15101517
15111518 def insert_request_with_kv_cache (
15121519 self ,
15131520 request : "Request" ,
1514- kv_cache_slices : List [jax .Array ],
1515- block_ids : List [ List [int ]],
1521+ kv_cache_slices : list [jax .Array ],
1522+ block_ids : list [ list [int ]],
15161523 ):
15171524 return self .kv_cache_manager .insert_request_with_kv_cache (
15181525 request , kv_cache_slices , block_ids )
@@ -1522,8 +1529,8 @@ def insert_request_with_kv_cache(
15221529 def _sync_weights (
15231530 self ,
15241531 updated_weights : jaxtyping .PyTree ,
1525- mappings : Dict [str , Tuple [str , Tuple [str ]]],
1526- transpose_keys : Dict [str , Tuple [int ]],
1532+ mappings : dict [str , tuple [str , tuple [str ]]],
1533+ transpose_keys : dict [str , tuple [int ]],
15271534 reshard_fn : Callable [[jaxtyping .PyTree , jaxtyping .PyTree ],
15281535 jaxtyping .PyTree ] = None
15291536 ) -> None :
0 commit comments