Commit 9449ed8

Merge branch 'main' into autotuner_cudagraph
2 parents 5001a53 + 5fb337c

8 files changed (+626, −112 lines)


helion/_compiler/device_ir.py

Lines changed: 303 additions & 109 deletions
Large diffs are not rendered by default.
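
This file carries the bulk of the commit: device-IR lowering for while loops. For context, the sketch below is generic torch.fx, not Helion's device_ir.py code; it illustrates the constraint the design works around: FX cannot trace data-dependent control flow, so a loop's condition and body are outlined into separately stored sub-graphs, and the loop itself is recorded as one opaque call node carrying their ids, which is exactly the shape of the _while_loop op added in _tracing_ops.py below.

import torch
import torch.fx

def opaque_while(cond_graph_id, body_graph_id, args):
    # tracing placeholder only; never executed (cf. _while_loop below)
    raise AssertionError("should not be called")

# wrap() makes symbolic_trace emit a single call_function node for this call
# instead of tracing into (and executing) the body
torch.fx.wrap("opaque_while")

def f(x):
    return opaque_while(0, 1, [x])  # the ids reference outlined sub-graphs

print(torch.fx.symbolic_trace(f).graph)
# the printed graph contains one opaque node, roughly:
#   call_function[target=opaque_while](args = (0, 1, [%x]), kwargs = {})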

helion/_compiler/type_propagation.py

Lines changed: 4 additions & 1 deletion
@@ -2126,7 +2126,10 @@ def visit_Assert(self, node: ast.Assert) -> TypeInfo:
 
     visit_Raise: _VisitMethod = generic_statement  # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride]
     visit_Delete: _VisitMethod = generic_statement  # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride]
-    visit_Pass: _VisitMethod = generic_statement  # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride]
+
+    def visit_Pass(self, node: ast.Pass) -> TypeInfo:
+        return NoType(origin=self.origin())
+
     visit_TypeAlias: _VisitMethod = generic_statement  # pyright: ignore[reportAssignmentType, reportIncompatibleMethodOverride]
     visit_Import: _VisitMethod = generic_statement  # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride]
     visit_ImportFrom: _VisitMethod = generic_statement  # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride]
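
With this change, pass no longer routes through generic_statement and instead propagates NoType, i.e. a statement that produces no value. The new device-side while loops rely on this for empty spin-wait bodies. A sketch of such a kernel, reconstructed from the src comments in the test_loops.expected journal at the bottom of this commit (the bare @helion.kernel decorator is an assumption):

import torch
import helion
import helion.language as hl

@helion.kernel  # assumption: decorator options are not shown in the journal
def kernel(grad_x_lock: torch.Tensor):
    for idx in hl.tile(grad_x_lock.size(0)):
        # spin until the lock is acquired; the body is just `pass`, which
        # type propagation must now accept inside device code
        while hl.atomic_cas(grad_x_lock, [idx], 0, 1) == 1:
            pass
        hl.atomic_cas(grad_x_lock, [idx], 1, 0)  # release the lock
    return grad_x_lock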

helion/autotuner/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -25,3 +25,8 @@
     "PatternSearch": PatternSearch,
     "RandomSearch": RandomSearch,
 }
+
+
+cache_classes = {
+    "LocalAutotuneCache": LocalAutotuneCache,
+    "StrictLocalAutotuneCache": StrictLocalAutotuneCache,
+}
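
A minimal sketch of how this registry is consumed (mirroring the default_autotuner_fn change in settings.py below):

from helion.autotuner import cache_classes

cache_cls = cache_classes.get("StrictLocalAutotuneCache")
if cache_cls is None:  # unknown names fail fast instead of silently defaulting
    raise ValueError("unknown cache name")
# the chosen class wraps a search object, e.g.:
#   cache_cls(PatternSearch(bound_kernel, args))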

helion/language/_tracing_ops.py

Lines changed: 17 additions & 0 deletions
@@ -88,6 +88,23 @@ def _(state: CodegenState) -> None:
     return HostFunction.current().device_ir.graphs[state.proxy_arg(0)].codegen(state)  # pyright: ignore[reportArgumentType,reportCallIssue]
 
 
+@has_side_effect
+@_decorators.api()
+def _while_loop(
+    cond_graph_id: int,
+    body_graph_id: int,
+    args: list[object],
+    orelse_graph_id: int | None = None,
+) -> list[object]:
+    """Represent a while loop in FX since FX lacks native control flow."""
+    raise AssertionError("this should never be called")
+
+
+@_decorators.codegen(_while_loop)
+def _(state: CodegenState) -> None:
+    return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state)  # pyright: ignore[reportArgumentType,reportCallIssue]
+
+
 @has_side_effect
 @_decorators.api()
 def _if(test: object, graph_id: int, args: list[object]) -> list[object]:
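
The op is a tracing placeholder: calling it directly is an error, and the codegen hook resolves the body graph by id (proxy_arg(1) is body_graph_id). The kind of Helion source it stands for, reconstructed from the src comments in the TestLoops.test_while_accumulates_tensor journal below (the bare decorator is an assumption):

import torch
import helion
import helion.language as hl

@helion.kernel  # assumption: decorator options are not shown in the journal
def kernel(x: torch.Tensor):
    out = torch.empty_like(x)
    for tile in hl.tile(x.shape):
        acc = torch.zeros_like(x[tile])
        steps = torch.zeros([], device=x.device, dtype=torch.int32)
        # traced as a single _while_loop node whose condition and body live
        # in separate sub-graphs referenced by id
        while steps < 4:
            acc = acc + 1
            steps = steps + 1
        out[tile] = acc
    return out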

helion/runtime/settings.py

Lines changed: 28 additions & 2 deletions
@@ -131,6 +131,13 @@ def _env_get_literal(
     )
 
 
+def _env_get_str(var_name: str, default: str) -> str:
+    value = os.environ.get(var_name)
+    if value is None or (value := value.strip()) == "":
+        return default
+    return value
+
+
 def _get_index_dtype() -> torch.dtype:
     value = os.environ.get("HELION_INDEX_DTYPE")
     if value is None or (token := value.strip()) == "":
@@ -184,7 +191,7 @@ def _get_autotune_config_overrides() -> dict[str, object]:
 def default_autotuner_fn(
     bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
 ) -> BaseAutotuner:
-    from ..autotuner import LocalAutotuneCache
+    from ..autotuner import cache_classes
     from ..autotuner import search_algorithms
 
     autotuner_name = os.environ.get("HELION_AUTOTUNER", "PatternSearch")
@@ -223,7 +230,16 @@
         assert profile.random_search is not None
         kwargs.setdefault("count", profile.random_search.count)
 
-    return LocalAutotuneCache(autotuner_cls(bound_kernel, args, **kwargs))  # pyright: ignore[reportArgumentType]
+    settings = bound_kernel.settings
+    cache_name = settings.autotune_cache
+    cache_cls = cache_classes.get(cache_name)
+    if cache_cls is None:
+        raise ValueError(
+            f"Unknown HELION_AUTOTUNE_CACHE value: {cache_name}, valid options are: "
+            f"{', '.join(cache_classes.keys())}"
+        )
+
+    return cache_cls(autotuner_cls(bound_kernel, args, **kwargs))  # pyright: ignore[reportArgumentType]
 
 
 def _get_autotune_random_seed() -> int:
@@ -348,6 +364,11 @@ class _Settings:
         )
     )
     ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode)
+    autotune_cache: str = dataclasses.field(
+        default_factory=functools.partial(
+            _env_get_str, "HELION_AUTOTUNE_CACHE", "LocalAutotuneCache"
+        )
+    )
     autotuner_fn: AutotunerFunction = default_autotuner_fn
     autotune_baseline_fn: Callable[..., object] | None = None
 
@@ -413,6 +434,11 @@ class Settings(_Settings):
             "Should have the same signature as the kernel function. "
            "Pass as @helion.kernel(..., autotune_baseline_fn=my_baseline_fn)."
        ),
+        "autotune_cache": (
+            "The name of the autotuner cache class to use. "
+            "Set HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache to enable strict caching. "
+            "Defaults to 'LocalAutotuneCache'."
+        ),
     }
 
     def __init__(self, **settings: object) -> None:
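
End to end, cache selection happens per bound kernel at autotune time. A usage sketch (the env-var route documented above; valid values are the class names registered in cache_classes, and because the field is populated by an env-reading default_factory, the variable must be set before the kernel's Settings are constructed):

import os

# pick the strict cache for every kernel bound in this process
os.environ["HELION_AUTOTUNE_CACHE"] = "StrictLocalAutotuneCache"

# an unrecognized name fails fast inside default_autotuner_fn:
#   ValueError: Unknown HELION_AUTOTUNE_CACHE value: InvalidCacheName,
#   valid options are: LocalAutotuneCache, StrictLocalAutotuneCache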

test/test_autotuner.py

Lines changed: 57 additions & 0 deletions
@@ -7,6 +7,7 @@
 import logging
 import math
 import multiprocessing as mp
+import operator
 import os
 from pathlib import Path
 import pickle
@@ -41,6 +42,8 @@
 from helion.autotuner.config_generation import ConfigGeneration
 from helion.autotuner.effort_profile import get_effort_profile
 from helion.autotuner.finite_search import FiniteSearch
+from helion.autotuner.local_cache import LocalAutotuneCache
+from helion.autotuner.local_cache import StrictLocalAutotuneCache
 from helion.autotuner.logger import LambdaLogger
 from helion.autotuner.random_search import RandomSearch
 import helion.language as hl
@@ -955,5 +958,59 @@ def test_autotune_random_seed_from_settings(self) -> None:
         self.assertNotEqual(first, second)
 
 
+class TestAutotuneCacheSelection(TestCase):
+    """Selection of the autotune cache via HELION_AUTOTUNE_CACHE."""
+
+    def _make_bound(self):
+        @helion.kernel(autotune_baseline_fn=operator.add, autotune_log_level=0)
+        def add(a: torch.Tensor, b: torch.Tensor):
+            out = torch.empty_like(a)
+            for tile in hl.tile(out.size()):
+                out[tile] = a[tile] + b[tile]
+            return out
+
+        args = (
+            torch.randn([8], device=DEVICE),
+            torch.randn([8], device=DEVICE),
+        )
+        return add.bind(args), args
+
+    def test_autotune_cache_default_is_local(self):
+        """Default (no env var set) -> LocalAutotuneCache."""
+        with without_env_var("HELION_AUTOTUNE_CACHE"):
+            bound, args = self._make_bound()
+            with patch("torch.accelerator.synchronize", autospec=True) as sync:
+                sync.return_value = None
+                autotuner = bound.settings.autotuner_fn(bound, args)
+        self.assertIsInstance(autotuner, LocalAutotuneCache)
+        self.assertNotIsInstance(autotuner, StrictLocalAutotuneCache)
+
+    def test_autotune_cache_strict_selected_by_env(self):
+        """HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache -> StrictLocalAutotuneCache."""
+        with patch.dict(
+            os.environ,
+            {"HELION_AUTOTUNE_CACHE": "StrictLocalAutotuneCache"},
+            clear=False,
+        ):
+            bound, args = self._make_bound()
+            with patch("torch.accelerator.synchronize", autospec=True) as sync:
+                sync.return_value = None
+                autotuner = bound.settings.autotuner_fn(bound, args)
+        self.assertIsInstance(autotuner, StrictLocalAutotuneCache)
+
+    def test_autotune_cache_invalid_raises(self):
+        """Invalid HELION_AUTOTUNE_CACHE value should raise a ValueError."""
+        with patch.dict(
+            os.environ, {"HELION_AUTOTUNE_CACHE": "InvalidCacheName"}, clear=False
+        ):
+            bound, args = self._make_bound()
+            with patch("torch.accelerator.synchronize", autospec=True) as sync:
+                sync.return_value = None
+                with self.assertRaisesRegex(
+                    ValueError, "Unknown HELION_AUTOTUNE_CACHE"
+                ):
+                    bound.settings.autotuner_fn(bound, args)
+
+
 if __name__ == "__main__":
     unittest.main()

test/test_loops.expected

Lines changed: 155 additions & 0 deletions
@@ -1730,3 +1730,158 @@ def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_matmul, (_NUM_SM,), x, y, out, _NUM_SM, _BLOCK_SIZE_1, _BLOCK_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
     # src[test_loops.py:N]: return out
     return out
+
+--- assertExpectedJournal(TestLoops.test_while_accumulates_tensor)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_kernel(out, _BLOCK_SIZE_0: tl.constexpr):
+    # src[test_loops.py:N]: for tile in hl.tile(x.shape):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    # src[test_loops.py:N]: acc = torch.zeros_like(x[tile])
+    acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
+    # src[test_loops.py:N]: steps = torch.zeros([], device=x.device, dtype=torch.int32)
+    steps = tl.full([], 0, tl.int32)
+    # src[test_loops.py:N]: while steps < 4:
+    # src[test_loops.py:N]: acc = acc + 1
+    # src[test_loops.py:N]: steps = steps + 1
+    steps_copy = steps
+    steps_copy_0 = steps_copy
+    # src[test_loops.py:N]: while steps < 4:
+    v_0 = tl.full([], 4, tl.int32)
+    v_1 = steps_copy_0 < v_0
+    # src[test_loops.py:N]: while steps < 4:
+    # src[test_loops.py:N]: acc = acc + 1
+    # src[test_loops.py:N]: steps = steps + 1
+    while_cond = v_1
+    while while_cond:
+        steps_copy_1 = steps
+        acc_copy = acc
+        steps_copy_1_0 = steps_copy_1
+        acc_copy_0 = acc_copy
+        # src[test_loops.py:N]: acc = acc + 1
+        v_2 = 1.0
+        acc = acc_copy_0 + v_2
+        # src[test_loops.py:N]: steps = steps + 1
+        v_4 = tl.full([], 1, tl.int32)
+        steps = steps_copy_1_0 + v_4
+        # src[test_loops.py:N]: while steps < 4:
+        # src[test_loops.py:N]: acc = acc + 1
+        # src[test_loops.py:N]: steps = steps + 1
+        steps_copy_2 = steps
+        steps_copy_2_0 = steps_copy_2
+        # src[test_loops.py:N]: while steps < 4:
+        v_6 = tl.full([], 4, tl.int32)
+        v_7 = steps_copy_2_0 < v_6
+        # src[test_loops.py:N]: while steps < 4:
+        # src[test_loops.py:N]: acc = acc + 1
+        # src[test_loops.py:N]: steps = steps + 1
+        while_cond = v_7
+    # src[test_loops.py:N]: out[tile] = acc
+    tl.store(out + indices_0 * 1, acc, None)
+
+def kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_loops.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_loops.py:N]: for tile in hl.tile(x.shape):
+    _BLOCK_SIZE_0 = 16
+    # src[test_loops.py:N]: for tile in hl.tile(x.shape):
+    # src[test_loops.py:N]: acc = torch.zeros_like(x[tile])
+    # src[test_loops.py:N]: steps = torch.zeros([], device=x.device, dtype=torch.int32)
+    # src[test_loops.py:N-N]: ...
+    _launcher(_helion_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), out, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
+    # src[test_loops.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestLoops.test_while_atomic_add_accumulates)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_kernel(counters, values, totals):
+    # src[test_loops.py:N]: for idx in hl.tile(values.size(0)):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
+    # src[test_loops.py:N]: while hl.atomic_add(counters, [idx], 1).sum() < 1:
+    atomic_add = tl.atomic_add(counters + indices_0 * 1, 1, mask=None, sem='relaxed')
+    sum_1 = tl.cast(tl.sum(atomic_add, 0), tl.float32)
+    v_0 = 1.0
+    v_1 = sum_1 < v_0
+    # src[test_loops.py:N]: while hl.atomic_add(counters, [idx], 1).sum() < 1:
+    # src[test_loops.py:N]: hl.atomic_add(totals, [idx], values[idx])
+    while_cond = v_1
+    while while_cond:
+        # src[test_loops.py:N]: hl.atomic_add(totals, [idx], values[idx])
+        load = tl.load(values + indices_0 * 1, None)
+        tl.atomic_add(totals + indices_0 * 1, load, mask=None, sem='relaxed')
+        # src[test_loops.py:N]: while hl.atomic_add(counters, [idx], 1).sum() < 1:
+        atomic_add_1 = tl.atomic_add(counters + indices_0 * 1, 1, mask=None, sem='relaxed')
+        sum_2 = tl.cast(tl.sum(atomic_add_1, 0), tl.float32)
+        v_2 = 1.0
+        v_3 = sum_2 < v_2
+        # src[test_loops.py:N]: while hl.atomic_add(counters, [idx], 1).sum() < 1:
+        # src[test_loops.py:N]: hl.atomic_add(totals, [idx], values[idx])
+        while_cond = v_3
+
+def kernel(values: torch.Tensor, totals: torch.Tensor, counters: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_loops.py:N]: for idx in hl.tile(values.size(0)):
+    # src[test_loops.py:N]: while hl.atomic_add(counters, [idx], 1).sum() < 1:
+    # src[test_loops.py:N]: hl.atomic_add(totals, [idx], values[idx])
+    _launcher(_helion_kernel, (8,), counters, values, totals, num_warps=4, num_stages=1)
+    # src[test_loops.py:N]: return totals
+    return totals
+
+--- assertExpectedJournal(TestLoops.test_while_atomic_cas_pass)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_kernel(grad_x_lock, _BLOCK_SIZE_0: tl.constexpr):
+    # src[test_loops.py:N]: for idx in hl.tile(grad_x_lock.size(0)):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    # src[test_loops.py:N]: while hl.atomic_cas(grad_x_lock, [idx], 0, 1) == 1:
+    atomic_cas = tl.atomic_cas(grad_x_lock + indices_0 * 1, 0, 1, sem='relaxed')
+    v_0 = tl.full([], 1, tl.int32)
+    v_1 = atomic_cas == v_0
+    # src[test_loops.py:N]: while hl.atomic_cas(grad_x_lock, [idx], 0, 1) == 1:
+    # src[test_loops.py:N]: pass
+    while_cond = v_1
+    while while_cond:
+        # src[test_loops.py:N]: while hl.atomic_cas(grad_x_lock, [idx], 0, 1) == 1:
+        atomic_cas_1 = tl.atomic_cas(grad_x_lock + indices_0 * 1, 0, 1, sem='relaxed')
+        v_2 = tl.full([], 1, tl.int32)
+        v_3 = atomic_cas_1 == v_2
+        # src[test_loops.py:N]: while hl.atomic_cas(grad_x_lock, [idx], 0, 1) == 1:
+        # src[test_loops.py:N]: pass
+        while_cond = v_3
+    # src[test_loops.py:N]: hl.atomic_cas(grad_x_lock, [idx], 1, 0)
+    tl.atomic_cas(grad_x_lock + indices_0 * 1, 1, 0, sem='relaxed')
+
+def kernel(grad_x_lock: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_loops.py:N]: for idx in hl.tile(grad_x_lock.size(0)):
+    _BLOCK_SIZE_0 = 16
+    # src[test_loops.py:N]: for idx in hl.tile(grad_x_lock.size(0)):
+    # src[test_loops.py:N]: while hl.atomic_cas(grad_x_lock, [idx], 0, 1) == 1:
+    # src[test_loops.py:N]: pass
+    # src[test_loops.py:N-N]: ...
+    _launcher(_helion_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), grad_x_lock, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
+    # src[test_loops.py:N]: return grad_x_lock
+    return grad_x_lock
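
For orientation, the Helion source behind the second journal, reconstructed from its src comments (the argument order follows the generated kernel(values, totals, counters) wrapper; the bare decorator is an assumption):

import torch
import helion
import helion.language as hl

@helion.kernel  # assumption: decorator options are not shown in the journal
def kernel(values: torch.Tensor, totals: torch.Tensor, counters: torch.Tensor):
    for idx in hl.tile(values.size(0)):
        # hl.atomic_add returns the pre-increment value, so the body keeps
        # adding values[idx] into totals[idx] until the counter was already >= 1
        while hl.atomic_add(counters, [idx], 1).sum() < 1:
            hl.atomic_add(totals, [idx], values[idx])
    return totals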
