Commit 8ea69b8

Test on attention type and automatically modify the flash block sizes object when 'tokamax_flash' is requested
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent: 65c4e40

File tree: 1 file changed (+1, −1)

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion

@@ -240,7 +240,7 @@ def _tpu_flash_attention(
           block_q_dkv=min(q_max_block_size, query.shape[2]),
           block_kv_dkv=min(kv_max_block_size, key.shape[2]),
           block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-          block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+          block_q_dq=None if attention_kernel == "tokamax_flash" else min(q_max_block_size, query.shape[2]),
           block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
           use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
       )
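The substance of the change: the else branch of block_q_dq previously read a value back off a pre-built block_sizes object, while every sibling field clamps its block size with min(max_block_size, sequence_length); after this commit, block_q_dq is clamped the same way. Per the diff, requesting 'tokamax_flash' also enables the fused backward kernel and nulls out the standalone dq block sizes. Below is a minimal, self-contained sketch of the resulting logic. The BlockSizes dataclass and make_block_sizes helper are stand-ins invented for illustration (they mirror only the fields visible in this hunk; in maxdiffusion the real object comes from the Pallas splash-attention kernels), and q_seq_len/kv_seq_len stand in for query.shape[2] and key.shape[2].

from dataclasses import dataclass
from typing import Optional


# Stand-in mirroring only the fields visible in this hunk; the real object
# in maxdiffusion comes from JAX's Pallas splash-attention kernels.
@dataclass
class BlockSizes:
  block_q_dkv: int
  block_kv_dkv: int
  block_kv_dkv_compute: int
  block_q_dq: Optional[int]
  block_kv_dq: Optional[int]
  use_fused_bwd_kernel: bool


def make_block_sizes(attention_kernel: str,
                     q_max_block_size: int, kv_max_block_size: int,
                     q_seq_len: int, kv_seq_len: int) -> BlockSizes:
  """Hypothetical helper reproducing the conditional shown in the diff."""
  fused = attention_kernel == "tokamax_flash"
  return BlockSizes(
      block_q_dkv=min(q_max_block_size, q_seq_len),
      block_kv_dkv=min(kv_max_block_size, kv_seq_len),
      block_kv_dkv_compute=min(kv_max_block_size, q_seq_len),
      # The fixed line: clamp like the sibling fields instead of reading
      # block_sizes.block_q_dq off a pre-existing object.
      block_q_dq=None if fused else min(q_max_block_size, q_seq_len),
      block_kv_dq=None if fused else min(kv_max_block_size, q_seq_len),
      # The fused backward kernel computes dq alongside dkv, so the
      # standalone dq block sizes above are left unset (None) for it.
      use_fused_bwd_kernel=fused,
  )


# 'tokamax_flash' nulls the dq blocks and enables the fused backward kernel:
print(make_block_sizes("tokamax_flash", 512, 1024, 256, 256))
# Any other kernel keeps explicit, clamped dq block sizes:
print(make_block_sizes("flash", 512, 1024, 256, 256))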
