
Commit f0631e3

[blackwell attn example] qk scale as param (#969)
1 parent: f8bad83


2 files changed: +19 −7 lines

benchmarks/run.py

Lines changed: 3 additions & 3 deletions
@@ -306,7 +306,7 @@ class RunResult:
     "blackwell_attentions": (
         "tritonbench.operators.blackwell_attentions.operator",
         "examples.blackwell_attention",
-        "blackwell_attention",
+        "blackwell_attention_tritonbench",
         {
             "d_head": 128,  # Set default head dimension to 128 for TLX attention compatibility
             "num_inputs": 6,  # flash_attention takes long time on Benchmark CI, so use fewer inputs instead.
@@ -594,8 +594,8 @@ class RunResult:
     "triton_tutorial_flash_v2_tma_ws_persistent-accuracy": "triton_accuracy",
     "flex_attention-speedup": "torch_compile_speedup",
     "flex_attention-accuracy": "torch_compile_accuracy",
-    "helion_attention-speedup": "helion_speedup",
-    "helion_attention-accuracy": "helion_accuracy",
+    "helion_blackwell_attention_tritonbench-speedup": "helion_speedup",
+    "helion_blackwell_attention_tritonbench-accuracy": "helion_accuracy",
     },
 }
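The remapped entry now points the blackwell_attentions benchmark at the new blackwell_attention_tritonbench wrapper in examples.blackwell_attention. As a rough sketch of how a (operator module, example module, function name, defaults) entry like this is typically resolved (the importlib lookup below is illustrative, not the actual run.py logic):

import importlib

# Entry as it appears after this commit: tritonbench operator module, Helion example
# module, function name to load from the example module, and default benchmark args.
entry = (
    "tritonbench.operators.blackwell_attentions.operator",
    "examples.blackwell_attention",
    "blackwell_attention_tritonbench",
    {"d_head": 128, "num_inputs": 6},
)

_operator_module, example_module, func_name, _defaults = entry
kernel_fn = getattr(importlib.import_module(example_module), func_name)  # the wrapper added below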

examples/blackwell_attention.py

Lines changed: 16 additions & 4 deletions
@@ -13,6 +13,7 @@
 from __future__ import annotations

 import math
+from typing import Callable

 import torch
 from triton.testing import do_bench
@@ -96,8 +97,8 @@ def _fma_f32x2(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor
     static_shapes=True,
     autotune_accuracy_check=False,
 )
-def blackwell_attention(
-    q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor
+def blackwell_attention_kernel(
+    q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, qk_scale: float
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Computes scaled dot-product attention.
@@ -143,8 +144,7 @@ def blackwell_attention(
     hl.register_tunable("_triton_config_maxRegAutoWS", EnumFragment(choices=(152, 192)))
     SUBTILING = True
     VECT_MUL = 1
-    sm_scale = 1.0 / math.sqrt(D)
-    qk_scale = sm_scale * 1.44269504  # 1/log(2)
+    qk_scale = qk_scale * 1.44269504  # 1/log(2)
     for tile_m in hl.tile(MM, block_size=block_m):
         m_i = hl.zeros([tile_m]) - float("inf")
         l_i = hl.zeros([tile_m]) + 1.0
@@ -205,6 +205,18 @@ def blackwell_attention(
     return o.reshape(B, H, M, Dv), lse.reshape(B, H, M)


+def blackwell_attention(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return blackwell_attention_kernel(q, k, v, qk_scale=math.sqrt(1.0 / q.shape[-1]))
+
+
+def blackwell_attention_tritonbench(
+    tb_mod: object, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+) -> Callable:
+    return lambda: blackwell_attention(q, k, v)
+
+
 # %%
 # Testing Function
 # ----------------
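For reference, a minimal usage sketch of the new entry points (the tensor shapes, dtype, and device below are assumptions chosen to match the [B, H, M, D] layout and the d_head=128 default, not taken from this commit):

import math
import torch

from examples.blackwell_attention import blackwell_attention, blackwell_attention_kernel

# Assumed inputs: batch 2, 8 heads, 1024 queries/keys, head dim 128, fp16 on GPU.
q, k, v = (
    torch.randn(2, 8, 1024, 128, device="cuda", dtype=torch.float16) for _ in range(3)
)

# Wrapper: derives qk_scale = sqrt(1.0 / head_dim) internally, as before this commit.
out, lse = blackwell_attention(q, k, v)

# Kernel: pass the softmax scale explicitly (here the same default value, written out).
out2, lse2 = blackwell_attention_kernel(q, k, v, qk_scale=1.0 / math.sqrt(q.shape[-1]))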
