1818
1919from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
2020from .helpers import build_model_with_cfg
21- from .layers import ClassifierHead , AvgPool2dSame , ConvBnAct , SEModule
21+ from .layers import ClassifierHead , AvgPool2dSame , ConvBnAct , SEModule , DropPath
2222from .registry import register_model
2323
2424
@@ -195,14 +195,15 @@ class RegStage(nn.Module):
195195 """Stage (sequence of blocks w/ the same output shape)."""
196196
197197 def __init__ (self , in_chs , out_chs , stride , dilation , depth , bottle_ratio , group_width ,
198- block_fn = Bottleneck , se_ratio = 0. ):
198+ block_fn = Bottleneck , se_ratio = 0. , drop_path_rate = None , drop_block = None ):
199199 super (RegStage , self ).__init__ ()
200200 block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args
201201 first_dilation = 1 if dilation in (1 , 2 ) else 2
202202 for i in range (depth ):
203203 block_stride = stride if i == 0 else 1
204204 block_in_chs = in_chs if i == 0 else out_chs
205205 block_dilation = first_dilation if i == 0 else dilation
206+ drop_path = DropPath (drop_path_rate [i ]) if drop_path_rate is not None else None
206207 if (block_in_chs != out_chs ) or (block_stride != 1 ):
207208 proj_block = downsample_conv (block_in_chs , out_chs , 1 , block_stride , block_dilation )
208209 else :
@@ -212,7 +213,7 @@ def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group
212213 self .add_module (
213214 name , block_fn (
214215 block_in_chs , out_chs , block_stride , block_dilation , bottle_ratio , group_width , se_ratio ,
215- downsample = proj_block , ** block_kwargs )
216+ downsample = proj_block , drop_block = drop_block , drop_path = drop_path , ** block_kwargs )
216217 )
217218
218219 def forward (self , x ):
@@ -229,7 +230,7 @@ class RegNet(nn.Module):
229230 """
230231
231232 def __init__ (self , cfg , in_chans = 3 , num_classes = 1000 , output_stride = 32 , global_pool = 'avg' , drop_rate = 0. ,
232- zero_init_last_bn = True ):
233+ drop_path_rate = 0. , zero_init_last_bn = True ):
233234 super ().__init__ ()
234235 # TODO add drop block, drop path, anti-aliasing, custom bn/act args
235236 self .num_classes = num_classes
@@ -244,7 +245,7 @@ def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_p
244245 # Construct the stages
245246 prev_width = stem_width
246247 curr_stride = 2
247- stage_params = self ._get_stage_params (cfg , output_stride = output_stride )
248+ stage_params = self ._get_stage_params (cfg , output_stride = output_stride , drop_path_rate = drop_path_rate )
248249 se_ratio = cfg ['se_ratio' ]
249250 for i , stage_args in enumerate (stage_params ):
250251 stage_name = "s{}" .format (i + 1 )
@@ -272,7 +273,7 @@ def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_p
272273 if hasattr (m , 'zero_init_last_bn' ):
273274 m .zero_init_last_bn ()
274275
275- def _get_stage_params (self , cfg , default_stride = 2 , output_stride = 32 ):
276+ def _get_stage_params (self , cfg , default_stride = 2 , output_stride = 32 , drop_path_rate = 0. ):
276277 # Generate RegNet ws per block
277278 w_a , w_0 , w_m , d = cfg ['wa' ], cfg ['w0' ], cfg ['wm' ], cfg ['depth' ]
278279 widths , num_stages , _ , _ = generate_regnet (w_a , w_0 , w_m , d )
@@ -285,24 +286,26 @@ def _get_stage_params(self, cfg, default_stride=2, output_stride=32):
285286 stage_bottle_ratios = [cfg ['bottle_ratio' ] for _ in range (num_stages )]
286287 stage_strides = []
287288 stage_dilations = []
288- total_stride = 2
289+ net_stride = 2
289290 dilation = 1
290291 for _ in range (num_stages ):
291- if total_stride >= output_stride :
292+ if net_stride >= output_stride :
292293 dilation *= default_stride
293294 stride = 1
294295 else :
295296 stride = default_stride
296- total_stride *= stride
297+ net_stride *= stride
297298 stage_strides .append (stride )
298299 stage_dilations .append (dilation )
300+ stage_dpr = np .split (np .linspace (0 , drop_path_rate , d ), np .cumsum (stage_depths [:- 1 ]))
299301
300302 # Adjust the compatibility of ws and gws
301303 stage_widths , stage_groups = adjust_widths_groups_comp (stage_widths , stage_bottle_ratios , stage_groups )
302- param_names = ['out_chs' , 'stride' , 'dilation' , 'depth' , 'bottle_ratio' , 'group_width' ]
304+ param_names = ['out_chs' , 'stride' , 'dilation' , 'depth' , 'bottle_ratio' , 'group_width' , 'drop_path_rate' ]
303305 stage_params = [
304306 dict (zip (param_names , params )) for params in
305- zip (stage_widths , stage_strides , stage_dilations , stage_depths , stage_bottle_ratios , stage_groups )]
307+ zip (stage_widths , stage_strides , stage_dilations , stage_depths , stage_bottle_ratios , stage_groups ,
308+ stage_dpr )]
306309 return stage_params
307310
308311 def get_classifier (self ):
0 commit comments