Skip to content

Commit 093cd3f

Browse files
authored
fix dispatch_attention_fn check (#12636)
* fix * fix
1 parent aecf0c5 commit 093cd3f

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -383,12 +383,18 @@ def _check_shape(
383383
attn_mask: Optional[torch.Tensor] = None,
384384
**kwargs,
385385
) -> None:
386+
# Expected shapes:
387+
# query: (batch_size, seq_len_q, num_heads, head_dim)
388+
# key: (batch_size, seq_len_kv, num_heads, head_dim)
389+
# value: (batch_size, seq_len_kv, num_heads, head_dim)
390+
# attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv)
391+
# or (batch_size, num_heads, seq_len_q, seq_len_kv)
386392
if query.shape[-1] != key.shape[-1]:
387-
raise ValueError("Query and key must have the same last dimension.")
388-
if query.shape[-2] != value.shape[-2]:
389-
raise ValueError("Query and value must have the same second to last dimension.")
390-
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
391-
raise ValueError("Attention mask must match the key's second to last dimension.")
393+
raise ValueError("Query and key must have the same head dimension.")
394+
if key.shape[-3] != value.shape[-3]:
395+
raise ValueError("Key and value must have the same sequence length.")
396+
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-3]:
397+
raise ValueError("Attention mask must match the key's sequence length.")
392398

393399

394400
# ===== Helper functions =====

src/diffusers/utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
4343
DIFFUSERS_REQUEST_TIMEOUT = 60
4444
DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
45-
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
45+
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_VARS_TRUE_VALUES
4646
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
4747
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
4848
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES

0 commit comments

Comments
 (0)