From ab04e7d3b8bfc8f7a710230fa309a2578ee6021b Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Tue, 16 Jun 2026 09:46:26 +0000 Subject: [PATCH] Fix custom flash block sizes fallback in Ulysses attention. Ensure that block sizes and heads_per_tile fall back to default values when resolved as None from CustomFlashBlockSizes dataclass. This fixes a TypeError in ulysses_custom attention when heads_per_tile is not specified. --- src/maxdiffusion/models/attention_flax.py | 24 +++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 3572c2b88..edc9f4f7b 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -635,19 +635,19 @@ def wrap_ulysses_attention(query, key, value): if flash_block_sizes is not None: if isinstance(flash_block_sizes, dict): - bq = flash_block_sizes.get("block_q", bq) - bkv = flash_block_sizes.get("block_kv", bkv) - bkv_compute = flash_block_sizes.get("block_kv_compute", bkv_compute) - bkv_compute_in = flash_block_sizes.get("block_kv_compute_in", bkv_compute_in) - heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile) - vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes) + bq = flash_block_sizes.get("block_q", None) or bq + bkv = flash_block_sizes.get("block_kv", None) or bkv + bkv_compute = flash_block_sizes.get("block_kv_compute", None) or bkv_compute + bkv_compute_in = flash_block_sizes.get("block_kv_compute_in", None) or bkv_compute_in + heads_per_tile = flash_block_sizes.get("heads_per_tile", None) or heads_per_tile + vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", None) or vmem_limit_bytes else: - bq = getattr(flash_block_sizes, "block_q", bq) - bkv = getattr(flash_block_sizes, "block_kv", bkv) - bkv_compute = getattr(flash_block_sizes, "block_kv_compute", bkv_compute) - bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", bkv_compute_in) - heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile) - vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", vmem_limit_bytes) + bq = getattr(flash_block_sizes, "block_q", None) or bq + bkv = getattr(flash_block_sizes, "block_kv", None) or bkv + bkv_compute = getattr(flash_block_sizes, "block_kv_compute", None) or bkv_compute + bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", None) or bkv_compute_in + heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", None) or heads_per_tile + vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", None) or vmem_limit_bytes if use_base2_exp: query = query * LOG2E