Skip to content

Commit 09284ad

Browse files
Add ConfigEncoder for ML-based autotuners
ConfigEncoder converts Helion's discrete configurations into numerical vectors suitable for machine learning models such as Random Forests and Gaussian Processes. It is a required dependency for DESurrogateHybrid and other ML-assisted autotuners. It handles:

- Power-of-2 values with log2 encoding
- Categorical variables with one-hot encoding
- Proper bounds computation for optimization
1 parent 5535177 commit 09284ad

File tree

1 file changed

+135
-0
lines changed

1 file changed

+135
-0
lines changed
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from __future__ import annotations
2+
3+
import math
4+
from typing import TYPE_CHECKING
5+
6+
import numpy as np
7+
8+
from .config_fragment import Category
9+
10+
if TYPE_CHECKING:
11+
from .config_generation import ConfigGeneration
12+
from .config_generation import FlatConfig
13+
14+
15+
class ConfigEncoder:
    """
    Encodes Helion configurations into numerical vectors for Gaussian Process models.

    Handles various config types:
    - Power-of-2 values (block sizes, num_warps): log2 encoding
    - Integers: direct encoding
    - Booleans: 0/1 encoding
    - Enums: one-hot encoding
    """

    def __init__(self, config_gen: ConfigGeneration) -> None:
        """
        Initialize the encoder with a configuration generator.

        Args:
            config_gen: The configuration generator containing the flat spec.
        """
        self.config_gen = config_gen
        self.flat_spec = config_gen.flat_spec
        self._compute_encoding_metadata()

    def _compute_encoding_metadata(self) -> None:
        """Precompute per-spec encoding ranges to determine output dimensionality."""
        self.encoded_dim = 0
        # One entry per flat-spec slot: (start_idx, end_idx, type), where
        # [start_idx, end_idx) is the slice of the encoded vector owned by
        # that spec and type is "numerical" or "enum".
        self.encoding_map: list[tuple[int, int, str]] = []

        for spec in self.flat_spec:
            category = spec.category()
            start_idx = self.encoded_dim

            if category in {
                Category.BLOCK_SIZE,
                Category.NUM_WARPS,
            }:
                # Single numerical value (log2-encoded at encode time)
                self.encoded_dim += 1
                self.encoding_map.append((start_idx, self.encoded_dim, "numerical"))
            elif hasattr(spec, "choices"):
                # Enum - one-hot encoding, one dimension per choice
                num_choices = len(spec.choices)  # type: ignore[no-untyped-call]
                self.encoded_dim += num_choices
                self.encoding_map.append((start_idx, self.encoded_dim, "enum"))
            else:
                # Boolean or other single value
                self.encoded_dim += 1
                self.encoding_map.append((start_idx, self.encoded_dim, "numerical"))

    def encode(self, flat_config: FlatConfig) -> np.ndarray:
        """
        Convert a flat configuration to a numerical vector.

        Args:
            flat_config: The flat configuration values.

        Returns:
            A numpy array of shape (self.encoded_dim,) suitable for GP training.
        """
        encoded = np.zeros(self.encoded_dim, dtype=np.float64)

        for flat_idx, spec in enumerate(self.flat_spec):
            value = flat_config[flat_idx]
            category = spec.category()
            enc_start, _enc_end, enc_type = self.encoding_map[flat_idx]

            if enc_type == "numerical":
                if category in {Category.BLOCK_SIZE, Category.NUM_WARPS}:
                    # Power-of-2: use log2 encoding; non-positive or
                    # non-numeric values fall back to 0.0.
                    if isinstance(value, (int, float)) and value > 0:
                        encoded[enc_start] = math.log2(float(value))
                    else:
                        encoded[enc_start] = 0.0
                else:
                    # Other numerical (including booleans, which are ints):
                    # direct encoding, non-numeric values fall back to 0.0.
                    encoded[enc_start] = (
                        float(value) if isinstance(value, (int, float)) else 0.0
                    )
            elif enc_type == "enum":
                # One-hot encoding
                if hasattr(spec, "choices"):
                    choices = spec.choices  # type: ignore[attr-defined]
                    try:
                        choice_idx = choices.index(value)
                        encoded[enc_start + choice_idx] = 1.0
                    except (ValueError, IndexError):
                        # Default to first choice if value not found
                        encoded[enc_start] = 1.0

        return encoded

    def get_bounds(self) -> list[tuple[float, float]]:
        """
        Get bounds for each encoded dimension.

        Returns:
            List of (min, max) tuples, one per encoded dimension, in the same
            order as the vector produced by :meth:`encode`.
        """
        bounds: list[tuple[float, float]] = []

        for flat_idx, spec in enumerate(self.flat_spec):
            category = spec.category()
            enc_start, enc_end, enc_type = self.encoding_map[flat_idx]

            if enc_type == "numerical":
                if category in {Category.BLOCK_SIZE, Category.NUM_WARPS}:
                    # Power-of-2: log2 bounds. Guard against specs missing
                    # low/high and against non-positive values that would
                    # make log2 raise (log2(x) for x <= 0 is undefined).
                    low = float(getattr(spec, "low", 1.0))
                    high = float(getattr(spec, "high", 1.0))
                    min_val = math.log2(low) if low > 0 else 0.0
                    max_val = math.log2(high) if high > 0 else 0.0
                    bounds.append((min_val, max_val))
                else:
                    # Other numerical bounds. Booleans (routed here by
                    # _compute_encoding_metadata) have no low/high attributes,
                    # so default to the [0, 1] unit interval instead of
                    # raising AttributeError.
                    low = float(getattr(spec, "low", 0.0))
                    high = float(getattr(spec, "high", 1.0))
                    bounds.append((low, high))
            elif enc_type == "enum":
                # One-hot: each dimension is 0 or 1
                num_choices = enc_end - enc_start
                bounds.extend([(0.0, 1.0)] * num_choices)

        return bounds

0 commit comments

Comments
 (0)