1212__all__ = ['parse_model_name' , 'safe_model_name' , 'create_model' ]
1313
1414
15- def parse_model_name (model_name ):
15+ def parse_model_name (model_name : str ):
1616 if model_name .startswith ('hf_hub' ):
1717 # NOTE for backwards compat, deprecate hf_hub use
1818 model_name = model_name .replace ('hf_hub' , 'hf-hub' )
@@ -26,7 +26,7 @@ def parse_model_name(model_name):
2626 return 'timm' , model_name
2727
2828
29- def safe_model_name (model_name , remove_source = True ):
29+ def safe_model_name (model_name : str , remove_source : bool = True ):
3030 # return a filename / path safe model name
3131 def make_safe (name ):
3232 return '' .join (c if c .isalnum () else '_' for c in name ).rstrip ('_' )
@@ -56,16 +56,19 @@ def create_model(
5656 </Tip>
5757
5858 Args:
59- model_name (`str`): Name of model to instantiate.
60- pretrained (`bool`): If set to `True`, load pretrained ImageNet-1k weights.
61- pretrained_cfg (`Union[str, dict, PretrainedCfg]`): Pass in an external pretrained_cfg for model.
62- pretrained_cfg_overlay (`dict`): Replace key-values in base pretrained_cfg with these.
63- checkpoint_path (`str`): Path of checkpoint to load _after_ the model is initialized.
64- scriptable (`bool`): Set layer config so that model is jit scriptable (not working for all models yet).
65- exportable (`bool`): Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
66- no_jit (`bool`): Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
67- **drop_rate (`float`): Dropout rate for training. Defaults to `0.0`.
68- **global_pool (`str`): Global pooling type. Defaults to `'avg'`.
59+ model_name: Name of model to instantiate.
60+ pretrained: If set to `True`, load pretrained ImageNet-1k weights.
61+ pretrained_cfg: Pass in an external pretrained_cfg for model.
62+ pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
63+ checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
64+ scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
65+ exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
66+ no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
67+
68+ Keyword Args:
69+ drop_rate (float): Classifier dropout rate for training.
70+ drop_path_rate (float): Stochastic depth drop rate for training.
71+ global_pool (str): Classifier global pooling type.
6972
7073 Example:
7174
0 commit comments