Skip to content

Commit 292d310

Browse files
authored
Fix sharding mismatch that caused recompilation in the Qwen2.5-vl-7b integration test (#1117)
Signed-off-by: Kewei Wang <keweiwang@google.com>
1 parent 36bd457 commit 292d310

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

tpu_inference/models/common/model_loader.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -242,10 +242,11 @@ def run_get_multimodal_embeddings(graphdef, state, image_grid_thw,
242242
model = nnx.merge(graphdef, state)
243243
return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
244244

245+
embed_sharding = NamedSharding(mesh, PartitionSpec(None))
245246
# This function calculates the embeddings of the input text and then merges them with the image embeddings
246247
@functools.partial(
247248
jax.jit,
248-
out_shardings=(logits_sharding),
249+
out_shardings=(embed_sharding),
249250
)
250251
def run_get_input_embeddings(graphdef, state, *args, **kwargs):
251252
model = nnx.merge(graphdef, state)

tpu_inference/runner/compilation_manager.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -332,13 +332,15 @@ def _precompile_select_from_array(self) -> None:
332332
index_paddings = self.runner.num_reqs_paddings
333333
dp_sharding = NamedSharding(self.runner.mesh,
334334
PartitionSpec(ShardingAxisName.ATTN_DATA))
335+
hidden_states_sharding = NamedSharding(
336+
self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
335337
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
336338
self._precompile_select_from_array_helper(
337339
name="select all logits",
338340
source_paddings=self.runner.num_tokens_paddings,
339341
indices_paddings=index_paddings,
340342
hidden_dim=hsize,
341-
input_sharding=dp_sharding,
343+
input_sharding=hidden_states_sharding,
342344
indices_sharding=dp_sharding if dp_size > 1 else None,
343345
)
344346

0 commit comments

Comments
 (0)