
Conversation

@fsx950223 (Contributor) commented Nov 8, 2025

Purpose

Fix #27442

Test Plan

pytest entrypoints/openai/test_transcription_validation.py::test_basic_audio[openai/whisper-large-v3-turbo] --maxfail=1

Test Result

=============================================================================== test session starts ===============================================================================
platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0
rootdir: /mnt/raid0/sixifang/vllm
configfile: pyproject.toml
plugins: asyncio-1.1.0, anyio-4.10.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 1 item                                                                                                                                                                  

entrypoints/openai/test_transcription_validation.py .                                                                                                                       [100%]

================================================================================ warnings summary =================================================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

../../../../../usr/lib/python3/dist-packages/pyparsing.py:108
  /usr/lib/python3/dist-packages/pyparsing.py:108: DeprecationWarning: module 'sre_constants' is deprecated
    import sre_constants

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================== 1 passed, 3 warnings in 98.74s (0:01:38) =====================================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute


Signed-off-by: fsx950223 <fsx950223@outlook.com>
@fsx950223 fsx950223 requested a review from tdoublep as a code owner November 8, 2025 11:47
@mergify mergify bot added the v1 label Nov 8, 2025
@gemini-code-assist (bot) left a comment

Code Review

This pull request aims to add cross-attention support to the Triton attention backend. While the intent is good, it introduces a critical issue where cross-attention is incorrectly treated as causal attention, which will lead to incorrect model outputs. The underlying Triton kernel, unified_attention, currently only supports causal attention and would need to be modified to correctly handle cross-attention. The other change to add ROCm support in a utility function appears correct.

Comment on lines +247 to +251

         if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "TritonAttentionImpl"
+                "Encoder self-attention is not implemented for TritonAttentionImpl"
             )

         self.attn_type = attn_type
gemini-code-assist (bot), severity: critical

This change enables AttentionType.ENCODER_DECODER, but the forward method of TritonAttentionImpl subsequently calls unified_attention with causal=True hardcoded. Cross-attention is non-causal, so this will result in incorrect attention calculations for encoder-decoder models.

Furthermore, the unified_attention function itself asserts that it only supports causal attention. To properly support cross-attention, unified_attention and its underlying Triton kernel must be updated to handle non-causal cases. Without this, enabling ENCODER_DECODER here introduces a bug.
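To make the failure mode concrete, here is a small illustrative PyTorch sketch (not code from this PR): with a causal mask, each decoder query only attends to a prefix of the encoder keys instead of the full encoder sequence.

```python
import torch

torch.manual_seed(0)
num_q, num_k, dim = 4, 6, 8          # 4 decoder queries attending over 6 encoder keys
q = torch.randn(num_q, dim)
k = torch.randn(num_k, dim)
v = torch.randn(num_k, dim)

scores = (q @ k.T) / dim**0.5

# Causal mask: query i only sees keys 0..i. Correct for decoder self-attention,
# but for cross-attention it hides most of the encoder sequence from early queries.
causal_mask = torch.ones(num_q, num_k).triu(1).bool()
causal_out = torch.softmax(scores.masked_fill(causal_mask, float("-inf")), dim=-1) @ v

# Non-causal attention (correct for cross-attention): every query sees all encoder keys.
cross_out = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(causal_out, cross_out))  # False: the causal mask changes the output
```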

Member: Is this valid?

fsx950223 (Author): Yes

@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 347 to 353
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table

-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])

         unified_attention(
             q=query[:num_actual_tokens],


P1: Cross-attention still forced through causal mask

The constructor now allows AttentionType.ENCODER_DECODER, but the execution path still unconditionally calls unified_attention(..., causal=True, …). Cross-attention builders explicitly set attn_metadata.causal = False so that queries can see the full encoder sequence. For batches where multiple decoder queries are processed at once (e.g. teacher-forced decoding or speculative decoding of several tokens), keeping causal=True restricts each query to only the prefix of encoder keys and produces incorrect attention weights. If cross-attention is intended to work for Triton, this should respect attn_metadata.causal (and the kernel would need to support the non-causal case); otherwise the change to accept ENCODER_DECODER silently yields wrong results.
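As a rough sketch of that suggestion (a toy example; the class and field names are illustrative, not vLLM's actual metadata types), the backend would pick the mask mode from the metadata rather than hardcoding it, and the kernel would need a matching non-causal path:

```python
from dataclasses import dataclass

@dataclass
class ToyAttnMetadata:
    max_seq_len: int
    causal: bool  # cross-attention metadata builders would set this to False

def pick_causal(attn_type: str, meta: ToyAttnMetadata) -> bool:
    # Decoder self-attention stays causal; encoder/decoder cross-attention must not be.
    if attn_type == "encoder_decoder":
        return meta.causal
    return True

print(pick_causal("decoder", ToyAttnMetadata(max_seq_len=16, causal=True)))           # True
print(pick_causal("encoder_decoder", ToyAttnMetadata(max_seq_len=16, causal=False)))  # False
```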


@maleksan85 (Contributor) left a comment
LGTM! Thanks!

@mergify (bot) commented Nov 11, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @fsx950223.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 11, 2025
Signed-off-by: fsx950223 <fsx950223@outlook.com>
         # in the same decoder block.
-        if current_platform.is_cuda() or current_platform.is_xpu():
+        if (
+            current_platform.is_cuda()
Collaborator (NIT): use current_platform.is_cuda_alike() and remove current_platform.is_rocm().
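For reference, a sketch of the condition the NIT suggests, assuming vLLM's current_platform helper, where is_cuda_alike() covers both CUDA and ROCm:

```python
from vllm.platforms import current_platform

# is_cuda_alike() already covers both CUDA and ROCm, so the separate is_rocm()
# branch can be dropped while XPU keeps its own check.
if current_platform.is_cuda_alike() or current_platform.is_xpu():
    ...  # platform-specific branch from vllm/v1/worker/utils.py
```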

@tjtanaa (Collaborator) commented Nov 12, 2025

/gemini review

@gemini-code-assist (bot) left a comment

Code Review

This pull request aims to add cross-attention support to the Triton attention backend. While the changes in vllm/v1/worker/utils.py and the robustness improvement in triton_attn.py are correct, there is a critical issue in the main logic for handling cross-attention. The current implementation incorrectly attempts to cache encoder keys and values into the decoder's paged KV cache, which will lead to memory corruption or a crash. I've provided a detailed comment and a suggested fix to prevent this bug.
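The suggested fix is not reproduced here; as a hedged sketch of the kind of guard the review asks for (a toy predicate with illustrative names, not the PR's code), the paged-cache write would be skipped for cross-attention, whose encoder K/V are computed and cached once rather than appended on every decode step:

```python
def should_write_kv_to_paged_cache(attn_type: str, key, value) -> bool:
    """Toy predicate (illustrative only): skip the paged-cache write for
    cross-attention, whose encoder K/V are computed and cached once."""
    return attn_type != "encoder_decoder" and key is not None and value is not None

print(should_write_kv_to_paged_cache("decoder", object(), object()))  # True
print(should_write_kv_to_paged_cache("encoder_decoder", None, None))  # False
```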

@tjtanaa (Collaborator) commented Nov 12, 2025

@SageMoore @tdoublep @bringlein Could you help to take a look? Thank you.

Signed-off-by: fsx950223 <fsx950223@outlook.com>
@tjtanaa (Collaborator) commented Nov 19, 2025

@SageMoore @tdoublep @bringlein Could you help to take a look? Thank you.

@SageMoore (Contributor) left a comment

Looks reasonable. Just one question.

@@ -346,7 +347,7 @@ def forward(
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table

-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
Contributor:

This change is a noop right? We are just getting the num_kv_heads from a different spot? Is there some subtle difference that I'm missing?

fsx950223 (Author):
No

Contributor:
Can we just leave the original code then?

@fsx950223 (Author) commented Nov 20, 2025:
No, key may be None.
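A hedged sketch of the reasoning (tensor layouts are assumptions for illustration, not taken from the PR): during cross-attention decode steps no new encoder keys arrive, so key is None and key.shape[1] would fail, while the paged cache tensor is always allocated and exposes the same num_kv_heads in its layout.

```python
import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64

# Assumed layouts for illustration:
#   key:       (num_tokens, num_kv_heads, head_size)
#   key_cache: (num_blocks, block_size, num_kv_heads, head_size)
key_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size)
key = None  # cross-attention decode step: no new encoder keys this iteration

cu_seqlens_q = torch.tensor([0, 1, 2])  # toy cumulative query lengths for 2 sequences

# key.shape[1] would fail here because key is None; reading the head count from
# the cache tensor gives the same value whenever key does exist.
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
print(descale_shape)  # (2, 4)
```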


@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 21, 2025
@sammysun0711 (Contributor) commented:

@fsx950223, could you please rebase?


@vllm-bot vllm-bot merged commit fc9f821 into vllm-project:main Nov 21, 2025
42 of 45 checks passed
ywang96 pushed a commit to ywang96/vllm that referenced this pull request Nov 23, 2025
Signed-off-by: fsx950223 <fsx950223@outlook.com>
lpapavassiliou pushed a commit to lpapavassiliou/vllm that referenced this pull request Nov 24, 2025
Signed-off-by: fsx950223 <fsx950223@outlook.com>
RunkaiTao pushed a commit to RunkaiTao/vllm that referenced this pull request Nov 24, 2025
Signed-off-by: fsx950223 <fsx950223@outlook.com>
Signed-off-by: Runkai Tao <rt572@physics.rutgers.edu>
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
Signed-off-by: fsx950223 <fsx950223@outlook.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: fsx950223 <fsx950223@outlook.com>
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
Signed-off-by: fsx950223 <fsx950223@outlook.com>

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[CI Failure][AMD] Encoder-Decoder Models Fail on AMD CI

9 participants