|
7 | 7 | import logging |
8 | 8 | import math |
9 | 9 | import multiprocessing as mp |
| 10 | +import operator |
10 | 11 | import os |
11 | 12 | from pathlib import Path |
12 | 13 | import pickle |
|
41 | 42 | from helion.autotuner.config_generation import ConfigGeneration |
42 | 43 | from helion.autotuner.effort_profile import get_effort_profile |
43 | 44 | from helion.autotuner.finite_search import FiniteSearch |
| 45 | +from helion.autotuner.local_cache import LocalAutotuneCache |
| 46 | +from helion.autotuner.local_cache import StrictLocalAutotuneCache |
44 | 47 | from helion.autotuner.logger import LambdaLogger |
45 | 48 | from helion.autotuner.random_search import RandomSearch |
46 | 49 | import helion.language as hl |
@@ -955,5 +958,59 @@ def test_autotune_random_seed_from_settings(self) -> None: |
955 | 958 | self.assertNotEqual(first, second) |
956 | 959 |
|
957 | 960 |
|
| 961 | +class TestAutotuneCacheSelection(TestCase): |
| 962 | + """Selection of the autotune cache via HELION_AUTOTUNE_CACHE.""" |
| 963 | + |
| 964 | + def _make_bound(self): |
| 965 | + @helion.kernel(autotune_baseline_fn=operator.add, autotune_log_level=0) |
| 966 | + def add(a: torch.Tensor, b: torch.Tensor): |
| 967 | + out = torch.empty_like(a) |
| 968 | + for tile in hl.tile(out.size()): |
| 969 | + out[tile] = a[tile] + b[tile] |
| 970 | + return out |
| 971 | + |
| 972 | + args = ( |
| 973 | + torch.randn([8], device=DEVICE), |
| 974 | + torch.randn([8], device=DEVICE), |
| 975 | + ) |
| 976 | + return add.bind(args), args |
| 977 | + |
| 978 | + def test_autotune_cache_default_is_local(self): |
| 979 | + """Default (no env var set) -> LocalAutotuneCache.""" |
| 980 | + with without_env_var("HELION_AUTOTUNE_CACHE"): |
| 981 | + bound, args = self._make_bound() |
| 982 | + with patch("torch.accelerator.synchronize", autospec=True) as sync: |
| 983 | + sync.return_value = None |
| 984 | + autotuner = bound.settings.autotuner_fn(bound, args) |
| 985 | + self.assertIsInstance(autotuner, LocalAutotuneCache) |
| 986 | + self.assertNotIsInstance(autotuner, StrictLocalAutotuneCache) |
| 987 | + |
| 988 | + def test_autotune_cache_strict_selected_by_env(self): |
| 989 | + """HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache -> StrictLocalAutotuneCache.""" |
| 990 | + with patch.dict( |
| 991 | + os.environ, |
| 992 | + {"HELION_AUTOTUNE_CACHE": "StrictLocalAutotuneCache"}, |
| 993 | + clear=False, |
| 994 | + ): |
| 995 | + bound, args = self._make_bound() |
| 996 | + with patch("torch.accelerator.synchronize", autospec=True) as sync: |
| 997 | + sync.return_value = None |
| 998 | + autotuner = bound.settings.autotuner_fn(bound, args) |
| 999 | + self.assertIsInstance(autotuner, StrictLocalAutotuneCache) |
| 1000 | + |
| 1001 | + def test_autotune_cache_invalid_raises(self): |
| 1002 | + """Invalid HELION_AUTOTUNE_CACHE value should raise a ValueError.""" |
| 1003 | + with patch.dict( |
| 1004 | + os.environ, {"HELION_AUTOTUNE_CACHE": "InvalidCacheName"}, clear=False |
| 1005 | + ): |
| 1006 | + bound, args = self._make_bound() |
| 1007 | + with patch("torch.accelerator.synchronize", autospec=True) as sync: |
| 1008 | + sync.return_value = None |
| 1009 | + with self.assertRaisesRegex( |
| 1010 | + ValueError, "Unknown HELION_AUTOTUNE_CACHE" |
| 1011 | + ): |
| 1012 | + bound.settings.autotuner_fn(bound, args) |
| 1013 | + |
| 1014 | + |
958 | 1015 | if __name__ == "__main__": |
959 | 1016 | unittest.main() |
0 commit comments