Skip to content

Commit 1c17f0e

Browse files
authored
Enable Chaining in Auto3DSeg CLI (#7168)
Fixes #7167 ### Description Make all the setting methods in `AutoRunner` to return `self` to enable chaining in cli. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). --------- Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
1 parent 798570c commit 1c17f0e

File tree

1 file changed

+30
-10
lines changed

1 file changed

+30
-10
lines changed

monai/apps/auto3dseg/auto_runner.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def inspect_datalist_folds(self, datalist_filename: str) -> int:
407407

408408
def set_gpu_customization(
409409
self, gpu_customization: bool = False, gpu_customization_specs: dict[str, Any] | None = None
410-
) -> None:
410+
) -> AutoRunner:
411411
"""
412412
Set options for GPU-based parameter customization/optimization.
413413
@@ -442,7 +442,9 @@ def set_gpu_customization(
442442
if gpu_customization_specs is not None:
443443
self.gpu_customization_specs = gpu_customization_specs
444444

445-
def set_num_fold(self, num_fold: int = 5) -> None:
445+
return self
446+
447+
def set_num_fold(self, num_fold: int = 5) -> AutoRunner:
446448
"""
447449
Set the number of cross validation folds for all algos.
448450
@@ -454,7 +456,9 @@ def set_num_fold(self, num_fold: int = 5) -> None:
454456
raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}")
455457
self.num_fold = num_fold
456458

457-
def set_training_params(self, params: dict[str, Any] | None = None) -> None:
459+
return self
460+
461+
def set_training_params(self, params: dict[str, Any] | None = None) -> AutoRunner:
458462
"""
459463
Set the training params for all algos.
460464
@@ -474,13 +478,15 @@ def set_training_params(self, params: dict[str, Any] | None = None) -> None:
474478
DeprecationWarning,
475479
)
476480

481+
return self
482+
477483
def set_device_info(
478484
self,
479485
cuda_visible_devices: list[int] | str | None = None,
480486
num_nodes: int | None = None,
481487
mn_start_method: str | None = None,
482488
cmd_prefix: str | None = None,
483-
) -> None:
489+
) -> AutoRunner:
484490
"""
485491
Set the device related info
486492
@@ -531,7 +537,9 @@ def set_device_info(
531537
if cmd_prefix is not None:
532538
logger.info(f"Using user defined command running prefix {cmd_prefix}, will override other settings")
533539

534-
def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs: Any) -> None:
540+
return self
541+
542+
def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs: Any) -> AutoRunner:
535543
"""
536544
Set the bundle ensemble method name and parameters for save image transform parameters.
537545
@@ -546,7 +554,9 @@ def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFol
546554
)
547555
self.kwargs.update(kwargs)
548556

549-
def set_image_save_transform(self, **kwargs: Any) -> None:
557+
return self
558+
559+
def set_image_save_transform(self, **kwargs: Any) -> AutoRunner:
550560
"""
551561
Set the ensemble output transform.
552562
@@ -565,7 +575,9 @@ def set_image_save_transform(self, **kwargs: Any) -> None:
565575
"Check https://docs.monai.io/en/stable/transforms.html#saveimage for more information."
566576
)
567577

568-
def set_prediction_params(self, params: dict[str, Any] | None = None) -> None:
578+
return self
579+
580+
def set_prediction_params(self, params: dict[str, Any] | None = None) -> AutoRunner:
569581
"""
570582
Set the prediction params for all algos.
571583
@@ -581,7 +593,9 @@ def set_prediction_params(self, params: dict[str, Any] | None = None) -> None:
581593
"""
582594
self.pred_params = deepcopy(params) if params is not None else {}
583595

584-
def set_analyze_params(self, params: dict[str, Any] | None = None) -> None:
596+
return self
597+
598+
def set_analyze_params(self, params: dict[str, Any] | None = None) -> AutoRunner:
585599
"""
586600
Set the data analysis extra params.
587601
@@ -595,7 +609,9 @@ def set_analyze_params(self, params: dict[str, Any] | None = None) -> None:
595609
else:
596610
self.analyze_params = deepcopy(params)
597611

598-
def set_hpo_params(self, params: dict[str, Any] | None = None) -> None:
612+
return self
613+
614+
def set_hpo_params(self, params: dict[str, Any] | None = None) -> AutoRunner:
599615
"""
600616
Set parameters for the HPO module and the algos before the training. It will attempt to (1) override bundle
601617
templates with the key-value pairs in ``params`` (2) change the config of the HPO module (e.g. NNI) if the
@@ -621,7 +637,9 @@ def set_hpo_params(self, params: dict[str, Any] | None = None) -> None:
621637
"""
622638
self.hpo_params = self.train_params if params is None else params
623639

624-
def set_nni_search_space(self, search_space):
640+
return self
641+
642+
def set_nni_search_space(self, search_space: dict[str, Any]) -> AutoRunner:
625643
"""
626644
Set the search space for NNI parameter search.
627645
@@ -638,6 +656,8 @@ def set_nni_search_space(self, search_space):
638656
self.search_space = search_space
639657
self.hpo_tasks = value_combinations
640658

659+
return self
660+
641661
def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None:
642662
"""
643663
Train the Algos in a sequential scheme. The order of training is randomized.

0 commit comments

Comments
 (0)