From 232b99f4ce74acef3167a06c5754e331371e0f45 Mon Sep 17 00:00:00 2001
From: fsx950223
Date: Sat, 8 Nov 2025 11:46:05 +0000
Subject: [PATCH 1/2] fix cross attention

Signed-off-by: fsx950223
---
 vllm/v1/attention/backends/triton_attn.py | 17 +++++++++--------
 vllm/v1/worker/utils.py                   |  6 +++++-
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index b1d34dbfd172..f8271aadebcf 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -244,14 +244,11 @@ def __init__(
 
         TritonAttentionBackend.validate_head_size(head_size)
 
-        if attn_type != AttentionType.DECODER:
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "TritonAttentionImpl"
+                "Encoder self-attention is not implemented for TritonAttentionImpl"
             )
-
+        self.attn_type = attn_type
         self.fp8_dtype = current_platform.fp8_dtype()
         self.sinks = sinks
 
@@ -312,7 +309,11 @@ def forward(
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(1)
 
-        if self.kv_sharing_target_layer_name is None:
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             if self.kv_cache_dtype.startswith("fp8"):
@@ -346,7 +347,7 @@ def forward(
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table
 
-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
 
         unified_attention(
             q=query[:num_actual_tokens],
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 0ca7e81a5c7b..7d184fefb0e1 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -316,7 +316,11 @@ def bind_kv_cache(
             # TODO - analyze where runner_kv_caches is used and the right
             # way to ensure it properly reflects multiple attention layers
             # in the same decoder block.
-            if current_platform.is_cuda() or current_platform.is_xpu():
+            if (
+                current_platform.is_cuda()
+                or current_platform.is_xpu()
+                or current_platform.is_rocm()
+            ):
                 # We know that the GPU runner is not impacted by this
                 # case. Some test code depends on runner_kv_caches, but
                 # not in a way that's impacted by ignoring this.

From 9a17437c758f785569a4c436f7a4ce26459b55a5 Mon Sep 17 00:00:00 2001
From: fsx950223
Date: Thu, 13 Nov 2025 03:49:46 +0000
Subject: [PATCH 2/2] change api

Signed-off-by: fsx950223
---
 vllm/v1/worker/utils.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 7d184fefb0e1..9c21a87f4f51 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -316,11 +316,7 @@ def bind_kv_cache(
             # TODO - analyze where runner_kv_caches is used and the right
             # way to ensure it properly reflects multiple attention layers
             # in the same decoder block.
-            if (
-                current_platform.is_cuda()
-                or current_platform.is_xpu()
-                or current_platform.is_rocm()
-            ):
+            if current_platform.is_cuda_alike() or current_platform.is_xpu():
                 # We know that the GPU runner is not impacted by this
                 # case. Some test code depends on runner_kv_caches, but
                 # not in a way that's impacted by ignoring this.
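
A minimal standalone sketch of the guard pattern the first patch introduces in
TritonAttentionImpl.forward, assuming that ENCODER_DECODER cross-attention decode
steps pass key/value as None once the encoder K/V have been written to the paged
cache, and that the cache layout is (num_blocks, block_size, num_kv_heads, head_size)
so that key_cache.shape[2] gives the KV-head count even when key is None. The helper
name and shapes below are illustrative assumptions, not vLLM code.

    import torch


    def maybe_write_kv(key_cache, value_cache, key, value, slot_mapping):
        # Cross-attention decode: the encoder K/V were already cached during the
        # encoder pass, so there is nothing new to store on this step.
        if key is None or value is None:
            return
        # Self-attention / prefill: scatter the new keys and values into the
        # paged cache at the slots assigned to these tokens.
        flat_k = key_cache.view(-1, *key_cache.shape[2:])
        flat_v = value_cache.view(-1, *value_cache.shape[2:])
        flat_k[slot_mapping] = key
        flat_v[slot_mapping] = value


    # Assumed paged-cache layout: (num_blocks, block_size, num_kv_heads, head_size).
    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
    value_cache = torch.zeros_like(key_cache)

    # The KV-head count comes from the cache, not from `key`, since `key` may be None.
    num_seqs = 2
    descale_shape = (num_seqs, key_cache.shape[2])

    # Cross-attention decode step: no-op, the cache is left untouched.
    maybe_write_kv(key_cache, value_cache, key=None, value=None, slot_mapping=None)

    # Prefill / self-attention step: two new tokens land in cache slots 0 and 1.
    new_k = torch.randn(2, num_kv_heads, head_size)
    new_v = torch.randn(2, num_kv_heads, head_size)
    maybe_write_kv(key_cache, value_cache, new_k, new_v, torch.tensor([0, 1]))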