Commit bcdf5e8

bzgoogle committed

    fix bug after rebase

1 parent d438930 commit bcdf5e8

File tree: 3 files changed, +90 −80 lines changed

tpu_inference/runner/compilation_manager.py

Lines changed: 21 additions & 17 deletions
@@ -1,18 +1,20 @@
 import os
 import time
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any
 
 import jax
 import jax.numpy as jnp
 import numpy as np
-import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 
+import vllm.envs as envs
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.jax.sample.sampling import sample
-from tpu_inference.layers.jax.sample.sampling_metadata import \
-    TPUSupportedSamplingMetadata
+from tpu_inference.layers.jax.sample.sampling_metadata import (
+    TPUSupportedSamplingMetadata,
+)
 from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.utils import device_array
@@ -36,9 +38,9 @@ def __init__(self, runner: "TPUModelRunner"):
                                      envs.VLLM_XLA_CACHE_PATH)
 
     def _create_dummy_tensor(self,
-                             shape: Tuple[int, ...],
+                             shape: tuple[int, ...],
                              dtype: Any,
-                             sharding: Optional[NamedSharding] = None) -> Any:
+                             sharding: NamedSharding | None = None) -> Any:
         """Helper to create dummy tensors for precompilation."""
         tensor = jnp.ones(shape, dtype=dtype)
         if sharding:
@@ -272,11 +274,11 @@ def _precompile_backbone_with_inputs_embeds(self) -> None:
     def _precompile_select_from_array_helper(
         self,
         name: str,
-        source_paddings: List[int],
-        indices_paddings: List[int],
+        source_paddings: list[int],
+        indices_paddings: list[int],
         hidden_dim: int,
-        input_sharding: Optional[NamedSharding] = None,
-        indices_sharding: Optional[NamedSharding] = None,
+        input_sharding: NamedSharding | None = None,
+        indices_sharding: NamedSharding | None = None,
         only_equal_paddings: bool = False,
         check_should_skip_padding: bool = True,
     ) -> None:
@@ -348,16 +350,18 @@ def _precompile_select_from_array(self) -> None:
             source_paddings=self.runner.num_logits_paddings,
             indices_paddings=self.runner.num_reqs_paddings,
             hidden_dim=vocab_size,
-            input_sharding=NamedSharding(self.runner.mesh,
-                                         PartitionSpec(None, ('model', 'expert')),
+            input_sharding=NamedSharding(
+                self.runner.mesh, PartitionSpec(None,
+                                                ('model', 'expert'))),
         )
         self._precompile_select_from_array_helper(
             name="select target tokens for spec decoding",
             source_paddings=self.runner.num_logits_paddings,
             indices_paddings=self.runner.num_logits_paddings,
             hidden_dim=vocab_size,
-            input_sharding=NamedSharding(self.runner.mesh,
-                                         PartitionSpec(None, ('model', 'expert')),
+            input_sharding=NamedSharding(
+                self.runner.mesh, PartitionSpec(None,
+                                                ('model', 'expert'))),
             only_equal_paddings=True,
         )
 
@@ -389,7 +393,7 @@ def _precompile_sampling(self) -> None:
         for num_reqs in self.runner.num_reqs_paddings:
             logits_sharding = NamedSharding(
                 self.runner.mesh,
-                PartitionSpec(ShardingAxisName.ATTN_DATA, ('model', 'expert'))
+                PartitionSpec(ShardingAxisName.ATTN_DATA, ('model', 'expert')))
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             sampling_metadata_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(
@@ -478,8 +482,8 @@ def _precompile_rejection_sampler(self) -> None:
         vocab_size = self.runner.model_config.get_vocab_size()
         for num_logits in self.runner.num_logits_paddings:
             for num_reqs in self.runner.num_reqs_paddings:
-                sharding = NamedSharding(self.runner.mesh,
-                                         PartitionSpec(None, ('model', 'expert')))
+                sharding = NamedSharding(
+                    self.runner.mesh, PartitionSpec(None, ('model', 'expert')))
                 target_probs = self._create_dummy_tensor(
                     (num_logits, vocab_size), jnp.bfloat16, sharding)
                 draft_token_ids = self._create_dummy_tensor((num_logits, ),
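Most of the churn in this file is mechanical typing modernization: `typing.Tuple`/`List` become the builtin generics of PEP 585, `Optional[X]` becomes the PEP 604 union `X | None`, and `Callable` moves to `collections.abc`; the remaining hunks close `NamedSharding(...)` calls whose trailing parentheses were lost in the rebase. A minimal standalone sketch of the before/after typing pattern (illustrative only, not code from this repo; the names `pick`, `xs`, `idx`, and `key` are hypothetical):

# Before (pre-Python 3.9 style):
#     from typing import Callable, Dict, List, Optional, Tuple
#     def pick(xs: List[int], idx: Optional[Tuple[int, ...]] = None) -> Dict[int, int]: ...
#
# After (PEP 585 builtin generics + PEP 604 unions), the style this commit adopts:
from collections.abc import Callable


def pick(xs: list[int],
         idx: tuple[int, ...] | None = None,
         key: Callable[[int], int] = lambda v: v) -> dict[int, int]:
    """Map key(value) -> value; idx optionally restricts which positions are used."""
    positions = range(len(xs)) if idx is None else idx
    return {key(xs[i]): xs[i] for i in positions}


assert pick([10, 20, 30], idx=(0, 2)) == {10: 10, 30: 30}

Note the `X | Y` union syntax requires Python 3.10 or newer.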

tpu_inference/runner/kv_cache.py

Lines changed: 4 additions & 5 deletions
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any
 
 import jax
 import jax.numpy as jnp
@@ -46,9 +46,9 @@ def create_kv_caches(
     num_kv_heads: int,
     head_size: int,
     mesh: Mesh,
-    layer_names: List[str],
+    layer_names: list[str],
     cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
-) -> List[jax.Array]:
+) -> list[jax.Array]:
     """
     Creates a list of KV cache where each array mapps to single attention layer.
 
@@ -78,8 +78,7 @@ def create_kv_caches(
 
     sharding = NamedSharding(
         mesh,
-        PartitionSpec(ShardingAxisName.ATTN_DATA, None,
-                      ('model', 'expert'))
+        PartitionSpec(ShardingAxisName.ATTN_DATA, None, ('model', 'expert')))
 
     def _allocate() -> jax.Array:
        return jnp.empty(
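The final hunk here, like the ones in compilation_manager.py, restores the closing parenthesis so that `NamedSharding(mesh, PartitionSpec(...))` is well-formed again. For readers unfamiliar with the JAX API, here is a small self-contained sketch of how NamedSharding pairs a device mesh with a PartitionSpec; it is simplified to a 1-D mesh rather than the runner's real (data, attn_dp, expert, model) layout:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Toy 1-D device mesh; the runner builds a 4-D (data, attn_dp, expert, model) mesh.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

# Shard the trailing head dimension across "model" and replicate the others (None),
# a simplified analogue of PartitionSpec(ATTN_DATA, None, ('model', 'expert')) above.
sharding = NamedSharding(mesh, PartitionSpec(None, None, "model"))
kv = jax.device_put(jnp.empty((16, 64, 128), dtype=jnp.bfloat16), sharding)
print(kv.sharding)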

tpu_inference/runner/tpu_jax_runner.py

Lines changed: 65 additions & 58 deletions
@@ -2,66 +2,73 @@
 import functools
 import os
 import random
+from collections.abc import Callable
 from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+from typing import Any, cast
 
 import jax
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
 import torch
-import vllm.envs as envs
 from flax import nnx
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import j2t_dtype
-from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer import (get_kv_transfer_group,
-                                          has_kv_transfer_group)
-from vllm.forward_context import set_forward_context
-from vllm.sequence import IntermediateTensors
-from vllm.tasks import SupportedTask
-from vllm.utils.math_utils import cdiv
-from vllm.v1.core.sched.output import GrammarOutput
-from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
-from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             DraftTokenIds, KVConnectorOutput,
-                             ModelRunnerOutput)
-from vllm.v1.request import Request
-from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-from vllm.v1.worker.kv_connector_model_runner_mixin import \
-    KVConnectorModelRunnerMixin
-from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
+import vllm.envs as envs
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
-from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
-                                                      gather_logprobs, sample)
-from tpu_inference.layers.jax.sample.sampling_metadata import \
-    TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import (ShardingAxisName,
-                                               ShardingConfigManager)
+from tpu_inference.layers.jax.sample.sampling import (
+    compute_logprobs,
+    gather_logprobs,
+    sample,
+)
+from tpu_inference.layers.jax.sample.sampling_metadata import (
+    TPUSupportedSamplingMetadata,
+)
+from tpu_inference.layers.jax.sharding import ShardingAxisName, ShardingConfigManager
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.models.jax.utils.weight_utils import (
-    shard_put, transfer_state_with_mappings)
+    shard_put,
+    transfer_state_with_mappings,
+)
 from tpu_inference.runner import utils as runner_utils
 from tpu_inference.runner.compilation_manager import CompilationManager
 from tpu_inference.runner.input_batch_jax import CachedRequestState, InputBatch
 from tpu_inference.runner.kv_cache_manager import KVCacheManager
 from tpu_inference.runner.lora_utils import LoraUtils
 from tpu_inference.runner.multimodal_manager import MultiModalManager
-from tpu_inference.runner.persistent_batch_manager import \
-    PersistentBatchManager
+from tpu_inference.runner.persistent_batch_manager import PersistentBatchManager
 from tpu_inference.runner.speculative_decoding_manager import (
-    SpecDecodeMetadata, SpeculativeDecodingManager)
-from tpu_inference.runner.structured_decoding_manager import \
-    StructuredDecodingManager
+    SpecDecodeMetadata,
+    SpeculativeDecodingManager,
+)
+from tpu_inference.runner.structured_decoding_manager import StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
-from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function)
+from tpu_inference.utils import device_array, make_optimized_mesh, time_function
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
+from vllm.forward_context import set_forward_context
+from vllm.sequence import IntermediateTensors
+from vllm.tasks import SupportedTask
+from vllm.utils.math_utils import cdiv
+from vllm.v1.core.sched.output import GrammarOutput
+from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import (
+    EMPTY_MODEL_RUNNER_OUTPUT,
+    AsyncModelRunnerOutput,
+    DraftTokenIds,
+    KVConnectorOutput,
+    ModelRunnerOutput,
+)
+from vllm.v1.request import Request
+from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
+from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 logger = init_logger(__name__)
 
@@ -105,7 +112,7 @@ def __init__(
         next_tokens: jax.Array,
         num_reqs: int,
         discard_sampled_tokens_req_indices: list[int],
-        logits_indices_selector: Optional[List[int]] = None,
+        logits_indices_selector: list[int] | None = None,
     ):
         self._model_runner_output = model_runner_output
         self._next_tokens = next_tokens
@@ -133,7 +140,7 @@ class AsyncPreResults:
     request_seq_lens: list[tuple[int, CachedRequestState, int]]
     discard_sampled_tokens_req_indices: list[int]
     placeholder_req_id_to_index: dict[str, int]
-    logits_indices_selector: Optional[List[int]] = None
+    logits_indices_selector: list[int] | None = None
 
 
 @dataclass
@@ -143,13 +150,13 @@ class ExecuteModelState:
 
     scheduler_output: "VllmSchedulerOutput"
     attn_metadata: AttentionMetadata
-    input_ids: Optional[jax.Array]
+    input_ids: jax.Array | None
     hidden_states: jax.Array
     logits: jax.Array
-    aux_hidden_states: Optional[jax.Array]
-    spec_decode_metadata: Optional[SpecDecodeMetadata]
-    kv_connector_output: Optional[KVConnectorOutput]
-    logits_indices_selector: Optional[List[int]] = None
+    aux_hidden_states: jax.Array | None
+    spec_decode_metadata: SpecDecodeMetadata | None
+    kv_connector_output: KVConnectorOutput | None
+    logits_indices_selector: list[int] | None = None
 
 
 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -192,7 +199,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def __init__(
         self,
         vllm_config: VllmConfig,
-        devices: List[Any],
+        devices: list[Any],
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -266,8 +273,8 @@ def _init_mesh(self) -> None:
             axis_names = ("data", "attn_dp", "expert", "model")
             mesh_shape = (sharding_strategy.model_dp_size,
                           sharding_strategy.attn_dp_size,
-                          sharding_strategy.expert_size,
-                          sharding_strategy.tp_size)
+                          sharding_strategy.expert_size, 2)
+            print(f"DEBUG: {sharding_strategy}")
 
         else:
             axis_names = ("data", "model")
@@ -462,7 +469,7 @@ def capture_model(self) -> None:
     def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[IntermediateTensors] = None,
+        intermediate_tensors: IntermediateTensors | None = None,
     ) -> ModelRunnerOutput | None:
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called "
@@ -691,13 +698,13 @@ def _sample_from_logits(
         self,
         scheduler_output: "VllmSchedulerOutput",
         attn_metadata: AttentionMetadata,
-        input_ids: Optional[jax.Array],
+        input_ids: jax.Array | None,
         hidden_states: jax.Array,
         logits: jax.Array,
-        aux_hidden_states: Optional[jax.Array],
-        spec_decode_metadata: Optional[SpecDecodeMetadata],
-        kv_connector_output: Optional[KVConnectorOutput],
-        logits_indices_selector: Optional[List[int]] = None,
+        aux_hidden_states: jax.Array | None,
+        spec_decode_metadata: SpecDecodeMetadata | None,
+        kv_connector_output: KVConnectorOutput | None,
+        logits_indices_selector: list[int] | None = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
         padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
             self.input_batch.num_reqs, self.max_num_reqs)
@@ -1493,26 +1500,26 @@ def _get_input_ids_embeds(self, input_ids: jax.Array,
         else:
             return input_ids, None
 
-    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+    def take_draft_token_ids(self) -> DraftTokenIds | None:
        return self.speculative_decoding_manager.take_draft_token_ids()
 
     ###### Local disagg utilities ######
 
     def get_kv_cache_for_block_ids(
         self,
-        block_ids: List[int],
-    ) -> List[jax.Array]:
+        block_ids: list[int],
+    ) -> list[jax.Array]:
         return self.kv_cache_manager.get_kv_cache_for_block_ids(block_ids)
 
     def transfer_kv_cache(self,
-                          kv_cache_slices: List[jax.Array]) -> List[jax.Array]:
+                          kv_cache_slices: list[jax.Array]) -> list[jax.Array]:
         return self.kv_cache_manager.transfer_kv_cache(kv_cache_slices)
 
     def insert_request_with_kv_cache(
         self,
         request: "Request",
-        kv_cache_slices: List[jax.Array],
-        block_ids: List[List[int]],
+        kv_cache_slices: list[jax.Array],
+        block_ids: list[list[int]],
     ):
         return self.kv_cache_manager.insert_request_with_kv_cache(
             request, kv_cache_slices, block_ids)
@@ -1522,8 +1529,8 @@ def insert_request_with_kv_cache(
     def _sync_weights(
         self,
         updated_weights: jaxtyping.PyTree,
-        mappings: Dict[str, Tuple[str, Tuple[str]]],
-        transpose_keys: Dict[str, Tuple[int]],
+        mappings: dict[str, tuple[str, tuple[str]]],
+        transpose_keys: dict[str, tuple[int]],
         reshard_fn: Callable[[jaxtyping.PyTree, jaxtyping.PyTree],
                              jaxtyping.PyTree] = None
     ) -> None:
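One detail worth noting in the sharding specs throughout this commit: a single PartitionSpec entry may be a tuple of mesh axes, e.g. ('model', 'expert'), which shards that one array dimension across the product of both axes. A hedged sketch under a toy two-axis mesh (axis sizes degenerate to 1 on a single-device host; the real runner sizes them from its sharding strategy):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devs = np.array(jax.devices())
# Toy (expert, model) mesh; with N devices this might be reshape(2, N // 2), etc.
mesh = Mesh(devs.reshape(1, devs.size), axis_names=("expert", "model"))

# The vocab dimension of the logits is split across model * expert devices,
# while the request dimension stays replicated (None), as in the diff above.
logits_sharding = NamedSharding(mesh, PartitionSpec(None, ("model", "expert")))
logits = jax.device_put(jnp.zeros((8, 256), jnp.float32), logits_sharding)
print(logits.sharding.spec)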
