fix cross attention #28346
Changes from all commits: 232b99f, a5bd54a, 9a17437, 086ebdf, f5c18a9
@@ -244,14 +244,11 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        if attn_type != AttentionType.DECODER:
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "TritonAttentionImpl"
+                "Encoder self-attention is not implemented for TritonAttentionImpl"
             )

         self.attn_type = attn_type
         self.fp8_dtype = current_platform.fp8_dtype()

         self.sinks = sinks
@@ -312,7 +309,11 @@ def forward(
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(1)

-        if self.kv_sharing_target_layer_name is None:
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             if self.kv_cache_dtype.startswith("fp8"):
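Why `key` and `value` can be `None` here: for encoder-decoder cross-attention, the encoder's keys and values are written to the cache once, and subsequent decode steps call the layer with no new key/value tensors, so the cache write has to be skipped. The sketch below only illustrates that guard; the shapes, the helper name, and the scatter logic are assumptions for illustration, not the backend's actual cache-write kernel.

```python
from typing import Optional

import torch


def maybe_write_kv(key: Optional[torch.Tensor],
                   value: Optional[torch.Tensor],
                   key_cache: torch.Tensor,
                   value_cache: torch.Tensor,
                   slot_mapping: torch.Tensor,
                   kv_sharing_target_layer_name: Optional[str] = None) -> None:
    # Skip the cache write when sharing KV with an earlier layer, or when this
    # is a cross-attention decode step that brings no new keys/values.
    if kv_sharing_target_layer_name is not None or key is None or value is None:
        return
    # Illustrative scatter of new K/V into a paged cache assumed to be laid out
    # as (num_blocks, block_size, num_kv_heads, head_size); slot_mapping holds
    # one flat slot index per token.
    key_cache.view(-1, *key_cache.shape[-2:])[slot_mapping] = key
    value_cache.view(-1, *value_cache.shape[-2:])[slot_mapping] = value
```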
@@ -346,7 +347,7 @@ def forward(
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table

-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
Contributor: This change is a noop, right? We are just getting the same value from `key_cache` instead of `key`?

Contributor (Author): No.

Contributor: Can we just leave the original code then?

Contributor (Author): No, `key` may be None.
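For readers following this thread, a small sketch of why reading the head dimension from the cache is not just cosmetic. The tensor layouts below are assumptions about this backend (they are not stated in the diff); the point is only that `key.shape[1]` and `key_cache.shape[2]` name the same `num_kv_heads` value, and the cache is the only one of the two guaranteed to exist on a cross-attention decode step.

```python
from typing import Optional

import torch


def fp8_descale_shape(cu_seqlens_q: torch.Tensor,
                      key: Optional[torch.Tensor],
                      key_cache: torch.Tensor) -> tuple[int, int]:
    # Assumed layouts (illustrative only):
    #   key:       (num_tokens, num_kv_heads, head_size), or None on
    #              cross-attention decode steps
    #   key_cache: (num_blocks, block_size, num_kv_heads, head_size)
    num_seqs = cu_seqlens_q.shape[0] - 1
    # key.shape[1] and key_cache.shape[2] both equal num_kv_heads, but only
    # the cache is guaranteed to be present when key is None.
    return (num_seqs, key_cache.shape[2])
```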
Comment on lines 347 to 353 (the `unified_attention(q=query[:num_actual_tokens], ...)` call):
This change enables `AttentionType.ENCODER_DECODER`, but the `forward` method of `TritonAttentionImpl` subsequently calls `unified_attention` with `causal=True` hardcoded. Cross-attention is non-causal, so this will result in incorrect attention calculations for encoder-decoder models.

Furthermore, the `unified_attention` function itself asserts that it only supports causal attention. To properly support cross-attention, `unified_attention` and its underlying Triton kernel must be updated to handle non-causal cases. Without this, enabling `ENCODER_DECODER` here introduces a bug.
Is this valid?
Yes
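One way the non-causal concern could be addressed, sketched under stated assumptions rather than taken from this PR: derive the causal flag from the layer's attention type instead of hardcoding it, and pass it through to the kernel. The stand-in `AttentionType` class and `is_causal` helper below are hypothetical, and `unified_attention` would still need a real non-causal code path for this to be correct.

```python
# Hypothetical sketch, not code from this PR.
from enum import Enum


class AttentionType(str, Enum):
    # Stand-in for vLLM's AttentionType constants.
    DECODER = "decoder"
    ENCODER_DECODER = "encoder_decoder"


def is_causal(attn_type: AttentionType) -> bool:
    # Decoder self-attention masks future tokens; cross-attention reads the
    # whole (already complete) encoder sequence, so no causal mask is needed.
    return attn_type == AttentionType.DECODER


# Usage sketch inside forward(), assuming the kernel grows non-causal support:
#   unified_attention(..., causal=is_causal(self.attn_type))
```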