@@ -462,6 +462,93 @@ def tile_offset_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
462462 _launcher(_helion_tile_offset_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, x, out.size(0), x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
463463 return out
464464
465+ --- assertExpectedJournal(TestIndexing.test_tile_with_offset_from_expr)
466+ from __future__ import annotations
467+
468+ import torch
469+ import triton
470+ import triton.language as tl
471+ from torch._inductor.runtime import triton_helpers
472+ from helion.runtime import default_launcher as _default_launcher
473+
@triton.jit
def _helion_attention(q, k, v, lse, o, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    # Fused flash-attention-style kernel over flattened (rows, 64) q/k/v.
    # One program handles a _BLOCK_SIZE_0-row tile of queries, streaming the
    # 64 key/value rows of its sequence in _BLOCK_SIZE_1 chunks with the usual
    # online-softmax running max / running denominator rescaling.
    # Writes the normalized output tile to `o` and the log-sum-exp to `lse`.
    # NOTE(review): head dim and sequence length are specialized to 64 here
    # (the `64` strides/loop bound) — presumably by the autotuner; confirm.
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
    full = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
    v_0 = float('inf')
    v_1 = full - v_0  # running row max, initialized to -inf
    full_1 = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
    v_2 = 1.0
    v_3 = full_1 + v_2  # running softmax denominator, initialized to 1.0
    acc = tl.full([_BLOCK_SIZE_0, 64], 0.0, tl.float32)
    q_i = tl.load(q + (indices_0[:, None] * 64 + indices_3[None, :] * 1), None)
    # Start row of the 64-row k/v segment this query tile attends to.
    symnode_0 = 64 * triton_helpers.div_floor_integer(offset_0, 64)
    for offset_2 in tl.range(0, 64, _BLOCK_SIZE_1):
        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
        # Snapshot loop-carried values (removed the redundant double-copy
        # chains and the dead v_4..v_7 index temporaries of the original).
        q_i_copy_0 = q_i
        v_1_copy_0 = v_1
        acc_copy_0 = acc
        v_3_copy_0 = v_3
        k_j = tl.load(k + ((indices_2 + symnode_0)[:, None] * 64 + indices_3[None, :] * 1), None)
        v_j = tl.load(v + ((indices_2 + symnode_0)[:, None] * 64 + indices_3[None, :] * 1), None)
        permute = tl.permute(k_j, [1, 0])
        qk = tl.dot(tl.cast(q_i_copy_0, tl.bfloat16), tl.cast(permute, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
        amax = tl.cast(tl.max(qk, 1), tl.float32)
        v_8 = 0.18033688  # softmax scale (1/sqrt(64) * log2(e), pre-folded)
        v_9 = amax * v_8
        v_10 = triton_helpers.maximum(v_1_copy_0, v_9)  # new running max
        v_11 = 0.18033688
        v_12 = qk * v_11
        subscript = v_10[:, None]
        v_13 = v_12 - subscript
        # Fix: original called libdevice.exp2/log2 but never imported
        # `libdevice`; tl.exp2/tl.log2 are equivalent and already in scope.
        v_14 = tl.exp2(v_13)
        v_15 = v_1_copy_0 - v_10
        v_16 = tl.exp2(v_15)  # rescale factor for previous accumulator
        l_ij = tl.cast(tl.sum(v_14, 1), tl.float32)
        subscript_1 = v_16[:, None]
        v_17 = acc_copy_0 * subscript_1
        v_18 = tl.cast(v_14, tl.bfloat16)
        acc = tl.dot(tl.cast(v_18, tl.bfloat16), tl.cast(v_j, tl.bfloat16), acc=v_17, input_precision='tf32', out_dtype=tl.float32)
        v_19 = v_3_copy_0 * v_16
        v_3 = v_19 + l_ij  # updated denominator
        v_1 = v_10  # updated running max
    # Epilogue: finalize log-sum-exp and normalize the accumulator.
    v_21 = tl.log2(v_3)
    v_22 = v_1 + v_21
    subscript_2 = v_3[:, None]
    v_23 = acc / subscript_2
    tl.store(lse + indices_0 * 1, v_22, None)
    v_24 = tl.cast(v_23, tl.bfloat16)
    tl.store(o + (indices_0[:, None] * 64 + indices_3[None, :] * 1), v_24, None)
533+
def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _launcher=_default_launcher):
    """Host-side launcher for the fused attention kernel.

    Flattens (B, H, M, D) q/k/v into 2-D (rows, 64) views, allocates the
    output and log-sum-exp buffers, launches ``_helion_attention`` with one
    program per _BLOCK_SIZE_0-row query tile, and reshapes the results back.

    Returns:
        (o, lse): attention output of shape (B, H, M, 64) and the per-row
        log-sum-exp of shape (B, H, M) in float32.
    """
    B, H, M, D = q_in.shape
    Bk, Hk, N, Dk = k_in.shape
    Bv, Hv, Nv, Dv = v_in.shape
    # Head dims are specialized to 64 (the kernel hard-codes this stride).
    D = 64
    Dv = 64
    q = q_in.reshape(-1, D)
    k = k_in.reshape(-1, D)
    v = v_in.reshape(-1, Dv)
    MM = q.shape[0]
    o = q.new_empty(MM, Dv)
    lse = q.new_empty(MM, dtype=torch.float32)
    _BLOCK_SIZE_0 = 32
    _RDIM_SIZE_2 = 64
    _BLOCK_SIZE_1 = 32
    # Fix: grid size was hard-coded as cdiv(8192, block); use the actual
    # flattened row count so shapes other than B*H*M == 8192 launch enough
    # (and not too many) programs. Identical when MM == 8192.
    _launcher(_helion_attention, (triton.cdiv(MM, _BLOCK_SIZE_0),), q, k, v, lse, o, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
    return (o.reshape(B, H, M, Dv), lse.reshape(B, H, M))
551+
465552--- assertExpectedJournal(TestIndexing.test_tile_with_offset_pointer)
466553from __future__ import annotations
467554
0 commit comments