
Commit 2baaaaa

committed
it's up and running well (at least for qwen-embed and last-pooling)
1 parent abe35dc commit 2baaaaa

File tree

9 files changed: +206 -146 lines changed


tpu_inference/layers/jax/pool/pooler.py

Lines changed: 10 additions & 16 deletions

@@ -9,14 +9,9 @@
 from vllm.config.pooler import PoolerConfig
 
 
-@jax.tree_util.register_dataclass
-@dataclass
-class PoolingResult:
-    """Outputs produced by pooling kernels."""
-
-    num_reqs: int
-    pooler_output: jax.Array  # [padded_num_reqs, dim]
-    # or [padded_num_reqs, padded_max_num_batched_token_per_req, dim] for allpool
+# [padded_num_reqs, dim]
+# or [padded_num_reqs, padded_max_num_batched_token_per_req, dim] for allpool
+PoolerOutput = jax.Array
 
 
 class PoolingType(enum.Enum):
@@ -139,7 +134,7 @@ def __call__(
         token_embeddings: jax.Array,
         token_mask: jax.Array,
         pooling_metadata: TPUSupportedPoolingMetadata,
-    ) -> PoolingResult:
+    ) -> PoolerOutput:
         raise NotImplementedError
 
 
@@ -152,7 +147,7 @@ def __call__(
         self,
         pooled: jax.Array,
         pooling_metadata: TPUSupportedPoolingMetadata,
-    ) -> PoolingResult:
+    ) -> PoolerOutput:
 
         # In the torch version, this part should handle other computations related to pooling_params, such as
         # normalization and truncating the embedding dimensions (for matryoshka models).
@@ -166,10 +161,7 @@ def __call__(
         if self.default_normalize:
             pooled = normalize(pooled)
 
-        return PoolingResult(
-            num_reqs=pooling_metadata.num_reqs,
-            pooler_output=pooled,
-        )
+        return pooled
 
 
 class Pooler(nnx.Module):
@@ -187,7 +179,7 @@ def __call__(
         self,
         hidden_states: jax.Array,
         pooling_metadata: TPUSupportedPoolingMetadata,
-    ) -> PoolingResult:
+    ) -> PoolerOutput:
         raise NotImplementedError
 
     def get_supported_tasks(self) -> set[str]:
@@ -213,7 +205,9 @@ def __call__(
         self,
         hidden_states: jax.Array,
         pooling_metadata: TPUSupportedPoolingMetadata,
-    ) -> PoolingResult:
+    ) -> PoolerOutput:
+        hidden_states = hidden_states.astype(jnp.float32)
+        # the output must be of type torch.Tensor, but we cannot convert numpy to torch if the dtype is bf16
         pooled = self.pooling(hidden_states, pooling_metadata)
         return self.head(pooled, pooling_metadata)
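
With PoolingResult gone, a pooler call now returns a bare jax.Array. A minimal sketch of the new calling convention, assuming a padded [padded_num_reqs, dim] output and a caller that tracks the real request count itself; the helper name is illustrative, not from this diff:

import jax.numpy as jnp

def consume_pooler_output(pooled: jnp.ndarray, num_reqs: int) -> jnp.ndarray:
    # No wrapper to unpack: the caller slices the padded batch
    # down to the real requests on its own.
    return pooled[:num_reqs]  # [num_reqs, dim]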

tpu_inference/layers/jax/pool/pooling.py

Lines changed: 3 additions & 0 deletions

@@ -5,6 +5,9 @@
 from .pooler import Pooler, PoolerOutput
 from .pooling_metadata import TPUSupportedPoolingMetadata
 
+
+# actually my idea is not to jit this function but the model.pooler,
+# we can put some postprocessing here.
 @jax.jit
 def pool(
     hidden_states: jax.Array,
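
As the new comment hints, an alternative is to jit the module's pooler call rather than this free function. A hedged sketch of that idea with flax.nnx, where model and metadata are stand-in names rather than objects from this diff:

from flax import nnx

@nnx.jit
def run_pooler(model, hidden_states, metadata):
    # Tracing happens once per padded shape; postprocessing
    # (e.g. the torch conversion) stays outside the traced region.
    return model.pooler(hidden_states, metadata)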

tpu_inference/layers/jax/pool/pooling_metadata.py

Lines changed: 10 additions & 8 deletions

@@ -25,15 +25,15 @@ def build_pooling_cursor(
     assert len(prompt_lens) == len(num_scheduled_tokens)
 
     n_seq = len(num_scheduled_tokens)
-    num_sched_tokens_padded = jnp.zeros(padded_num_seqs)
-    num_sched_tokens_padded = num_sched_tokens_padded.at[:n_seq].set(
-        jnp.asarary(num_scheduled_tokens, dtype=jnp.int32)
+    num_scheduled_tokens_padded = jnp.zeros(padded_num_seqs)
+    num_scheduled_tokens_padded = num_scheduled_tokens_padded.at[:n_seq].set(
+        jnp.asarray(num_scheduled_tokens, dtype=jnp.int32)
     )
-    cumsum = jnp.cumsum(num_scheduled_tokens)
-    first_token_indices = jnp.concatenate((jnp.asarray(0), cumsum[:-1]))
-    last_token_indices = first_token_indices + num_sched_tokens_padded - 1
+    cumsum = jnp.cumsum(num_scheduled_tokens_padded, dtype=jnp.int64)
+    first_token_indices = jnp.concatenate((jnp.asarray((0,)), cumsum[:-1]))
+    last_token_indices = (first_token_indices + num_scheduled_tokens_padded - 1).astype(jnp.int64)
     last_token_indices = jnp.where(
-        num_sched_tokens_padded > 0, last_token_indices, first_token_indices
+        num_scheduled_tokens_padded > 0, last_token_indices, first_token_indices
     )
     return first_token_indices, last_token_indices
 
@@ -42,11 +42,13 @@ def build_pooling_cursor(
     jax.tree_util.register_dataclass,
     data_fields=(
         "prompt_lens",
+        "first_token_indices",
+        "last_token_indices",
         "normalize",
         "num_reqs",
         "padded_num_reqs",
     ),
-    meta_fields=("task_id",),
+    meta_fields=("task",),
 )
 @dataclass
 class TPUSupportedPoolingMetadata:
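
The cursor math above is easiest to see on concrete numbers. A small worked example, assuming three real requests padded to four slots:

import jax.numpy as jnp

num_scheduled_tokens_padded = jnp.asarray([3, 1, 4, 0])
cumsum = jnp.cumsum(num_scheduled_tokens_padded)           # [3, 4, 8, 8]
first = jnp.concatenate((jnp.asarray((0,)), cumsum[:-1]))  # [0, 3, 4, 8]
last = first + num_scheduled_tokens_padded - 1             # [2, 3, 7, 7]
# Empty padding slots fall back to their first index instead of index - 1:
last = jnp.where(num_scheduled_tokens_padded > 0, last, first)  # [2, 3, 7, 8]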

tpu_inference/models/common/model_loader.py

Lines changed: 1 addition & 1 deletion

@@ -281,7 +281,7 @@ def combine_hidden_states(graphdef, state, hidden_states):
         run_get_multimodal_embeddings, graphdef)
     get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
                                                 graphdef)
-    lora_manager, model = None, None
+    lora_manager, _ = None, None
     combine_hidden_states_fn = functools.partial(combine_hidden_states,
                                                  graphdef)

tpu_inference/models/jax/adapters.py

Lines changed: 11 additions & 6 deletions

@@ -4,12 +4,8 @@
 from flax import nnx
 from jax.sharding import Mesh
 
-from vllm.config import VllmConfig
-from vllm.model_executor.models.interfaces_base import (
-    VllmModelForPooling,
-    is_pooling_model,
-)
 from tpu_inference.layers.jax.pool.pooler import Pooler
+from vllm.config import VllmConfig
 
 _T = tp.TypeVar("_T", bound=type[nnx.Module])
 
@@ -18,6 +14,15 @@
     "ForConditionalGeneration",
 )
 
+class PoolingMixin:
+    """
+    Same as VllmModelForPooling.
+    """
+    is_pooling_model: tp.ClassVar[tp.Literal[True]] = True
+
+    default_pooling_type: tp.ClassVar[str] = "LAST"
+    pooler: Pooler
+
 
 def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
     model_name = orig_model_name
@@ -27,7 +32,7 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
 
 
 def _create_pooling_model_cls(orig_cls: _T) -> _T:
-    class ModelForPooling(orig_cls, VllmModelForPooling):
+    class ModelForPooling(orig_cls, PoolingMixin):
         is_pooling_model = True
 
         def __init__(
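
A hedged usage sketch of the wrapper above; MyForCausalLM is a placeholder class, not one from this repo:

# The wrapped class advertises pooling support through the mixin's class vars.
PooledCls = _create_pooling_model_cls(MyForCausalLM)
assert PooledCls.is_pooling_model
assert PooledCls.default_pooling_type == "LAST"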

tpu_inference/models/jax/utils/weight_utils.py

Lines changed: 3 additions & 0 deletions

@@ -316,6 +316,9 @@ def _load_hf_weights_on_thread(vllm_config,
         if hf_key.endswith(".weight"):
             hf_key = hf_key.removesuffix(".weight")
 
+        if not hf_key.startswith('models.'):
+            hf_key = 'model.' + hf_key
+
         # Find the corresponding model key using the HF key
         if "layers" in hf_key:
             layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
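
For illustration, the prefix normalization above applied to a sample key, assuming checkpoint keys arrive without a top-level prefix (the key string is made up):

hf_key = "layers.0.self_attn.q_proj"   # after ".weight" was stripped
if not hf_key.startswith('models.'):
    hf_key = 'model.' + hf_key         # -> "model.layers.0.self_attn.q_proj"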

tpu_inference/runner/compilation_manager.py

Lines changed: 72 additions & 2 deletions

@@ -11,9 +11,14 @@
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.pool.pooling import pool
+from tpu_inference.layers.jax.pool.pooling_metadata import (
+    TPUSupportedPoolingMetadata,
+)
 from tpu_inference.layers.jax.sample.sampling import sample
-from tpu_inference.layers.jax.sample.sampling_metadata import \
-    TPUSupportedSamplingMetadata
+from tpu_inference.layers.jax.sample.sampling_metadata import (
+    TPUSupportedSamplingMetadata,
+)
 from tpu_inference.logger import init_logger
 from tpu_inference.utils import device_array
@@ -79,6 +84,9 @@ def capture_model(self) -> None:
             self._run_compilation, )
         self._precompile_input_embeddings_merger()
         self._precompile_backbone_with_inputs_embeds()
+        if self.runner.is_pooling_model:
+            self._precompile_pooling()
+            return
         if self.runner.scheduler_config.async_scheduling:
             self._precompile_substitute_placeholder_token()
         self._precompile_select_from_array()
@@ -90,6 +98,68 @@ def capture_model(self) -> None:
         if self.runner.speculative_config:
             self._precompile_speculative_decoding()
 
+    def _precompile_pooling(self) -> None:
+        pooler = getattr(self.runner, "pooler", None)
+        if pooler is None:
+            logger.warning(
+                "Pooling precompile skipped because model has no pooler attribute.")
+            return
+
+        logger.info("Precompiling pooling kernels for pooling models.")
+
+        hidden_size = self.runner.model_config.get_hidden_size()
+        dtype = self.runner.model_config.dtype
+        hidden_sharding = NamedSharding(
+            self.runner.mesh, PartitionSpec(None, None))
+
+        for num_tokens in self.runner.num_tokens_paddings:
+            hidden_states = self._create_dummy_tensor(
+                (num_tokens, hidden_size), dtype, sharding=hidden_sharding)
+
+            for num_reqs in self.runner.num_reqs_paddings:
+                if num_reqs == 0 or num_reqs > num_tokens:
+                    continue
+
+                prompt_lens = np.ones(num_reqs, dtype=np.int32)
+                first_token_indices = np.arange(num_reqs, dtype=np.int32)
+                last_token_indices = first_token_indices.copy()
+                normalize = np.ones(num_reqs, dtype=np.int8)
+
+                (
+                    prompt_lens,
+                    normalize,
+                    first_token_indices,
+                    last_token_indices,
+                ) = device_array(
+                    self.runner.mesh,
+                    (
+                        prompt_lens,
+                        normalize,
+                        first_token_indices,
+                        last_token_indices,
+                    ),
+                )
+
+                pooling_metadata = TPUSupportedPoolingMetadata(
+                    prompt_lens=prompt_lens,
+                    first_token_indices=first_token_indices,
+                    last_token_indices=last_token_indices,
+                    normalize=normalize,
+                    num_reqs=num_reqs,
+                    padded_num_reqs=num_reqs,
+                    task="embed",
+                )
+
+                self._run_compilation(
+                    "pool",
+                    pool,
+                    hidden_states,
+                    pooling_metadata,
+                    pooler,
+                    num_tokens=num_tokens,
+                    num_reqs=num_reqs,
+                )
+
     def _precompile_input_embeddings_merger(self) -> None:
         for num_tokens in self.runner.num_tokens_paddings:
             hidden_size = self.runner.vllm_config.model_config.get_hidden_size(
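
The nested loop above warms one XLA program per (num_tokens, num_reqs) padding bucket so serving never pays a JIT stall. A minimal standalone sketch of the same ahead-of-time idea, with made-up bucket sizes and a toy pooling function in place of the real pool():

import jax
import jax.numpy as jnp

@jax.jit
def toy_pool(hidden_states):
    # Stand-in for the real pool(); mean over the token axis.
    return hidden_states.mean(axis=0)

for num_tokens in (16, 32, 64):  # illustrative padding buckets
    x = jax.ShapeDtypeStruct((num_tokens, 128), jnp.float32)
    toy_pool.lower(x).compile()  # trace + compile without running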
