Commit e500e91

[doc] Add deployment/autotuning guide (#869)
1 parent 5c8e194 commit e500e91

6 files changed: +237 additions, -2 deletions

docs/api/autotuner.md

Lines changed: 7 additions & 0 deletions

````diff
@@ -52,3 +52,10 @@ The autotuner supports multiple search strategies:
 .. automodule:: helion.autotuner.finite_search
    :members:
 ```
+
+### Local Cache
+
+```{eval-rst}
+.. automodule:: helion.autotuner.local_cache
+   :members:
+```
````

docs/api/exceptions.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -1,6 +1,6 @@
 # Exceptions
 
-The `helion.exc` module provides a exception hierarchy for error handling and diagnostics.
+The `helion.exc` module provides an exception hierarchy for error handling and diagnostics.
 
 ```{eval-rst}
 .. currentmodule:: helion.exc
````

docs/api/index.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -101,6 +101,7 @@ runtime
 register_tunable
 constexpr
 specialize
+```
 
 ### Language Classes
 
@@ -130,4 +131,3 @@ runtime
 tile_block_size
 tile_id
 ```
-```
````

docs/api/settings.md

Lines changed: 9 additions & 0 deletions

````diff
@@ -162,6 +162,15 @@ with helion.set_default_settings(
 Dict of config key/value pairs to force during autotuning. Useful for disabling problematic candidates or pinning experimental options.
 ```
 
+### Autotuning Cache
+
+Helion stores the best-performing configs discovered during autotuning in an on-disk cache so subsequent runs can skip the search.
+
+- `HELION_CACHE_DIR`: Override the directory used to store cache entries. Defaults to PyTorch’s `torch._inductor` cache path (typically `/tmp/torchinductor_$USER/helion`).
+- `HELION_SKIP_CACHE`: Set to `1` to ignore cached entries and force the autotuner to re-run even if a matching artifact exists.
+
+See {py:class}`helion.autotuner.LocalAutotuneCache` for details on cache keys and behavior.
+
 ### Debugging and Development
 
 ```{eval-rst}
````

docs/deployment_autotuning.md

Lines changed: 218 additions & 0 deletions

# Deployment and Autotuning

Helion’s autotuner explores a large search space, which is time-consuming, so production deployments should generate autotuned configs **ahead of time**. Run the autotuner on a development workstation or a dedicated tuning box that mirrors your target GPU/accelerator. Check tuned configs into your repository alongside the kernel, or package them as data files and load them with `helion.Config.load` (see {doc}`api/config`). This keeps production kernel startup fast and deterministic, while also giving explicit control over when autotuning happens.

If you don’t specify pre-tuned configs, Helion will autotune on the first call for each specialization key. This is convenient for experimentation, but not ideal for production, since the first call pays a large tuning cost. Helion writes successful tuning results to an on-disk cache (overridable with `HELION_CACHE_DIR`, skippable with `HELION_SKIP_CACHE`; see {doc}`api/settings`), so repeated runs on the same machine can reuse prior configs. For more on caching, see {py:class}`~helion.autotuner.local_cache.LocalAutotuneCache` and {py:class}`~helion.autotuner.local_cache.StrictLocalAutotuneCache`.
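
For example, a minimal sketch of steering the cache through these environment variables (an illustration only; the variables should be set before the first kernel call in the process, and the directory path is arbitrary):

```python
import os

# Illustrative: keep Helion's autotune cache in a project-local directory
# instead of the default torch._inductor cache path.
os.environ["HELION_CACHE_DIR"] = "/tmp/my_project/helion_cache"

# Illustrative: uncomment to ignore cached entries and force a fresh search.
# os.environ["HELION_SKIP_CACHE"] = "1"
```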

The rest of this document covers strategies for pre-tuning and deploying tuned configs, which is the recommended approach for production workloads.

## Run Autotuning Jobs

The simplest way to launch autotuning is straight through the kernel call:

```python
import torch

import helion
import helion.language as hl


@helion.kernel()
def my_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Example kernel body: a simple elementwise add over one tiled dimension.
    out = torch.empty_like(x)
    for tile in hl.tile(out.size(0)):
        out[tile] = x[tile] + y[tile]
    return out


example_inputs = (
    torch.randn(1048576, device="cuda"),
    torch.randn(1048576, device="cuda"),
)

# The first call triggers autotuning, which is cached for future calls
# and prints the best config found.
my_kernel(*example_inputs)
```

Set `HELION_FORCE_AUTOTUNE=1` to re-run tuning even when cached configs exist (documented in {doc}`api/settings`).
47+
48+
Call `my_kernel.autotune(example_inputs)` explicitly to separate
49+
tuning from execution (see {doc}`api/kernel`).
50+
`autotune()` returns the best config found, which you can save for
51+
later use. Tune against multiple sizes by invoking `autotune` with
52+
a list of representative shapes, for example:
53+
54+
```python
55+
datasets = {
56+
"s": (
57+
torch.randn(2**16, device="cuda"),
58+
torch.randn(2**16, device="cuda"),
59+
),
60+
"m": (
61+
torch.randn(2**20, device="cuda"),
62+
torch.randn(2**20, device="cuda"),
63+
),
64+
"l": (
65+
torch.randn(2**24, device="cuda"),
66+
torch.randn(2**24, device="cuda"),
67+
),
68+
}
69+
70+
for tag, args in datasets.items():
71+
config = my_kernel.autotune(args)
72+
config.save(f"configs/my_kernel_{tag}.json")
73+
```

### Direct Control Over Autotuners

When you need more control, construct autotuners manually. {py:class}`~helion.autotuner.pattern_search.PatternSearch` is the default autotuner:

```python
from helion.autotuner import PatternSearch

bound = my_kernel.bind(example_inputs)
tuner = PatternSearch(
    bound,
    example_inputs,
    # Double the defaults to explore more candidates:
    initial_population=200,  # Default is 100.
    copies=10,  # Default is 5.
    max_generations=40,  # Default is 20.
)
best_config = tuner.autotune()
best_config.save("configs/my_kernel.json")
```

- Adjust `initial_population`, `copies`, or `max_generations` to trade tuning time against coverage, or try different search algorithms (see the sketch after this list).
- Use different input tuples to produce multiple saved configs (`my_kernel_large.json`, `my_kernel_fp8.json`, etc.).
- Tuning runs can be seeded with `HELION_AUTOTUNE_RANDOM_SEED` if you need more reproducibility; see {doc}`api/settings`. Note this only affects which configs are tried, not the timing results.
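
For instance, if you prefer to benchmark only a fixed, hand-picked set of candidates rather than searching, the finite-search strategy documented under `helion.autotuner.finite_search` can stand in for `PatternSearch`. A hedged sketch, assuming the class is exported as `FiniteSearch` and accepts the candidate list as its third constructor argument:

```python
from helion.autotuner import FiniteSearch  # Assumed export location.

# Only these hand-picked configs are benchmarked; nothing beyond them is tried.
candidates = [
    helion.Config.load("configs/my_kernel_small.json"),
    helion.Config.load("configs/my_kernel_large.json"),
]
tuner = FiniteSearch(bound, example_inputs, candidates)
best_config = tuner.autotune()
```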

## Deploy a Single Config

If one configuration wins for every production call, bake it into the decorator:

```python
best = helion.Config.load("configs/my_kernel.json")

@helion.kernel(config=best)
def my_kernel(x, y):
    ...
```

The supplied `config` applies to **all** argument shapes, dtypes, and devices that hit this kernel. This is ideal for workloads with a single critical path or when you manage routing externally. `helion.Config.save` / `load` make it easy to copy configs between machines; details live in {doc}`api/config`. You can also copy and paste the config from the autotuner output.

## Deploy Multiple Configs

When you expect variability, supply a small list of candidates:

```python
candidate_configs = [
    helion.Config.load("configs/my_kernel_small.json"),
    helion.Config.load("configs/my_kernel_large.json"),
]

@helion.kernel(configs=candidate_configs, static_shapes=True)
def my_kernel(x, y):
    ...
```

Helion performs a lightweight benchmark (similar to Triton’s autotune) the first time each specialization key is seen, running each provided config and selecting the fastest.

A key detail here is controlling the specialization key, which determines when to re-benchmark. Options include:

- **Default (dynamic shapes):** Helion reuses the timing result as long as tensor dtypes and device types stay constant. Shape changes only trigger a re-selection when a dimension size crosses the buckets `{0, 1, ≥2}`.
- **`static_shapes=True`:** add this setting to the decorator to specialize on the exact shape/stride signature, rerunning the selection whenever those shapes differ.
- **Custom keys:** pass `key=` to group calls however you like. This custom key is in addition to the above.

As an example, you could trigger re-tuning with power-of-two bucketing:

```python
@helion.kernel(
    configs=candidate_configs,
    key=lambda x, y: helion.next_power_of_2(x.numel())
)
def my_kernel(x, y):
    ...
```

See {doc}`api/kernel` for the full decorator reference.

## Advanced Manual Deployment

Some teams prefer to skip all runtime selection, using Helion only as an ahead-of-time compiler. For this use case we provide `Kernel.bind` and `BoundKernel.compile_config`, enabling wrapper patterns that let you implement bespoke routing logic. For example, to route based on input size:

```python
bound = my_kernel.bind(example_inputs)

small_cfg = helion.Config.load("configs/my_kernel_small.json")
large_cfg = helion.Config.load("configs/my_kernel_large.json")

small_run = bound.compile_config(small_cfg)  # Returns a callable.
large_run = bound.compile_config(large_cfg)

def routed_my_kernel(x, y):
    runner = small_run if x.numel() <= 2**16 else large_run
    return runner(x, y)
```

`Kernel.bind` produces a `BoundKernel` tied to sample input types. You can pre-compile as many configs as you need using `BoundKernel.compile_config`. **Warning:** `kernel.bind()` specializes, and the result will only work with the same input types you passed.

- With `static_shapes=False` (the default) it will specialize on the input dtypes, device types, and whether each dynamic dimension falls into the 0, 1, or ≥2 bucket. Python types are also specialized. For dimensions that can vary across those buckets, supply representative inputs of size ≥2 to avoid excessive specialization.
- With `static_shapes=True` the bound kernel only works for the exact shape/stride signature of the example inputs. The generated code will have shapes baked in, which often provides a performance boost.

If you need to support multiple input types, bind multiple times with representative inputs, as in the sketch below.
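
A minimal sketch of that pattern, assuming the same kernel must serve both `float16` and `float32` inputs and that one saved config works for both (the config path is illustrative):

```python
# Hypothetical routing by dtype: one BoundKernel per input type.
fp32_inputs = (
    torch.randn(2**20, device="cuda", dtype=torch.float32),
    torch.randn(2**20, device="cuda", dtype=torch.float32),
)
fp16_inputs = tuple(t.to(torch.float16) for t in fp32_inputs)

cfg = helion.Config.load("configs/my_kernel.json")  # Illustrative path.
run_fp32 = my_kernel.bind(fp32_inputs).compile_config(cfg)
run_fp16 = my_kernel.bind(fp16_inputs).compile_config(cfg)

def typed_my_kernel(x, y):
    runner = run_fp16 if x.dtype == torch.float16 else run_fp32
    return runner(x, y)
```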

Alternatively, you can export Triton source with `bound.to_triton_code(small_cfg)` to drop Helion from your serving environment altogether, embedding the generated kernel in a custom runtime. The Triton kernels could then be compiled down into PTX/cubins to further remove Python from the critical path, but details on this are beyond the scope of this document.
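
As a hedged sketch, the export step might look like this (the output filename is arbitrary, and `bound` and `small_cfg` come from the routing example above):

```python
# Dump the generated Triton source for one config to a standalone file,
# which can then be vendored into a serving environment without Helion.
triton_src = bound.to_triton_code(small_cfg)
with open("my_kernel_triton.py", "w") as f:
    f.write(triton_src)
```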

docs/index.md

Lines changed: 1 addition & 0 deletions

````diff
@@ -6,6 +6,7 @@
 :hidden:
 
 installation
+deployment_autotuning
 ./examples/index
 helion_puzzles
 api/index
````
