diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 3572c2b8..edc9f4f7 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -635,19 +635,19 @@ def wrap_ulysses_attention(query, key, value): if flash_block_sizes is not None: if isinstance(flash_block_sizes, dict): - bq = flash_block_sizes.get("block_q", bq) - bkv = flash_block_sizes.get("block_kv", bkv) - bkv_compute = flash_block_sizes.get("block_kv_compute", bkv_compute) - bkv_compute_in = flash_block_sizes.get("block_kv_compute_in", bkv_compute_in) - heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile) - vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes) + bq = flash_block_sizes.get("block_q", None) or bq + bkv = flash_block_sizes.get("block_kv", None) or bkv + bkv_compute = flash_block_sizes.get("block_kv_compute", None) or bkv_compute + bkv_compute_in = flash_block_sizes.get("block_kv_compute_in", None) or bkv_compute_in + heads_per_tile = flash_block_sizes.get("heads_per_tile", None) or heads_per_tile + vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", None) or vmem_limit_bytes else: - bq = getattr(flash_block_sizes, "block_q", bq) - bkv = getattr(flash_block_sizes, "block_kv", bkv) - bkv_compute = getattr(flash_block_sizes, "block_kv_compute", bkv_compute) - bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", bkv_compute_in) - heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile) - vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", vmem_limit_bytes) + bq = getattr(flash_block_sizes, "block_q", None) or bq + bkv = getattr(flash_block_sizes, "block_kv", None) or bkv + bkv_compute = getattr(flash_block_sizes, "block_kv_compute", None) or bkv_compute + bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", None) or bkv_compute_in + heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", None) or heads_per_tile + vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", None) or vmem_limit_bytes if use_base2_exp: query = query * LOG2E