
Commit 15be507

[bugfix] fix siglip batch text output error (#28365)
Signed-off-by: piood <2477084691@qq.com>
1 parent 6f7de33 commit 15be507

File tree: 1 file changed (+61, -22 lines)


vllm/model_executor/models/siglip.py

Lines changed: 61 additions & 22 deletions
@@ -19,6 +19,7 @@
 )
 
 from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import divide, get_tensor_model_parallel_world_size

@@ -379,6 +380,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
+        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
     ) -> None:
         super().__init__()
 

@@ -413,8 +415,11 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
 
-        self.attn = MultiHeadAttention(
-            self.num_heads_per_partition, self.head_dim, self.scale
+        self.attn = attn_cls(
+            self.num_heads_per_partition,
+            self.head_dim,
+            self.scale,
+            prefix=f"{prefix}.attn",
         )
 
     def forward(
@@ -424,25 +429,7 @@
         """Input shape: Batch x Time x Channel"""
         qkv_states, _ = self.qkv_proj(hidden_states)
         query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
-
-        needs_unsqueeze = query_states.ndim == 2
-        if needs_unsqueeze:
-            query_states, key_states, value_states = (
-                query_states.unsqueeze(0),
-                key_states.unsqueeze(0),
-                value_states.unsqueeze(0),
-            )
-
         out = self.attn(query_states, key_states, value_states)
-
-        if needs_unsqueeze:
-            out, query_states, key_states, value_states = (
-                out.squeeze(0),
-                query_states.squeeze(0),
-                key_states.squeeze(0),
-                value_states.squeeze(0),
-            )
-
         attn_output, _ = self.out_proj(out)
 
         return attn_output, None
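
Context for this hunk: the attention layer now receives the raw token-packed tensor directly, so the old rank-2 workaround (unsqueeze everything into a single fake batch, run full attention, squeeze back) is removed. The sketch below, using plain PyTorch and toy tensors (not vLLM code), illustrates why that workaround was unsafe for batched text, under the assumption that several requests arrive packed along one token axis, which the EncoderOnlyAttention path is meant to keep separated per request.

# Toy sketch (plain PyTorch, made-up data): two packed sequences treated as one
# long sequence give different attention outputs than per-sequence attention,
# because tokens attend across the request boundary.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim = 8
seq_a = torch.randn(3, dim)         # request 1: 3 tokens
seq_b = torch.randn(4, dim)         # request 2: 4 tokens
packed = torch.cat([seq_a, seq_b])  # [7, dim], token-packed as in a flattened batch

# Per-sequence self-attention (what the text tower needs).
per_seq = torch.cat(
    [
        F.scaled_dot_product_attention(s[None], s[None], s[None])[0]
        for s in (seq_a, seq_b)
    ]
)

# One self-attention over the whole packed tensor (the old unsqueeze path).
fused = F.scaled_dot_product_attention(packed[None], packed[None], packed[None])[0]

print(torch.allclose(per_seq, fused))  # False: outputs leak across the boundary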
@@ -495,6 +482,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
+        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
     ) -> None:
         super().__init__()
 

@@ -504,6 +492,7 @@ def __init__(
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            attn_cls=attn_cls,
         )
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(

@@ -539,6 +528,7 @@ def __init__(
         num_hidden_layers_override: int | None = None,
         *,
         prefix: str = "",
+        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
     ) -> None:
         super().__init__()
 

@@ -555,6 +545,7 @@ def __init__(
                     config,
                     quant_config=quant_config,
                     prefix=f"{prefix}.layers.{layer_idx}",
+                    attn_cls=attn_cls,
                 )
                 for layer_idx in range(num_hidden_layers)
             ]

@@ -598,6 +589,7 @@ def __init__(
             config=config,
             quant_config=quant_config,
             prefix=f"{prefix}.encoder",
+            attn_cls=EncoderOnlyAttention,
         )
 
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@@ -709,6 +701,7 @@ def __init__(
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
+            attn_cls=MultiHeadAttention,
         )
 
         num_hidden_layers = config.num_hidden_layers
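
The two attn_cls= lines above complete the wiring: the text tower's encoder is built with EncoderOnlyAttention, while the vision tower keeps MultiHeadAttention. Below is a toy sketch of the same constructor-injection pattern; PackedAttention, BatchedAttention, ToyLayer, and ToyEncoder are made-up illustrative names, not vLLM APIs.

# Toy illustration of passing the attention class down the constructor chain,
# so a single encoder implementation can serve both towers.
import torch.nn as nn


class PackedAttention(nn.Module):
    """Stand-in for an attention that expects token-packed [total_tokens, dim] input."""


class BatchedAttention(nn.Module):
    """Stand-in for an attention that expects regular [batch, seq, dim] input."""


class ToyLayer(nn.Module):
    def __init__(self, attn_cls: type[nn.Module]) -> None:
        super().__init__()
        self.attn = attn_cls()  # the layer builds whatever class it was handed


class ToyEncoder(nn.Module):
    def __init__(self, attn_cls: type[nn.Module], num_layers: int = 2) -> None:
        super().__init__()
        self.layers = nn.ModuleList([ToyLayer(attn_cls) for _ in range(num_layers)])


text_encoder = ToyEncoder(attn_cls=PackedAttention)     # text tower's choice
vision_encoder = ToyEncoder(attn_cls=BatchedAttention)  # vision tower's choice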
@@ -1034,10 +1027,56 @@ def get_text_features(
             inputs_embeds=inputs_embeds,
         )
         text_features = self.text_model.head(last_hidden_state)
-        # Flip to extract CLS token (first token after reversal) for pooling
-        text_features = text_features.flip(0)
+
+        # SigLIP uses reversed position_ids;
+        # flip sequences to move EOS token to first position
+        text_features = self._flip_sequences_by_position_ids(
+            text_features, position_ids
+        )
+
         return text_features
 
+    def _flip_sequences_by_position_ids(
+        self,
+        features: torch.Tensor,
+        position_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Flip sequences so EOS token moves to first position for CLS pooling.
+
+        SigLIP position_ids are reversed within each sequence. This method
+        detects sequence boundaries and flips each sequence individually.
+        """
+        if len(features) == 1:
+            return features
+
+        # Detect sequence boundaries where position_ids decrease
+        position_diffs = position_ids[1:] - position_ids[:-1]
+        boundary_mask = position_diffs <= 0
+
+        boundary_indices = torch.cat(
+            [
+                torch.tensor([0], device=features.device),
+                torch.where(boundary_mask)[0] + 1,
+                torch.tensor([len(features)], device=features.device),
+            ]
+        )
+
+        # For each sequence [start, end), position i flips to: start + end - 1 - i
+        lengths = boundary_indices[1:] - boundary_indices[:-1]
+        starts = boundary_indices[:-1]
+        ends = boundary_indices[1:]
+
+        # Assign sequence ID to each element
+        sequence_ids = torch.arange(
+            len(lengths), device=features.device
+        ).repeat_interleave(lengths)
+
+        # Calculate flipped indices for all positions at once
+        current_positions = torch.arange(len(features), device=features.device)
+        flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions
+
+        return features[flip_indices]
+
     def get_image_features(
         self,
         pixel_values: torch.Tensor,
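
To make the indexing in _flip_sequences_by_position_ids concrete: the removed text_features.flip(0) reversed the whole packed tensor, which is only correct when it holds a single request, whereas the new helper reverses each request independently. Below is a standalone re-implementation on a toy packed batch (plain PyTorch, outside vLLM), under the assumption that position_ids are the engine's ordinary increasing 0..L-1 ids per request, which is what makes the <= 0 difference test land on request boundaries.

import torch


def flip_sequences_by_position_ids(
    features: torch.Tensor, position_ids: torch.Tensor
) -> torch.Tensor:
    """Reverse token order within each packed sequence (mirrors the helper above)."""
    if len(features) == 1:
        return features
    # A new request starts wherever position_ids stop increasing.
    boundary_mask = (position_ids[1:] - position_ids[:-1]) <= 0
    boundary_indices = torch.cat(
        [
            torch.tensor([0]),
            torch.where(boundary_mask)[0] + 1,
            torch.tensor([len(features)]),
        ]
    )
    lengths = boundary_indices[1:] - boundary_indices[:-1]
    starts, ends = boundary_indices[:-1], boundary_indices[1:]
    sequence_ids = torch.arange(len(lengths)).repeat_interleave(lengths)
    current = torch.arange(len(features))
    # Within [start, end), position i maps to start + end - 1 - i.
    flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current
    return features[flip_indices]


# Two packed requests of lengths 3 and 4; token i carries the value i.
position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3])
features = torch.arange(7, dtype=torch.float32).unsqueeze(-1)
flipped = flip_sequences_by_position_ids(features, position_ids)
print(flipped.squeeze(-1).tolist())  # [2.0, 1.0, 0.0, 6.0, 5.0, 4.0, 3.0]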
