11"""
2- Hyperparameters can be efficiently tuned with `optuna <https://optuna.readthedocs.io/>`_.
2+ Module for hyperparameter optimization.
3+
4+ Hyperparameters can be efficiently tuned with `optuna <https://optuna.readthedocs.io/>`.
35"""
46
57__all__ = ["optimize_hyperparameters" ]
@@ -87,8 +89,9 @@ def optimize_hyperparameters(
8789 ** kwargs : Any ,
8890) -> optuna .Study :
8991 """
90- Optimize hyperparameters. Run hyperparameter optimization. Learning rate for is determined with the
91- PyTorch Lightning learning rate finder.
92+ Optimize hyperparameters. Run hyperparameter optimization.
93+
94+ Learning rate for is determined with the PyTorch Lightning learning rate finder.
9295
9396 Args:
9497 train_dataloaders (DataLoader):
@@ -98,65 +101,68 @@ def optimize_hyperparameters(
98101 model_path (str):
99102 Folder to which model checkpoints are saved.
100103 monitor (str):
101- Metric to return. The hyper-parameter (HP) tuner trains a model for a certain HP config, and
102- reads this metric to score configuration. By default, the lower the better.
104+ Metric to return. The hyper-parameter (HP) tuner trains a model for a certain HP config,
105+ and reads this metric to score configuration. By default, the lower the better.
103106 direction (str):
104- By default, direction is "minimize", meaning that lower values of the specified `monitor` are
105- better. You can change this, e.g. to "maximize".
107+ By default, direction is "minimize", meaning that lower values of the specified
108+ ``monitor`` are better. You can change this, e.g. to "maximize".
106109 max_epochs (int, optional):
107110 Maximum number of epochs to run training. Defaults to 20.
108111 n_trials (int, optional):
109112 Number of hyperparameter trials to run. Defaults to 100.
110113 timeout (float, optional):
111- Time in seconds after which training is stopped regardless of number of epochs or validation
112- metric. Defaults to 3600*8.0.
114+ Time in seconds after which training is stopped regardless of number of epochs or
115+ validation metric. Defaults to 3600*8.0.
113116 input_params (dict, optional):
114- A dictionary, where each `key` contains another dictionary with two keys: `"method"` and
115- ` "ranges"`. Example:
116- >>> {"hidden_size": {
117- >>> "method": "suggest_int",
118- >>> "ranges": (16, 265),
119- >>> }}
120- The method key has to be a method of the `optuna.Trial` object. The ranges key are the input
121- ranges for the specified method.
117+ A dictionary, where each `` key`` contains another dictionary with two keys: `` "method"``
118+ and `` "ranges"` `. Example:
119+ >>> {"hidden_size": {
120+ "method": "suggest_int",
121+ "ranges": (16, 265),
122+ }}
123+ The method key has to be a method of the `` optuna.Trial`` object.
124+ The ranges key are the input ranges for the specified method.
122125 input_params_generator (Callable, optional):
123- A function with the following signature: `fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]
124- `, returning the parameter values to set up your model for the current trial/run.
126+ A function with the following signature:
127+ `fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]`,
128+ returning the parameter values to set up your model for the current trial/run.
125129 Example:
126- >>> def fn(trial: optuna.Trial , param_ranges: Tuple[int, int] = (16, 265)) -> Dict[str, Any]:
127- >>> param = trial.suggest_int("param", *param_ranges, log=True)
128- >>> model_params = {"param": param}
129- >>> return model_params
130- Then, when your model is created (before training it and report the metrics for the current
131- combination of hyperparameters), these dictionary is used as follows:
132- >>> model = YourModelClass.from_dataset(
133- >>> train_dataloaders.dataset,
134- >>> log_interval=-1,
135- >>> **model_params,
136- >>> )
130+ >>> def fn(trial, param_ranges = (16, 265)) -> Dict[str, Any]:
131+ param = trial.suggest_int("param", *param_ranges, log=True)
132+ model_params = {"param": param}
133+ return model_params
134+ Then, when your model is created (before training it and report the metrics for
135+ the current combination of hyperparameters), these dictionary is used as follows:
136+ >>> model = YourModelClass.from_dataset(
137+ train_dataloaders.dataset,
138+ log_interval=-1,
139+ **model_params,
140+ )
137141 generator_params (dict, optional):
138- The additional parameters to be passed to the `input_params_generator` function, if required.
142+ The additional parameters to be passed to the ``input_params_generator`` function,
143+ if required.
139144 learning_rate_range (Tuple[float, float], optional):
140145 Learning rate range. Defaults to (1e-5, 1.0).
141146 use_learning_rate_finder (bool):
142- If to use learning rate finder or optimize as part of hyperparameters. Defaults to True.
147+ If to use learning rate finder or optimize as part of hyperparameters.
148+ Defaults to True.
143149 trainer_kwargs (Dict[str, Any], optional):
144150 Additional arguments to the
145- ` PyTorch Lightning trainer <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>`
146- such as `limit_train_batches`. Defaults to {}.
151+ PyTorch Lightning trainer such as ``limit_train_batches``.
152+ Defaults to {}.
147153 log_dir (str, optional):
148154 Folder into which to log results for tensorboard. Defaults to "lightning_logs".
149155 study (optuna.Study, optional):
150156 Study to resume. Will create new study by default.
151157 verbose (Union[int, bool]):
152158 Level of verbosity.
153- * None: no change in verbosity level (equivalent to verbose=1 by optuna-set default ).
159+ * None: no change in verbosity level (equivalent to verbose=1).
154160 * 0 or False: log only warnings.
155161 * 1 or True: log pruning events.
156162 * 2: optuna logging level at debug level.
157163 Defaults to None.
158164 pruner (optuna.pruners.BasePruner, optional):
159- The optuna pruner to use. Defaults to `optuna.pruners.SuccessiveHalvingPruner()`.
165+ The optuna pruner to use. Defaults to `` optuna.pruners.SuccessiveHalvingPruner()` `.
160166 **kwargs:
161167 Additional arguments for your model's class.
162168
0 commit comments