2828
2929from monai .apps .mmars .mmars import _get_all_ngc_models
3030from monai .apps .utils import _basename , download_url , extractall , get_logger
31+ from monai .bundle .config_item import ConfigComponent
3132from monai .bundle .config_parser import ConfigParser
3233from monai .bundle .utils import DEFAULT_INFERENCE , DEFAULT_METADATA
3334from monai .bundle .workflows import BundleWorkflow , ConfigWorkflow
@@ -247,7 +248,7 @@ def _process_bundle_dir(bundle_dir: PathLike | None = None) -> Path:
247248 return Path (bundle_dir )
248249
249250
250- @deprecated_arg_default ("source" , "github" , "monaihosting" , since = "1.3" , replaced = "1.4 " )
251+ @deprecated_arg_default ("source" , "github" , "monaihosting" , since = "1.3" , replaced = "1.5 " )
251252def download (
252253 name : str | None = None ,
253254 version : str | None = None ,
@@ -375,8 +376,9 @@ def download(
375376 )
376377
377378
378- @deprecated_arg ("net_name" , since = "1.3" , removed = "1.4" , msg_suffix = "please use ``model`` instead." )
379- @deprecated_arg ("net_kwargs" , since = "1.3" , removed = "1.3" , msg_suffix = "please use ``model`` instead." )
379+ @deprecated_arg ("net_name" , since = "1.3" , removed = "1.5" , msg_suffix = "please use ``model`` instead." )
380+ @deprecated_arg ("net_kwargs" , since = "1.3" , removed = "1.5" , msg_suffix = "please use ``model`` instead." )
381+ @deprecated_arg ("return_state_dict" , since = "1.3" , removed = "1.5" )
380382def load (
381383 name : str ,
382384 model : torch .nn .Module | None = None ,
@@ -395,8 +397,10 @@ def load(
395397 workflow_name : str | BundleWorkflow | None = None ,
396398 args_file : str | None = None ,
397399 copy_model_args : dict | None = None ,
400+ return_state_dict : bool = True ,
401+ net_override : dict | None = None ,
398402 net_name : str | None = None ,
399- ** net_override : Any ,
403+ ** net_kwargs : Any ,
400404) -> object | tuple [torch .nn .Module , dict , dict ] | Any :
401405 """
402406 Load model weights or TorchScript module of a bundle.
@@ -441,7 +445,12 @@ def load(
441445 workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
442446 args_file: a JSON or YAML file to provide default values for all the args in "download" function.
443447 copy_model_args: other arguments for the `monai.networks.copy_model_state` function.
444- net_override: id-value pairs to override the parameters in the network of the bundle.
448+ return_state_dict: whether to return state dict, if True, return state_dict, else a corresponding network
449+ from `_workflow.network_def` will be instantiated and load the achieved weights.
450+ net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`.
451+ net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights.
452+ This argument only works when loading weights.
453+ net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`.
445454
446455 Returns:
447456 1. If `load_ts_module` is `False` and `model` is `None`,
@@ -452,9 +461,15 @@ def load(
452461 3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
453462 the corresponding metadata dict, and extra files dict.
454463 please check `monai.data.load_net_with_metadata` for more details.
464+ 4. If `return_state_dict` is True, return model weights, only used for compatibility
465+ when `model` and `net_name` are all `None`.
455466
456467 """
468+ if return_state_dict and (model is not None or net_name is not None ):
469+ warnings .warn ("Incompatible values: model and net_name are all specified, return state dict instead." )
470+
457471 bundle_dir_ = _process_bundle_dir (bundle_dir )
472+ net_override = {} if net_override is None else net_override
458473 copy_model_args = {} if copy_model_args is None else copy_model_args
459474
460475 if device is None :
@@ -466,7 +481,7 @@ def load(
466481 if remove_prefix :
467482 name = _remove_ngc_prefix (name , prefix = remove_prefix )
468483 full_path = os .path .join (bundle_dir_ , name , model_file )
469- if not os .path .exists (full_path ) or model is None :
484+ if not os .path .exists (full_path ):
470485 download (
471486 name = name ,
472487 version = version ,
@@ -477,34 +492,52 @@ def load(
477492 progress = progress ,
478493 args_file = args_file ,
479494 )
480- train_config_file = bundle_dir_ / name / "configs" / f"{ workflow_type } .json"
481- if train_config_file .is_file ():
482- _net_override = {f"network_def#{ key } " : value for key , value in net_override .items ()}
483- _workflow = create_workflow (
484- workflow_name = workflow_name ,
485- args_file = args_file ,
486- config_file = str (train_config_file ),
487- workflow_type = workflow_type ,
488- ** _net_override ,
489- )
490- else :
491- _workflow = None
492495
493496 # loading with `torch.jit.load`
494497 if load_ts_module is True :
495498 return load_net_with_metadata (full_path , map_location = torch .device (device ), more_extra_files = config_files )
496499 # loading with `torch.load`
497500 model_dict = torch .load (full_path , map_location = torch .device (device ))
501+
498502 if not isinstance (model_dict , Mapping ):
499503 warnings .warn (f"the state dictionary from { full_path } should be a dictionary but got { type (model_dict )} ." )
500504 model_dict = get_state_dict (model_dict )
501505
502- if model is None and _workflow is None :
506+ if return_state_dict :
503507 return model_dict
504- model = _workflow .network_def if model is None else model
505- model .to (device )
506508
507- copy_model_state (dst = model , src = model_dict if key_in_ckpt is None else model_dict [key_in_ckpt ], ** copy_model_args )
509+ _workflow = None
510+ if model is None and net_name is None :
511+ bundle_config_file = bundle_dir_ / name / "configs" / f"{ workflow_type } .json"
512+ if bundle_config_file .is_file ():
513+ _net_override = {f"network_def#{ key } " : value for key , value in net_override .items ()}
514+ _workflow = create_workflow (
515+ workflow_name = workflow_name ,
516+ args_file = args_file ,
517+ config_file = str (bundle_config_file ),
518+ workflow_type = workflow_type ,
519+ ** _net_override ,
520+ )
521+ else :
522+ warnings .warn (f"Cannot find the config file: { bundle_config_file } , return state dict instead." )
523+ return model_dict
524+ if _workflow is not None :
525+ if not hasattr (_workflow , "network_def" ):
526+ warnings .warn ("No available network definition in the bundle, return state dict instead." )
527+ return model_dict
528+ else :
529+ model = _workflow .network_def
530+ elif net_name is not None :
531+ net_kwargs ["_target_" ] = net_name
532+ configer = ConfigComponent (config = net_kwargs )
533+ model = configer .instantiate () # type: ignore
534+
535+ model .to (device ) # type: ignore
536+
537+ copy_model_state (
538+ dst = model , src = model_dict if key_in_ckpt is None else model_dict [key_in_ckpt ], ** copy_model_args # type: ignore
539+ )
540+
508541 return model
509542
510543
0 commit comments