Commit b40750f

Support hl.arange() with non-power-of-2 input (#862)

1 parent 4d41dcc · commit b40750f

File tree

5 files changed: +239 −47 lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 32 additions & 0 deletions

@@ -10,6 +10,7 @@
 import torch
 from torch._inductor.utils import triton_type
 from torch._prims_common import compute_required_storage_length
+from triton import next_power_of_2

 from .. import exc
 from .._compat import get_tensor_descriptor_fn_name
@@ -32,6 +33,32 @@
 ShapeLike = Sequence[SymIntLike]


+def _get_padded_iota_original_length(
+    state: CodegenState, index_position: int
+) -> int | None:
+    """Get the original length of a padded iota node at the given index position.
+
+    Args:
+        state: The codegen state containing fx_node information
+        index_position: The position in the index list to check
+
+    Returns:
+        The original (unpadded) length if the index is a padded iota, None otherwise
+    """
+    try:
+        index_node = state.fx_node.args[1][index_position]  # type: ignore[union-attr, index]
+        if (
+            isinstance(index_node, torch.fx.Node)
+            and index_node.target == torch.ops.prims.iota.default  # pyright: ignore[reportAttributeAccessIssue]
+            and isinstance(length_arg := index_node.args[0], int)
+            and length_arg != next_power_of_2(length_arg)
+        ):
+            return length_arg
+    except (AttributeError, IndexError, TypeError):
+        pass
+    return None
+
+
 class IndexingStrategy:
     def codegen_load(
         self,
@@ -634,6 +661,11 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
             if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
                 if mask := state.codegen.mask_var(block_idx):
                     mask_values.setdefault(f"({mask}){expand}")
+                # Check if this index comes from a padded hl.arange and generate mask
+                if (
+                    original_length := _get_padded_iota_original_length(state, n)
+                ) is not None:
+                    mask_values.setdefault(f"({index_var} < {original_length}){expand}")
             output_idx += 1
         elif (
             isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
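For intuition, here is a minimal standalone sketch of what the generated mask accomplishes (hypothetical variable names; plain PyTorch plus triton.next_power_of_2, not Helion internals): the iota is padded to a power of 2, and the `index < original_length` mask disables the extra lanes.

import torch
from triton import next_power_of_2

original_length = 7
padded = next_power_of_2(original_length)  # 8: tl.arange extents must be powers of 2
iota = torch.arange(padded)                # lanes 0..7, one lane too many
mask = iota < original_length              # the mask the new code emits: (index_var < 7)

# A masked gather touches only the first 7 lanes; the padded lane stays untouched.
src = torch.arange(100, 100 + original_length)
out = torch.zeros(padded, dtype=src.dtype)
out[mask] = src[iota[mask]]
assert out.tolist() == [100, 101, 102, 103, 104, 105, 106, 0]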

helion/_compiler/inductor_lowering.py

Lines changed: 9 additions & 2 deletions

@@ -39,6 +39,7 @@
 from torch.fx.interpreter import Interpreter
 from torch.fx.node import Node
 from torch.fx.node import map_arg
+from triton import next_power_of_2

 from .. import exc
 from ..exc import InductorLoweringError
@@ -1451,15 +1452,21 @@ def sympy_expr(self, expr: sympy.Expr) -> str:

 @register_lowering(torch.ops.prims.iota.default)  # pyright: ignore[reportAttributeAccessIssue]
 def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-    """Generate tl.arange for torch.ops.prims.iota.default operations."""
+    """Generate tl.arange for torch.ops.prims.iota.default operations with automatic power-of-2 padding."""
     start = node.kwargs.get("start", 0)
     step = node.kwargs.get("step", 1)
     dtype = (
         node.kwargs.get("dtype") or CompileEnvironment.current().settings.index_dtype
     )
     assert isinstance(dtype, torch.dtype)
     (length_arg,) = node.args  # expecting a single argument for length
-    expr = "tl.arange(0, {length})"
+
+    # Pad static non-power-of-2 lengths to next power of 2
+    length_expr = "{length}"
+    if isinstance(length_arg, int) and length_arg != next_power_of_2(length_arg):
+        length_expr = str(next_power_of_2(length_arg))
+
+    expr = f"tl.arange(0, {length_expr})"
     if step != 1:
         expr = f"{{step}} * {expr}"
     if start != 0:
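As a reading aid, a simplified sketch of the string construction above (hypothetical helper name; the `{length}`/`{step}` placeholders are filled in later by the surrounding lowering, and the `start` branch is cut off by the hunk, so it is omitted here):

from triton import next_power_of_2

def iota_expr_sketch(length_arg: object, step: int = 1) -> str:
    # Static non-power-of-2 lengths are padded; symbolic lengths keep the
    # "{length}" placeholder for the lowering to substitute later.
    length_expr = "{length}"
    if isinstance(length_arg, int) and length_arg != next_power_of_2(length_arg):
        length_expr = str(next_power_of_2(length_arg))
    expr = f"tl.arange(0, {length_expr})"
    if step != 1:
        expr = f"{{step}} * {expr}"
    return expr

assert iota_expr_sketch(7) == "tl.arange(0, 8)"                   # padded
assert iota_expr_sketch(8) == "tl.arange(0, {length})"            # already a power of 2
assert iota_expr_sketch(6, step=2) == "{step} * tl.arange(0, 8)"  # padded, with step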

helion/language/atomic_ops.py

Lines changed: 45 additions & 45 deletions

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import ast
+import itertools
 from typing import TYPE_CHECKING
 from typing import Callable

@@ -125,15 +126,22 @@ def _ref_apply(
     if tensor_indices:
         # Element-wise processing for tensor indices (handle first tensor index)
         i, tensor_idx = tensor_indices[0]
-        for j, elem in enumerate(tensor_idx):
+
+        if tensor_idx.ndim == 0:
+            coords_iter = [()]
+        else:
+            ranges = [range(dim) for dim in tensor_idx.shape]
+            coords_iter = itertools.product(*ranges)
+
+        for coords in coords_iter:
+            elem = tensor_idx[coords].item()
             new_index = processed_index.copy()
-            new_index[i] = int(elem.item())
-            val = (
-                value[j]
-                if isinstance(value, torch.Tensor) and value.numel() > 1
-                else value
-            )
-            apply_fn(target, tuple(new_index), val)
+            new_index[i] = int(elem)
+            if isinstance(value, torch.Tensor) and value.numel() > 1:
+                next_value = value[coords]
+            else:
+                next_value = value
+            _ref_apply(target, new_index, apply_fn, next_value)
     else:
         apply_fn(target, tuple(processed_index), value)

@@ -208,58 +216,50 @@ def _(
     _validate_sem(sem)
     from .ref_tile import RefTile

-    # Convert indices and detect tensor indices for element-wise updates
+    # Convert indices for shape computation and fast path detection
     processed_index: list[object] = []
-    tensor_indices: list[tuple[int, torch.Tensor]] = []
-    for i, idx in enumerate(index):
+    has_tensor_index = False
+    for idx in index:
         if isinstance(idx, RefTile):
             processed_index.append(idx._slice)
         elif isinstance(idx, torch.Tensor):
             if idx.numel() == 1:
                 processed_index.append(int(idx.item()))
             else:
                 processed_index.append(idx)
-                tensor_indices.append((i, idx))
+                has_tensor_index = True
         else:
             processed_index.append(idx)

-    if tensor_indices:
-        # Element-wise processing for the first tensor index to ensure correct semantics
-        i, idx_tensor = tensor_indices[0]
-        ret = torch.empty_like(idx_tensor, dtype=target.dtype, device=target.device)
-        # Flatten to assign easily
-        flat_ret = ret.reshape(-1)
-        flat_idx = idx_tensor.reshape(-1)
-        # Prepare value per element
-        if isinstance(value, torch.Tensor) and value.numel() > 1:
-            flat_val = value.reshape(-1)
+    def _convert_value_to_target_dtype(val: object) -> torch.Tensor:
+        if isinstance(val, torch.Tensor):
+            vt = val.to(device=target.device)
+            if vt.dtype != target.dtype:
+                vt = vt.to(dtype=target.dtype)
+            return vt
+        return torch.as_tensor(val, dtype=target.dtype, device=target.device)
+
+    if has_tensor_index:
+        ret_shape = SubscriptIndexing.compute_shape(target, processed_index)
+        prev_chunks: list[torch.Tensor] = []
+
+        def apply(t: torch.Tensor, idx_tuple: tuple, v: object) -> None:
+            prev_val = t[idx_tuple].clone()  # pyright: ignore[reportArgumentType]
+            val_tensor = _convert_value_to_target_dtype(v)
+            t[idx_tuple] = t[idx_tuple] + val_tensor  # pyright: ignore[reportArgumentType]
+            prev_chunks.append(prev_val.reshape(-1))
+
+        _ref_apply(target, index, apply, value)
+        if prev_chunks:
+            flat_prev = torch.cat(prev_chunks)
         else:
-            flat_val = None
-        for j, elem in enumerate(flat_idx):
-            new_index = list(processed_index)
-            new_index[i] = int(elem.item())
-            new_index_t = tuple(new_index)
-            prev = target[new_index_t]  # pyright: ignore[reportArgumentType]
-            vj = flat_val[j] if flat_val is not None else value
-            # Convert scalar to tensor on device
-            vj_t = (
-                vj
-                if isinstance(vj, torch.Tensor)
-                else torch.as_tensor(vj, dtype=target.dtype, device=target.device)
-            )
-            target[new_index_t] = target[new_index_t] + vj_t  # pyright: ignore[reportArgumentType]
-            flat_ret[j] = prev  # pyright: ignore[reportArgumentType]
-        return ret
+            flat_prev = target.new_empty(0, dtype=target.dtype, device=target.device)
+        return flat_prev.reshape(ret_shape)

-    # Scalar or simple indexing path
     idx_tuple = tuple(processed_index)
     prev = target[idx_tuple].clone()  # pyright: ignore[reportArgumentType]
-    val = (
-        value
-        if isinstance(value, torch.Tensor)
-        else torch.as_tensor(value, dtype=target.dtype, device=target.device)
-    )
-    target[idx_tuple] = target[idx_tuple] + val  # pyright: ignore[reportArgumentType]
+    val_tensor = _convert_value_to_target_dtype(value)
+    target[idx_tuple] = target[idx_tuple] + val_tensor  # pyright: ignore[reportArgumentType]
     return prev

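For intuition, a toy standalone sketch (made-up tensors, not the Helion API) of the coordinate walk `_ref_apply` now performs: itertools.product enumerates every coordinate of a multi-dimensional index tensor, so repeated indices accumulate one element at a time, matching atomic-add semantics.

import itertools
import torch

target = torch.zeros(4)
idx = torch.tensor([[0, 1], [1, 3]])  # 2-D tensor index; index 1 appears twice
val = torch.ones(2, 2)

# Walk every coordinate of idx, applying one element-wise add per coordinate.
for coords in itertools.product(*(range(d) for d in idx.shape)):
    i = int(idx[coords].item())
    target[i] = target[i] + val[coords]

assert target.tolist() == [1.0, 2.0, 0.0, 1.0]  # the duplicate index accumulated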

test/test_indexing.expected

Lines changed: 69 additions & 0 deletions

@@ -185,6 +185,75 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor,
     _launcher(_helion_broadcast_add_3d, (triton.cdiv(d0, _BLOCK_SIZE_0) * triton.cdiv(d1, _BLOCK_SIZE_1) * triton.cdiv(d2, _BLOCK_SIZE_2),), x, bias1, bias2, out, bias1.size(1), bias1.size(2), bias2.size(0), bias2.size(2), out.size(0), out.size(1), out.size(2), x.size(0), x.size(1), x.size(2), bias1.stride(0), bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(1), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2)
     return out

+--- assertExpectedJournal(TestIndexing.test_hl_arange_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion__matmul_layernorm_bwd_dxdy(z, grad_out, weight, mean, rstd, y, grad_x, x, grad_y, grad_out_stride_0, grad_out_stride_1, grad_x_stride_0, grad_x_stride_1, grad_y_stride_0, grad_y_stride_1, mean_stride_0, rstd_stride_0, weight_stride_0, x_stride_0, x_stride_1, y_stride_0, y_stride_1, z_stride_0, z_stride_1, m, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < 7
+    indices_2 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_2 < 3
+    load = tl.load(z + (indices_0[:, None] * z_stride_0 + indices_1[None, :] * z_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = tl.cast(load, tl.float32)
+    load_1 = tl.load(grad_out + (indices_0[:, None] * grad_out_stride_0 + indices_1[None, :] * grad_out_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_1 = tl.cast(load_1, tl.float32)
+    load_2 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0)
+    v_2 = tl.cast(load_2, tl.float32)
+    mean_tile = tl.load(mean + indices_0 * mean_stride_0, mask_0, other=0)
+    rstd_tile = tl.load(rstd + indices_0 * rstd_stride_0, mask_0, other=0)
+    subscript = mean_tile[:, None]
+    v_3 = v_0 - subscript
+    subscript_1 = rstd_tile[:, None]
+    v_4 = v_3 * subscript_1
+    v_5 = v_2[None, :]
+    v_6 = v_5 * v_1
+    v_7 = v_4 * v_6
+    sum_1 = tl.cast(tl.reshape(tl.sum(v_7, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_8 = 0.14285714285714285
+    v_9 = sum_1 * v_8
+    sum_2 = tl.cast(tl.reshape(tl.sum(v_6, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_10 = 0.14285714285714285
+    v_11 = sum_2 * v_10
+    v_12 = v_4 * v_9
+    v_13 = v_12 + v_11
+    v_14 = v_6 - v_13
+    subscript_2 = rstd_tile[:, None]
+    v_15 = v_14 * subscript_2
+    load_5 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0)
+    permute = tl.permute(load_5, [1, 0])
+    v_16 = tl.cast(permute, tl.float32)
+    mm = tl.dot(tl.reshape(tl.permute(tl.join(tl.cast(v_15, tl.float32), tl.zeros_like(tl.cast(v_15, tl.float32))), [0, 2, 1]), [16, 16]), tl.reshape(tl.permute(tl.join(tl.cast(v_16, tl.float32), tl.zeros_like(tl.cast(v_16, tl.float32))), [2, 0, 1]), [16, 4]), input_precision='tf32', out_dtype=tl.float32)
+    v_17 = tl.cast(mm, tl.float16)
+    tl.store(grad_x + (indices_0[:, None] * grad_x_stride_0 + indices_2[None, :] * grad_x_stride_1), v_17, mask_0[:, None] & mask_2[None, :])
+    load_6 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+    permute_1 = tl.permute(load_6, [1, 0])
+    v_18 = tl.cast(permute_1, tl.float32)
+    mm_1 = tl.dot(tl.cast(v_18, tl.float32), tl.cast(v_15, tl.float32), input_precision='tf32', out_dtype=tl.float32)
+    v_19 = tl.cast(mm_1, tl.float16)
+    iota = tl.arange(0, 4)
+    iota_1 = tl.arange(0, 8)
+    tl.atomic_add(grad_y + (iota[:, None] * grad_y_stride_0 + iota_1[None, :] * grad_y_stride_1), v_19, mask=(iota < 3)[:, None] & (iota_1 < 7)[None, :], sem='relaxed')
+
+def _matmul_layernorm_bwd_dxdy(grad_out: torch.Tensor, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, mean: torch.Tensor, rstd: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launcher):
+    m, n = z.shape
+    grad_x = torch.empty_like(x)
+    grad_y = torch.zeros_like(y)
+    _BLOCK_SIZE_0 = 16
+    _RDIM_SIZE_1 = 8
+    _RDIM_SIZE_2 = 4
+    _launcher(_helion__matmul_layernorm_bwd_dxdy, (triton.cdiv(m, _BLOCK_SIZE_0),), z, grad_out, weight, mean, rstd, y, grad_x, x, grad_y, grad_out.stride(0), grad_out.stride(1), grad_x.stride(0), grad_x.stride(1), grad_y.stride(0), grad_y.stride(1), mean.stride(0), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), z.stride(1), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=2)
+    return (grad_x, grad_y)
+
 --- assertExpectedJournal(TestIndexing.test_mask_load)
 from __future__ import annotations

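A quick sanity check on the sizes in the journal above (a hedged note, not part of the commit): _RDIM_SIZE_1 = 8 and _RDIM_SIZE_2 = 4 are exactly the next powers of 2 of the specialized extents n=7 and k=3, and the masks `indices_1 < 7` / `indices_2 < 3` (and `iota_1 < 7` / `iota < 3` in the atomic add) restore the original lengths.

from triton import next_power_of_2

assert next_power_of_2(7) == 8  # n=7 -> _RDIM_SIZE_1 = 8, masked by indices_1 < 7
assert next_power_of_2(3) == 4  # k=3 -> _RDIM_SIZE_2 = 4, masked by indices_2 < 3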

test/test_indexing.py

Lines changed: 84 additions & 0 deletions

@@ -63,6 +63,90 @@ def arange(length: int, device: torch.device) -> torch.Tensor:
         )
         self.assertExpectedJournal(code)

+    def test_hl_arange_non_power_of_2(self):
+        @helion.kernel
+        def _matmul_layernorm_bwd_dxdy(
+            grad_out: torch.Tensor,
+            x: torch.Tensor,
+            y: torch.Tensor,
+            z: torch.Tensor,
+            mean: torch.Tensor,
+            rstd: torch.Tensor,
+            weight: torch.Tensor,
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            m, n = z.shape
+            k = x.shape[1]
+            n = hl.specialize(n)
+            k = hl.specialize(k)
+
+            grad_x = torch.empty_like(x)
+            grad_y = torch.zeros_like(y)
+
+            for tile_m in hl.tile(m):
+                z_tile = z[tile_m, :].to(torch.float32)
+                dy_tile = grad_out[tile_m, :].to(torch.float32)
+                w = weight[:].to(torch.float32)
+                mean_tile = mean[tile_m]
+                rstd_tile = rstd[tile_m]
+
+                z_hat = (z_tile - mean_tile[:, None]) * rstd_tile[:, None]
+                wdy = w * dy_tile
+                c1 = torch.sum(z_hat * wdy, dim=-1, keepdim=True) / float(n)
+                c2 = torch.sum(wdy, dim=-1, keepdim=True) / float(n)
+                dz = (wdy - (z_hat * c1 + c2)) * rstd_tile[:, None]
+
+                grad_x[tile_m, :] = (dz @ y[:, :].t().to(torch.float32)).to(x.dtype)
+                grad_y_update = (x[tile_m, :].t().to(torch.float32) @ dz).to(y.dtype)
+
+                hl.atomic_add(
+                    grad_y,
+                    [
+                        hl.arange(0, k),
+                        hl.arange(0, n),
+                    ],
+                    grad_y_update,
+                )
+
+            return grad_x, grad_y
+
+        m, k, n = 5, 3, 7
+        eps = 1e-5
+
+        x = torch.randn((m, k), device=DEVICE, dtype=torch.float16)
+        y = torch.randn((k, n), device=DEVICE, dtype=torch.float16)
+        weight = torch.randn((n,), device=DEVICE, dtype=torch.float16)
+        grad_out = torch.randn((m, n), device=DEVICE, dtype=torch.float16)
+
+        z = (x @ y).to(torch.float32)
+        var, mean = torch.var_mean(z, dim=-1, keepdim=True, correction=0)
+        rstd = torch.rsqrt(var + eps)
+
+        code, (grad_x, grad_y) = code_and_output(
+            _matmul_layernorm_bwd_dxdy,
+            (
+                grad_out,
+                x,
+                y,
+                z.to(x.dtype),
+                mean.squeeze(-1),
+                rstd.squeeze(-1),
+                weight,
+            ),
+        )
+
+        # PyTorch reference gradients
+        z_hat = (z - mean) * rstd
+        wdy = weight.to(torch.float32) * grad_out.to(torch.float32)
+        c1 = torch.sum(z_hat * wdy, dim=-1, keepdim=True) / float(n)
+        c2 = torch.sum(wdy, dim=-1, keepdim=True) / float(n)
+        dz = (wdy - (z_hat * c1 + c2)) * rstd
+        ref_grad_x = (dz @ y.to(torch.float32).t()).to(grad_x.dtype)
+        ref_grad_y = (x.to(torch.float32).t() @ dz).to(grad_y.dtype)
+
+        torch.testing.assert_close(grad_x, ref_grad_x, rtol=1e-3, atol=2e-3)
+        torch.testing.assert_close(grad_y, ref_grad_y, rtol=1e-3, atol=2e-3)
+        self.assertExpectedJournal(code)
+
     def test_pairwise_add(self):
         @helion.kernel()
         def pairwise_add(x: torch.Tensor) -> torch.Tensor:
