13 | 13 | from __future__ import annotations |
14 | 14 |
15 | 15 | import math |
| 16 | +from typing import Callable |
16 | 17 |
17 | 18 | import torch |
18 | 19 | from triton.testing import do_bench |
@@ -96,8 +97,8 @@ def _fma_f32x2(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tenso |
96 | 97 | static_shapes=True, |
97 | 98 | autotune_accuracy_check=False, |
98 | 99 | ) |
99 | | -def blackwell_attention( |
100 | | - q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor |
| 100 | +def blackwell_attention_kernel( |
| 101 | + q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, qk_scale: float |
101 | 102 | ) -> tuple[torch.Tensor, torch.Tensor]: |
102 | 103 | """ |
103 | 104 | Computes scaled dot-product attention. |
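This hunk renames the kernel entry point and hoists the softmax scale out of the body into an explicit `qk_scale` argument, so the call site now chooses the scale. A minimal calling sketch, assuming the kernel from this file is in scope, a Helion-supported GPU is available, and with made-up shapes and dtype:

```python
import math

import torch

# Hypothetical call site; B, H, M, D and the dtype are assumptions.
# Assumes blackwell_attention_kernel from this file is in scope.
B, H, M, D = 2, 4, 1024, 128
q = torch.randn(B, H, M, D, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# The scale is no longer derived from D inside the kernel: pass the classic
# 1/sqrt(D), or any other temperature the caller wants.
out, lse = blackwell_attention_kernel(q, k, v, qk_scale=1.0 / math.sqrt(D))
```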
@@ -143,8 +144,7 @@ def blackwell_attention( |
143 | 144 | hl.register_tunable("_triton_config_maxRegAutoWS", EnumFragment(choices=(152, 192))) |
144 | 145 | SUBTILING = True |
145 | 146 | VECT_MUL = 1 |
146 | | - sm_scale = 1.0 / math.sqrt(D) |
147 | | - qk_scale = sm_scale * 1.44269504 # 1/log(2) |
| 147 | + qk_scale = qk_scale * 1.44269504 # 1/log(2) |
148 | 148 | for tile_m in hl.tile(MM, block_size=block_m): |
149 | 149 | m_i = hl.zeros([tile_m]) - float("inf") |
150 | 150 | l_i = hl.zeros([tile_m]) + 1.0 |
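The surviving line folds 1.44269504 ≈ log2(e) (the `# 1/log(2)` constant) into the caller-supplied scale because the tiled softmax is evaluated with base-2 exponentials, which map directly to the GPU's fast `exp2`. A small sketch of the identity being relied on; the vector size and the example 1/sqrt(D) scale are assumptions:

```python
import math

import torch

# exp(x * s) == 2 ** (x * s * log2(e)), so folding log2(e) into qk_scale lets
# the inner loop use exp2 instead of exp.
x = torch.randn(16)
s = 1.0 / math.sqrt(128)   # example 1/sqrt(D) scale (assumption)
log2e = 1.44269504         # log2(e) == 1/ln(2), the same constant as in the kernel
torch.testing.assert_close(torch.exp(x * s), torch.exp2(x * s * log2e))
```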
@@ -205,6 +205,18 @@ def blackwell_attention( |
205 | 205 | return o.reshape(B, H, M, Dv), lse.reshape(B, H, M) |
206 | 206 |
207 | 207 |
| 208 | +def blackwell_attention( |
| 209 | + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor |
| 210 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 211 | + return blackwell_attention_kernel(q, k, v, qk_scale=math.sqrt(1.0 / q.shape[-1])) |
| 212 | + |
| 213 | + |
| 214 | +def blackwell_attention_tritonbench( |
| 215 | + tb_mod: object, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor |
| 216 | +) -> Callable: |
| 217 | + return lambda: blackwell_attention(q, k, v) |
| 218 | + |
| 219 | + |
208 | 220 | # %% |
209 | 221 | # Testing Function |
210 | 222 | # ---------------- |
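The new `blackwell_attention` wrapper preserves the old default by passing `qk_scale = sqrt(1/D) = 1/sqrt(D)`, and `blackwell_attention_tritonbench` adapts it to TritonBench's convention of returning a zero-argument callable. A hedged smoke test of the default-scale path; shapes, dtype, and tolerances are assumptions, and a Helion-supported GPU is required:

```python
import torch

# Hypothetical sanity check against PyTorch SDPA, whose default scale is also 1/sqrt(D).
# Assumes blackwell_attention from this file is in scope.
B, H, M, D = 2, 4, 512, 128
q = torch.randn(B, H, M, D, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

out, lse = blackwell_attention(q, k, v)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)
```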