Skip to content

Commit f472e32

Browse files
authored
feat(autotuner): Make autotune cache class configurable via env var (#1071)
Signed-off-by: Alessandro Sangiorgi <asangior@redhat.com>
1 parent 4db264a commit f472e32

File tree

3 files changed

+90
-2
lines changed

3 files changed

+90
-2
lines changed

helion/autotuner/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,8 @@
2525
"PatternSearch": PatternSearch,
2626
"RandomSearch": RandomSearch,
2727
}
28+
29+
cache_classes = {
30+
"LocalAutotuneCache": LocalAutotuneCache,
31+
"StrictLocalAutotuneCache": StrictLocalAutotuneCache,
32+
}

helion/runtime/settings.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,13 @@ def _env_get_literal(
131131
)
132132

133133

134+
def _env_get_str(var_name: str, default: str) -> str:
135+
value = os.environ.get(var_name)
136+
if value is None or (value := value.strip()) == "":
137+
return default
138+
return value
139+
140+
134141
def _get_index_dtype() -> torch.dtype:
135142
value = os.environ.get("HELION_INDEX_DTYPE")
136143
if value is None or (token := value.strip()) == "":
@@ -184,7 +191,7 @@ def _get_autotune_config_overrides() -> dict[str, object]:
184191
def default_autotuner_fn(
185192
bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
186193
) -> BaseAutotuner:
187-
from ..autotuner import LocalAutotuneCache
194+
from ..autotuner import cache_classes
188195
from ..autotuner import search_algorithms
189196

190197
autotuner_name = os.environ.get("HELION_AUTOTUNER", "PatternSearch")
@@ -223,7 +230,16 @@ def default_autotuner_fn(
223230
assert profile.random_search is not None
224231
kwargs.setdefault("count", profile.random_search.count)
225232

226-
return LocalAutotuneCache(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType]
233+
settings = bound_kernel.settings
234+
cache_name = settings.autotune_cache
235+
cache_cls = cache_classes.get(cache_name)
236+
if cache_cls is None:
237+
raise ValueError(
238+
f"Unknown HELION_AUTOTUNE_CACHE value: {cache_name}, valid options are: "
239+
f"{', '.join(cache_classes.keys())}"
240+
)
241+
242+
return cache_cls(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType]
227243

228244

229245
def _get_autotune_random_seed() -> int:
@@ -348,6 +364,11 @@ class _Settings:
348364
)
349365
)
350366
ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode)
367+
autotune_cache: str = dataclasses.field(
368+
default_factory=functools.partial(
369+
_env_get_str, "HELION_AUTOTUNE_CACHE", "LocalAutotuneCache"
370+
)
371+
)
351372
autotuner_fn: AutotunerFunction = default_autotuner_fn
352373
autotune_baseline_fn: Callable[..., object] | None = None
353374

@@ -413,6 +434,11 @@ class Settings(_Settings):
413434
"Should have the same signature as the kernel function. "
414435
"Pass as @helion.kernel(..., autotune_baseline_fn=my_baseline_fn)."
415436
),
437+
"autotune_cache": (
438+
"The name of the autotuner cache class to use. "
439+
"Set HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache to enable strict caching. "
440+
"Defaults to 'LocalAutotuneCache'."
441+
),
416442
}
417443

418444
def __init__(self, **settings: object) -> None:

test/test_autotuner.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
import math
99
import multiprocessing as mp
10+
import operator
1011
import os
1112
from pathlib import Path
1213
import pickle
@@ -41,6 +42,8 @@
4142
from helion.autotuner.config_generation import ConfigGeneration
4243
from helion.autotuner.effort_profile import get_effort_profile
4344
from helion.autotuner.finite_search import FiniteSearch
45+
from helion.autotuner.local_cache import LocalAutotuneCache
46+
from helion.autotuner.local_cache import StrictLocalAutotuneCache
4447
from helion.autotuner.logger import LambdaLogger
4548
from helion.autotuner.random_search import RandomSearch
4649
import helion.language as hl
@@ -955,5 +958,59 @@ def test_autotune_random_seed_from_settings(self) -> None:
955958
self.assertNotEqual(first, second)
956959

957960

961+
class TestAutotuneCacheSelection(TestCase):
962+
"""Selection of the autotune cache via HELION_AUTOTUNE_CACHE."""
963+
964+
def _make_bound(self):
965+
@helion.kernel(autotune_baseline_fn=operator.add, autotune_log_level=0)
966+
def add(a: torch.Tensor, b: torch.Tensor):
967+
out = torch.empty_like(a)
968+
for tile in hl.tile(out.size()):
969+
out[tile] = a[tile] + b[tile]
970+
return out
971+
972+
args = (
973+
torch.randn([8], device=DEVICE),
974+
torch.randn([8], device=DEVICE),
975+
)
976+
return add.bind(args), args
977+
978+
def test_autotune_cache_default_is_local(self):
979+
"""Default (no env var set) -> LocalAutotuneCache."""
980+
with without_env_var("HELION_AUTOTUNE_CACHE"):
981+
bound, args = self._make_bound()
982+
with patch("torch.accelerator.synchronize", autospec=True) as sync:
983+
sync.return_value = None
984+
autotuner = bound.settings.autotuner_fn(bound, args)
985+
self.assertIsInstance(autotuner, LocalAutotuneCache)
986+
self.assertNotIsInstance(autotuner, StrictLocalAutotuneCache)
987+
988+
def test_autotune_cache_strict_selected_by_env(self):
989+
"""HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache -> StrictLocalAutotuneCache."""
990+
with patch.dict(
991+
os.environ,
992+
{"HELION_AUTOTUNE_CACHE": "StrictLocalAutotuneCache"},
993+
clear=False,
994+
):
995+
bound, args = self._make_bound()
996+
with patch("torch.accelerator.synchronize", autospec=True) as sync:
997+
sync.return_value = None
998+
autotuner = bound.settings.autotuner_fn(bound, args)
999+
self.assertIsInstance(autotuner, StrictLocalAutotuneCache)
1000+
1001+
def test_autotune_cache_invalid_raises(self):
1002+
"""Invalid HELION_AUTOTUNE_CACHE value should raise a ValueError."""
1003+
with patch.dict(
1004+
os.environ, {"HELION_AUTOTUNE_CACHE": "InvalidCacheName"}, clear=False
1005+
):
1006+
bound, args = self._make_bound()
1007+
with patch("torch.accelerator.synchronize", autospec=True) as sync:
1008+
sync.return_value = None
1009+
with self.assertRaisesRegex(
1010+
ValueError, "Unknown HELION_AUTOTUNE_CACHE"
1011+
):
1012+
bound.settings.autotuner_fn(bound, args)
1013+
1014+
9581015
if __name__ == "__main__":
9591016
unittest.main()

0 commit comments

Comments
 (0)