@@ -185,6 +185,75 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor,
185185 _launcher(_helion_broadcast_add_3d, (triton.cdiv(d0, _BLOCK_SIZE_0) * triton.cdiv(d1, _BLOCK_SIZE_1) * triton.cdiv(d2, _BLOCK_SIZE_2),), x, bias1, bias2, out, bias1.size(1), bias1.size(2), bias2.size(0), bias2.size(2), out.size(0), out.size(1), out.size(2), x.size(0), x.size(1), x.size(2), bias1.stride(0), bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(1), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2)
186186 return out
187187
188+ --- assertExpectedJournal(TestIndexing.test_hl_arange_non_power_of_2)
189+ from __future__ import annotations
190+
191+ import torch
192+ import triton
193+ import triton.language as tl
194+ from helion.runtime import default_launcher as _default_launcher
195+
@triton.jit
def _helion__matmul_layernorm_bwd_dxdy(z, grad_out, weight, mean, rstd, y, grad_x, x, grad_y, grad_out_stride_0, grad_out_stride_1, grad_x_stride_0, grad_x_stride_1, grad_y_stride_0, grad_y_stride_1, mean_stride_0, rstd_stride_0, weight_stride_0, x_stride_0, x_stride_1, y_stride_0, y_stride_1, z_stride_0, z_stride_1, m, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
    """Fused backward pass producing grad_x and grad_y for a matmul+layernorm.

    Generated (Helion-emitted) kernel: one program per block of _BLOCK_SIZE_0
    rows of ``z``.  The normalized dimension is hard-coded to 7 (see
    ``mask_1`` and the 1/7 constants below) and the inner matmul dimension to
    3 (see ``mask_2``) — sizes baked in by the code generator.
    ``grad_x`` is written directly; ``grad_y`` is accumulated with atomics
    because every program along the row dimension touches the same
    ``grad_y`` elements, so the caller must pass it in pre-zeroed.
    """
    # Row-block offsets for this program; columns/inner dim are covered whole.
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < m
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < 7  # normalized dim is 7; _RDIM_SIZE_1 is its pow-2 padding
    indices_2 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
    mask_2 = indices_2 < 3  # inner matmul dim is 3; _RDIM_SIZE_2 pads it to 4
    # Load the layernorm input tile z, upstream grad, and affine weight (fp32).
    load = tl.load(z + (indices_0[:, None] * z_stride_0 + indices_1[None, :] * z_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
    v_0 = tl.cast(load, tl.float32)
    load_1 = tl.load(grad_out + (indices_0[:, None] * grad_out_stride_0 + indices_1[None, :] * grad_out_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
    v_1 = tl.cast(load_1, tl.float32)
    load_2 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0)
    v_2 = tl.cast(load_2, tl.float32)
    # Per-row saved statistics from the forward pass.
    mean_tile = tl.load(mean + indices_0 * mean_stride_0, mask_0, other=0)
    rstd_tile = tl.load(rstd + indices_0 * rstd_stride_0, mask_0, other=0)
    # xhat = (z - mean) * rstd  (v_4); wdy = weight * grad_out (v_6).
    subscript = mean_tile[:, None]
    v_3 = v_0 - subscript
    subscript_1 = rstd_tile[:, None]
    v_4 = v_3 * subscript_1
    v_5 = v_2[None, :]
    v_6 = v_5 * v_1
    v_7 = v_4 * v_6
    # Row means (sum * 1/7) of xhat*wdy and wdy — the two correction terms of
    # the layernorm input-gradient formula.
    sum_1 = tl.cast(tl.reshape(tl.sum(v_7, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
    v_8 = 0.14285714285714285
    v_9 = sum_1 * v_8
    sum_2 = tl.cast(tl.reshape(tl.sum(v_6, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
    v_10 = 0.14285714285714285
    v_11 = sum_2 * v_10
    # v_15 = (wdy - (xhat * mean(xhat*wdy) + mean(wdy))) * rstd
    #      = gradient w.r.t. the layernorm input (the matmul output).
    v_12 = v_4 * v_9
    v_13 = v_12 + v_11
    v_14 = v_6 - v_13
    subscript_2 = rstd_tile[:, None]
    v_15 = v_14 * subscript_2
    # grad_x = v_15 @ y^T.  The tl.join/zeros/permute/reshape dance interleaves
    # zeros to widen the K dimension from _RDIM_SIZE_1=8 to 16 — presumably to
    # satisfy tl.dot's minimum-K constraint (generated padding; the zero lanes
    # contribute nothing to the dot product).  TODO(review): confirm intent.
    load_5 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0)
    permute = tl.permute(load_5, [1, 0])
    v_16 = tl.cast(permute, tl.float32)
    mm = tl.dot(tl.reshape(tl.permute(tl.join(tl.cast(v_15, tl.float32), tl.zeros_like(tl.cast(v_15, tl.float32))), [0, 2, 1]), [16, 16]), tl.reshape(tl.permute(tl.join(tl.cast(v_16, tl.float32), tl.zeros_like(tl.cast(v_16, tl.float32))), [2, 0, 1]), [16, 4]), input_precision='tf32', out_dtype=tl.float32)
    v_17 = tl.cast(mm, tl.float16)
    tl.store(grad_x + (indices_0[:, None] * grad_x_stride_0 + indices_2[None, :] * grad_x_stride_1), v_17, mask_0[:, None] & mask_2[None, :])
    # grad_y partial = x_tile^T @ v_15 for this row block (K=_BLOCK_SIZE_0=16,
    # already large enough — no padding needed here).
    load_6 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
    permute_1 = tl.permute(load_6, [1, 0])
    v_18 = tl.cast(permute_1, tl.float32)
    mm_1 = tl.dot(tl.cast(v_18, tl.float32), tl.cast(v_15, tl.float32), input_precision='tf32', out_dtype=tl.float32)
    v_19 = tl.cast(mm_1, tl.float16)
    # Accumulate across row-block programs: every program adds into the same
    # grad_y region, hence atomic_add (relaxed ordering is fine for pure sums).
    iota = tl.arange(0, 4)
    iota_1 = tl.arange(0, 8)
    tl.atomic_add(grad_y + (iota[:, None] * grad_y_stride_0 + iota_1[None, :] * grad_y_stride_1), v_19, mask=(iota < 3)[:, None] & (iota_1 < 7)[None, :], sem='relaxed')
246+
def _matmul_layernorm_bwd_dxdy(grad_out: torch.Tensor, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, mean: torch.Tensor, rstd: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launcher):
    """Host-side launcher for the fused matmul+layernorm backward kernel.

    Allocates the two gradient buffers, picks the (fixed, generated) tile
    sizes, and dispatches one kernel program per block of rows of ``z``.
    Returns ``(grad_x, grad_y)``.
    """
    m, n = z.shape
    # grad_y must start at zero: the kernel accumulates into it with
    # tl.atomic_add.  grad_x is fully overwritten, so empty_like suffices.
    grad_x = torch.empty_like(x)
    grad_y = torch.zeros_like(y)
    # Constexpr tile sizes baked into the generated kernel.
    block_rows = 16
    rdim_cols = 8
    rdim_inner = 4
    # 1-D grid over row blocks of z.
    grid = (triton.cdiv(m, block_rows),)
    _launcher(_helion__matmul_layernorm_bwd_dxdy, grid, z, grad_out, weight, mean, rstd, y, grad_x, x, grad_y, grad_out.stride(0), grad_out.stride(1), grad_x.stride(0), grad_x.stride(1), grad_y.stride(0), grad_y.stride(1), mean.stride(0), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), z.stride(1), m, block_rows, rdim_cols, rdim_inner, num_warps=4, num_stages=2)
    return (grad_x, grad_y)
256+
188257--- assertExpectedJournal(TestIndexing.test_mask_load)
189258from __future__ import annotations
190259
0 commit comments