 )
 
 from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -379,6 +380,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
+        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
     ) -> None:
         super().__init__()
 
@@ -413,8 +415,11 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
 
-        self.attn = MultiHeadAttention(
-            self.num_heads_per_partition, self.head_dim, self.scale
+        self.attn = attn_cls(
+            self.num_heads_per_partition,
+            self.head_dim,
+            self.scale,
+            prefix=f"{prefix}.attn",
         )
 
     def forward(
@@ -424,25 +429,7 @@ def forward(
424429 """Input shape: Batch x Time x Channel"""
425430 qkv_states , _ = self .qkv_proj (hidden_states )
426431 query_states , key_states , value_states = qkv_states .chunk (3 , dim = - 1 )
427-
428- needs_unsqueeze = query_states .ndim == 2
429- if needs_unsqueeze :
430- query_states , key_states , value_states = (
431- query_states .unsqueeze (0 ),
432- key_states .unsqueeze (0 ),
433- value_states .unsqueeze (0 ),
434- )
435-
436432 out = self .attn (query_states , key_states , value_states )
437-
438- if needs_unsqueeze :
439- out , query_states , key_states , value_states = (
440- out .squeeze (0 ),
441- query_states .squeeze (0 ),
442- key_states .squeeze (0 ),
443- value_states .squeeze (0 ),
444- )
445-
446433 attn_output , _ = self .out_proj (out )
447434
448435 return attn_output , None
@@ -495,6 +482,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
+        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
     ) -> None:
         super().__init__()
 
@@ -504,6 +492,7 @@ def __init__(
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            attn_cls=attn_cls,
         )
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -539,6 +528,7 @@ def __init__(
         num_hidden_layers_override: int | None = None,
         *,
         prefix: str = "",
+        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
     ) -> None:
         super().__init__()
 
@@ -555,6 +545,7 @@ def __init__(
                     config,
                     quant_config=quant_config,
                     prefix=f"{prefix}.layers.{layer_idx}",
+                    attn_cls=attn_cls,
                 )
                 for layer_idx in range(num_hidden_layers)
             ]
@@ -598,6 +589,7 @@ def __init__(
             config=config,
             quant_config=quant_config,
             prefix=f"{prefix}.encoder",
+            attn_cls=EncoderOnlyAttention,
         )
 
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@@ -709,6 +701,7 @@ def __init__(
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
+            attn_cls=MultiHeadAttention,
         )
 
         num_hidden_layers = config.num_hidden_layers
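
A note on the attn_cls plumbing above: the parameter carries a class object, not an instance (hence the type[EncoderOnlyAttention] | type[MultiHeadAttention] annotation), and each SiglipAttention instantiates whatever class its encoder was given; the text tower passes EncoderOnlyAttention while the vision tower keeps MultiHeadAttention. The toy sketch below illustrates only that injection pattern with made-up classes; it does not use vLLM's real attention constructors.

import torch
from torch import nn


class ToyAttention(nn.Module):
    """Stand-in for MultiHeadAttention / EncoderOnlyAttention (toy, not vLLM's)."""

    def __init__(self, num_heads: int, head_dim: int, scale: float, *, prefix: str = "") -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Naive scaled dot-product, just enough to make the sketch runnable.
        return torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1) @ v


class ToyEncoderLayer(nn.Module):
    """Mirrors the diff: the layer receives a class object and instantiates it itself."""

    def __init__(self, hidden: int, *, prefix: str = "", attn_cls: type[nn.Module]) -> None:
        super().__init__()
        head_dim = hidden // 4
        self.attn = attn_cls(4, head_dim, head_dim**-0.5, prefix=f"{prefix}.attn")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.attn(x, x, x)


# The caller decides which attention implementation a tower gets, as the diff
# does with EncoderOnlyAttention (text) vs. MultiHeadAttention (vision).
layer = ToyEncoderLayer(64, prefix="encoder.layers.0", attn_cls=ToyAttention)
print(layer(torch.randn(10, 64)).shape)  # torch.Size([10, 64])
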
@@ -1034,10 +1027,56 @@ def get_text_features(
             inputs_embeds=inputs_embeds,
         )
         text_features = self.text_model.head(last_hidden_state)
-        # Flip to extract CLS token (first token after reversal) for pooling
-        text_features = text_features.flip(0)
+
+        # SigLIP uses reversed position_ids;
+        # flip sequences to move EOS token to first position
+        text_features = self._flip_sequences_by_position_ids(
+            text_features, position_ids
+        )
+
         return text_features
 
+    def _flip_sequences_by_position_ids(
+        self,
+        features: torch.Tensor,
+        position_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Flip sequences so EOS token moves to first position for CLS pooling.
+
+        SigLIP position_ids are reversed within each sequence. This method detects
+        sequence boundaries and flips each sequence individually.
+        """
+        if len(features) == 1:
+            return features
+
+        # Detect sequence boundaries where position_ids decrease
+        position_diffs = position_ids[1:] - position_ids[:-1]
+        boundary_mask = position_diffs <= 0
+
+        boundary_indices = torch.cat(
+            [
+                torch.tensor([0], device=features.device),
+                torch.where(boundary_mask)[0] + 1,
+                torch.tensor([len(features)], device=features.device),
+            ]
+        )
+
+        # For each sequence [start, end), position i flips to: start + end - 1 - i
+        lengths = boundary_indices[1:] - boundary_indices[:-1]
+        starts = boundary_indices[:-1]
+        ends = boundary_indices[1:]
+
+        # Assign sequence ID to each element
+        sequence_ids = torch.arange(
+            len(lengths), device=features.device
+        ).repeat_interleave(lengths)
+
+        # Calculate flipped indices for all positions at once
+        current_positions = torch.arange(len(features), device=features.device)
+        flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions
+
+        return features[flip_indices]
+
     def get_image_features(
         self,
         pixel_values: torch.Tensor,
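
To sanity-check the index arithmetic in _flip_sequences_by_position_ids, here is a standalone sketch of the same computation that runs without vLLM. It assumes position_ids increase within each packed sequence and reset at each sequence start, which is the pattern the position_diffs <= 0 boundary check detects; the function name and test values are made up for illustration.

import torch


def flip_by_position_ids(features: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    """Standalone copy of the flip logic above, for experimentation only."""
    if len(features) == 1:
        return features
    # Sequence boundaries are where position_ids stop increasing.
    diffs = position_ids[1:] - position_ids[:-1]
    boundary_indices = torch.cat(
        [
            torch.tensor([0]),
            torch.where(diffs <= 0)[0] + 1,
            torch.tensor([len(features)]),
        ]
    )
    lengths = boundary_indices[1:] - boundary_indices[:-1]
    starts, ends = boundary_indices[:-1], boundary_indices[1:]
    # Map every flattened token to its sequence, then flip within each [start, end).
    seq_ids = torch.arange(len(lengths)).repeat_interleave(lengths)
    pos = torch.arange(len(features))
    return features[starts[seq_ids] + ends[seq_ids] - 1 - pos]


# Two packed sequences of lengths 3 and 4, flattened into one token stream.
position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3])
features = torch.arange(7).unsqueeze(-1)  # token i carries value i

flipped = flip_by_position_ids(features, position_ids)
print(flipped.squeeze(-1).tolist())  # [2, 1, 0, 6, 5, 4, 3] -- each sequence reversed in place
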