Skip to content

Commit b9ce9a3

Browse files
authored
[BugFix] Add fallback path in apply_rotary_pos_emb_flashattn for non-cuda platforms (#28447)
Signed-off-by: Lin, Fanli <fanli.lin@intel.com>
1 parent 4ccffe5 commit b9ce9a3

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

vllm/model_executor/models/keye.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,13 @@ def apply_rotary_pos_emb_flashatt(
346346
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
347347
elif current_platform.is_rocm():
348348
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
349+
else:
350+
# For other platforms, use PyTorch fallback
351+
from vllm.model_executor.layers.rotary_embedding.common import (
352+
apply_rotary_emb_torch,
353+
)
354+
355+
apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
349356

350357
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
351358
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)

0 commit comments

Comments
 (0)