Skip to content

Commit 276344a

Browse files
authored
Fix CUDA IMA from combination of unrolling + pipelining (#920)
1 parent e0ff088 commit 276344a

File tree

3 files changed

+104
-0
lines changed

3 files changed

+104
-0
lines changed

helion/_compiler/tile_strategy.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import dataclasses
66
import functools
77
import itertools
8+
import math
89
import operator
910
from typing import TYPE_CHECKING
1011
from typing import NamedTuple
@@ -147,10 +148,28 @@ def get_tl_range_kwargs(config: Config, block_idx: int) -> list[str]:
147148
range_num_stages = env.config_spec.range_num_stages.config_get(
148149
config.range_num_stages, block_idx, 0
149150
)
151+
150152
if config.indexing == "tensor_descriptor" and range_num_stages > 0:
151153
# Tensor descriptor + multi-stage tl.range pipelines tend to cause
152154
# CUDA "misaligned address" or "unspecified launch failure" errors.
153155
range_num_stages = 0
156+
elif (
157+
range_num_stages > 1
158+
and range_unroll_factor > 1
159+
and env.block_sizes[block_idx].size
160+
and env.block_sizes[block_idx].numel.is_number
161+
):
162+
# Combining loop unrolling with pipelining can trigger a CUDA illegal memory access (IMA)
163+
# Clamp range_num_stages so the unrolled step size plus pipelining stays within the loop bounds
164+
loop_numel = int(env.block_sizes[block_idx].numel)
165+
block_size = int(env.block_sizes[block_idx].from_config_assert(config))
166+
step = range_unroll_factor * block_size
167+
last_offset = ((loop_numel - 1) // block_size) * block_size
168+
remainder = loop_numel - last_offset
169+
range_num_stages = min(
170+
max(1, int(math.ceil(remainder / step))), range_num_stages
171+
)
172+
154173
if range_num_stages > 0:
155174
kwargs.append(f"num_stages={range_num_stages}")
156175

@@ -194,6 +213,7 @@ def get_range_call_str(
194213

195214
if use_static_range:
196215
return f"tl.static_range({', '.join(range_args)})"
216+
197217
range_kwargs = TileStrategy.get_tl_range_kwargs(config, block_ids[0])
198218
return f"tl.range({', '.join(range_args + range_kwargs)})"
199219

test/test_loops.expected

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,3 +1145,46 @@ def three_pass_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
11451145
_BLOCK_SIZE_3 = 8
11461146
_launcher(_helion_three_pass_kernel, (triton.cdiv(B, _BLOCK_SIZE_0),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), B, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=2)
11471147
return out
1148+
1149+
--- assertExpectedJournal(TestLoops.test_unroll_with_pipelining)
1150+
from __future__ import annotations
1151+
1152+
import torch
1153+
import helion
1154+
import triton
1155+
import triton.language as tl
1156+
from helion.runtime import default_launcher as _default_launcher
1157+
1158+
@triton.jit
1159+
def _helion_matmul(x, y, out, _NUM_SM: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
1160+
total_pids = tl.cdiv(256, _BLOCK_SIZE_1) * tl.cdiv(256, _BLOCK_SIZE_0)
1161+
block_size = tl.cdiv(total_pids, _NUM_SM)
1162+
start_pid = tl.program_id(0) * block_size
1163+
end_pid = tl.minimum(start_pid + block_size, total_pids)
1164+
for virtual_pid in tl.range(start_pid, end_pid, loop_unroll_factor=4, num_stages=1):
1165+
num_blocks_0 = tl.cdiv(256, _BLOCK_SIZE_1)
1166+
pid_0 = virtual_pid % num_blocks_0
1167+
pid_1 = virtual_pid // num_blocks_0
1168+
offset_1 = pid_0 * _BLOCK_SIZE_1
1169+
offset_0 = pid_1 * _BLOCK_SIZE_0
1170+
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
1171+
for offset_2 in tl.range(0, 256, _BLOCK_SIZE_2, loop_unroll_factor=4, num_stages=1):
1172+
acc_copy = acc
1173+
acc_copy_0 = acc_copy
1174+
load = tl.load(tl.make_block_ptr(x, [256, 256], [256, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
1175+
load_1 = tl.load(tl.make_block_ptr(y, [256, 256], [256, 1], [offset_2, offset_1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
1176+
acc = tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
1177+
v_0 = tl.cast(acc, tl.bfloat16)
1178+
tl.store(tl.make_block_ptr(out, [256, 256], [256, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_0, boundary_check=[0, 1])
1179+
1180+
def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
1181+
m, k = x.size()
1182+
k2, n = y.size()
1183+
assert k == k2, f'size mismatch {k} != {k2}'
1184+
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
1185+
_NUM_SM = helion.runtime.get_num_sm(x.device)
1186+
_BLOCK_SIZE_1 = 16
1187+
_BLOCK_SIZE_0 = 64
1188+
_BLOCK_SIZE_2 = 16
1189+
_launcher(_helion_matmul, (_NUM_SM,), x, y, out, _NUM_SM, _BLOCK_SIZE_1, _BLOCK_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=2)
1190+
return out

test/test_loops.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,47 @@ def three_pass_kernel(x: torch.Tensor) -> torch.Tensor:
12131213
torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
12141214
self.assertExpectedJournal(code)
12151215

1216+
def test_unroll_with_pipelining(self):
1217+
@helion.kernel(static_shapes=True)
1218+
def matmul(
1219+
x: torch.Tensor,
1220+
y: torch.Tensor,
1221+
) -> torch.Tensor:
1222+
m, k = x.size()
1223+
k2, n = y.size()
1224+
assert k == k2, f"size mismatch {k} != {k2}"
1225+
out = torch.empty(
1226+
[m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
1227+
)
1228+
for tile_m, tile_n in hl.tile([m, n]):
1229+
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
1230+
for tile_k in hl.tile(k):
1231+
acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
1232+
out[tile_m, tile_n] = acc
1233+
return out
1234+
1235+
a = torch.randn(256, 256, device=DEVICE, dtype=torch.bfloat16)
1236+
b = torch.randn(256, 256, device=DEVICE, dtype=torch.bfloat16)
1237+
1238+
code, result = code_and_output(
1239+
matmul,
1240+
(a, b),
1241+
block_sizes=[64, 16, 16],
1242+
indexing="block_ptr",
1243+
loop_orders=[[1, 0]],
1244+
pid_type="persistent_blocked",
1245+
range_num_stages=[4, 2],
1246+
range_unroll_factors=[4, 4],
1247+
)
1248+
1249+
expected = torch.matmul(a, b)
1250+
torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)
1251+
self.assertExpectedJournal(code)
1252+
1253+
# With both unrolling and pipelining requested, the guard should
1254+
# clamp the requested stage count so the generated code uses num_stages=1
1255+
self.assertIn("num_stages=1", code)
1256+
12161257

12171258
if __name__ == "__main__":
12181259
unittest.main()

0 commit comments

Comments
 (0)