@@ -113,10 +113,9 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
             digits of the SHA256 hash of the contents of the file. The hash is used to
             ensure unique names and to verify the contents of the file. Default: False
     """
-    if cfg is None:
-        cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url']:
-        _logger.warning("Pretrained model URL does not exist, using random initialization.")
+    cfg = cfg or getattr(model, 'default_cfg')
+    if cfg is None or not cfg.get('url', None):
+        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
     url = cfg['url']
 
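Note that the new fallback is slightly broader than the old one: `cfg or getattr(...)` tests truthiness rather than `is None`, so an explicitly passed empty dict now also falls back to the model's `default_cfg`. A minimal, self-contained sketch of the new guard (the `Dummy` class is hypothetical, standing in for a timm model that carries a `default_cfg` attribute):

class Dummy:
    default_cfg = {'url': ''}  # registered model with no downloadable weights

model = Dummy()
cfg = None
cfg = cfg or getattr(model, 'default_cfg')  # falls back to the model's default_cfg
if cfg is None or not cfg.get('url', None):
    print("No pretrained weights exist for this model. Using random initialization.")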
@@ -174,9 +173,8 @@ def adapt_input_conv(in_chans, conv_weight):
 
 
 def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
-    if cfg is None:
-        cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url']:
+    cfg = cfg or getattr(model, 'default_cfg')
+    if cfg is None or not cfg.get('url', None):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
 
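For dict configs the collapsed guard is equivalent to the old two-part check, since `.get('url', None)` is falsy both when the key is missing and when the URL is an empty string. A quick sanity check (the example URL is hypothetical):

for cfg in ({}, {'url': ''}, {'url': 'https://example.com/weights.pth'}):
    old_guard = 'url' not in cfg or not cfg['url']
    new_guard = not cfg.get('url', None)
    assert old_guard == new_guard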
@@ -376,3 +374,11 @@ def build_model_with_cfg(
         model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg
 
     return model
+
+
+def model_parameters(model, exclude_head=False):
+    if exclude_head:
+        # FIXME this is a bit of a quick and dirty hack to skip classifier head params based on ordering
+        return [p for p in model.parameters()][:-2]
+    else:
+        return model.parameters()
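As the FIXME notes, the new `model_parameters` helper assumes the classifier head's weight and bias are the last two parameters in registration order. A hypothetical usage sketch, assuming the helper above is in scope (the toy model is not from the source):

import torch.nn as nn

# Toy model; the final Linear layer acts as the classifier head.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 10))

all_params = list(model_parameters(model))
body_params = list(model_parameters(model, exclude_head=True))
assert len(body_params) == len(all_params) - 2  # head weight + bias skipped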