@@ -407,7 +407,7 @@ def inspect_datalist_folds(self, datalist_filename: str) -> int:
407407
408408 def set_gpu_customization (
409409 self , gpu_customization : bool = False , gpu_customization_specs : dict [str , Any ] | None = None
410- ) -> None :
410+ ) -> AutoRunner :
411411 """
412412 Set options for GPU-based parameter customization/optimization.
413413
@@ -442,7 +442,9 @@ def set_gpu_customization(
442442 if gpu_customization_specs is not None :
443443 self .gpu_customization_specs = gpu_customization_specs
444444
445- def set_num_fold (self , num_fold : int = 5 ) -> None :
445+ return self
446+
447+ def set_num_fold (self , num_fold : int = 5 ) -> AutoRunner :
446448 """
447449 Set the number of cross validation folds for all algos.
448450
@@ -454,7 +456,9 @@ def set_num_fold(self, num_fold: int = 5) -> None:
454456 raise ValueError (f"num_fold is expected to be an integer greater than zero. Now it gets { num_fold } " )
455457 self .num_fold = num_fold
456458
457- def set_training_params (self , params : dict [str , Any ] | None = None ) -> None :
459+ return self
460+
461+ def set_training_params (self , params : dict [str , Any ] | None = None ) -> AutoRunner :
458462 """
459463 Set the training params for all algos.
460464
@@ -474,13 +478,15 @@ def set_training_params(self, params: dict[str, Any] | None = None) -> None:
474478 DeprecationWarning ,
475479 )
476480
481+ return self
482+
477483 def set_device_info (
478484 self ,
479485 cuda_visible_devices : list [int ] | str | None = None ,
480486 num_nodes : int | None = None ,
481487 mn_start_method : str | None = None ,
482488 cmd_prefix : str | None = None ,
483- ) -> None :
489+ ) -> AutoRunner :
484490 """
485491 Set the device related info
486492
@@ -531,7 +537,9 @@ def set_device_info(
531537 if cmd_prefix is not None :
532538 logger .info (f"Using user defined command running prefix { cmd_prefix } , will override other settings" )
533539
534- def set_ensemble_method (self , ensemble_method_name : str = "AlgoEnsembleBestByFold" , ** kwargs : Any ) -> None :
540+ return self
541+
542+ def set_ensemble_method (self , ensemble_method_name : str = "AlgoEnsembleBestByFold" , ** kwargs : Any ) -> AutoRunner :
535543 """
536544 Set the bundle ensemble method name and parameters for save image transform parameters.
537545
@@ -546,7 +554,9 @@ def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFol
546554 )
547555 self .kwargs .update (kwargs )
548556
549- def set_image_save_transform (self , ** kwargs : Any ) -> None :
557+ return self
558+
559+ def set_image_save_transform (self , ** kwargs : Any ) -> AutoRunner :
550560 """
551561 Set the ensemble output transform.
552562
@@ -565,7 +575,9 @@ def set_image_save_transform(self, **kwargs: Any) -> None:
565575 "Check https://docs.monai.io/en/stable/transforms.html#saveimage for more information."
566576 )
567577
568- def set_prediction_params (self , params : dict [str , Any ] | None = None ) -> None :
578+ return self
579+
580+ def set_prediction_params (self , params : dict [str , Any ] | None = None ) -> AutoRunner :
569581 """
570582 Set the prediction params for all algos.
571583
@@ -581,7 +593,9 @@ def set_prediction_params(self, params: dict[str, Any] | None = None) -> None:
581593 """
582594 self .pred_params = deepcopy (params ) if params is not None else {}
583595
584- def set_analyze_params (self , params : dict [str , Any ] | None = None ) -> None :
596+ return self
597+
598+ def set_analyze_params (self , params : dict [str , Any ] | None = None ) -> AutoRunner :
585599 """
586600 Set the data analysis extra params.
587601
@@ -595,7 +609,9 @@ def set_analyze_params(self, params: dict[str, Any] | None = None) -> None:
595609 else :
596610 self .analyze_params = deepcopy (params )
597611
598- def set_hpo_params (self , params : dict [str , Any ] | None = None ) -> None :
612+ return self
613+
614+ def set_hpo_params (self , params : dict [str , Any ] | None = None ) -> AutoRunner :
599615 """
600616 Set parameters for the HPO module and the algos before the training. It will attempt to (1) override bundle
601617 templates with the key-value pairs in ``params`` (2) change the config of the HPO module (e.g. NNI) if the
@@ -621,7 +637,9 @@ def set_hpo_params(self, params: dict[str, Any] | None = None) -> None:
621637 """
622638 self .hpo_params = self .train_params if params is None else params
623639
624- def set_nni_search_space (self , search_space ):
640+ return self
641+
642+ def set_nni_search_space (self , search_space : dict [str , Any ]) -> AutoRunner :
625643 """
626644 Set the search space for NNI parameter search.
627645
@@ -638,6 +656,8 @@ def set_nni_search_space(self, search_space):
638656 self .search_space = search_space
639657 self .hpo_tasks = value_combinations
640658
659+ return self
660+
641661 def _train_algo_in_sequence (self , history : list [dict [str , Any ]]) -> None :
642662 """
643663 Train the Algos in a sequential scheme. The order of training is randomized.
0 commit comments