
Commit 138c0a3

jbschlosser authored and meta-codesync[bot] committed
Support Blackwell CUTLASS attention kernels in torch.compile (#5136)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2138

Pull Request resolved: #5136

Support the fbgemm Blackwell CUTLASS attention kernel in torch.compile by adding a C++-side meta function.

Reviewed By: henrylhtsang

Differential Revision: D86986981

fbshipit-source-id: 066cd6b93c2d815e3f4e180806dd0af243db5724
1 parent bc6d968 commit 138c0a3
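A quick way to see what the new Meta registrations buy: under FakeTensorMode (the mechanism torch.compile uses for shape inference), calls to the fbgemm ops are answered by the C++ meta functions added in this commit rather than the CUDA kernels, so output shapes can be derived without touching a GPU. The sketch below is illustrative only and not part of the commit; it assumes the fbgemm_gpu gen_ai extension that registers fbgemm::fmha_fwd has been loaded, and the positional call relies only on the operator schema shown in the diffs below (query, key, value first, remaining arguments left at their defaults).

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Hypothetical sizes for a dense (non-varlen) case.
B, SQ, SK, HQ, HKV, D = 2, 64, 128, 8, 2, 128

with FakeTensorMode():
    # Fake tensors carry only metadata (shape/dtype/device); nothing is allocated,
    # and the op call below dispatches to dispatch_fmha_fwd_meta, not the CUDA kernel.
    q = torch.empty(B, SQ, HQ, D, dtype=torch.bfloat16, device="cuda")
    k = torch.empty(B, SK, HKV, D, dtype=torch.bfloat16, device="cuda")
    v = torch.empty(B, SK, HKV, D, dtype=torch.bfloat16, device="cuda")
    out, lse = torch.ops.fbgemm.fmha_fwd(q, k, v)
    assert out.shape == q.shape        # meta returns at::empty_like(q)
    assert lse.shape == (B, HQ, SQ)    # meta returns logsumexp of shape (B, H_Q, SQ)

This is the same mechanism the new test_compile test exercises end to end via torch.compile(cutlass_blackwell_fmha_func, fullgraph=True).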

File tree

4 files changed, +146 -15 lines


fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_bwd.cu

Lines changed: 27 additions & 4 deletions
@@ -171,6 +171,26 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
 }
 }
 
+std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd_meta(
+    const at::Tensor& dOutput,
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    const at::Tensor& output,
+    const at::Tensor& softmax_lse,
+    const std::optional<at::Tensor>& cu_seqlens_q,
+    const std::optional<at::Tensor>& cu_seqlens_k,
+    std::optional<c10::SymInt> max_seq_len_q,
+    std::optional<c10::SymInt> max_seq_len_k,
+    std::optional<double> softmax_scale,
+    bool causal,
+    c10::SymInt window_size_left,
+    c10::SymInt window_size_right,
+    bool bottom_right,
+    bool deterministic) {
+  return std::make_tuple(at::empty_like(query), at::empty_like(key), at::empty_like(value));
+}
+
 // -------------------------------------------------------------------------------------------------
 // Op registration
 // -------------------------------------------------------------------------------------------------
@@ -185,12 +205,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
     " Tensor softmax_lse, "
     " Tensor? cu_seqlens_q=None, "
     " Tensor? cu_seqlens_k=None, "
-    " int? max_seq_len_q=None, "
-    " int? max_seq_len_k=None, "
+    " SymInt? max_seq_len_q=None, "
+    " SymInt? max_seq_len_k=None, "
     " float? softmax_scale=None, "
     " bool causal=False, "
-    " int window_size_left=-1, "
-    " int window_size_right=-1, "
+    " SymInt window_size_left=-1, "
+    " SymInt window_size_right=-1, "
     " bool bottom_right=True, "
     " bool deterministic=False"
     ") -> (Tensor, Tensor, Tensor)");
@@ -199,4 +219,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("fmha_bwd", dispatch_fmha_bwd);
 }
+TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
+  m.impl("fmha_bwd", dispatch_fmha_bwd_meta);
+}
 #endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu

Lines changed: 33 additions & 5 deletions
@@ -139,6 +139,31 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
 }
 }
 
+std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd_meta(
+    const at::Tensor& q,
+    const at::Tensor& k, // (batch_size, KV_seqlen, num_KV_heads, head_dim) if non-paged or (num_blocks, page_block_size, num_KV_heads, head_dim) if paged
+    const at::Tensor& v, // (batch_size, KV_seqlen, num_KV_heads, head_dim) if non-paged or (num_blocks, page_block_size, num_KV_heads, head_dim) if paged
+    const std::optional<at::Tensor>& cu_seqlens_q,
+    const std::optional<at::Tensor>& cu_seqlens_k,
+    std::optional<c10::SymInt> max_seq_len_q,
+    std::optional<c10::SymInt> max_seq_len_k,
+    std::optional<double> softmax_scale,
+    bool causal,
+    const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table, // dim: (batch_size, max_num_pages_per_seq) , null if non-paged
+    std::optional<c10::SymInt> seqlen_k,
+    c10::SymInt window_size_left,
+    c10::SymInt window_size_right,
+    bool bottom_right) {
+  auto output = at::empty_like(q);
+  bool k_is_varlen = max_seq_len_q.has_value();
+  auto SQ = k_is_varlen ? q.sym_size(0) : q.sym_size(1);
+  auto H_Q = k_is_varlen ? q.sym_size(1) : q.sym_size(2);
+  auto B = k_is_varlen ? 1 : q.sym_size(0);
+  auto logsumexp = q.new_empty_symint({B, H_Q, SQ}, q.options());
+  return std::make_tuple(output, logsumexp);
+}
+
 // -------------------------------------------------------------------------------------------------
 // Op registration
 // -------------------------------------------------------------------------------------------------
@@ -150,20 +175,23 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
     " Tensor value, "
     " Tensor? cu_seqlens_q=None, "
     " Tensor? cu_seqlens_k=None, "
-    " int? max_seq_len_q=None, "
-    " int? max_seq_len_k=None, "
+    " SymInt? max_seq_len_q=None, "
+    " SymInt? max_seq_len_k=None, "
     " float? softmax_scale=None, "
     " bool causal=False, "
     " Tensor? seqlen_kv=None, "
     " Tensor? page_table=None, "
-    " int? seqlen_k=None, "
-    " int window_size_left=-1, "
-    " int window_size_right=-1, "
+    " SymInt? seqlen_k=None, "
+    " SymInt window_size_left=-1, "
+    " SymInt window_size_right=-1, "
     " bool bottom_right=True"
     ") -> (Tensor, Tensor)");
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("fmha_fwd", dispatch_fmha_fwd);
 }
+TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
+  m.impl("fmha_fwd", dispatch_fmha_fwd_meta);
+}
 #endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED
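For reference, the shape bookkeeping in dispatch_fmha_fwd_meta above can be written out in plain Python. This mirror is illustrative only (the helper name and layout comments are ours, not part of the commit); "varlen" here means max_seq_len_q was supplied, matching the has_value() check in the C++ code.

def fmha_fwd_meta_shapes(q_shape, varlen):
    # Mirrors dispatch_fmha_fwd_meta: output matches q; logsumexp is (B, H_Q, SQ).
    if varlen:
        # Packed/varlen layout: q is (total_tokens, num_Q_heads, head_dim); batch folds to 1.
        total, heads = q_shape[0], q_shape[1]
        return q_shape, (1, heads, total)
    # Dense layout: q is (batch_size, Q_seqlen, num_Q_heads, head_dim).
    batch, seqlen, heads = q_shape[0], q_shape[1], q_shape[2]
    return q_shape, (batch, heads, seqlen)

assert fmha_fwd_meta_shapes((2, 64, 8, 128), varlen=False) == ((2, 64, 8, 128), (2, 8, 64))
assert fmha_fwd_meta_shapes((4096, 8, 128), varlen=True) == ((4096, 8, 128), (1, 8, 4096))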

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu

Lines changed: 13 additions & 0 deletions
@@ -310,6 +310,16 @@ at::Tensor dispatch_fmha_gen_fwd(
   });
 }
 
+at::Tensor dispatch_fmha_gen_fwd_meta(
+    const at::Tensor& q,
+    const at::Tensor& k,
+    const at::Tensor& v,
+    const at::Tensor& seqlen_kv,
+    const std::optional<at::Tensor>& batch_idx,
+    int64_t kernel_type
+) {
+  return at::empty_like(q);
+}
 
 // -------------------------------------------------------------------------------------------------
 // Op registration
@@ -329,4 +339,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("fmha_gen_fwd", dispatch_fmha_gen_fwd);
 }
+TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
+  m.impl("fmha_gen_fwd", dispatch_fmha_gen_fwd_meta);
+}
 #endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED

fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py

Lines changed: 73 additions & 6 deletions
@@ -293,6 +293,7 @@ def _execute_cutlass_blackwell_attn_dense(
         deterministic: bool,
         sm_scale: Optional[float],
         is_paged: Optional[bool],
+        use_compile: bool = False,
     ) -> None:
         device = torch.accelerator.current_accelerator()
         assert device is not None
@@ -369,9 +370,12 @@ def _execute_cutlass_blackwell_attn_dense(
         )
 
         # Run tested kernel
+        func_to_test = cutlass_blackwell_fmha_func
+        if use_compile:
+            func_to_test = torch.compile(func_to_test, fullgraph=True)
         if is_paged:
             assert k_paged is not None and v_paged is not None
-            out_paged = cutlass_blackwell_fmha_func(
+            out_paged = func_to_test(
                 q,
                 k_paged,
                 v_paged,
@@ -384,7 +388,7 @@ def _execute_cutlass_blackwell_attn_dense(
                 softmax_scale=sm_scale,
             )
 
-        out = cutlass_blackwell_fmha_func(
+        out = func_to_test(
             q,
             k,
             v,
@@ -411,7 +415,7 @@ def _execute_cutlass_blackwell_attn_dense(
 
         if deterministic:
             # Rerun the test. The outputs must be bit-wise exact
-            out_d = cutlass_blackwell_fmha_func(
+            out_d = func_to_test(
                 q,
                 cast(torch.Tensor, k_paged) if is_paged else k,
                 cast(torch.Tensor, v_paged) if is_paged else v,
@@ -479,6 +483,7 @@ def _execute_cutlass_blackwell_attn_varlen(
         deterministic: bool,
         sm_scale: Optional[float],
         is_paged: Optional[bool],
+        use_compile: bool = False,
     ) -> None:
         device = torch.accelerator.current_accelerator()
         assert device is not None
@@ -572,9 +577,12 @@ def _execute_cutlass_blackwell_attn_varlen(
             softmax_scale=sm_scale,
         )
 
+        func_to_test = cutlass_blackwell_fmha_func
+        if use_compile:
+            func_to_test = torch.compile(func_to_test, fullgraph=True)
         if is_paged:
             assert k_paged is not None and v_paged is not None
-            out_unpad_paged = cutlass_blackwell_fmha_func(
+            out_unpad_paged = func_to_test(
                 q_unpad,
                 k_paged,
                 v_paged,
@@ -590,7 +598,7 @@ def _execute_cutlass_blackwell_attn_varlen(
             )
             out_paged = output_pad_fn(out_unpad_paged)
 
-        out_unpad = cutlass_blackwell_fmha_func(
+        out_unpad = func_to_test(
             q_unpad,
             k_unpad,
             v_unpad,
@@ -617,7 +625,7 @@ def _execute_cutlass_blackwell_attn_varlen(
 
         if deterministic:
             # Rerun the test. The outputs must be bit-wise exact
-            out_unpad_d = cutlass_blackwell_fmha_func(
+            out_unpad_d = func_to_test(
                 q_unpad,
                 cast(torch.Tensor, k_paged) if is_paged else k_unpad,
                 cast(torch.Tensor, v_paged) if is_paged else v_unpad,
@@ -1165,3 +1173,62 @@ def test_backward(
             sm_scale=sm_scale,
             is_paged=False,
         )
+
+    @skip_cuda_lt_sm100
+    @skip_rocm
+    @parameterized.expand(
+        [
+            (
+                is_varlen,
+                is_mqa,
+                seqlen_q,
+            )
+            for is_varlen in [False, True]
+            for is_mqa in [False, True]
+            for seqlen_q in [1, 64]
+        ]
+    )
+    def test_compile(
+        self,
+        is_varlen: bool,
+        is_mqa: bool,
+        seqlen_q: int,
+    ):
+        test_func = (
+            self._execute_cutlass_blackwell_attn_varlen
+            if is_varlen
+            else self._execute_cutlass_blackwell_attn_dense
+        )
+        q_heads = 8
+        kv_heads = 2 if is_mqa else q_heads
+        batch_size = 2
+        seqlen_k = 128
+        kv_heads = 2
+        head_dim = 128
+        dtype = torch.bfloat16
+        causal = True
+        # Decode kernel does not support sliding window attention yet
+        window_size = (-1, -1)
+        deterministic = False
+        # Backward pass is not supported for generation phase (sq=1)
+        is_decode = seqlen_q == 1
+        fwd_only = is_decode
+        # Decode kernel does not support sm_scale
+        sm_scale = None if is_decode else 1.0 / head_dim
+
+        test_func(
+            batch_size,
+            seqlen_q,
+            seqlen_k,
+            q_heads=q_heads,
+            kv_heads=kv_heads,
+            head_dim=head_dim,
+            page_block_size=0,
+            dtype=dtype,
+            causal=causal,
+            window_size=window_size,
+            fwd_only=fwd_only,
+            deterministic=deterministic,
+            sm_scale=sm_scale,
+            is_paged=False,
+        )
