@@ -513,12 +513,9 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        if attn_type != AttentionType.DECODER:
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "FlashAttentionImpl"
+                "Encoder self-attention is not implemented for FlashAttentionImpl"
             )
 
     def extend_forward(
@@ -674,7 +671,14 @@ def forward(
         # performance to make sure it does not introduce any overhead.
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(0)
-        if self.kv_sharing_target_layer_name is None:
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in the KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             # NOTE(woosuk): Here, key and value are padded while slot_mapping
@@ -700,8 +704,8 @@ def forward(
 
         # decode:extend:prefill
         query = query[:num_actual_tokens]
-        key = key[:num_actual_tokens]
-        value = value[:num_actual_tokens]
+        key = key[:num_actual_tokens] if key is not None else key_cache[:num_actual_tokens]
+        value = value[:num_actual_tokens] if value is not None else value_cache[:num_actual_tokens]
 
         output_actual_tokens = output[:num_actual_tokens]
 
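For context, below is a minimal, self-contained sketch (not vLLM's actual implementation) of the control flow this diff enables for cross attention: key/value are written into the KV cache the one time they are available from the encoder output, and later calls pass key=None/value=None and slice them back out of the cache. The function name, cache layout, and use of plain scaled_dot_product_attention in place of the FlashAttention kernel are all illustrative assumptions.

import torch


def toy_cross_attn_forward(query, key, value, kv_cache, num_actual_tokens):
    # Illustrative only: kv_cache stacks [key_cache, value_cache] along dim 0,
    # each with shape [max_tokens, num_heads, head_dim].
    key_cache, value_cache = kv_cache.unbind(0)

    # Cross attention: key/value are only provided on the call where the
    # encoder output is available; afterwards they are read from the cache.
    if key is not None and value is not None:
        key_cache[: key.shape[0]].copy_(key)
        value_cache[: value.shape[0]].copy_(value)

    query = query[:num_actual_tokens]
    key = key[:num_actual_tokens] if key is not None else key_cache[:num_actual_tokens]
    value = value[:num_actual_tokens] if value is not None else value_cache[:num_actual_tokens]

    # Plain SDPA stands in for the FlashAttention kernel in this sketch.
    return torch.nn.functional.scaled_dot_product_attention(
        query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1)
    ).transpose(0, 1)


if __name__ == "__main__":
    num_tokens, num_heads, head_dim = 4, 8, 64
    q = torch.randn(num_tokens, num_heads, head_dim)
    k = torch.randn(num_tokens, num_heads, head_dim)
    v = torch.randn(num_tokens, num_heads, head_dim)
    kv_cache = torch.zeros(2, 16, num_heads, head_dim)

    # First call: encoder-derived key/value are cached.
    out1 = toy_cross_attn_forward(q, k, v, kv_cache, num_tokens)
    # Later calls: key/value are None and come back from the cache.
    out2 = toy_cross_attn_forward(q, None, None, kv_cache, num_tokens)
    assert torch.allclose(out1, out2)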