@@ -383,12 +383,18 @@ def _check_shape(
     attn_mask: Optional[torch.Tensor] = None,
     **kwargs,
 ) -> None:
+    # Expected shapes:
+    #   query:     (batch_size, seq_len_q, num_heads, head_dim)
+    #   key:       (batch_size, seq_len_kv, num_heads, head_dim)
+    #   value:     (batch_size, seq_len_kv, num_heads, head_dim)
+    #   attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv)
+    #              or (batch_size, num_heads, seq_len_q, seq_len_kv)
     if query.shape[-1] != key.shape[-1]:
-        raise ValueError("Query and key must have the same last dimension.")
-    if query.shape[-2] != value.shape[-2]:
-        raise ValueError("Query and value must have the same second to last dimension.")
-    if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
-        raise ValueError("Attention mask must match the key's second to last dimension.")
+        raise ValueError("Query and key must have the same head dimension.")
+    if key.shape[-3] != value.shape[-3]:
+        raise ValueError("Key and value must have the same sequence length.")
+    if attn_mask is not None and attn_mask.shape[-1] != key.shape[-3]:
+        raise ValueError("Attention mask must match the key's sequence length.")


# ===== Helper functions =====
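For context, a minimal sketch of how the updated check behaves under the shape layout documented in the new comment. The positional parameters query, key, and value are assumed from the function body (only the tail of the signature is visible in this hunk), and the concrete sizes are illustrative:

    import torch

    # Illustrative sizes; layout follows the comment added in this diff.
    batch_size, num_heads, head_dim = 2, 4, 64
    seq_len_q, seq_len_kv = 8, 16

    query = torch.randn(batch_size, seq_len_q, num_heads, head_dim)
    key = torch.randn(batch_size, seq_len_kv, num_heads, head_dim)
    value = torch.randn(batch_size, seq_len_kv, num_heads, head_dim)
    attn_mask = torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool)

    _check_shape(query, key, value, attn_mask=attn_mask)  # passes silently

    # A key/value sequence-length mismatch now raises:
    short_value = torch.randn(batch_size, seq_len_kv - 1, num_heads, head_dim)
    try:
        _check_shape(query, key, short_value)
    except ValueError as e:
        print(e)  # Key and value must have the same sequence length.

The axis change is the substance of the fix: under the (batch, seq_len, heads, dim) layout, the old comparisons against shape[-2] were checking num_heads rather than sequence length, so the new code compares shape[-3] (and the mask's last dimension against it) instead.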