33Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
44 - https://arxiv.org/abs/2101.08692
55
6-
76Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
87 - https://arxiv.org/abs/2102.06171
98
109Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
1110
1211Status:
1312* These models are a work in progress, experiments ongoing.
14- * Two pretrained weights so far, more to come.
15- * Model details update to closer match official JAX code now that it's released
13+ * Pretrained weights for two models so far, more to come.
14+ * Model details updated to closer match official JAX code now that it's released
1615* NF-ResNet, NF-RegNet-B, and NFNet-F models supported
1716
1817Hacked together by / copyright Ross Wightman, 2021.
@@ -150,7 +149,7 @@ def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None):
150149 num_features = channels [- 1 ] * 2
151150 attn_kwargs = attn_kwargs or dict (reduction_ratio = 0.5 , divisor = 8 )
152151 cfg = NfCfg (
153- depths = depths , channels = channels , stem_type = 'nff ' , group_size = 128 , bottle_ratio = 0.5 , extra_conv = True ,
152+ depths = depths , channels = channels , stem_type = 'deep_quad ' , group_size = 128 , bottle_ratio = 0.5 , extra_conv = True ,
154153 num_features = num_features , act_layer = act_layer , attn_layer = attn_layer , attn_kwargs = attn_kwargs )
155154 return cfg
156155
@@ -176,9 +175,6 @@ def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None):
176175 nfnet_f6s = _nfnet_cfg (depths = (7 , 14 , 42 , 21 ), act_layer = 'silu' ),
177176 nfnet_f7s = _nfnet_cfg (depths = (8 , 16 , 48 , 24 ), act_layer = 'silu' ),
178177
179- # NFNet-F models w/ SiLU (much faster in PyTorch)
180- # FIXME add remainder if silu vs gelu proves worthwhile
181-
182178 # EffNet influenced RegNet defs.
183179 # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8.
184180 nf_regnet_b0 = _nfreg_cfg (depths = (1 , 3 , 6 , 6 )),
@@ -194,9 +190,9 @@ def _nfnet_cfg(depths, act_layer='gelu', attn_layer='se', attn_kwargs=None):
194190 nf_resnet50 = _nfres_cfg (depths = (3 , 4 , 6 , 3 )),
195191 nf_resnet101 = _nfres_cfg (depths = (3 , 4 , 23 , 3 )),
196192
197- nf_seresnet26 = _nfres_cfg (depths = (2 , 2 , 2 , 2 ), attn_layer = 'se' , attn_kwargs = dict (reduction_ratio = 0.25 )),
198- nf_seresnet50 = _nfres_cfg (depths = (3 , 4 , 6 , 3 ), attn_layer = 'se' , attn_kwargs = dict (reduction_ratio = 0.25 )),
199- nf_seresnet101 = _nfres_cfg (depths = (3 , 4 , 23 , 3 ), attn_layer = 'se' , attn_kwargs = dict (reduction_ratio = 0.25 )),
193+ nf_seresnet26 = _nfres_cfg (depths = (2 , 2 , 2 , 2 ), attn_layer = 'se' , attn_kwargs = dict (reduction_ratio = 1 / 16 )),
194+ nf_seresnet50 = _nfres_cfg (depths = (3 , 4 , 6 , 3 ), attn_layer = 'se' , attn_kwargs = dict (reduction_ratio = 1 / 16 )),
195+ nf_seresnet101 = _nfres_cfg (depths = (3 , 4 , 23 , 3 ), attn_layer = 'se' , attn_kwargs = dict (reduction_ratio = 1 / 16 )),
200196
201197 nf_ecaresnet26 = _nfres_cfg (depths = (2 , 2 , 2 , 2 ), attn_layer = 'eca' , attn_kwargs = dict ()),
202198 nf_ecaresnet50 = _nfres_cfg (depths = (3 , 4 , 6 , 3 ), attn_layer = 'eca' , attn_kwargs = dict ()),
@@ -315,38 +311,26 @@ def forward(self, x):
315311 return out
316312
317313
318- def stem_info (stem_type ):
319- stem_stride = 2
320- if 'nff' in stem_type or 'pool' in stem_type :
321- stem_stride = 4
322- stem_feat = ''
323- if 'nff' in stem_type :
324- stem_feat = 'stem.act3'
325- elif 'deep' in stem_type and not 'pool' in stem_type :
326- stem_feat = 'stem.act2'
327- return stem_stride , stem_feat
328-
329-
330314def create_stem (in_chs , out_chs , stem_type = '' , conv_layer = None , act_layer = None ):
331315 stem_stride = 2
332- stem_feature = ''
316+ stem_feature = dict ( num_chs = out_chs , reduction = 2 , module = '' )
333317 stem = OrderedDict ()
334- assert stem_type in ('' , 'nff ' , 'deep ' , 'deep_tiered ' , '3x3' , '7x7' , 'deep_pool' , '3x3_pool' , '7x7_pool' )
318+ assert stem_type in ('' , 'deep ' , 'deep_tiered ' , 'deep_quad ' , '3x3' , '7x7' , 'deep_pool' , '3x3_pool' , '7x7_pool' )
335319 if 'deep' in stem_type or 'nff' in stem_type :
336320 # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
337- if 'nff ' in stem_type :
321+ if 'quad ' in stem_type :
338322 assert not 'pool' in stem_type
339323 stem_chs = (16 , 32 , 64 , out_chs )
340324 strides = (2 , 1 , 1 , 2 )
341325 stem_stride = 4
342- stem_feature = 'stem.act4'
326+ stem_feature = dict ( num_chs = 64 , reduction = 2 , module = 'stem.act4' )
343327 else :
344328 if 'tiered' in stem_type :
345- stem_chs = (3 * out_chs // 8 , out_chs // 2 , out_chs )
329+ stem_chs = (3 * out_chs // 8 , out_chs // 2 , out_chs ) # like 'T' resnets in resnet.py
346330 else :
347- stem_chs = (out_chs // 2 , out_chs // 2 , out_chs )
331+ stem_chs = (out_chs // 2 , out_chs // 2 , out_chs ) # 'D' ResNets
348332 strides = (2 , 1 , 1 )
349- stem_feature = 'stem.act3'
333+ stem_feature = dict ( num_chs = out_chs // 2 , reduction = 2 , module = 'stem.act3' )
350334 last_idx = len (stem_chs ) - 1
351335 for i , (c , s ) in enumerate (zip (stem_chs , strides )):
352336 stem [f'conv{ i + 1 } ' ] = conv_layer (in_chs , c , kernel_size = 3 , stride = s )
@@ -401,7 +385,7 @@ class NormFreeNet(nn.Module):
401385 * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
402386 impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
403387 * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
404- apply it in each activation. This is slightly slower, and yields slightly different results .
388+ apply it in each activation. This is slightly slower, numerically different, but matches official impl .
405389 * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
406390 for what it is/does. Approx 8-10% throughput loss.
407391 """
@@ -424,7 +408,7 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
424408 self .stem , stem_stride , stem_feat = create_stem (
425409 in_chans , stem_chs , cfg .stem_type , conv_layer = conv_layer , act_layer = act_layer )
426410
427- self .feature_info = [dict ( num_chs = stem_chs , reduction = 2 , module = stem_feat ) ] if stem_stride == 4 else []
411+ self .feature_info = [stem_feat ] if stem_stride == 4 else []
428412 drop_path_rates = [x .tolist () for x in torch .linspace (0 , drop_path_rate , sum (cfg .depths )).split (cfg .depths )]
429413 prev_chs = stem_chs
430414 net_stride = stem_stride
@@ -476,7 +460,6 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
476460 # The paper NFRegNet models have an EfficientNet-like final head convolution.
477461 self .num_features = make_divisible (cfg .width_factor * cfg .num_features , cfg .ch_div )
478462 self .final_conv = conv_layer (prev_chs , self .num_features , 1 )
479- # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's pushed into stdconv
480463 else :
481464 self .num_features = prev_chs
482465 self .final_conv = nn .Identity ()
@@ -554,10 +537,12 @@ def nfnet_f3(pretrained=False, **kwargs):
554537 return _create_normfreenet ('nfnet_f3' , pretrained = pretrained , ** kwargs )
555538
556539
540+ @register_model
557541def nfnet_f4 (pretrained = False , ** kwargs ):
558542 return _create_normfreenet ('nfnet_f4' , pretrained = pretrained , ** kwargs )
559543
560544
545+ @register_model
561546def nfnet_f5 (pretrained = False , ** kwargs ):
562547 return _create_normfreenet ('nfnet_f5' , pretrained = pretrained , ** kwargs )
563548
@@ -567,6 +552,7 @@ def nfnet_f6(pretrained=False, **kwargs):
567552 return _create_normfreenet ('nfnet_f6' , pretrained = pretrained , ** kwargs )
568553
569554
555+ @register_model
570556def nfnet_f7 (pretrained = False , ** kwargs ):
571557 return _create_normfreenet ('nfnet_f7' , pretrained = pretrained , ** kwargs )
572558
@@ -591,10 +577,12 @@ def nfnet_f3s(pretrained=False, **kwargs):
591577 return _create_normfreenet ('nfnet_f3s' , pretrained = pretrained , ** kwargs )
592578
593579
580+ @register_model
594581def nfnet_f4s (pretrained = False , ** kwargs ):
595582 return _create_normfreenet ('nfnet_f4s' , pretrained = pretrained , ** kwargs )
596583
597584
585+ @register_model
598586def nfnet_f5s (pretrained = False , ** kwargs ):
599587 return _create_normfreenet ('nfnet_f5s' , pretrained = pretrained , ** kwargs )
600588
@@ -604,6 +592,7 @@ def nfnet_f6s(pretrained=False, **kwargs):
604592 return _create_normfreenet ('nfnet_f6s' , pretrained = pretrained , ** kwargs )
605593
606594
595+ @register_model
607596def nfnet_f7s (pretrained = False , ** kwargs ):
608597 return _create_normfreenet ('nfnet_f7s' , pretrained = pretrained , ** kwargs )
609598
0 commit comments