Skip to content

Commit 6890300

Browse files
committed
Add DropPath (stochastic depth) to RegNet
1 parent 47794d2 commit 6890300

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

timm/models/regnet.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2020
from .helpers import build_model_with_cfg
21-
from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule
21+
from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule, DropPath
2222
from .registry import register_model
2323

2424

@@ -195,14 +195,15 @@ class RegStage(nn.Module):
195195
"""Stage (sequence of blocks w/ the same output shape)."""
196196

197197
def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width,
198-
block_fn=Bottleneck, se_ratio=0.):
198+
block_fn=Bottleneck, se_ratio=0., drop_path_rate=None, drop_block=None):
199199
super(RegStage, self).__init__()
200200
block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args
201201
first_dilation = 1 if dilation in (1, 2) else 2
202202
for i in range(depth):
203203
block_stride = stride if i == 0 else 1
204204
block_in_chs = in_chs if i == 0 else out_chs
205205
block_dilation = first_dilation if i == 0 else dilation
206+
drop_path = DropPath(drop_path_rate[i]) if drop_path_rate is not None else None
206207
if (block_in_chs != out_chs) or (block_stride != 1):
207208
proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation)
208209
else:
@@ -212,7 +213,7 @@ def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group
212213
self.add_module(
213214
name, block_fn(
214215
block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio,
215-
downsample=proj_block, **block_kwargs)
216+
downsample=proj_block, drop_block=drop_block, drop_path=drop_path, **block_kwargs)
216217
)
217218

218219
def forward(self, x):
@@ -229,7 +230,7 @@ class RegNet(nn.Module):
229230
"""
230231

231232
def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
232-
zero_init_last_bn=True):
233+
drop_path_rate=0., zero_init_last_bn=True):
233234
super().__init__()
234235
# TODO add drop block, drop path, anti-aliasing, custom bn/act args
235236
self.num_classes = num_classes
@@ -244,7 +245,7 @@ def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_p
244245
# Construct the stages
245246
prev_width = stem_width
246247
curr_stride = 2
247-
stage_params = self._get_stage_params(cfg, output_stride=output_stride)
248+
stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
248249
se_ratio = cfg['se_ratio']
249250
for i, stage_args in enumerate(stage_params):
250251
stage_name = "s{}".format(i + 1)
@@ -272,7 +273,7 @@ def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_p
272273
if hasattr(m, 'zero_init_last_bn'):
273274
m.zero_init_last_bn()
274275

275-
def _get_stage_params(self, cfg, default_stride=2, output_stride=32):
276+
def _get_stage_params(self, cfg, default_stride=2, output_stride=32, drop_path_rate=0.):
276277
# Generate RegNet ws per block
277278
w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth']
278279
widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
@@ -285,24 +286,26 @@ def _get_stage_params(self, cfg, default_stride=2, output_stride=32):
285286
stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)]
286287
stage_strides = []
287288
stage_dilations = []
288-
total_stride = 2
289+
net_stride = 2
289290
dilation = 1
290291
for _ in range(num_stages):
291-
if total_stride >= output_stride:
292+
if net_stride >= output_stride:
292293
dilation *= default_stride
293294
stride = 1
294295
else:
295296
stride = default_stride
296-
total_stride *= stride
297+
net_stride *= stride
297298
stage_strides.append(stride)
298299
stage_dilations.append(dilation)
300+
stage_dpr = np.split(np.linspace(0, drop_path_rate, d), np.cumsum(stage_depths[:-1]))
299301

300302
# Adjust the compatibility of ws and gws
301303
stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
302-
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width']
304+
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rate']
303305
stage_params = [
304306
dict(zip(param_names, params)) for params in
305-
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups)]
307+
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
308+
stage_dpr)]
306309
return stage_params
307310

308311
def get_classifier(self):

0 commit comments

Comments
 (0)