Skip to content

Commit 2a4c027

Browse files
[ROCm][CI] Changed to flex attention for cross-attention
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
1 parent 0e62494 commit 2a4c027

File tree

2 files changed: 7 additions and 11 deletions

2 files changed

+7
-11
lines changed

tests/entrypoints/openai/test_transcription_validation.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -28,20 +28,19 @@
2828

2929

3030
@pytest.fixture(scope="module", autouse=True)
31-
def rocm_aiter_fa_attention():
31+
def rocm_flex_attention():
3232
"""
33-
Automatically sets VLLM_ATTENTION_BACKEND=ROCM_AITER_FA for ROCm
34-
for the duration of this test module.
33+
Sets VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for ROCm
34+
for the duration of this test module. For now the only
35+
attention backend that supports cross attention on ROCm
36+
is FLEX_ATTENTION.
3537
"""
3638
from vllm.platforms import current_platform
3739

3840
if current_platform.is_rocm():
39-
# Store previous value to restore later (cleanup)
4041
old_backend = os.environ.get("VLLM_ATTENTION_BACKEND")
41-
# Set the specific backend required for audio models on ROCm
42-
os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_FA"
42+
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
4343
yield
44-
# Cleanup: Restore the environment
4544
if old_backend is None:
4645
del os.environ["VLLM_ATTENTION_BACKEND"]
4746
else:

tests/entrypoints/openai/test_translation_validation.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,18 +22,15 @@
2222
@pytest.fixture(scope="module", autouse=True)
2323
def rocm_aiter_fa_attention():
2424
"""
25-
Automatically sets VLLM_ATTENTION_BACKEND=ROCM_AITER_FA for ROCm
25+
Sets VLLM_ATTENTION_BACKEND=ROCM_AITER_FA for ROCm
2626
for the duration of this test module.
2727
"""
2828
from vllm.platforms import current_platform
2929

3030
if current_platform.is_rocm():
31-
# Store previous value to restore later (cleanup)
3231
old_backend = os.environ.get("VLLM_ATTENTION_BACKEND")
33-
# Set the specific backend required for audio models on ROCm
3432
os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_FA"
3533
yield
36-
# Cleanup: Restore the environment
3734
if old_backend is None:
3835
del os.environ["VLLM_ATTENTION_BACKEND"]
3936
else:

0 commit comments

Comments
 (0)