-
Notifications
You must be signed in to change notification settings - Fork 783
feat: add tuple support and exam_cb for callbacks #665
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ca55033
89a68fe
e4c93fc
727a1d7
32e7bf4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| # Distributed under the MIT software license | ||
|
|
||
| import heapq | ||
| import inspect | ||
| import json | ||
| import logging | ||
| import os | ||
|
|
@@ -78,6 +79,88 @@ | |
| _log = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| _PROGRESS_CALLBACK_NAMES = ("bag", "stage", "step", "term", "metric") | ||
| _EXAM_CALLBACK_NAMES = ("bag", "stage", "step", "term", "gain") | ||
| _CallbackSpec = Callable[..., bool] | tuple[Callable[..., bool], ...] | ||
|
|
||
|
|
||
| def _classify_callback(callback): | ||
| if not callable(callback): | ||
| msg = "callback must be a callable or a tuple of callables" | ||
| _log.error(msg) | ||
| raise ValueError(msg) | ||
|
|
||
| try: | ||
| signature = inspect.signature(callback) | ||
| except (TypeError, ValueError) as exc: | ||
| msg = "callback must have an inspectable signature" | ||
| _log.error(msg) | ||
| raise ValueError(msg) from exc | ||
|
|
||
| has_metric = "metric" in signature.parameters | ||
| has_gain = "gain" in signature.parameters | ||
| if has_metric == has_gain: | ||
| msg = ( | ||
| "callback must accept either the progress signature " | ||
| "(*, bag, stage, step, term, metric) or the examination signature " | ||
| "(*, bag, stage, step, term, gain)" | ||
| ) | ||
| _log.error(msg) | ||
| raise ValueError(msg) | ||
|
|
||
| required_names = _PROGRESS_CALLBACK_NAMES if has_metric else _EXAM_CALLBACK_NAMES | ||
| missing_names = [ | ||
| name for name in required_names if name not in signature.parameters | ||
| ] | ||
| if missing_names: | ||
| msg = f"callback is missing required parameters: {missing_names}" | ||
| _log.error(msg) | ||
| raise ValueError(msg) | ||
|
|
||
| try: | ||
| signature.bind(**{name: None for name in required_names}) | ||
| except TypeError as exc: | ||
| msg = f"callback must be callable with keyword arguments {required_names}" | ||
| _log.error(msg) | ||
| raise ValueError(msg) from exc | ||
|
|
||
| return "progress" if has_metric else "exam" | ||
|
|
||
|
|
||
| def _normalize_callbacks(callback): | ||
| if callback is None: | ||
| return None, None | ||
|
|
||
| callbacks = callback if isinstance(callback, tuple) else (callback,) | ||
| if len(callbacks) == 0: | ||
| msg = "callback tuple cannot be empty" | ||
| _log.error(msg) | ||
| raise ValueError(msg) | ||
| if len(callbacks) > 2: | ||
| msg = "callback tuple can contain at most one progress callback and one examination callback" | ||
| _log.error(msg) | ||
| raise ValueError(msg) | ||
|
Comment on lines
+135
to
+142
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can get rid of these checks. An empty callback tuple should be allowed and return None, None, which it does in the code below already. And if callbacks has more than two values, then it will be detected by the duplicate checks below. |
||
|
|
||
| progress_callback = None | ||
| exam_callback = None | ||
| for callback_item in callbacks: | ||
| callback_type = _classify_callback(callback_item) | ||
| if callback_type == "progress": | ||
| if progress_callback is not None: | ||
| msg = "callback tuple cannot contain more than one progress callback" | ||
| _log.error(msg) | ||
| raise ValueError(msg) | ||
| progress_callback = callback_item | ||
| else: | ||
| if exam_callback is not None: | ||
| msg = "callback tuple cannot contain more than one examination callback" | ||
| _log.error(msg) | ||
| raise ValueError(msg) | ||
| exam_callback = callback_item | ||
|
|
||
| return progress_callback, exam_callback | ||
|
|
||
|
|
||
| class EBMExplanation(FeatureValueExplanation): | ||
| """Visualizes specifically for EBM.""" | ||
|
|
||
|
|
@@ -851,7 +934,8 @@ def fit( | |
| interaction_smoothing_rounds = 0 | ||
| early_stopping_rounds = 0 | ||
| early_stopping_tolerance = 0.0 | ||
| callback = None | ||
| progress_callback = None | ||
| exam_callback = None | ||
| min_samples_leaf = 0 | ||
| min_hessian = 0.0 | ||
| reg_alpha = 0.0 | ||
|
|
@@ -879,7 +963,7 @@ def fit( | |
| interaction_smoothing_rounds = self.interaction_smoothing_rounds | ||
| early_stopping_rounds = self.early_stopping_rounds | ||
| early_stopping_tolerance = self.early_stopping_tolerance | ||
| callback = self.callback | ||
| progress_callback, exam_callback = _normalize_callbacks(self.callback) | ||
| min_samples_leaf = self.min_samples_leaf | ||
| min_hessian = self.min_hessian | ||
| reg_alpha = self.reg_alpha | ||
|
|
@@ -1018,7 +1102,8 @@ def fit( | |
| shared, | ||
| ) | ||
|
|
||
| with nullcontext() if callback is None else SharedMemoryManager() as smm: | ||
| has_callback = progress_callback is not None or exam_callback is not None | ||
| with nullcontext() if not has_callback else SharedMemoryManager() as smm: | ||
| stop_flag: npt.NDArray[np.bool_] | None | ||
| if smm is not None: | ||
| shm = smm.SharedMemory(size=1) | ||
|
|
@@ -1034,7 +1119,8 @@ def fit( | |
| shm_name=shm_name, | ||
| bag_idx=idx, | ||
| stage=0, | ||
| callback=callback, | ||
| progress_callback=progress_callback, | ||
| exam_callback=exam_callback, | ||
| dataset=( | ||
| shared.name if shared.name is not None else shared.dataset | ||
| ), | ||
|
|
@@ -1274,7 +1360,8 @@ def fit( | |
| shm_name=shm_name, | ||
| bag_idx=idx, | ||
| stage=1, | ||
| callback=callback, | ||
| progress_callback=progress_callback, | ||
| exam_callback=exam_callback, | ||
| dataset=( | ||
| shared.name | ||
| if shared.name is not None | ||
|
|
@@ -1386,7 +1473,8 @@ def fit( | |
| shm_name=None, | ||
| bag_idx=0, | ||
| stage=-1, | ||
| callback=None, | ||
| progress_callback=None, | ||
| exam_callback=None, | ||
| dataset=shared.dataset, | ||
| intercept_rounds=develop.get_option("n_intercept_rounds_final"), | ||
| intercept_learning_rate=develop.get_option( | ||
|
|
@@ -3312,15 +3400,15 @@ class EBMModel(BaseEBM): | |
| tradeoff for the ensemble of models --- not the individual models --- a small | ||
| amount of overfitting of the individual models can improve the accuracy of | ||
| the ensemble as a whole. | ||
| callback : Optional[Callable[..., bool]], default=None | ||
| A user-defined function invoked after each progressing boosting step. Must use | ||
| keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``. | ||
| If it returns True, boosting is stopped immediately. | ||
| The callback receives: ``bag`` (int) the outer bag index, | ||
| ``stage`` (int) the boosting stage (0=mains, 1=pairs), | ||
| ``step`` (int) the number of boosting steps completed, | ||
| ``term`` (int) the index of the term that was just boosted, | ||
| and ``metric`` (float) the current validation metric. | ||
| callback : Optional[Union[Callable[..., bool], tuple[Callable[..., bool], ...]]], default=None | ||
| A user-defined callback or tuple of callbacks invoked during boosting. | ||
| A progress callback is invoked after each progressing boosting step and must use | ||
| keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)``. | ||
| An examination callback is invoked whenever a term is examined and its gain is | ||
| calculated, and must use keyword-only arguments: | ||
| ``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True, | ||
| boosting is stopped immediately. A tuple can contain at most one progress callback | ||
| and one examination callback. | ||
| min_samples_leaf : int, default=4 | ||
| Minimum number of samples allowed in the leaves. | ||
| min_hessian : float, default=0.0 | ||
|
|
@@ -3431,13 +3519,13 @@ def __init__( | |
| # Boosting | ||
| learning_rate: float = 0.02, | ||
| greedy_ratio: float | None = 10.0, | ||
| cyclic_progress: bool | float = False, | ||
| cyclic_progress: bool | float | int = False, # noqa: PYI041 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The float type hint includes int, so remove int and the noqa. |
||
| smoothing_rounds: int | None = 100, | ||
| interaction_smoothing_rounds: int | None = 50, | ||
| max_rounds: int | None = 50000, | ||
| early_stopping_rounds: int | None = 100, | ||
| early_stopping_tolerance: float | None = 1e-5, | ||
| callback: Callable[..., bool] | None = None, | ||
| callback: _CallbackSpec | None = None, | ||
| # Trees | ||
| min_samples_leaf: int | None = 4, | ||
| min_hessian: float | None = 0.0, | ||
|
|
@@ -3577,15 +3665,15 @@ class EBMClassifier(EBMClassifierMixin, EBMModel): | |
| tradeoff for the ensemble of models --- not the individual models --- a small | ||
| amount of overfitting of the individual models can improve the accuracy of | ||
| the ensemble as a whole. | ||
| callback : Optional[Callable[..., bool]], default=None | ||
| A user-defined function invoked after each progressing boosting step. Must use | ||
| keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``. | ||
| If it returns True, boosting is stopped immediately. | ||
| The callback receives: ``bag`` (int) the outer bag index, | ||
| ``stage`` (int) the boosting stage (0=mains, 1=pairs), | ||
| ``step`` (int) the number of boosting steps completed, | ||
| ``term`` (int) the index of the term that was just boosted, | ||
| and ``metric`` (float) the current validation metric. | ||
| callback : Optional[Union[Callable[..., bool], tuple[Callable[..., bool], ...]]], default=None | ||
| A user-defined callback or tuple of callbacks invoked during boosting. | ||
| A progress callback is invoked after each progressing boosting step and must use | ||
| keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)``. | ||
| An examination callback is invoked whenever a term is examined and its gain is | ||
| calculated, and must use keyword-only arguments: | ||
| ``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True, | ||
| boosting is stopped immediately. A tuple can contain at most one progress callback | ||
| and one examination callback. | ||
| min_samples_leaf : int, default=4 | ||
| Minimum number of samples allowed in the leaves. | ||
| min_hessian : float, default=1e-4 | ||
|
|
@@ -3755,13 +3843,13 @@ def __init__( | |
| # Boosting | ||
| learning_rate: float = 0.015, | ||
| greedy_ratio: float | None = 10.0, | ||
| cyclic_progress: bool | float = False, | ||
| cyclic_progress: bool | float | int = False, # noqa: PYI041 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove int, since the float type hint includes int, and also remove the noqa. |
||
| smoothing_rounds: int | None = 75, | ||
| interaction_smoothing_rounds: int | None = 75, | ||
| max_rounds: int | None = 50000, | ||
| early_stopping_rounds: int | None = 100, | ||
| early_stopping_tolerance: float | None = 1e-5, | ||
| callback: Callable[..., bool] | None = None, | ||
| callback: _CallbackSpec | None = None, | ||
| # Trees | ||
| min_samples_leaf: int | None = 4, | ||
| min_hessian: float | None = 1e-4, | ||
|
|
@@ -3903,15 +3991,15 @@ class EBMRegressor(EBMRegressorMixin, EBMModel): | |
| tradeoff for the ensemble of models --- not the individual models --- a small | ||
| amount of overfitting of the individual models can improve the accuracy of | ||
| the ensemble as a whole. | ||
| callback : Optional[Callable[..., bool]], default=None | ||
| A user-defined function invoked after each progressing boosting step. Must use | ||
| keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``. | ||
| If it returns True, boosting is stopped immediately. | ||
| The callback receives: ``bag`` (int) the outer bag index, | ||
| ``stage`` (int) the boosting stage (0=mains, 1=pairs), | ||
| ``step`` (int) the number of boosting steps completed, | ||
| ``term`` (int) the index of the term that was just boosted, | ||
| and ``metric`` (float) the current validation metric. | ||
| callback : Optional[Union[Callable[..., bool], tuple[Callable[..., bool], ...]]], default=None | ||
| A user-defined callback or tuple of callbacks invoked during boosting. | ||
| A progress callback is invoked after each progressing boosting step and must use | ||
| keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)``. | ||
| An examination callback is invoked whenever a term is examined and its gain is | ||
| calculated, and must use keyword-only arguments: | ||
| ``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True, | ||
| boosting is stopped immediately. A tuple can contain at most one progress callback | ||
| and one examination callback. | ||
| min_samples_leaf : int, default=4 | ||
| Minimum number of samples allowed in the leaves. | ||
| min_hessian : float, default=0.0 | ||
|
|
@@ -4085,13 +4173,13 @@ def __init__( | |
| # Boosting | ||
| learning_rate: float = 0.04, | ||
| greedy_ratio: float | None = 10.0, | ||
| cyclic_progress: bool | float = False, | ||
| cyclic_progress: bool | float | int = False, # noqa: PYI041 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as above: remove int and the noqa. |
||
| smoothing_rounds: int | None = 500, | ||
| interaction_smoothing_rounds: int | None = 100, | ||
| max_rounds: int | None = 50000, | ||
| early_stopping_rounds: int | None = 100, | ||
| early_stopping_tolerance: float | None = 1e-5, | ||
| callback: Callable[..., bool] | None = None, | ||
| callback: _CallbackSpec | None = None, | ||
| # Trees | ||
| min_samples_leaf: int | None = 4, | ||
| min_hessian: float | None = 0.0, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,7 +29,8 @@ def boost( | |
| shm_name, | ||
| bag_idx, | ||
| stage, | ||
| callback, | ||
| progress_callback, | ||
| exam_callback, | ||
| dataset, | ||
| intercept_rounds, | ||
| intercept_learning_rate, | ||
|
|
@@ -264,6 +265,19 @@ def boost( | |
| # penalize nominals a bit because they benefit from sorting categories | ||
| avg_gain *= gain_scale | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should check here. |
||
| if exam_callback is not None: | ||
| is_done = exam_callback( | ||
| bag=bag_idx, | ||
| stage=stage, | ||
| step=step_idx, | ||
| term=term_idx, | ||
| gain=avg_gain, | ||
| ) | ||
| if is_done: | ||
| if stop_flag is not None: | ||
| stop_flag[0] = True | ||
| break | ||
|
|
||
| gainkey = (-avg_gain, native.generate_seed(rng), term_idx) | ||
| if not make_progress and ( | ||
| bestkey is None or gainkey < bestkey | ||
|
|
@@ -368,8 +382,8 @@ def boost( | |
| if stop_flag is not None and stop_flag[0]: | ||
| break | ||
|
|
||
| if callback is not None: | ||
| is_done = callback( | ||
| if progress_callback is not None: | ||
| is_done = progress_callback( | ||
| bag=bag_idx, | ||
| stage=stage, | ||
| step=step_idx, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
This is probably simpler and more expandable as something like:
callback_types = {"progress" : {"bag", "stage", "step", "term", "metric"}, "exam": {"bag", "stage", "step", "term", "gain"}}
param_names = set(inspect.signature(callback).parameters)
for name, params in callback_types.items():
if params == param_names:
return name
raise something