
Commit 91e4a58

add support for whisper v1 using aiter unified attention and aiter flash attention
Signed-off-by: apinge <Tong.Qiu2@amd.com>
1 parent ca00b1b commit 91e4a58

File tree

vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
vllm/v1/attention/backends/rocm_attn.py

3 files changed: +23 -15 lines changed


vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 12 additions & 8 deletions
@@ -513,12 +513,9 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        if attn_type != AttentionType.DECODER:
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "FlashAttentionImpl"
+                "Encoder self-attention is not implemented for FlashAttentionImpl"
             )
 
     def extend_forward(
@@ -674,7 +671,14 @@ def forward(
         # performance to make sure it does not introduce any overhead.
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(0)
-        if self.kv_sharing_target_layer_name is None:
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             # NOTE(woosuk): Here, key and value are padded while slot_mapping
@@ -700,8 +704,8 @@ def forward(
 
         # decode:extend:prefill
         query = query[:num_actual_tokens]
-        key = key[:num_actual_tokens]
-        value = value[:num_actual_tokens]
+        key = key[:num_actual_tokens] if key is not None else key_cache[:num_actual_tokens]
+        value = value[:num_actual_tokens] if value is not None else value_cache[:num_actual_tokens]
 
         output_actual_tokens = output[:num_actual_tokens]
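The key behavioral change above: for Whisper-style cross-attention, key and value are computed once from the encoder output, written to the KV cache, and arrive as None on subsequent decode steps, where the backend must fall back to the cached tensors. A minimal sketch of that fallback follows; the function name and tensor shapes are illustrative, not the backend's actual API.

import torch


def select_kv_for_cross_attention(
    key: torch.Tensor | None,
    value: torch.Tensor | None,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_actual_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative only: mirrors the diff's fresh-K/V vs. cached-K/V fallback."""
    if key is not None and value is not None:
        # Fresh K/V computed from the encoder output: trim padding to the actual
        # token count (this is also the point where they would be cached).
        return key[:num_actual_tokens], value[:num_actual_tokens]
    # Cross-attention decode step: key/value are None, reuse the cached K/V.
    return key_cache[:num_actual_tokens], value_cache[:num_actual_tokens]


# Tiny usage example with made-up shapes (tokens x kv_heads x head_dim).
key_cache = torch.zeros(16, 4, 64)
value_cache = torch.zeros(16, 4, 64)
fresh_k = torch.randn(8, 4, 64)
fresh_v = torch.randn(8, 4, 64)

k, v = select_kv_for_cross_attention(fresh_k, fresh_v, key_cache, value_cache, 8)
k, v = select_kv_for_cross_attention(None, None, key_cache, value_cache, 8)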

vllm/v1/attention/backends/rocm_aiter_unified_attn.py

Lines changed: 9 additions & 2 deletions
@@ -142,7 +142,14 @@ def forward(
 
         key_cache, value_cache = kv_cache.unbind(0)
 
-        if self.kv_sharing_target_layer_name is None:
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             ops.reshape_and_cache_flash(
@@ -169,7 +176,7 @@ def forward(
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table
 
-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1] if key is not None else self.num_kv_heads)
 
         self.unified_attention(
             q=query[:num_actual_tokens],
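The descale_shape change guards the FP8 descale-factor shape against a None key: the KV-head count can no longer be read from key.shape[1] on cross-attention decode steps, so it falls back to the layer's configured head count. A small, self-contained sketch of that computation, with illustrative names and shapes:

import torch


def compute_descale_shape(
    cu_seqlens_q: torch.Tensor,
    key: torch.Tensor | None,
    num_kv_heads: int,
) -> tuple[int, int]:
    """Illustrative: (num_sequences, num_kv_heads) for per-sequence FP8 descales.

    cu_seqlens_q holds cumulative query lengths, so it has num_seqs + 1 entries.
    When key is None (cross-attention decode), the head count comes from config.
    """
    num_seqs = cu_seqlens_q.shape[0] - 1
    kv_heads = key.shape[1] if key is not None else num_kv_heads
    return (num_seqs, kv_heads)


# Example: 3 sequences with query lengths 4, 2, 5 -> cumulative [0, 4, 6, 11].
cu_seqlens_q = torch.tensor([0, 4, 6, 11], dtype=torch.int32)
print(compute_descale_shape(cu_seqlens_q, None, num_kv_heads=8))  # (3, 8)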

vllm/v1/attention/backends/rocm_attn.py

Lines changed: 2 additions & 5 deletions
@@ -238,12 +238,9 @@ def __init__(
 
         RocmAttentionBackend.validate_head_size(head_size)
 
-        if attn_type != AttentionType.DECODER:
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "RocmAttentionImpl"
+                "Encoder self-attention is not implemented for RocmAttentionImpl"
            )
 
         self.fp8_dtype = current_platform.fp8_dtype()
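Both attention impls now accept ENCODER_DECODER (cross-attention, as used by Whisper's decoder) in addition to DECODER; only encoder self-attention remains unsupported. A standalone sketch of that gate follows; the AttentionType enum is a local stand-in for vLLM's, and MyAttentionImpl is a hypothetical class used only to show the check:

from enum import Enum


class AttentionType(str, Enum):
    # Local stand-in for vLLM's AttentionType, for illustration only.
    DECODER = "decoder"
    ENCODER = "encoder"
    ENCODER_DECODER = "encoder_decoder"


class MyAttentionImpl:
    # Hypothetical impl showing only the attn_type gate from the diffs above.
    def __init__(self, attn_type: AttentionType = AttentionType.DECODER) -> None:
        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
            raise NotImplementedError(
                "Encoder self-attention is not implemented for MyAttentionImpl"
            )
        self.attn_type = attn_type


MyAttentionImpl(AttentionType.ENCODER_DECODER)  # accepted: Whisper cross-attention
try:
    MyAttentionImpl(AttentionType.ENCODER)      # still rejected
except NotImplementedError as exc:
    print(exc)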
