vllm/v1/attention/backends/triton_attn.py (17 changes: 9 additions & 8 deletions)
@@ -244,14 +244,11 @@ def __init__(

self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        if attn_type != AttentionType.DECODER:
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "TritonAttentionImpl"
+                "Encoder self-attention is not implemented for TritonAttentionImpl"
             )

self.attn_type = attn_type
Comment on lines +247 to +251
Contributor (critical):

This change enables AttentionType.ENCODER_DECODER, but the forward method of TritonAttentionImpl subsequently calls unified_attention with causal=True hardcoded. Cross-attention is non-causal, so this will result in incorrect attention calculations for encoder-decoder models.

Furthermore, the unified_attention function itself asserts that it only supports causal attention. To properly support cross-attention, unified_attention and its underlying Triton kernel must be updated to handle non-causal cases. Without this, enabling ENCODER_DECODER here introduces a bug.
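
To make the non-causal point concrete, here is a small self-contained PyTorch sketch (shapes and values are made up for illustration, not taken from vLLM): once more than one decoder query is batched, a causal mask over the encoder keys changes the result, so hardcoding causal=True cannot reproduce cross-attention.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: 4 decoder queries attending over 6 encoder keys.
q = torch.randn(1, 1, 4, 8)   # (batch, heads, num_decoder_tokens, head_dim)
k = torch.randn(1, 1, 6, 8)   # (batch, heads, num_encoder_tokens, head_dim)
v = torch.randn(1, 1, 6, 8)

# True cross-attention: every decoder query may attend to every encoder position.
non_causal = F.scaled_dot_product_attention(q, k, v)

# What a causal mask does here: query i only sees encoder keys 0..i, and
# encoder keys 4 and 5 are never attended by any query at all.
causal_mask = torch.ones(4, 6, dtype=torch.bool).tril()
masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

print(torch.allclose(non_causal, masked))  # False: the mask changes the output
```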

Member:

Is this valid?

Contributor Author:

Yes

self.fp8_dtype = current_platform.fp8_dtype()

self.sinks = sinks
@@ -312,7 +309,11 @@ def forward(
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(1)

-        if self.kv_sharing_target_layer_name is None:
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
if self.kv_cache_dtype.startswith("fp8"):
@@ -346,7 +347,7 @@ def forward(
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table

-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
Contributor:

This change is a no-op, right? We are just getting num_kv_heads from a different spot. Is there some subtle difference that I'm missing?

Contributor Author:

No

Contributor:

Can we just leave the original code then?

Contributor Author (@fsx950223, Nov 20, 2025):

No, key may be None.
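
For context, a minimal sketch of the situation described above, assuming an illustrative per-half cache layout of (num_blocks, block_size, num_kv_heads, head_size) after kv_cache.unbind(1) (an assumption for this sketch, not a guarantee from vLLM): on cross-attention decode steps the layer receives key=None because the encoder K/V were written to the cache earlier, so num_kv_heads has to be read from the cache tensor rather than from key.

```python
import torch

# Illustrative layout only: assume each cache half is
# (num_blocks, block_size, num_kv_heads, head_size) after kv_cache.unbind(1).
num_blocks, block_size, num_kv_heads, head_size = 16, 32, 4, 64
key_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size)

# On a cross-attention decode step the layer is called with key=None: the
# encoder K/V were written to the cache once up front and are only read here.
key = None
cu_seqlens_q = torch.tensor([0, 3, 7])  # two sequences in the batch

# Old expression: key.shape[1] raises AttributeError when key is None.
# New expression: take num_kv_heads from the cache tensor, which always exists.
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
print(descale_shape)  # (2, 4)
```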


unified_attention(
q=query[:num_actual_tokens],
Comment on lines 347 to 353

P1: Cross-attention still forced through causal mask

The constructor now allows AttentionType.ENCODER_DECODER, but the execution path still unconditionally calls unified_attention(..., causal=True, …). Cross-attention builders explicitly set attn_metadata.causal = False so that queries can see the full encoder sequence. For batches where multiple decoder queries are processed at once (e.g. teacher-forced decoding or speculative decoding of several tokens), keeping causal=True restricts each query to only the prefix of encoder keys and produces incorrect attention weights. If cross-attention is intended to work for Triton, this should respect attn_metadata.causal (and the kernel would need to support the non-causal case); otherwise the change to accept ENCODER_DECODER silently yields wrong results.
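
A minimal sketch of the suggested direction, assuming a metadata object that carries the builder-set causal flag (the stub below is hypothetical, and the Triton kernel would still need a non-causal code path for this to be correct end to end): forward attn_metadata.causal instead of hardcoding True.

```python
from dataclasses import dataclass

# Hypothetical stand-in for the real metadata object; the comment above states
# that cross-attention builders set attn_metadata.causal = False.
@dataclass
class AttnMetadataStub:
    causal: bool = True

def resolve_causal(attn_metadata) -> bool:
    # Forward the builder's decision instead of hardcoding True; fall back to
    # causal self-attention if the flag is absent (illustrative default only).
    return getattr(attn_metadata, "causal", True)

# Decoder self-attention keeps the causal mask; cross-attention drops it.
print(resolve_causal(AttnMetadataStub(causal=True)))   # True
print(resolve_causal(AttnMetadataStub(causal=False)))  # False
```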
