We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
apply_rotary_pos_emb_flashattn
1 parent 4ccffe5 commit b9ce9a3Copy full SHA for b9ce9a3
vllm/model_executor/models/keye.py
@@ -346,6 +346,13 @@ def apply_rotary_pos_emb_flashatt(
346
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
347
elif current_platform.is_rocm():
348
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
349
+ else:
350
+ # For other platforms, use PyTorch fallback
351
+ from vllm.model_executor.layers.rotary_embedding.common import (
352
+ apply_rotary_emb_torch,
353
+ )
354
+
355
+ apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
356
357
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
358
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
0 commit comments