 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.pool.pooling import pool
+from tpu_inference.layers.jax.pool.pooling_metadata import (
+    TPUSupportedPoolingMetadata,
+)
 from tpu_inference.layers.jax.sample.sampling import sample
-from tpu_inference.layers.jax.sample.sampling_metadata import \
-    TPUSupportedSamplingMetadata
+from tpu_inference.layers.jax.sample.sampling_metadata import (
+    TPUSupportedSamplingMetadata,
+)
 from tpu_inference.logger import init_logger
 from tpu_inference.utils import device_array
 
@@ -79,6 +84,9 @@ def capture_model(self) -> None:
             self._run_compilation, )
         self._precompile_input_embeddings_merger()
         self._precompile_backbone_with_inputs_embeds()
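+        # Pooling models never sample, so compile the pooling kernels and
+        # return early, skipping the sampling/speculative precompilation.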
+        if self.runner.is_pooling_model:
+            self._precompile_pooling()
+            return
         if self.runner.scheduler_config.async_scheduling:
             self._precompile_substitute_placeholder_token()
             self._precompile_select_from_array()
@@ -90,6 +98,68 @@ def capture_model(self) -> None:
         if self.runner.speculative_config:
             self._precompile_speculative_decoding()
 
+    def _precompile_pooling(self) -> None:
+        pooler = getattr(self.runner, "pooler", None)
+        if pooler is None:
+            logger.warning(
+                "Pooling precompile skipped because model has no pooler attribute.")
+            return
+
+        logger.info("Precompiling pooling kernels for pooling models.")
+
+        hidden_size = self.runner.model_config.get_hidden_size()
+        dtype = self.runner.model_config.dtype
+        hidden_sharding = NamedSharding(
+            self.runner.mesh, PartitionSpec(None, None))
+
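+        # Compile one pooling call per (num_tokens, num_reqs) padding bucket
+        # so no input shape triggers a recompilation at serving time.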
+        for num_tokens in self.runner.num_tokens_paddings:
+            hidden_states = self._create_dummy_tensor(
+                (num_tokens, hidden_size), dtype, sharding=hidden_sharding)
+
+            for num_reqs in self.runner.num_reqs_paddings:
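+                # Every request contributes at least one token, so paddings
+                # with num_reqs > num_tokens can never occur.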
+                if num_reqs == 0 or num_reqs > num_tokens:
+                    continue
+
+                prompt_lens = np.ones(num_reqs, dtype=np.int32)
+                first_token_indices = np.arange(num_reqs, dtype=np.int32)
+                last_token_indices = first_token_indices.copy()
+                normalize = np.ones(num_reqs, dtype=np.int8)
+
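+                # Move the host-side dummy metadata onto the TPU mesh; only
+                # the shapes matter for compilation, not the values.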
+                (
+                    prompt_lens,
+                    normalize,
+                    first_token_indices,
+                    last_token_indices,
+                ) = device_array(
+                    self.runner.mesh,
+                    (
+                        prompt_lens,
+                        normalize,
+                        first_token_indices,
+                        last_token_indices,
+                    ),
+                )
+
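+                # Build pooling metadata for the fully padded batch;
+                # precompilation traces the "embed" task.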
+                pooling_metadata = TPUSupportedPoolingMetadata(
+                    prompt_lens=prompt_lens,
+                    first_token_indices=first_token_indices,
+                    last_token_indices=last_token_indices,
+                    normalize=normalize,
+                    num_reqs=num_reqs,
+                    padded_num_reqs=num_reqs,
+                    task="embed",
+                )
+
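+                # Trace and compile pool() for this (num_tokens, num_reqs)
+                # shape.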
+                self._run_compilation(
+                    "pool",
+                    pool,
+                    hidden_states,
+                    pooling_metadata,
+                    pooler,
+                    num_tokens=num_tokens,
+                    num_reqs=num_reqs,
+                )
+
     def _precompile_input_embeddings_merger(self) -> None:
         for num_tokens in self.runner.num_tokens_paddings:
             hidden_size = self.runner.vllm_config.model_config.get_hidden_size(