
Commit a39cc43 (parent: ad93347)

Bring EfficientNet and MobileNetV3 up to date with my gen-efficientnet repo

* Split the MobileNetV3 and EfficientNet model files and put the builder and blocks in their own files (they were getting too large)
* Finalize the CondConv EfficientNet variant
* Add the AdvProp weight files and the B8 EfficientNet model
* Refine the feature extraction module for EfficientNet and MobileNetV3
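For orientation, here is a usage sketch (not part of the commit) that exercises the kinds of variants this commit touches through timm.create_model. The specific model names are assumptions based on the repo's naming conventions and may not match the registry exactly.

# Minimal sketch, not from this commit. Model names below are assumptions.
import torch
import timm

cc_model = timm.create_model('efficientnet_cc_b0_4e', pretrained=False)   # CondConv variant (assumed name)
b8_model = timm.create_model('tf_efficientnet_b8_ap', pretrained=False)   # AdvProp-trained B8 (assumed name)

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    print(cc_model(x).shape)  # expected: torch.Size([1, 1000])
    print(b8_model(x).shape)  # expected: torch.Size([1, 1000])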

File tree: 9 files changed, +1621 / -1291 lines

timm/models/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@
 from .nasnet import *
 from .pnasnet import *
 from .gen_efficientnet import *
+from .mobilenetv3 import *
 from .inception_v3 import *
 from .gluon_resnet import *
 from .gluon_xception import *

timm/models/activations.py
Lines changed: 38 additions & 63 deletions

@@ -7,91 +7,75 @@
 if _USE_MEM_EFFICIENT_ISH:
     # This version reduces memory overhead of Swish during training by
     # recomputing torch.sigmoid(x) in backward instead of saving it.
-    class SwishAutoFn(torch.autograd.Function):
-        """Swish - Described in: https://arxiv.org/abs/1710.05941
-        Memory efficient variant from:
-         https://medium.com/the-artificial-impostor/more-memory-efficient-swish-activation-function-e07c22c12a76
-        """
-        @staticmethod
-        def forward(ctx, x):
-            result = x.mul(torch.sigmoid(x))
-            ctx.save_for_backward(x)
-            return result
+    @torch.jit.script
+    def swish_jit_fwd(x):
+        return x.mul(torch.sigmoid(x))
 
-        @staticmethod
-        def backward(ctx, grad_output):
-            x = ctx.saved_variables[0]
-            sigmoid_x = torch.sigmoid(x)
-            return grad_output.mul(sigmoid_x * (1 + x * (1 - sigmoid_x)))
 
-    def swish(x, inplace=False):
-        # inplace ignored
-        return SwishAutoFn.apply(x)
+    @torch.jit.script
+    def swish_jit_bwd(x, grad_output):
+        x_sigmoid = torch.sigmoid(x)
+        return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
 
 
-    class MishAutoFn(torch.autograd.Function):
-        """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-        Experimental memory-efficient variant
+    class SwishJitAutoFn(torch.autograd.Function):
+        """ torch.jit.script optimised Swish
+        Inspired by conversation btw Jeremy Howard & Adam Pazske
+        https://twitter.com/jeremyphoward/status/1188251041835315200
         """
 
         @staticmethod
         def forward(ctx, x):
             ctx.save_for_backward(x)
-            y = x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))
-            return y
+            return swish_jit_fwd(x)
 
         @staticmethod
         def backward(ctx, grad_output):
-            x = ctx.saved_variables[0]
-            x_sigmoid = torch.sigmoid(x)
-            x_tanh_sp = F.softplus(x).tanh()
-            return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
+            x = ctx.saved_tensors[0]
+            return swish_jit_bwd(x, grad_output)
 
-    def mish(x, inplace=False):
-        # inplace ignored
-        return MishAutoFn.apply(x)
 
+    def swish(x, _inplace=False):
+        return SwishJitAutoFn.apply(x)
 
-    class WishAutoFn(torch.autograd.Function):
-        """Wish: My own mistaken creation while fiddling with Mish. Did well in some experiments.
-        Experimental memory-efficient variant
-        """
 
+    @torch.jit.script
+    def mish_jit_fwd(x):
+        return x.mul(torch.tanh(F.softplus(x)))
+
+
+    @torch.jit.script
+    def mish_jit_bwd(x, grad_output):
+        x_sigmoid = torch.sigmoid(x)
+        x_tanh_sp = F.softplus(x).tanh()
+        return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
+
+
+    class MishJitAutoFn(torch.autograd.Function):
         @staticmethod
         def forward(ctx, x):
             ctx.save_for_backward(x)
-            y = x.mul(torch.tanh(torch.exp(x)))
-            return y
+            return mish_jit_fwd(x)
 
         @staticmethod
         def backward(ctx, grad_output):
-            x = ctx.saved_variables[0]
-            x_exp = x.exp()
-            x_tanh_exp = x_exp.tanh()
-            return grad_output.mul(x_tanh_exp + x * x_exp * (1 - x_tanh_exp * x_tanh_exp))
-
-    def wish(x, inplace=False):
-        # inplace ignored
-        return WishAutoFn.apply(x)
+            x = ctx.saved_tensors[0]
+            return mish_jit_bwd(x, grad_output)
+
+
+    def mish(x, _inplace=False):
+        return MishJitAutoFn.apply(x)
+
 
 else:
     def swish(x, inplace=False):
         """Swish - Described in: https://arxiv.org/abs/1710.05941
         """
         return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
 
 
-    def mish(x, inplace=False):
+    def mish(x, _inplace=False):
         """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
         """
-        inner = F.softplus(x).tanh()
-        return x.mul_(inner) if inplace else x.mul(inner)
-
-
-    def wish(x, inplace=False):
-        """Wish: My own mistaken creation while fiddling with Mish. Did well in some experiments.
-        """
-        inner = x.exp().tanh()
-        return x.mul_(inner) if inplace else x.mul(inner)
+        return x.mul(F.softplus(x).tanh())
 
 
 class Swish(nn.Module):

@@ -112,15 +96,6 @@ def forward(self, x):
         return mish(x, self.inplace)
 
 
-class Wish(nn.Module):
-    def __init__(self, inplace=False):
-        super(Wish, self).__init__()
-        self.inplace = inplace
-
-    def forward(self, x):
-        return wish(x, self.inplace)
-
-
 def sigmoid(x, inplace=False):
     return x.sigmoid_() if inplace else x.sigmoid()
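Below is a minimal check (not part of the commit) that the jit-scripted Swish/Mish path above matches the naive formulations in both value and gradient. It assumes timm.models.activations exposes swish and mish as shown in the diff and that the memory-efficient branch is active; if it is not, the check passes trivially.

# Minimal equivalence check, assuming the module layout shown in the diff above.
import torch
import torch.nn.functional as F
from timm.models.activations import swish, mish

def naive_swish(x):
    return x * torch.sigmoid(x)

def naive_mish(x):
    return x * torch.tanh(F.softplus(x))

x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
for fast, ref in [(swish, naive_swish), (mish, naive_mish)]:
    y_fast, y_ref = fast(x), ref(x)
    assert torch.allclose(y_fast, y_ref)
    # gradients from the custom backward should match autograd on the naive version
    g_fast, = torch.autograd.grad(y_fast.sum(), x)
    g_ref, = torch.autograd.grad(y_ref.sum(), x)
    assert torch.allclose(g_fast, g_ref)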

timm/models/conv2d_layers.py
Lines changed: 45 additions & 47 deletions

@@ -102,13 +102,14 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
 
 class MixedConv2d(nn.Module):
     """ Mixed Grouped Convolution
-
     Based on MDConv and GroupedConv in MixNet impl:
       https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
+
+    NOTE: This does not currently work with torch.jit.script
     """
 
     def __init__(self, in_channels, out_channels, kernel_size=3,
-                 stride=1, padding='', dilation=1, mixed_dilated=False, depthwise=False, **kwargs):
+                 stride=1, padding='', dilation=1, depthwise=False, **kwargs):
         super(MixedConv2d, self).__init__()
 
         kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]

@@ -118,17 +119,13 @@ def __init__(self, in_channels, out_channels, kernel_size=3,
         self.in_channels = sum(in_splits)
         self.out_channels = sum(out_splits)
         for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
-            d = dilation
-            # FIXME make compat with non-square kernel/dilations/strides
-            if stride == 1 and mixed_dilated:
-                d, k = (k - 1) // 2, 3
             conv_groups = out_ch if depthwise else 1
             # use add_module to keep key space clean
             self.add_module(
                 str(idx),
                 create_conv2d_pad(
                     in_ch, out_ch, k, stride=stride,
-                    padding=padding, dilation=d, groups=conv_groups, **kwargs)
+                    padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
             )
         self.splits = in_splits
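For reference, a small usage sketch (not from the commit) of MixedConv2d after this change. The import path mirrors the file above, and the behaviour assumed here is that channels are split across the listed kernel sizes, each split getting its own (optionally depthwise) convolution.

# Minimal sketch assuming the MixedConv2d signature shown in the diff above.
import torch
from timm.models.conv2d_layers import MixedConv2d

# 48 input/output channels split across 3x3, 5x5 and 7x7 depthwise kernels
m = MixedConv2d(48, 48, kernel_size=[3, 5, 7], stride=1, depthwise=True)
x = torch.randn(2, 48, 32, 32)
print(m(x).shape)  # expected: torch.Size([2, 48, 32, 32])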

@@ -154,12 +151,12 @@ def condconv_initializer(weight):
 
 class CondConv2d(nn.Module):
     """ Conditional Convolution
-
     Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
 
     Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
     https://github.com/pytorch/pytorch/issues/17983
     """
+    __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
 
     def __init__(self, in_channels, out_channels, kernel_size=3,
                  stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):

@@ -171,13 +168,10 @@ def __init__(self, in_channels, out_channels, kernel_size=3,
         self.stride = _pair(stride)
         padding_val, is_padding_dynamic = get_padding_value(
             padding, kernel_size, stride=stride, dilation=dilation)
-        self.conv_fn = conv2d_same if is_padding_dynamic else F.conv2d
+        self.dynamic_padding = is_padding_dynamic  # if in forward to work with torchscript
         self.padding = _pair(padding_val)
         self.dilation = _pair(dilation)
-        self.transposed = False
-        self.output_padding = _pair(0)
         self.groups = groups
-        self.padding_mode = 'zero'
         self.num_experts = num_experts
 
         self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size

@@ -186,60 +180,63 @@ def __init__(self, in_channels, out_channels, kernel_size=3,
             weight_num_param *= wd
         self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
 
-        # FIXME I haven't tested bias yet
         if bias:
             self.bias_shape = (self.out_channels,)
-            condconv_bias_shape = (self.num_experts, self.out_channels)
-            self.bias = torch.nn.Parameter(torch.Tensor(condconv_bias_shape))
+            self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
         else:
             self.register_parameter('bias', None)
 
         self.reset_parameters()
-        # FIXME once I'm satisfied this works, remove the looping path?
-        self._use_groups = True  # use groups for parallel per-batch-element kernel convolution
 
     def reset_parameters(self):
         init_weight = get_condconv_initializer(
             partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
         init_weight(self.weight)
         if self.bias is not None:
-            # FIXME bias not tested
             fan_in = np.prod(self.weight_shape[1:])
             bound = 1 / math.sqrt(fan_in)
             init_bias = get_condconv_initializer(
                 partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
             init_bias(self.bias)
 
     def forward(self, x, routing_weights):
-        weight = torch.matmul(routing_weights, self.weight)
-        bias = torch.matmul(routing_weights, self.bias) if self.bias is not None else None
         B, C, H, W = x.shape
-        if self._use_groups:
-            new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
-            weight = weight.view(new_weight_shape)
-            # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
-            x = x.view(1, B * C, H, W)
-            out = self.conv_fn(
+        weight = torch.matmul(routing_weights, self.weight)
+        new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
+        weight = weight.view(new_weight_shape)
+        bias = None
+        if self.bias is not None:
+            bias = torch.matmul(routing_weights, self.bias)
+            bias = bias.view(B * self.out_channels)
+        # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
+        x = x.view(1, B * C, H, W)
+        if self.dynamic_padding:
+            out = conv2d_same(
                 x, weight, bias, stride=self.stride, padding=self.padding,
                 dilation=self.dilation, groups=self.groups * B)
-            out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
         else:
-            x = torch.split(x, 1, 0)
-            weight = torch.split(weight, 1, 0)
-            if self.bias is not None:
-                bias = torch.matmul(routing_weights, self.bias)
-                bias = torch.split(bias, 1, 0)
-            else:
-                bias = [None] * B
-            out = []
-            for xi, wi, bi in zip(x, weight, bias):
-                wi = wi.view(*self.weight_shape)
-                if bi is not None:
-                    bi = bi.view(*self.bias_shape)
-                out.append(self.conv_fn(
-                    xi, wi, bi, stride=self.stride, padding=self.padding,
-                    dilation=self.dilation, groups=self.groups))
-            out = torch.cat(out, 0)
+            out = F.conv2d(
+                x, weight, bias, stride=self.stride, padding=self.padding,
+                dilation=self.dilation, groups=self.groups * B)
+        out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
+
+        # Literal port (from TF definition)
+        # x = torch.split(x, 1, 0)
+        # weight = torch.split(weight, 1, 0)
+        # if self.bias is not None:
+        #     bias = torch.matmul(routing_weights, self.bias)
+        #     bias = torch.split(bias, 1, 0)
+        # else:
+        #     bias = [None] * B
+        # out = []
+        # for xi, wi, bi in zip(x, weight, bias):
+        #     wi = wi.view(*self.weight_shape)
+        #     if bi is not None:
+        #         bi = bi.view(*self.bias_shape)
+        #     out.append(self.conv_fn(
+        #         xi, wi, bi, stride=self.stride, padding=self.padding,
+        #         dilation=self.dilation, groups=self.groups))
+        # out = torch.cat(out, 0)
         return out
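CondConv2d's forward now commits to the grouped-convolution trick: the expert-mixed kernels for the whole batch are stacked along the output-channel axis, the batch is folded into the channel dimension, and a single convolution with groups multiplied by the batch size applies a different kernel to each sample. A standalone sketch of that mechanism (plain PyTorch only, not the commit's code):

# Standalone illustration of the per-sample grouped-conv trick used by CondConv2d.
import torch
import torch.nn.functional as F

B, C_in, C_out, H, W, K = 4, 8, 16, 14, 14, 3
x = torch.randn(B, C_in, H, W)
per_sample_weight = torch.randn(B, C_out, C_in, K, K)  # a different kernel per sample

# Fold the batch into channels and run one grouped conv with groups=B
out = F.conv2d(
    x.view(1, B * C_in, H, W),
    per_sample_weight.view(B * C_out, C_in, K, K),
    padding=1, groups=B)
out = out.view(B, C_out, H, W)

# Reference: loop over the batch, one conv per sample
ref = torch.cat([F.conv2d(x[i:i + 1], per_sample_weight[i], padding=1) for i in range(B)])
assert torch.allclose(out, ref, atol=1e-5)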

@@ -250,13 +247,14 @@ def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
         assert 'num_experts' not in kwargs  # MixNet + CondConv combo not supported currently
         # We're going to use only lists for defining the MixedConv2d kernel groups,
         # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
-        return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
+        m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
     else:
         depthwise = kwargs.pop('depthwise', False)
         groups = out_chs if depthwise else 1
         if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
-            create_fn = CondConv2d
+            m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
         else:
-            create_fn = create_conv2d_pad
-        return create_fn(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
+            m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
+    return m
+