1212__all__ = ['parse_model_name' , 'safe_model_name' , 'create_model' ]
1313
1414
15- def parse_model_name (model_name ):
15+ def parse_model_name (model_name : str ):
1616 if model_name .startswith ('hf_hub' ):
1717 # NOTE for backwards compat, deprecate hf_hub use
1818 model_name = model_name .replace ('hf_hub' , 'hf-hub' )
@@ -26,7 +26,7 @@ def parse_model_name(model_name):
2626 return 'timm' , model_name
2727
2828
29- def safe_model_name (model_name , remove_source = True ):
29+ def safe_model_name (model_name : str , remove_source : bool = True ):
3030 # return a filename / path safe model name
3131 def make_safe (name ):
3232 return '' .join (c if c .isalnum () else '_' for c in name ).rstrip ('_' )
@@ -46,27 +46,48 @@ def create_model(
4646 no_jit : Optional [bool ] = None ,
4747 ** kwargs ,
4848):
49- """Create a model
49+ """Create a model.
5050
5151 Lookup model's entrypoint function and pass relevant args to create a new model.
5252
53- **kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg()
54- and then the model class __init__(). kwargs values set to None are pruned before passing.
53+ <Tip>
54+ **kwargs will be passed through entrypoint fn to ``timm.models.build_model_with_cfg()``
55+ and then the model class __init__(). kwargs values set to None are pruned before passing.
56+ </Tip>
5557
5658 Args:
57- model_name (str): name of model to instantiate
58- pretrained (bool): load pretrained ImageNet-1k weights if true
59- pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model
60- pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these
61- checkpoint_path (str): path of checkpoint to load _after_ the model is initialized
62- scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
63- exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
64- no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
59+ model_name: Name of model to instantiate.
60+ pretrained: If set to `True`, load pretrained ImageNet-1k weights.
61+ pretrained_cfg: Pass in an external pretrained_cfg for model.
62+ pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
63+ checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
64+ scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
65+ exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
66+ no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
6567
6668 Keyword Args:
67- drop_rate (float): dropout rate for training (default: 0.0)
68- global_pool (str): global pool type (default: 'avg')
69- **: other kwargs are consumed by builder or model __init__()
69+ drop_rate (float): Classifier dropout rate for training.
70+ drop_path_rate (float): Stochastic depth drop rate for training.
71+ global_pool (str): Classifier global pooling type.
72+
73+ Example:
74+
75+ ```py
76+ >>> from timm import create_model
77+
78+ >>> # Create a MobileNetV3-Large model with no pretrained weights.
79+ >>> model = create_model('mobilenetv3_large_100')
80+
81+ >>> # Create a MobileNetV3-Large model with pretrained weights.
82+ >>> model = create_model('mobilenetv3_large_100', pretrained=True)
83+ >>> model.num_classes
84+ 1000
85+
86+ >>> # Create a MobileNetV3-Large model with pretrained weights and a new head with 10 classes.
87+ >>> model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
88+ >>> model.num_classes
89+ 10
90+ ```
7091 """
7192 # Parameters that aren't supported by all models or are intended to only override model defaults if set
7293 # should default to None in command line args/cfg. Remove them if they are present and not set so that
0 commit comments