[PyTorch] Fix FlashAttention 2 head_dim > 192 on sm103 and other architectures#2836
pedramr wants to merge 1 commit into NVIDIA:main
Conversation
…itectures

Replace the exact-match compute capability allowlist with a >= sm80 range check, matching flash-attn's own gate: Dao-AILab/flash-attention@bbb21d6

The allowlist ((8,0), (9,0), (10,0), (12,0)) missed sm103 (B300), sm89 (L40S), sm86 (A40), and others where FA2 supports head_dim up to 256. The sm103 case was validated on hardware with head_dim=256; the remaining architectures appear to be supported based on flash-attn's >= sm80 guarantee.

Signed-off-by: Pedram Razavi <pedram.razavi@gmail.com>
Greptile Summary: This PR fixes a regression where FlashAttention 2 with head_dim > 192 was disabled on supported architectures such as sm103, sm89, and sm86 because their compute capabilities were missing from an exact-match allowlist.
Confidence Score: 5/5 — Safe to merge. The behavioural fix is correct; the only finding is a P2 dead-code style note on the new condition. No files require special attention.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[get_attention_backend called] --> B{device_compute_capability < sm80?}
    B -- Yes --> C["use_flash_attention_2 = False (line 431)"]
    B -- No --> D{use_flash_attention_2?}
    C --> END[continue with FA2 disabled]
    D -- No --> END
    D -- Yes --> E{"head_dim_qk > 256
or head_dim_qk % 8 != 0
or (head_dim_qk > 192 and cc < sm80)?"}
    E -- "Third branch always False: cc >= sm80 guaranteed here" --> F[Only first two branches can disable FA2]
    E -- "Yes (first/second branch)" --> G["use_flash_attention_2 = False (line 646)"]
    F --> H["FA2 stays enabled for head_dim 193-256 on sm80+ (sm86, sm89, sm103, etc.)"]
    G --> END
```
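The decision flow above can be sketched in Python. This is a hypothetical simplification of the gate described in the flowchart, not the actual Transformer Engine source; the function name and signature are illustrative only, and it assumes the post-fix (>= sm80 range check) behavior:

```python
def fa2_supported(device_compute_capability, head_dim_qk):
    """Illustrative model of the FlashAttention 2 gate after the fix.

    device_compute_capability: tuple like (8, 6) for sm86.
    head_dim_qk: per-head dimension of Q/K.
    """
    # Early gate (flowchart node B/C): FA2 requires at least sm80.
    if device_compute_capability < (8, 0):
        return False
    # head_dim gate (node E): at most 256 and a multiple of 8. With the
    # exact-match allowlist removed, sm86/sm89/sm103 keep FA2 enabled
    # for head_dim 193-256 (node H).
    if head_dim_qk > 256 or head_dim_qk % 8 != 0:
        return False
    return True
```

For example, `fa2_supported((10, 3), 256)` returns True (sm103 with head_dim=256, the case validated on hardware), while `fa2_supported((7, 5), 128)` returns False because sm75 fails the early gate.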
Reviews (1): Last reviewed commit: "[PyTorch] Fix FlashAttention 2 head_dim ..."
```diff
-                head_dim_qk > 192
-                and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
-            )
+            or (head_dim_qk > 192 and device_compute_capability < (8, 0))
```
Dead code: condition is always False at this point
Lines 428–431 unconditionally set use_flash_attention_2 = False whenever device_compute_capability < (8, 0). By the time execution reaches line 634, use_flash_attention_2 can only be True if device_compute_capability >= (8, 0), so the sub-expression device_compute_capability < (8, 0) is never true and the entire third or branch is unreachable. The bug-fix intent is correct (no longer blocking head_dim > 192 on sm86/sm89/sm103), but the residual condition could be confusing to future readers who might believe it provides a meaningful guard.
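The reviewer's reasoning can be checked mechanically. The sketch below models the two gates with illustrative helper names (not the actual Transformer Engine functions) and confirms that for every capability surviving the early sm80 gate, the third `or` branch evaluates to False:

```python
def reaches_head_dim_check(cc):
    # Models the early gate (lines 428-431): FA2 is already disabled
    # below sm80, so the head_dim condition at line 634 is only reached
    # when cc >= (8, 0).
    return cc >= (8, 0)

def dead_branch(cc, head_dim_qk):
    # The residual third branch of the head_dim condition.
    return head_dim_qk > 192 and cc < (8, 0)

# For every capability that passes the early gate, the branch is False,
# regardless of head_dim: cc >= (8, 0) and cc < (8, 0) cannot both hold.
capabilities = [(8, 0), (8, 6), (8, 9), (9, 0), (10, 0), (10, 3), (12, 0)]
for cc in capabilities:
    if reaches_head_dim_check(cc):
        assert not dead_branch(cc, 256)
```

This is why the branch is unreachable rather than merely rare: the early gate and the branch's own predicate are logically contradictory.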
Consider removing the dead branch entirely:

```diff
             or head_dim_qk % 8 != 0
-            or (head_dim_qk > 192 and device_compute_capability < (8, 0))
```
Description
The `head_dim > 192` gate for FlashAttention 2 in `get_attention_backend` used an exact-match compute capability allowlist: `(8,0), (9,0), (10,0), (12,0)`. This excluded sm103 (B300/GB300), sm89 (L40S/RTX 4090), sm86 (A40/RTX 3090), and other valid architectures where flash-attn supports head_dim up to 256.

This PR replaces the allowlist with a `>= sm80` range check, matching flash-attn's own gate: Dao-AILab/flash-attention@bbb21d6

The sm103 case was validated on hardware with head_dim=256; the remaining architectures appear to be supported based on flash-attn's >= sm80 guarantee.
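A minimal before/after sketch of the gate, simplified from the PR description (these are standalone predicates for illustration, not the literal Transformer Engine source):

```python
# Old behavior: exact-match allowlist for head_dim > 192.
ALLOWLIST = ((8, 0), (9, 0), (10, 0), (12, 0))

def old_gate_blocks(cc, head_dim_qk):
    return head_dim_qk > 192 and cc not in ALLOWLIST

# New behavior: simple range check matching flash-attn's own >= sm80 gate.
def new_gate_blocks(cc, head_dim_qk):
    return head_dim_qk > 192 and cc < (8, 0)

# sm89 (L40S / RTX 4090) with head_dim=256: blocked before, allowed now.
assert old_gate_blocks((8, 9), 256) is True
assert new_gate_blocks((8, 9), 256) is False

# sm103 (B300), the case validated on hardware: same story.
assert old_gate_blocks((10, 3), 256) is True
assert new_gate_blocks((10, 3), 256) is False
```

The range check is also robust to future architectures: any new sm >= 80 capability passes without needing to be added to a list.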
Changes
- Replaced the exact-match compute capability allowlist with a `device_compute_capability < (8, 0)` range check, widening head_dim > 192 support from sm80/90/100+ to all sm80+ architectures