File tree: vllm/v1/attention/backends — 1 file changed (+4, −4 lines)

Original file line number Diff line number Diff line change @@ -641,10 +641,6 @@ def _run_sdpa_forward(
641641 attn_metadata : TorchSDPAMetadata ,
642642 attn_type : str = AttentionType .DECODER ,
643643 ) -> None :
644- if self .num_kv_heads != self .num_heads :
645- key = key .repeat_interleave (self .num_queries_per_kv , dim = 1 )
646- value = value .repeat_interleave (self .num_queries_per_kv , dim = 1 )
647-
648644 attn_masks = attn_metadata .get_attn_bias (attn_type )
649645 if attn_masks is None :
650646 if self .alibi_slopes is not None :
@@ -665,6 +661,10 @@ def _run_sdpa_forward(
665661 key = key .movedim (0 , key .dim () - 2 )
666662 value = value .movedim (0 , value .dim () - 2 )
667663
664+ if self .num_kv_heads != self .num_heads :
665+ key = key .repeat_interleave (self .num_queries_per_kv , dim = - 3 )
666+ value = value .repeat_interleave (self .num_queries_per_kv , dim = - 3 )
667+
668668 causal_attn = (attn_type == AttentionType .DECODER )
669669
670670 seq_lens_q , seq_lens_kv = attn_metadata .get_seq_lens (attn_type )
You can’t perform that action at this time.
0 commit comments