fix cross attention #28346
Conversation
Signed-off-by: fsx950223 <fsx950223@outlook.com>
Code Review
This pull request aims to add cross-attention support to the Triton attention backend. While the intent is good, it introduces a critical issue where cross-attention is incorrectly treated as causal attention, which will lead to incorrect model outputs. The underlying Triton kernel, unified_attention, currently only supports causal attention and would need to be modified to correctly handle cross-attention. The other change to add ROCm support in a utility function appears correct.
```diff
 if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
     raise NotImplementedError(
-        "Encoder self-attention and "
-        "encoder/decoder cross-attention "
-        "are not implemented for "
-        "TritonAttentionImpl"
+        "Encoder self-attention is not implemented for TritonAttentionImpl"
     )

 self.attn_type = attn_type
```
This change enables AttentionType.ENCODER_DECODER, but the forward method of TritonAttentionImpl subsequently calls unified_attention with causal=True hardcoded. Cross-attention is non-causal, so this will result in incorrect attention calculations for encoder-decoder models.
Furthermore, the unified_attention function itself asserts that it only supports causal attention. To properly support cross-attention, unified_attention and its underlying Triton kernel must be updated to handle non-causal cases. Without this, enabling ENCODER_DECODER here introduces a bug.
Is this valid?
Yes
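For reference, a minimal sketch of the gating described in the review above: derive the causal flag from the attention type instead of hardcoding `causal=True`. The helper is illustrative only; plain strings stand in for vLLM's `AttentionType` members, and it assumes the Triton kernel gains a non-causal path.

```python
def pick_causal_flag(attn_type: str, metadata_causal: bool = True) -> bool:
    """Decide the causal flag to pass to the attention kernel.

    Cross-attention (encoder/decoder) must be non-causal: every decoder
    query may attend to the full encoder sequence. Decoder self-attention
    keeps whatever the metadata builder requested.
    """
    if attn_type == "encoder_decoder":  # AttentionType.ENCODER_DECODER in vLLM
        return False
    return metadata_causal


# Hardcoding causal=True skips this decision entirely, which is the bug
# flagged in the review.
assert pick_causal_flag("encoder_decoder") is False
assert pick_causal_flag("decoder") is True
```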
💡 Codex Review
Here are some automated review suggestions for this pull request.
```diff
     max_seqlen_k = attn_metadata.max_seq_len
     block_table = attn_metadata.block_table

-    descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+    descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])

     unified_attention(
         q=query[:num_actual_tokens],
```
Cross-attention still forced through causal mask
The constructor now allows AttentionType.ENCODER_DECODER, but the execution path still unconditionally calls unified_attention(..., causal=True, …). Cross-attention builders explicitly set attn_metadata.causal = False so that queries can see the full encoder sequence. For batches where multiple decoder queries are processed at once (e.g. teacher-forced decoding or speculative decoding of several tokens), keeping causal=True restricts each query to only the prefix of encoder keys and produces incorrect attention weights. If cross-attention is intended to work for Triton, this should respect attn_metadata.causal (and the kernel would need to support the non-causal case); otherwise the change to accept ENCODER_DECODER silently yields wrong results.
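To make the failure mode concrete, here is a small standalone illustration (plain Python with toy sizes, nothing here touches vLLM): with several decoder queries batched against the encoder keys, a causal mask hides most of the encoder from the early queries, while cross-attention should expose all of it.

```python
# Toy illustration of the masking bug described above; sizes are made up.
num_decoder_queries = 3  # e.g. teacher-forced or speculative decode tokens
num_encoder_keys = 5     # full encoder sequence length

causal_mask = [
    [k <= q for k in range(num_encoder_keys)]
    for q in range(num_decoder_queries)
]
cross_attention_mask = [
    [True] * num_encoder_keys for _ in range(num_decoder_queries)
]

# Under causal=True the first decoder query sees only 1 of 5 encoder
# positions; correct cross-attention must see all 5.
print(sum(causal_mask[0]), "vs", sum(cross_attention_mask[0]))  # 1 vs 5
```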
maleksan85
left a comment
LGTM! Thanks!
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: fsx950223 <fsx950223@outlook.com>
vllm/v1/worker/utils.py
Outdated
```diff
 # in the same decoder block.
-if current_platform.is_cuda() or current_platform.is_xpu():
+if (
+    current_platform.is_cuda()
```
NITS: use current_platform.is_cuda_alike() and remove current_platform.is_rocm()
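In code, the nit amounts to something like the following (a sketch of the suggested condition, not the merged change; `is_cuda_alike()` covers both CUDA and ROCm):

```python
from vllm.platforms import current_platform

# Sketch of the suggested check: is_cuda_alike() already returns True on
# both CUDA and ROCm, so a separate is_rocm() branch is redundant.
if current_platform.is_cuda_alike() or current_platform.is_xpu():
    pass  # ...existing CUDA/ROCm/XPU-only handling stays here...
```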
/gemini review
Code Review
This pull request aims to add cross-attention support to the Triton attention backend. While the changes in vllm/v1/worker/utils.py and the robustness improvement in triton_attn.py are correct, there is a critical issue in the main logic for handling cross-attention. The current implementation incorrectly attempts to cache encoder keys and values into the decoder's paged KV cache, which will lead to memory corruption or a crash. I've provided a detailed comment and a suggested fix to prevent this bug.
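One plausible shape of such a guard, sketched here as an assumption rather than the reviewer's exact suggestion, is to skip the paged-KV-cache write whenever no new key/value tensors are supplied:

```python
def should_write_kv_cache(key, value) -> bool:
    """Hypothetical guard (sketch): write to the paged KV cache only when new
    key/value tensors are actually present. On cross-attention decode steps
    they are None; the encoder K/V were cached once during the encoder pass
    and must only be read afterwards, not re-written into the decoder cache.
    """
    return key is not None and value is not None


# The backend's existing reshape-and-cache call would run only when this
# returns True.
assert should_write_kv_cache(None, None) is False
```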
@SageMoore @tdoublep @bringlein Could you help to take a look? Thank you.
Signed-off-by: fsx950223 <fsx950223@outlook.com>
Force-pushed from 2b244f5 to 9a17437.
@SageMoore @tdoublep @bringlein Could you help to take a look? Thank you.
SageMoore
left a comment
Looks reasonable. Just one question.
```diff
@@ -346,7 +347,7 @@ def forward(
     max_seqlen_k = attn_metadata.max_seq_len
     block_table = attn_metadata.block_table

-    descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+    descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
```
This change is a noop right? We are just getting the num_kv_heads from a different spot? Is there some subtle difference that I'm missing?
No
Can we just leave the original code then?
No, key may be None.
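A short standalone illustration of that point (shapes are made up; the `(num_blocks, block_size, num_kv_heads, head_size)` cache layout is an assumption mirroring the flash-style cache the backend slices `key_cache` from):

```python
import torch

# Made-up shapes, for illustration only.
num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64
key_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size)

key = None          # cross-attention decode step: no new key tensor arrives
num_query_seqs = 3  # stands in for cu_seqlens_q.shape[0] - 1

# Old code: key.shape[1] raises AttributeError when key is None.
# New code: read num_kv_heads from the cache; same value whenever key exists.
descale_shape = (num_query_seqs, key_cache.shape[2])
print(descale_shape)  # (3, 4)
```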
@fsx950223, could you please rebase?
The H100 test is already failing in the latest nightly: https://buildkite.com/vllm/ci/builds/40068/steps/canvas?jid=019aa55f-e011-41eb-bf7f-b02500cc76d4#019aa55f-e011-41eb-bf7f-b02500cc76d4, so the failure is not caused by this PR and we can merge this.
Signed-off-by: fsx950223 <fsx950223@outlook.com>
Signed-off-by: fsx950223 <fsx950223@outlook.com>
Signed-off-by: fsx950223 <fsx950223@outlook.com> Signed-off-by: Runkai Tao <rt572@physics.rutgers.edu>
Signed-off-by: fsx950223 <fsx950223@outlook.com>
Signed-off-by: fsx950223 <fsx950223@outlook.com>
Signed-off-by: fsx950223 <fsx950223@outlook.com>
Purpose
Fix #27442
Test Plan
pytest entrypoints/openai/test_transcription_validation.py::test_basic_audio[openai/whisper-large-v3-turbo] --maxfail=1
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.