Skip to content

Commit 5c8e194

Browse files
authored
Set setstate/getstate methods to Config (#868)
1 parent 86c54e5 commit 5c8e194

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

helion/runtime/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ def __eq__(self, other: object) -> bool:
111111
def __hash__(self) -> int:
112112
return hash(frozenset([(k, _list_to_tuple(v)) for k, v in self.config.items()]))
113113

114+
def __getstate__(self) -> dict[str, object]:
115+
return dict(self.config)
116+
117+
def __setstate__(self, state: dict[str, object]) -> None:
118+
self.config = dict(state)
119+
114120
def to_json(self) -> str:
115121
"""Convert the config to a JSON string."""
116122
return json.dumps(self.config, indent=2)

test/test_autotuner.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import math
55
import os
66
from pathlib import Path
7+
import pickle
78
import random
89
import tempfile
910
from types import SimpleNamespace
@@ -135,6 +136,21 @@ def test_save_load_config(self):
135136
self.assertEqual(config, loaded_config)
136137
self.assertExpectedJournal(config.to_json())
137138

139+
def test_config_pickle_roundtrip(self):
140+
config = helion.Config(
141+
block_sizes=[64, 64, 32],
142+
loop_orders=[[1, 0]],
143+
num_warps=4,
144+
num_stages=2,
145+
indexing="tensor_descriptor",
146+
extra_metadata={"nested": [1, 2, 3]},
147+
)
148+
restored = pickle.loads(pickle.dumps(config))
149+
self.assertIsInstance(restored, helion.Config)
150+
self.assertEqual(config, restored)
151+
self.assertIsNot(config, restored)
152+
self.assertIsNot(config.config, restored.config)
153+
138154
def test_run_fixed_config(self):
139155
@helion.kernel(
140156
config=helion.Config(

0 commit comments

Comments
 (0)