Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/hyperactive/base/_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,24 @@ class BaseExperiment(BaseObject):

def __init__(self):
super().__init__()
self._callbacks_pre = []
self._callbacks_post = []

def add_callback(self, callback, pre=False):
"""Register a callback.

Parameters
----------
callback : callable
For post callbacks: callback(experiment, params, result, metadata).
For pre callbacks: callback(experiment, params).
pre : bool, default=False
If True, callback runs before evaluation. If False, after.
"""
if pre:
self._callbacks_pre.append(callback)
else:
self._callbacks_post.append(callback)

def __call__(self, params):
"""Score parameters. Same as score call, returns only a first element."""
Expand Down Expand Up @@ -77,10 +95,24 @@ def evaluate(self, params):
f"Parameters passed to {type(self)}.evaluate do not match: "
f"expected {paramnames}, got {list(params.keys())}."
)

self._run_callbacks_pre(params)
res, metadata = self._evaluate(params)
res = np.float64(res)
self._run_callbacks_post(params, res, metadata)

return res, metadata

def _run_callbacks_pre(self, params):
"""Run pre-evaluation callbacks."""
for callback in self._callbacks_pre:
callback(self, params)

def _run_callbacks_post(self, params, result, metadata):
"""Run post-evaluation callbacks."""
for callback in self._callbacks_post:
callback(self, params, result, metadata)

def _evaluate(self, params):
"""Evaluate the parameters.

Expand Down
33 changes: 33 additions & 0 deletions src/hyperactive/base/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,32 @@ class BaseOptimizer(BaseObject):

def __init__(self):
super().__init__()
self._callbacks_pre_solve = []
self._callbacks_post_solve = []

assert hasattr(self, "experiment"), "Optimizer must have an experiment."
search_config = self.get_params()
self._experiment = search_config.pop("experiment", None)

if self.get_tag("info:name") is None:
self.set_tags(**{"info:name": self.__class__.__name__})

def add_callback(self, callback, pre=False):
"""Register a callback.

Parameters
----------
callback : callable
For post callbacks: callback(optimizer, best_params).
For pre callbacks: callback(optimizer).
pre : bool, default=False
If True, callback runs before solve. If False, after.
"""
if pre:
self._callbacks_pre_solve.append(callback)
else:
self._callbacks_post_solve.append(callback)

def get_search_config(self):
"""Get the search configuration.

Expand Down Expand Up @@ -76,13 +95,27 @@ def solve(self):
The dict ``best_params`` can be used in ``experiment.score`` or
``experiment.evaluate`` directly.
"""
self._run_callbacks_pre_solve()

experiment = self.get_experiment()
search_config = self.get_search_config()

best_params = self._solve(experiment, **search_config)
self.best_params_ = best_params
self._run_callbacks_post_solve(best_params)

return best_params

def _run_callbacks_pre_solve(self):
"""Run pre-solve callbacks."""
for callback in self._callbacks_pre_solve:
callback(self)

def _run_callbacks_post_solve(self, best_params):
"""Run post-solve callbacks."""
for callback in self._callbacks_post_solve:
callback(self, best_params)

def _solve(self, experiment, *args, **kwargs):
"""Run the optimization search process.

