Skip to content

Commit 0caf09e

Browse files
Add DE-Surrogate hybrid autotuner algorithm + early stopping option for DE and DE-Surrogate (#1096)
1 parent 913f7c7 commit 0caf09e

File tree

10 files changed

+559
-6
lines changed

10 files changed

+559
-6
lines changed

.github/workflows/benchmark.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ jobs:
100100
- name: Install Helion
101101
run: |
102102
source .venv/bin/activate
103-
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev]'
103+
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate]'
104104
python -c "import helion; print(helion.__name__)"
105105
106106
- name: Install Benchmark Requirements

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ jobs:
146146
run: |
147147
source .venv/bin/activate
148148
uv pip install setuptools ninja
149-
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev]'
149+
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate]'
150150
python -c "import helion; print(helion.__name__)"
151151
152152
- name: Run Tests

helion/autotuner/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .config_fragment import ListOf as ListOf
77
from .config_fragment import PowerOfTwoFragment as PowerOfTwoFragment
88
from .config_spec import ConfigSpec as ConfigSpec
9+
from .de_surrogate_hybrid import DESurrogateHybrid as DESurrogateHybrid
910
from .differential_evolution import (
1011
DifferentialEvolutionSearch as DifferentialEvolutionSearch,
1112
)
@@ -20,6 +21,7 @@
2021
from .random_search import RandomSearch as RandomSearch
2122

2223
search_algorithms = {
24+
"DESurrogateHybrid": DESurrogateHybrid,
2325
"DifferentialEvolutionSearch": DifferentialEvolutionSearch,
2426
"FiniteSearch": FiniteSearch,
2527
"PatternSearch": PatternSearch,

helion/autotuner/base_search.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,7 @@ class PopulationMember:
624624
perfs (list[float]): The performance of the configuration, accumulated over multiple benchmarks.
625625
flat_values (FlatConfig): The flat representation of the configuration values.
626626
config (Config): The full configuration object.
627+
compile_time (float | None): The compilation time for this configuration.
627628
"""
628629

629630
fn: Callable[..., object]

helion/autotuner/config_fragment.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,26 @@ def get_minimum(self) -> int:
5757
"""
5858
raise NotImplementedError
5959

60+
def encode_scalar(self, value: object) -> float:
61+
"""
62+
Encode a configuration value into a float for ML models.
63+
64+
This is used by surrogate-assisted algorithms to convert configurations
65+
into numerical vectors for prediction models.
66+
67+
Args:
68+
value: The configuration value to encode.
69+
70+
Returns:
71+
A float representing the encoded value.
72+
"""
73+
# Default: convert to float if possible
74+
if not isinstance(value, (int, float, bool)):
75+
raise TypeError(
76+
f"Cannot encode {type(value).__name__} value {value!r} for ML"
77+
)
78+
return float(value)
79+
6080

6181
@dataclasses.dataclass
6282
class PermutationFragment(ConfigSpecFragment):
@@ -121,6 +141,14 @@ def pattern_neighbors(self, current: object) -> list[object]:
121141
neighbors.append(upper)
122142
return neighbors
123143

144+
def encode_scalar(self, value: object) -> float:
145+
"""Encode integer values directly as floats."""
146+
if not isinstance(value, (int, float)):
147+
raise TypeError(
148+
f"Expected int/float for BaseIntegerFragment, got {type(value).__name__}: {value!r}"
149+
)
150+
return float(value)
151+
124152

125153
class PowerOfTwoFragment(BaseIntegerFragment):
126154
def random(self) -> int:
@@ -152,6 +180,20 @@ def differential_mutation(self, a: object, b: object, c: object) -> int:
152180
return self.clamp(ai * 2)
153181
return ai
154182

183+
def encode_scalar(self, value: object) -> float:
184+
"""Encode power-of-2 values using log2 transformation."""
185+
import math
186+
187+
if not isinstance(value, (int, float)):
188+
raise TypeError(
189+
f"Expected int/float for PowerOfTwoFragment, got {type(value).__name__}: {value!r}"
190+
)
191+
if value <= 0:
192+
raise ValueError(
193+
f"Expected positive value for PowerOfTwoFragment, got {value}"
194+
)
195+
return math.log2(float(value))
196+
155197

156198
class IntegerFragment(BaseIntegerFragment):
157199
def random(self) -> int:
@@ -193,6 +235,17 @@ def differential_mutation(self, a: object, b: object, c: object) -> object:
193235
choices.remove(a)
194236
return random.choice(choices)
195237

238+
def encode_scalar(self, value: object) -> float:
239+
"""Encode enum values as their index."""
240+
try:
241+
choice_idx = self.choices.index(value)
242+
except ValueError:
243+
raise ValueError(
244+
f"Invalid enum value {value!r} for EnumFragment. "
245+
f"Valid choices: {self.choices}"
246+
) from None
247+
return float(choice_idx)
248+
196249

197250
class BooleanFragment(ConfigSpecFragment):
198251
def default(self) -> bool:

helion/autotuner/config_generation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,24 @@ def differential_mutation(
181181
# TODO(jansel): can this be larger? (too large and Triton compile times blow up)
182182
self.shrink_config(result, 8192)
183183
return result
184+
185+
def encode_config(self, flat_config: FlatConfig) -> list[float]:
186+
"""
187+
Encode a flat configuration into a numerical vector for ML models.
188+
189+
This is used by surrogate-assisted algorithms (e.g., DE-Surrogate) that need
190+
to represent configurations as continuous vectors for prediction models.
191+
192+
Args:
193+
flat_config: The flat configuration values to encode.
194+
195+
Returns:
196+
A list of floats representing the encoded configuration.
197+
"""
198+
encoded: list[float] = []
199+
200+
for flat_idx, spec in enumerate(self.flat_spec):
201+
value = flat_config[flat_idx]
202+
encoded.append(spec.encode_scalar(value))
203+
204+
return encoded

0 commit comments

Comments
 (0)