1 parent 65c4e40 commit 8ea69b8
src/maxdiffusion/models/attention_flax.py
@@ -240,7 +240,7 @@ def _tpu_flash_attention(
       block_q_dkv=min(q_max_block_size, query.shape[2]),
       block_kv_dkv=min(kv_max_block_size, key.shape[2]),
       block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-      block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+      block_q_dq=None if attention_kernel == "tokamax_flash" else min(q_max_block_size, query.shape[2]),
       block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
       use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
   )
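For context, below is a minimal, self-contained sketch (not the repository's code) of the clamping pattern this commit applies to `block_q_dq`: each forward/backward block size is capped at the actual sequence length, and the dq block sizes are left unset when the `tokamax_flash` fused backward kernel is selected. The function name `pick_block_sizes`, the default maxima of 512/1024, and the plain dict return are illustrative assumptions, not the signature used in `attention_flax.py`.

```python
# Illustrative sketch of the block-size selection pattern in the diff.
# Assumptions: q_max_block_size / kv_max_block_size defaults and the dict
# return shape are hypothetical; only the min()-clamping and the
# tokamax_flash special case mirror the committed change.

def pick_block_sizes(query_seq_len: int, kv_seq_len: int,
                     q_max_block_size: int = 512,
                     kv_max_block_size: int = 1024,
                     attention_kernel: str = "flash") -> dict:
    """Clamp backward-pass block sizes to the sequence lengths."""
    tokamax = attention_kernel == "tokamax_flash"
    return {
        "block_q_dkv": min(q_max_block_size, query_seq_len),
        "block_kv_dkv": min(kv_max_block_size, kv_seq_len),
        "block_kv_dkv_compute": min(kv_max_block_size, query_seq_len),
        # The fused backward kernel used by tokamax_flash manages its own
        # dq blocking, so these stay unset (None) in that case.
        "block_q_dq": None if tokamax else min(q_max_block_size, query_seq_len),
        "block_kv_dq": None if tokamax else min(kv_max_block_size, query_seq_len),
        "use_fused_bwd_kernel": tokamax,
    }

if __name__ == "__main__":
    # A 256-token sequence gets 256-sized blocks rather than the 512/1024 maxima.
    print(pick_block_sizes(query_seq_len=256, kv_seq_len=256))
```

The point of the fix itself is that `block_q_dq` is now clamped with `min(q_max_block_size, query.shape[2])`, matching the other block-size arguments, instead of reading `block_sizes.block_q_dq`.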