Commit d4d122b

Register tile symbol origin, to support tile + offset use case in blackwell attention (#939)
1 parent 3c8c390 commit d4d122b

7 files changed (+243, -13 lines)
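
The use case this commit enables, excerpted (lightly trimmed) from the test added to test/test_indexing.py below — an expression derived from tile_m.begin is used as a loop-invariant offset when indexing with another tile (hl is helion.language; M, N, MM, block_m, block_n, k, v are as defined in that test):

    # Excerpt from test_tile_with_offset_from_expr (full test in the diff below).
    for tile_m in hl.tile(MM, block_size=block_m):
        start_N = tile_m.begin // M * N              # symbolic offset built from tile_m.begin
        for tile_n in hl.tile(0, N, block_size=block_n):
            k_j = k[tile_n + start_N, :]             # tile + offset indexing
            v_j = v[tile_n + start_N, :]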

helion/_compiler/compile_environment.py

Lines changed: 2 additions & 2 deletions

@@ -416,7 +416,7 @@ def has_current() -> bool:
         except NoCurrentEnvironment:
             return False
 
-    def get_block_id(self, size: int | torch.SymInt | sympy.Expr) -> int | None:
+    def get_block_id(self, size: int | torch.SymInt | sympy.Basic) -> int | None:
         """
         Get the block ID associated with a given size expression.
 
@@ -425,7 +425,7 @@ def get_block_id(self, size: int | torch.SymInt | sympy.Expr) -> int | None:
         symbolic expressions to find their associated block IDs.
 
         Args:
-            size: The size expression to check. Can be an integer, torch.SymInt, or sympy.Expr.
+            size: The size expression to check. Can be an integer, torch.SymInt, or sympy.Basic.
 
         Returns:
             The block ID if the size corresponds to a registered block size, None otherwise.
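
For context (not part of the diff): sympy.Basic is the base class of every sympy object, while sympy.Expr covers only arithmetic expressions, so the widened annotation presumably lets callers pass arbitrary sympy objects without narrowing them first. A minimal plain-sympy illustration of the difference:

    import sympy

    x = sympy.Symbol("x")
    assert isinstance(x + 1, sympy.Expr)                # arithmetic expressions are Expr (and Basic)
    assert isinstance(sympy.Eq(x, 1), sympy.Basic)      # relationals are Basic...
    assert not isinstance(sympy.Eq(x, 1), sympy.Expr)   # ...but not Expr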

helion/_compiler/generate_ast.py

Lines changed: 57 additions & 1 deletion

@@ -25,15 +25,17 @@
 from .inductor_lowering import CodegenState
 from .inductor_lowering import codegen_call_with_graph
 from .program_id import ForEachProgramID
+from .tile_strategy import DeviceLoopState
 from .variable_origin import ArgumentOrigin
 
 if TYPE_CHECKING:
     from collections.abc import Iterator
 
+    import sympy
+
     from ..runtime import Config
     from .host_function import HostFunction
     from .tile_strategy import DeviceLoopOrGridState
-    from .tile_strategy import DeviceLoopState
     from .type_propagation import TensorType
 
 
@@ -97,6 +99,60 @@ def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name:
             )
         return create(ast.Name, id=varname, ctx=ast.Load())
 
+    def lift_symnode(
+        self,
+        expr: ast.AST,
+        sym_expr: sympy.Expr,
+        *,
+        dce: bool = False,
+        prefix: str = "symnode",
+    ) -> ast.Name:
+        if isinstance(expr, ast.Name):
+            return expr
+        assert isinstance(expr, ExtendedAST), expr
+
+        target_statements = self.statements_stack[-1]
+        env = CompileEnvironment.current()
+        # Identify every block dimension the symbolic value depends on so we know
+        # which loop nests the expression depends on.
+        dep_block_ids = {
+            block_id
+            for symbol in sym_expr.free_symbols
+            if (block_id := env.get_block_id(symbol)) is not None
+        }
+
+        # Walk outward through the active device loops: as soon as we see a loop
+        # whose block id appears in the dependency set we must stop, otherwise we
+        # can safely hoist into that loop's outer prefix (which executes before the
+        # loop body).
+        for loop_state in reversed(self._active_loop_stack()):
+            if dep_block_ids.intersection(loop_state.block_ids):
+                break
+            target_statements = loop_state.outer_prefix
+
+        with expr:
+            varname = self.tmpvar(dce=dce, prefix=prefix)
+            # Emit the temporary into the chosen statement list so the symbolic
+            # expression is computed exactly once at the appropriate scope.
+            target_statements.append(
+                statement_from_string(f"{varname} = {{expr}}", expr=expr)
+            )
+        # Reuse the temporary everywhere else in the kernel body.
+        return create(ast.Name, id=varname, ctx=ast.Load())
+
+    def _active_loop_stack(self) -> list[DeviceLoopState]:
+        seen: set[int] = set()
+        stack: list[DeviceLoopState] = []
+        for loops in self.active_device_loops.values():
+            for loop_state in loops:
+                if not isinstance(loop_state, DeviceLoopState):
+                    continue
+                key = id(loop_state)
+                if key not in seen:
+                    stack.append(loop_state)
+                    seen.add(key)
+        return stack
+
     @contextlib.contextmanager
     def set_statements(self, new_statements: list[ast.AST] | None) -> Iterator[None]:
         if new_statements is None:
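
A standalone sketch of the hoisting scan performed by lift_symnode, using hypothetical stand-ins for DeviceLoopState and its block_ids/outer_prefix fields (an illustration of the placement logic, not the Helion API):

    from dataclasses import dataclass, field

    @dataclass
    class Loop:
        block_ids: set[int]                                     # block ids this loop introduces
        outer_prefix: list[str] = field(default_factory=list)   # statements emitted just before the loop
        body: list[str] = field(default_factory=list)

    def place(stmt: str, dep_block_ids: set[int], active_loops: list[Loop], current: list[str]) -> None:
        target = current                         # default: innermost statement list
        for loop in reversed(active_loops):      # walk from innermost loop outward
            if dep_block_ids & loop.block_ids:   # depends on this loop's index: cannot hoist past it
                break
            target = loop.outer_prefix           # otherwise hoist above this loop
        target.append(stmt)

    outer, inner = Loop({0}), Loop({1})
    # A value that depends only on the outer loop's block id gets hoisted out of the
    # inner loop (into inner.outer_prefix) but stays inside the outer loop.
    place("symnode_0 = offset_0 // heads", {0}, [outer, inner], inner.body)
    assert inner.outer_prefix == ["symnode_0 = offset_0 // heads"]
    assert outer.outer_prefix == [] and inner.body == []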

helion/language/_tracing_ops.py

Lines changed: 4 additions & 2 deletions

@@ -54,8 +54,10 @@ def _(state: CodegenState) -> ast.AST:
         if block_size_var is None:
             return expr_from_string("1")
         return expr_from_string(block_size_var)
-    return state.codegen.lift(
-        expr_from_string(state.sympy_expr(val._sympy_())),
+    sym_expr = val._sympy_()
+    return state.codegen.lift_symnode(
+        expr_from_string(state.sympy_expr(sym_expr)),
+        sym_expr,
         dce=True,
         prefix="symnode",
     )
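
The new sym_expr argument is what lift_symnode inspects: it walks sym_expr.free_symbols and maps each symbol back to a block id via get_block_id. A plain-sympy illustration of that inspection (symbol names chosen to mirror the symnode expressions in the expected-output diffs below, purely illustrative):

    import sympy

    offset_0, heads = sympy.symbols("offset_0 heads", integer=True, positive=True)
    sym_expr = sympy.floor(offset_0 / heads)   # e.g. the floor-division symnode in the expected output
    assert sym_expr.free_symbols == {offset_0, heads}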

helion/language/tile_ops.py

Lines changed: 24 additions & 6 deletions

@@ -7,6 +7,9 @@
 from .. import exc
 from .._compiler.ast_extension import expr_from_string
 from .._compiler.compile_environment import CompileEnvironment
+from .._compiler.host_function import HostFunction
+from .._compiler.host_function import SymbolOrigin
+from .._compiler.variable_origin import GridOrigin
 from . import _decorators
 
 if TYPE_CHECKING:
@@ -17,6 +20,13 @@
     from .tile_interface import TileInterface
 
 
+def _register_tile_symbol_origin(symbol: torch.SymInt, tile_index: int) -> None:
+    """Register the origin for a tile-related symbol so it can be resolved during codegen."""
+    HostFunction.current().expr_to_origin[symbol._sympy_()] = SymbolOrigin(
+        GridOrigin(tile_index)
+    )
+
+
 @_decorators.api(tiles_as_sizes=True)
 def tile_index(tile: TileInterface) -> torch.Tensor:
     """
@@ -68,10 +78,12 @@ def tile_begin(tile: TileInterface) -> int:
 
 @_decorators.register_fake(tile_begin)
 def _(tile: torch.SymInt) -> torch.SymInt:
-    _disable_flatten_get_tile(tile)  # update config spec if needed
-    return CompileEnvironment.current().cached_create_unbacked_symint(
+    index = _disable_flatten_get_tile(tile)  # update config spec if needed
+    result = CompileEnvironment.current().cached_create_unbacked_symint(
         ("tile_begin", tile)
     )
+    _register_tile_symbol_origin(result, index)
+    return result
 
 
 def _disable_flatten_get_tile(tile: object) -> int:
@@ -109,10 +121,12 @@ def tile_end(tile: TileInterface) -> int:
 
 @_decorators.register_fake(tile_end)
 def _(tile: torch.SymInt) -> torch.SymInt:
-    _disable_flatten_get_tile(tile)  # update config spec if needed
-    return CompileEnvironment.current().cached_create_unbacked_symint(
+    index = _disable_flatten_get_tile(tile)  # update config spec if needed
+    result = CompileEnvironment.current().cached_create_unbacked_symint(
         ("tile_end", tile)
     )
+    _register_tile_symbol_origin(result, index)
+    return result
 
 
 @_decorators.codegen(tile_end)
@@ -175,9 +189,13 @@ def tile_id(tile: TileInterface) -> int:
 
 @_decorators.register_fake(tile_id)
 def _(tile: torch.SymInt) -> torch.SymInt:
-    _disable_flatten_get_tile(tile)  # update config spec if needed
+    index = _disable_flatten_get_tile(tile)  # update config spec if needed
     assert isinstance(tile, torch.SymInt)
-    return CompileEnvironment.current().cached_create_unbacked_symint(("tile_id", tile))
+    result = CompileEnvironment.current().cached_create_unbacked_symint(
+        ("tile_id", tile)
+    )
+    _register_tile_symbol_origin(result, index)
+    return result
 
 
 @_decorators.codegen(tile_id)

test/test_examples.expected

Lines changed: 2 additions & 2 deletions

@@ -987,6 +987,8 @@ def _helion_fp8_attention_kernel(q, k, v, out, out_stride_0, heads, _RDIM_SIZE_2
     pid_0 = tl.program_id(0)
     offset_0 = pid_0
     indices_5 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    symnode_0 = triton_helpers.div_floor_integer(offset_0, heads)
+    symnode_1 = triton_helpers.remainder_integer(offset_0, heads)
     for offset_4 in tl.range(0, 256, _BLOCK_SIZE_1):
         indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         m_i = tl.full([_BLOCK_SIZE_1], float('-inf'), tl.float32)
@@ -1028,8 +1030,6 @@ def _helion_fp8_attention_kernel(q, k, v, out, out_stride_0, heads, _RDIM_SIZE_2
         subscript_2 = l_i[:, None]
         v_11 = acc / subscript_2
         v_12 = tl.cast(v_11, tl.float8e4nv)
-        symnode_0 = triton_helpers.div_floor_integer(offset_0, heads)
-        symnode_1 = triton_helpers.remainder_integer(offset_0, heads)
         tl.store(out + (symnode_0 * out_stride_0 + symnode_1 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_12, None)
 
 def fp8_attention_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, batch: int, heads: int, *, _launcher=_default_launcher):

test/test_indexing.expected

Lines changed: 87 additions & 0 deletions

@@ -462,6 +462,93 @@ def tile_offset_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_tile_offset_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, x, out.size(0), x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
     return out
 
+--- assertExpectedJournal(TestIndexing.test_tile_with_offset_from_expr)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime import triton_helpers
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_attention(q, k, v, lse, o, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    full = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
+    v_0 = float('inf')
+    v_1 = full - v_0
+    full_1 = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
+    v_2 = 1.0
+    v_3 = full_1 + v_2
+    acc = tl.full([_BLOCK_SIZE_0, 64], 0.0, tl.float32)
+    q_i = tl.load(q + (indices_0[:, None] * 64 + indices_3[None, :] * 1), None)
+    symnode_0 = 64 * triton_helpers.div_floor_integer(offset_0, 64)
+    for offset_2 in tl.range(0, 64, _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        q_i_copy = q_i
+        v_1_copy = v_1
+        acc_copy = acc
+        v_3_copy = v_3
+        q_i_copy_0 = q_i_copy
+        v_1_copy_0 = v_1_copy
+        acc_copy_0 = acc_copy
+        v_3_copy_0 = v_3_copy
+        v_4 = tl.cast(symnode_0, tl.int32)
+        v_5 = indices_2 + v_4
+        k_j = tl.load(k + ((indices_2 + symnode_0)[:, None] * 64 + indices_3[None, :] * 1), None)
+        v_6 = tl.cast(symnode_0, tl.int32)
+        v_7 = indices_2 + v_6
+        v_j = tl.load(v + ((indices_2 + symnode_0)[:, None] * 64 + indices_3[None, :] * 1), None)
+        permute = tl.permute(k_j, [1, 0])
+        qk = tl.dot(tl.cast(q_i_copy_0, tl.bfloat16), tl.cast(permute, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
+        amax = tl.cast(tl.max(qk, 1), tl.float32)
+        v_8 = 0.18033688
+        v_9 = amax * v_8
+        v_10 = triton_helpers.maximum(v_1_copy_0, v_9)
+        v_11 = 0.18033688
+        v_12 = qk * v_11
+        subscript = v_10[:, None]
+        v_13 = v_12 - subscript
+        v_14 = libdevice.exp2(v_13)
+        v_15 = v_1_copy_0 - v_10
+        v_16 = libdevice.exp2(v_15)
+        l_ij = tl.cast(tl.sum(v_14, 1), tl.float32)
+        subscript_1 = v_16[:, None]
+        v_17 = acc_copy_0 * subscript_1
+        v_18 = tl.cast(v_14, tl.bfloat16)
+        acc = tl.dot(tl.cast(v_18, tl.bfloat16), tl.cast(v_j, tl.bfloat16), acc=v_17, input_precision='tf32', out_dtype=tl.float32)
+        v_19 = v_3_copy_0 * v_16
+        v_3 = v_19 + l_ij
+        v_1 = v_10
+    v_21 = libdevice.log2(v_3)
+    v_22 = v_1 + v_21
+    subscript_2 = v_3[:, None]
+    v_23 = acc / subscript_2
+    tl.store(lse + indices_0 * 1, v_22, None)
+    v_24 = tl.cast(v_23, tl.bfloat16)
+    tl.store(o + (indices_0[:, None] * 64 + indices_3[None, :] * 1), v_24, None)
+
+def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _launcher=_default_launcher):
+    B, H, M, D = q_in.shape
+    Bk, Hk, N, Dk = k_in.shape
+    Bv, Hv, Nv, Dv = v_in.shape
+    D = 64
+    Dv = 64
+    q = q_in.reshape(-1, D)
+    k = k_in.reshape(-1, D)
+    v = v_in.reshape(-1, Dv)
+    MM = q.shape[0]
+    o = q.new_empty(MM, Dv)
+    lse = q.new_empty(MM, dtype=torch.float32)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_attention, (triton.cdiv(8192, _BLOCK_SIZE_0),), q, k, v, lse, o, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
+    return (o.reshape(B, H, M, Dv), lse.reshape(B, H, M))
+
 --- assertExpectedJournal(TestIndexing.test_tile_with_offset_pointer)
 from __future__ import annotations
 
test/test_indexing.py

Lines changed: 67 additions & 0 deletions

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import math
 import unittest
 
 import torch
@@ -1302,6 +1303,72 @@ def tile_offset_2d_kernel(x: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result, x[10:, :])
         self.assertExpectedJournal(code)
 
+    @skipIfRefEager(
+        "Test is block size dependent which is not supported in ref eager mode"
+    )
+    def test_tile_with_offset_from_expr(self):
+        @helion.kernel(
+            autotune_effort="none",
+            static_shapes=True,
+        )
+        def attention(
+            q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            B, H, M, D = q_in.shape
+            Bk, Hk, N, Dk = k_in.shape
+            Bv, Hv, Nv, Dv = v_in.shape
+            D = hl.specialize(D)
+            Dv = hl.specialize(Dv)
+            q = q_in.reshape(-1, D)
+            k = k_in.reshape(-1, D)
+            v = v_in.reshape(-1, Dv)
+            MM = q.shape[0]
+            o = q.new_empty(MM, Dv)
+            lse = q.new_empty(MM, dtype=torch.float32)
+            block_m = hl.register_block_size(M)
+            block_n = hl.register_block_size(N)
+            sm_scale = 1.0 / math.sqrt(D)
+            qk_scale = sm_scale * 1.44269504  # 1/log(2)
+            for tile_m in hl.tile(MM, block_size=block_m):
+                m_i = hl.zeros([tile_m]) - float("inf")
+                l_i = hl.zeros([tile_m]) + 1.0
+                acc = hl.zeros([tile_m, Dv])
+                q_i = q[tile_m, :]
+
+                start_N = tile_m.begin // M * N
+                for tile_n in hl.tile(0, N, block_size=block_n):
+                    k_j = k[tile_n + start_N, :]
+                    v_j = v[tile_n + start_N, :]
+                    qk = hl.dot(q_i, k_j.T, out_dtype=torch.float32)
+                    m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
+                    qk = qk * qk_scale - m_ij[:, None]
+                    p = torch.exp2(qk)
+                    alpha = torch.exp2(m_i - m_ij)
+                    l_ij = torch.sum(p, -1)
+                    acc = acc * alpha[:, None]
+                    p = p.to(v.dtype)
+                    acc = hl.dot(p, v_j, acc=acc)
+                    l_i = l_i * alpha + l_ij
+                    m_i = m_ij
+
+                m_i += torch.log2(l_i)
+                acc = acc / l_i[:, None]
+                lse[tile_m] = m_i
+                o[tile_m, :] = acc
+
+            return o.reshape(B, H, M, Dv), lse.reshape(B, H, M)
+
+        z, h, n_ctx, head_dim = 4, 32, 64, 64
+        dtype = torch.bfloat16
+        q, k, v = [
+            torch.randn((z, h, n_ctx, head_dim), dtype=dtype, device=DEVICE)
+            for _ in range(3)
+        ]
+        code, (o, lse) = code_and_output(attention, (q, k, v))
+        torch_out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        torch.testing.assert_close(o, torch_out, atol=1e-2, rtol=1e-2)
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
