Commit c1b4188 (parent: d18b3ed)

[Bugfix] Fix attention backend signature (#1103)

1 file changed (+3, −1):

tpu_inference/platforms/tpu_platform.py
@@ -7,6 +7,7 @@
 import vllm.envs as vllm_envs
 from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
+from vllm.attention.backends.abstract import AttentionType
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
 from vllm.sampling_params import SamplingParams, SamplingType
@@ -57,7 +58,8 @@ class TpuPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: jnp.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink: bool, use_sparse: bool) -> str:
+                             has_sink: bool, use_sparse: bool,
+                             attn_type: AttentionType) -> str:
         from vllm.attention.backends.registry import _Backend
         if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