Expand Down
175 changes: 175 additions & 0 deletions src/hyperactive/base/tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Tests for callback functionality."""

# copyright: hyperactive developers, MIT License (see LICENSE file)

import numpy as np

def test_experiment_callbacks_post():
"""Test that post-evaluation callbacks are called."""
from hyperactive.experiment.bench import Sphere

callback_calls = []

def track_callback(exp, params, result, metadata):
callback_calls.append(
{
"params": params.copy(),
"result": result,
}
)

exp = Sphere(n_dim=2)
exp.add_callback(track_callback)

exp.evaluate({"x0": 0.0, "x1": 0.0})
exp.evaluate({"x0": 1.0, "x1": 1.0})
exp.evaluate({"x0": 2.0, "x1": 2.0})

assert len(callback_calls) == 3
assert callback_calls[0]["params"] == {"x0": 0.0, "x1": 0.0}
assert callback_calls[1]["params"] == {"x0": 1.0, "x1": 1.0}
assert callback_calls[2]["params"] == {"x0": 2.0, "x1": 2.0}

def test_experiment_callbacks_pre():
"""Test that pre-evaluation callbacks are called."""
from hyperactive.experiment.bench import Sphere

pre_calls = []

def pre_callback(exp, params):
pre_calls.append(params.copy())

exp = Sphere(n_dim=2)
exp.add_callback(pre_callback, pre=True)

exp.evaluate({"x0": 1.0, "x1": 2.0})

assert len(pre_calls) == 1
assert pre_calls[0] == {"x0": 1.0, "x1": 2.0}

def test_optimizer_callbacks():
"""Test optimizer pre/post solve callbacks."""
from hyperactive.experiment.bench import Sphere
from hyperactive.opt import HillClimbing

pre_solve_called = []
post_solve_called = []

def pre_solve_cb(optimizer):
pre_solve_called.append(True)

def post_solve_cb(optimizer, best_params):
post_solve_called.append(best_params)

exp = Sphere(n_dim=2)
optimizer = HillClimbing(
search_space={
"x0": np.linspace(-5, 5, 11),
"x1": np.linspace(-5, 5, 11),
},
n_iter=10,
experiment=exp,
)
optimizer.add_callback(pre_solve_cb, pre=True)
optimizer.add_callback(post_solve_cb)

best_params = optimizer.solve()

assert len(pre_solve_called) == 1
assert len(post_solve_called) == 1
assert post_solve_called[0] == best_params

def test_history_callback():
"""Test HistoryCallback records evaluations."""
from hyperactive.experiment.bench import Sphere
from hyperactive.opt import HillClimbing
from hyperactive.utils.callbacks import HistoryCallback

history_cb = HistoryCallback()
exp = Sphere(n_dim=2)
exp.add_callback(history_cb)

optimizer = HillClimbing(
search_space={
"x0": np.linspace(-5, 5, 11),
"x1": np.linspace(-5, 5, 11),
},
n_iter=10,
experiment=exp,
)

optimizer.solve()

assert len(history_cb.history) >= 10
for record in history_cb.history:
assert "params" in record
assert "result" in record
assert "metadata" in record

best = history_cb.get_best(higher_is_better=False)
assert best is not None
assert "result" in best

def test_logging_callback(capsys):
"""Test LoggingCallback prints to stdout."""
from hyperactive.experiment.bench import Sphere
from hyperactive.utils.callbacks import LoggingCallback

exp = Sphere(n_dim=2)
exp.add_callback(LoggingCallback())

exp.evaluate({"x0": 0.0, "x1": 0.0})
exp.evaluate({"x0": 1.0, "x1": 1.0})

captured = capsys.readouterr()
assert "Eval 1:" in captured.out
assert "Eval 2:" in captured.out

def test_sleep_callback():
"""Test SleepCallback adds delay."""
import time

from hyperactive.experiment.bench import Sphere
from hyperactive.utils.callbacks import SleepCallback

exp = Sphere(n_dim=2)
exp.add_callback(SleepCallback(seconds=0.1))

start = time.time()
exp.evaluate({"x0": 0.0, "x1": 0.0})
exp.evaluate({"x0": 1.0, "x1": 1.0})
elapsed = time.time() - start

assert elapsed >= 0.2

def test_target_reached_callback():
"""Test TargetReachedCallback tracks target."""
from hyperactive.experiment.bench import Sphere
from hyperactive.utils.callbacks import TargetReachedCallback

target_cb = TargetReachedCallback(target_score=0.5, higher_is_better=False)
exp = Sphere(n_dim=2)
exp.add_callback(target_cb)

exp.evaluate({"x0": 5.0, "x1": 5.0})
assert not target_cb.reached

exp.evaluate({"x0": 0.0, "x1": 0.0})
assert target_cb.reached

def test_multiple_callbacks():
"""Test that multiple callbacks can be registered."""
from hyperactive.experiment.bench import Sphere
from hyperactive.utils.callbacks import HistoryCallback, LoggingCallback

history_cb = HistoryCallback()
log_cb = LoggingCallback()
exp = Sphere(n_dim=2)
exp.add_callback(history_cb)
exp.add_callback(log_cb)

exp.evaluate({"x0": 1.0, "x1": 1.0})
exp.evaluate({"x0": 2.0, "x1": 2.0})

assert len(history_cb.history) == 2
assert log_cb._count == 2
10 changes: 10 additions & 0 deletions src/hyperactive/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
"""Utility functionality."""

from hyperactive.utils.callbacks import (
HistoryCallback,
LoggingCallback,
SleepCallback,
TargetReachedCallback,
)
from hyperactive.utils.estimator_checks import check_estimator

__all__ = [
"check_estimator",
"HistoryCallback",
"LoggingCallback",
"SleepCallback",
"TargetReachedCallback",
]
Loading