
Commit 51f09a7

Set static_shapes=True (#937)
1 parent e317eb0 commit 51f09a7


41 files changed: +3043 / −3495 lines

README.md

Lines changed: 12 additions & 0 deletions
@@ -189,6 +189,18 @@ and configurations directly from your code.

 **For production deployment**, we recommend using ahead-of-time tuned configurations rather than relying on runtime autotuning. The autotuning process can be time-consuming and resource-intensive, making it unsuitable for production environments where predictable performance and startup times are critical.

+### Static shapes and autotuning keys
+
+By default Helion uses static shapes (`static_shapes=True`). This means each unique input shape/stride signature is treated as its own specialization and will be autotuned separately. This typically yields the best performance, but may increase autotuning time when many shapes are encountered.
+
+If you want to reduce autotuning time by sharing configurations between different shapes, set `static_shapes=False`. In this mode, the autotuning key ignores exact sizes, allowing a single tuned config to be reused across multiple shapes. This can come with a performance penalty compared to fully specialized static shapes.
+
+```python
+@helion.kernel(static_shapes=False)
+def my_kernel(x: torch.Tensor) -> torch.Tensor:
+    ...
+```
+
 ## Configurations

 Helion configurations include the following options:
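
To make the trade-off described in the new README section concrete, here is a minimal usage sketch. The `add_one` kernel, tensor shapes, and device are illustrative assumptions, not part of this commit; the tile/indexing pattern follows the Helion kernels shown elsewhere in this diff.

```python
import torch
import helion
import helion.language as hl

# Illustrative kernel (not part of this commit); the tile pattern mirrors
# the Helion kernels shown in the diffs below.
@helion.kernel()  # default: static_shapes=True
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] + 1
    return out

a = torch.randn([512, 512], device="cuda")   # assumes a CUDA device
b = torch.randn([1024, 256], device="cuda")

# With static_shapes=True (the new default), each distinct shape/stride
# signature is its own specialization, so these two calls are autotuned
# separately. Decorating with static_shapes=False instead would let one
# tuned config be shared across both shapes.
add_one(a)
add_one(b)
```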

examples/low_mem_dropout.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@


 # %%
-@helion.kernel()
+@helion.kernel(static_shapes=False)
 def low_mem_dropout(p: float, x: torch.Tensor, seed: int) -> torch.Tensor:
     """
     Applies dropout on x using p
@@ -57,7 +57,7 @@ def low_mem_dropout(p: float, x: torch.Tensor, seed: int) -> torch.Tensor:


 # %%
-@helion.kernel()
+@helion.kernel(static_shapes=False)
 def low_mem_dropout_bwd(p: float, grad_y: torch.Tensor, seed: int) -> torch.Tensor:
     """
     For low mem dropout we are applying randomness inside both fwd and bwd
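
The dropout example keeps `static_shapes=False`, which, per the README section above, lets a single tuned config be reused across different input sizes. A hedged usage sketch follows; the import path, sizes, seed, and device are assumptions.

```python
import torch

# Assumed import path for the example kernel shown in the diff above.
from examples.low_mem_dropout import low_mem_dropout

x_small = torch.randn(2048, device="cuda")       # assumes a CUDA device
x_large = torch.randn(1_000_000, device="cuda")

# With static_shapes=False the autotuning key ignores exact sizes, so the
# config tuned for the first call can be reused for the second instead of
# triggering a separate autotune per input size.
y_small = low_mem_dropout(0.25, x_small, 42)
y_large = low_mem_dropout(0.25, x_large, 42)
```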

helion/runtime/settings.py

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ class _Settings:
         "Literal['tf32', 'tf32x3', 'ieee']",
         os.environ.get("TRITON_F32_DEFAULT", "tf32"),
     )
-    static_shapes: bool = False
+    static_shapes: bool = True
     autotune_log_level: int = logging.INFO
     autotune_compile_timeout: int = int(
         os.environ.get("HELION_AUTOTUNE_COMPILE_TIMEOUT", "60")

test/test_associative_scan.expected

Lines changed: 223 additions & 281 deletions
(Large diff not shown.)

test/test_atomic_ops.expected

Lines changed: 59 additions & 68 deletions
(Large diff not shown.)

test/test_broadcasting.expected

Lines changed: 67 additions & 77 deletions
(Large diff not shown.)

test/test_closures.expected

Lines changed: 10 additions & 13 deletions
@@ -12,28 +12,26 @@ from helion.runtime import default_launcher as _default_launcher
 import helion._testing.basic_kernels as _source_module

 @triton.jit
-def _helion_use_globals(a, _source_module_attr_global_tensor, out, a_size_0, a_size_1, _source_module_attr_global_tensor_stride_0, a_stride_0, a_stride_1, out_stride_0, out_stride_1, _source_module_attr_global_float, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
-    num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
+def _helion_use_globals(a, _source_module_attr_global_tensor, out, _source_module_attr_global_float, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(512, _BLOCK_SIZE_0)
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
-    mask_0 = indices_0 < a_size_0
     offset_1 = pid_1 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
-    mask_1 = indices_1 < a_size_1
-    load = tl.load(a + (indices_0[:, None] * a_stride_0 + indices_1[None, :] * a_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-    load_1 = tl.load(_source_module_attr_global_tensor + indices_1[None, :] * _source_module_attr_global_tensor_stride_0, mask_1[None, :], other=0)
+    load = tl.load(a + (indices_0[:, None] * 512 + indices_1[None, :] * 1), None)
+    load_1 = tl.load(_source_module_attr_global_tensor + indices_1[None, :] * 1, None)
     v_0 = load + load_1
     v_1 = tl_math.sin(v_0)
     v_2 = v_1 + _source_module_attr_global_float
-    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_2, mask_0[:, None] & mask_1[None, :])
+    tl.store(out + (indices_0[:, None] * 512 + indices_1[None, :] * 1), v_2, None)

 def use_globals(a, *, _launcher=_default_launcher):
     out = _source_module.empty_like(a)
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_1 = 32
-    _launcher(_helion_use_globals, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, _source_module.global_tensor, out, a.size(0), a.size(1), _source_module.global_tensor.stride(0), a.stride(0), a.stride(1), out.stride(0), out.stride(1), _source_module.global_float, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
+    _launcher(_helion_use_globals, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), a, _source_module.global_tensor, out, _source_module.global_float, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
     return out

 --- assertExpectedJournal(TestClosures.test_fn_arg_with_closure)
@@ -160,17 +158,16 @@ from helion.runtime import default_launcher as _default_launcher
 import test.test_closures as _source_module

 @triton.jit
-def _helion_call_func_arg_on_host(a, out, a_size_0, a_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+def _helion_call_func_arg_on_host(a, out, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
-    mask_0 = indices_0 < a_size_0
-    load = tl.load(a + indices_0 * a_stride_0, mask_0, other=0)
+    load = tl.load(a + indices_0 * 1, None)
     v_0 = tl_math.sin(load)
-    tl.store(out + indices_0 * out_stride_0, v_0, mask_0)
+    tl.store(out + indices_0 * 1, v_0, None)

 def call_func_arg_on_host(a, alloc, *, _launcher=_default_launcher):
     out = alloc(a)
     _BLOCK_SIZE_0 = 512
-    _launcher(_helion_call_func_arg_on_host, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, out, a.size(0), a.stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
+    _launcher(_helion_call_func_arg_on_host, (triton.cdiv(512, _BLOCK_SIZE_0),), a, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
     return out

test/test_closures.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 global_tensor = torch.randn([512], device=DEVICE)


-@helion.kernel
+@helion.kernel(static_shapes=False)
 def sin_func_arg(a, fn) -> torch.Tensor:
     out = torch.empty_like(a)
     for tile in hl.tile(a.size()):
