diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 889c79db18ef..09c36043c8c8 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -244,14 +244,11 @@ def __init__(
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        if attn_type != AttentionType.DECODER:
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "TritonAttentionImpl"
+                "Encoder self-attention is not implemented for TritonAttentionImpl"
             )
-
+        self.attn_type = attn_type
         self.fp8_dtype = current_platform.fp8_dtype()
         self.sinks = sinks
 
@@ -312,7 +309,11 @@ def forward(
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(1)
 
-        if self.kv_sharing_target_layer_name is None:
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             if self.kv_cache_dtype.startswith("fp8"):
@@ -346,7 +347,7 @@ def forward(
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table
 
-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
 
         unified_attention(
             q=query[:num_actual_tokens],
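
A minimal sketch of why the last hunk switches from `key.shape[1]` to `key_cache.shape[2]`, assuming the cache layout implied by `kv_cache.unbind(1)` is `[num_blocks, 2, block_size, num_kv_heads, head_size]`: both expressions yield `num_kv_heads`, but only the cache is guaranteed to exist on cross-attention decode steps, where `key`/`value` presumably arrive as `None` (the encoder K/V are already cached) and the new guard skips the cache write. The tensor sizes below are toy values, not taken from the backend.

```python
import torch

# Toy sizes for illustration only; real values come from the model/cache config.
num_tokens, num_kv_heads, head_size = 4, 2, 8
num_blocks, block_size = 16, 32

# Assumed cache layout behind `kv_cache.unbind(1)` in the backend:
# kv_cache: [num_blocks, 2, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)
key_cache, value_cache = kv_cache.unbind(1)

# Prefill / self-attention: `key` is a real tensor and both shapes agree.
key = torch.zeros(num_tokens, num_kv_heads, head_size)
assert key.shape[1] == key_cache.shape[2] == num_kv_heads

# Cross-attention decode: `key`/`value` are None, so the descale shape must
# come from the cache rather than from `key`.
key = None
cu_seqlens_q = torch.arange(num_tokens + 1, dtype=torch.int32)
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
print(descale_shape)  # (4, 2)
```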