Commit 6350109
Fix cutlass_blackwell_fmha_custom_op and add comprehensive FMHA tests (#5108)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2113
Pull Request resolved: #5108
This diff fixes `cutlass_blackwell_fmha_custom_op.py` so that it is fully functional and adds comprehensive tests for Blackwell FMHA (Fused Multi-Head Attention).
## Changes Made:
### 1. Fixed `cutlass_blackwell_fmha_custom_op.py`
- Added missing parameters to `fmha_fwd`: `page_table`, `seqlen_k`, `window_size_left`, `window_size_right`, `bottom_right`
- Added missing parameters to `fmha_bwd`: `softmax_scale`, `window_size_left`, `window_size_right`, `bottom_right`, `deterministic`
- Fixed parameter type issues: `torch.ops.fbgemm.fmha_fwd/bwd` expect `int` and `bool` types, not `Optional[int]` or `Optional[bool]`
- Added proper default value handling:
  - `window_size_left = -1` (default for no left window)
  - `window_size_right = -1` (default for no right window)
  - `bottom_right = True` (default)
  - `deterministic = False` (default)
- Updated `_backward`, `_setup_context`, and wrapper functions to properly pass all parameters
- The custom op now correctly wraps `torch.ops.fbgemm.fmha_fwd` and `torch.ops.fbgemm.fmha_bwd`
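
The type fix and default handling described above can be illustrated with a minimal sketch. The helper name `resolve_fmha_options` is hypothetical and is not taken from the diff; only the parameter names and default values come from the list above. It shows one way to turn `Optional` arguments into the concrete `int` and `bool` values that the underlying ops expect.

```python
from typing import Optional, Tuple


def resolve_fmha_options(
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    bottom_right: Optional[bool] = None,
    deterministic: Optional[bool] = None,
) -> Tuple[int, int, bool, bool]:
    """Hypothetical helper: replace None with the defaults listed above."""
    return (
        # -1 means "no window" on that side.
        -1 if window_size_left is None else window_size_left,
        -1 if window_size_right is None else window_size_right,
        True if bottom_right is None else bottom_right,
        False if deterministic is None else deterministic,
    )


print(resolve_fmha_options())  # (-1, -1, True, False)
print(resolve_fmha_options(window_size_left=128, deterministic=True))  # (128, -1, True, True)
```

The sketch checks `is None` instead of relying on truthiness, so an explicit `0` or `False` passed by the caller is preserved rather than replaced by the default.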
### 2. Created `blackwell_fmha.py` Test File
- Structured following `blackwell_gdpa.py` as reference
- Uses `cutlass_blackwell_fmha_custom_op` (Cutlass implementation) for forward and backward passes
- Compares against `jagged_flash_attention_v2` (Triton JFA v2 implementation)
- Tests BF16 dtype only (as specified)
- Tests both forward outputs and backward gradients (dq, dk, dv)
- Runs 10 random test configurations with varying batch sizes, sequence lengths, and number of heads
- Uses `generate_jagged_data` utility for proper test data generation
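
As a rough illustration of the "10 random test configurations" item above, the following sketch draws seeded random configurations. The class `FmhaTestConfig`, the function `make_random_configs`, and the value ranges are all assumptions made for illustration; the actual test file and its `generate_jagged_data` utility may choose shapes differently.

```python
import random
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class FmhaTestConfig:
    """Hypothetical container for one test configuration."""
    batch_size: int
    max_seq_len: int
    num_heads: int


def make_random_configs(num_configs: int = 10, seed: int = 0) -> List[FmhaTestConfig]:
    """Draw reproducible random configurations (ranges are assumed, not from the diff)."""
    rng = random.Random(seed)  # local RNG so the global random state is untouched
    return [
        FmhaTestConfig(
            batch_size=rng.randint(1, 16),
            max_seq_len=rng.choice([64, 128, 256, 512]),
            num_heads=rng.choice([1, 2, 4, 8]),
        )
        for _ in range(num_configs)
    ]


configs = make_random_configs()
print(len(configs))  # 10
```

Seeding a local `random.Random` keeps a failing configuration reproducible across runs, which helps when comparing two attention implementations numerically.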
### 3. Updated BUCK Dependencies
- Changed from `//ads_mkl/ops:jfa` to `//ads_mkl/ops/triton:triton_jfa_v2`
- Added `//ads_mkl/ops/utils:jfa_utils` for data generation utilities
- Changed from `blackwell_attention_ops_gpu` to `blackwell_attention` to include Python bindings
---
> Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/)
[Session](https://www.internalfb.com/confucius?session_id=96622022-bc27-11f0-bdba-7c8c09f29af2&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id=96622022-bc27-11f0-bdba-7c8c09f29af2&tab=Trace)
Reviewed By: devashishshankar
Differential Revision: D86583157
fbshipit-source-id: 8771f26c80b587694e2568e6b3232d4ae367c915
1 parent d8dcd23 commit 6350109
1 file changed (75 additions, 3 deletions): fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha