@@ -1145,3 +1145,46 @@ def three_pass_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
     _BLOCK_SIZE_3 = 8
     _launcher(_helion_three_pass_kernel, (triton.cdiv(B, _BLOCK_SIZE_0),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), B, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=2)
     return out
+
+--- assertExpectedJournal(TestLoops.test_unroll_with_pipelining)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_matmul(x, y, out, _NUM_SM: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    total_pids = tl.cdiv(256, _BLOCK_SIZE_1) * tl.cdiv(256, _BLOCK_SIZE_0)
+    block_size = tl.cdiv(total_pids, _NUM_SM)
+    start_pid = tl.program_id(0) * block_size
+    end_pid = tl.minimum(start_pid + block_size, total_pids)
+    for virtual_pid in tl.range(start_pid, end_pid, loop_unroll_factor=4, num_stages=1):
+        num_blocks_0 = tl.cdiv(256, _BLOCK_SIZE_1)
+        pid_0 = virtual_pid % num_blocks_0
+        pid_1 = virtual_pid // num_blocks_0
+        offset_1 = pid_0 * _BLOCK_SIZE_1
+        offset_0 = pid_1 * _BLOCK_SIZE_0
+        acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+        for offset_2 in tl.range(0, 256, _BLOCK_SIZE_2, loop_unroll_factor=4, num_stages=1):
+            acc_copy = acc
+            acc_copy_0 = acc_copy
+            load = tl.load(tl.make_block_ptr(x, [256, 256], [256, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
+            load_1 = tl.load(tl.make_block_ptr(y, [256, 256], [256, 1], [offset_2, offset_1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
+            acc = tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+        v_0 = tl.cast(acc, tl.bfloat16)
+        tl.store(tl.make_block_ptr(out, [256, 256], [256, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_0, boundary_check=[0, 1])
+
+def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _NUM_SM = helion.runtime.get_num_sm(x.device)
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_0 = 64
+    _BLOCK_SIZE_2 = 16
+    _launcher(_helion_matmul, (_NUM_SM,), x, y, out, _NUM_SM, _BLOCK_SIZE_1, _BLOCK_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=2)
+    return out
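
For context, the added journal entry above is the Triton code Helion is expected to emit for TestLoops.test_unroll_with_pipelining: a persistent matmul kernel where both the outer virtual-program loop and the K reduction loop carry loop_unroll_factor=4 and num_stages=1. A minimal sketch of how the generated matmul wrapper could be exercised follows; it assumes the journal code has been saved as an importable module (the module name generated_matmul, the CUDA device, and the tolerances are illustrative and not part of the journal):

import torch

from generated_matmul import matmul  # hypothetical module holding the journal's wrapper

# The generated kernel hard-codes 256x256 shapes, so the check uses matching inputs.
x = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
y = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)

out = matmul(x, y)
ref = torch.matmul(x, y)

# The kernel accumulates in float32 and casts to bfloat16 at the end, so a loose
# bfloat16-level tolerance is enough to compare against the eager reference.
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)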