
Commit f670d98

Make a few more layers symbolically traceable (remove from FX leaf modules)
* Remove dtype kwarg from .to() calls in EvoNorm as it messed up the script + trace combo
* BatchNormAct2d always uses a custom forward (cut & paste from original) instead of super().forward. Fixes #1176
* BlurPool groups==channels, no need to use input.dim[1]
1 parent a9ecb88 commit f670d98
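With these layers no longer registered as FX leaf modules, they should pass through torch.fx symbolic tracing on their own. A minimal sanity check along those lines (the constructor arguments and input shape below are illustrative assumptions, not taken from the commit):

    import torch
    import torch.fx
    from timm.models.layers import BlurPool2d, BatchNormAct2d

    # Both layers should now be symbolically traceable without being special-cased as leaves.
    for layer in (BlurPool2d(channels=32), BatchNormAct2d(32)):
        layer.eval()  # keep running-stat updates out of the traced graph
        gm = torch.fx.symbolic_trace(layer)
        out = gm(torch.randn(1, 32, 16, 16))
        print(type(layer).__name__, tuple(out.shape))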

File tree: 5 files changed, +31 −48 lines


timm/models/fx_features.py

Lines changed: 1 addition & 9 deletions
@@ -15,25 +15,17 @@
     has_fx_feature_extraction = False
 
 # Layers we went to treat as leaf modules
-from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
-from .layers import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2
-from .layers import EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
+from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
 from .layers.non_local_attn import BilinearAttnTransform
 from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
 
 # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
 # BUT modules from timm.models should use the registration mechanism below
 _leaf_modules = {
-    BatchNormAct2d,  # reason: flow control for jit scripting
     BilinearAttnTransform,  # reason: flow control t <= 1
-    BlurPool2d,  # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
     # Reason: get_same_padding has a max which raises a control flow error
     Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
     CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
-    DropPath,  # reason: TypeError: rand recieved Proxy in `size` argument
-    EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,  # to(dtype) use that causes tracing failure (on scripted models only?)
-    EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a,
-
 }
 
 try:
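For context on how the shrunken `_leaf_modules` set gets used: FX feature extraction typically feeds such a set into a custom `torch.fx.Tracer` whose `is_leaf_module` stops tracing inside those classes. A rough sketch of that mechanism (the `LeafTracer` class below is an illustration, not the exact code in `fx_features.py`):

    import torch
    import torch.fx

    class LeafTracer(torch.fx.Tracer):
        """Tracer that keeps the registered modules as opaque call_module nodes."""
        def __init__(self, leaf_modules):
            super().__init__()
            self.leaf_modules = tuple(leaf_modules)

        def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
            # Anything in the leaf set is recorded as a single node instead of being traced into,
            # so its internals never see FX Proxies. Fewer leaves means more of the graph is visible.
            return isinstance(m, self.leaf_modules) or super().is_leaf_module(m, module_qualified_name)

    # Usage sketch: trace a model while keeping the remaining leaf modules opaque.
    # tracer = LeafTracer(_leaf_modules)
    # graph = tracer.trace(model)
    # graph_module = torch.fx.GraphModule(tracer.root, graph)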

timm/models/layers/blur_pool.py

Lines changed: 1 addition & 1 deletion
@@ -39,4 +39,4 @@ def __init__(self, channels, filt_size=3, stride=2) -> None:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = F.pad(x, self.padding, 'reflect')
-        return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1])
+        return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
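The change is behavior-preserving: for a depthwise blur, `self.channels` always equals `x.shape[1]` at runtime, so the convolution is identical; the difference is that `groups` is now a plain Python int rather than a value read off the input, which is what FX tracing needs. A quick eager-mode check of that equivalence (shapes and filter are illustrative):

    import torch
    import torch.nn.functional as F

    channels = 4
    x = torch.randn(2, channels, 8, 8)
    # A per-channel 3x3 box-blur filter, one kernel per channel (depthwise layout).
    filt = torch.ones(channels, 1, 3, 3) / 9.0

    a = F.conv2d(x, filt, stride=2, groups=x.shape[1])   # old: groups read from the input
    b = F.conv2d(x, filt, stride=2, groups=channels)     # new: groups from a static attribute
    assert torch.equal(a, b)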

timm/models/layers/drop.py

Lines changed: 8 additions & 8 deletions
@@ -107,13 +107,13 @@ class DropBlock2d(nn.Module):
 
     def __init__(
             self,
-            drop_prob=0.1,
-            block_size=7,
-            gamma_scale=1.0,
-            with_noise=False,
-            inplace=False,
-            batchwise=False,
-            fast=True):
+            drop_prob: float = 0.1,
+            block_size: int = 7,
+            gamma_scale: float = 1.0,
+            with_noise: bool = False,
+            inplace: bool = False,
+            batchwise: bool = False,
+            fast: bool = True):
         super(DropBlock2d, self).__init__()
         self.drop_prob = drop_prob
         self.gamma_scale = gamma_scale
@@ -157,7 +157,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b
 class DropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
     """
-    def __init__(self, drop_prob=None, scale_by_keep=True):
+    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
         super(DropPath, self).__init__()
         self.drop_prob = drop_prob
         self.scale_by_keep = scale_by_keep
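DropPath's removal from `_leaf_modules` above relies on the stochastic-depth mask being built through tensor methods (which accept FX Proxies) rather than `torch.rand(size)` (which rejects a Proxy in its `size` argument, per the removed leaf-module comment). A sketch of that trace-friendly formulation, not necessarily the exact timm `drop_path` code:

    import torch

    def drop_path_sketch(x: torch.Tensor, drop_prob: float = 0., training: bool = False,
                         scale_by_keep: bool = True) -> torch.Tensor:
        """Stochastic depth: zero whole samples of a residual branch with probability drop_prob."""
        if drop_prob == 0. or not training:
            return x
        keep_prob = 1.0 - drop_prob
        # Per-sample mask broadcast over all non-batch dims; x.new_empty(...).bernoulli_()
        # keeps the shape handling on tensor methods, which FX tracing can record as graph nodes.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0. and scale_by_keep:
            mask.div_(keep_prob)
        return x * mask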

timm/models/layers/evo_norm.py

Lines changed: 15 additions & 15 deletions
@@ -92,7 +92,7 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5):
     _assert(C % groups == 0, '')
     x_dtype = x.dtype
     x = x.reshape(B, groups, C // groups, H, W)
-    rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
+    rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(x_dtype)
     return rms.expand(x.shape).reshape(B, C, H, W)
 
 
@@ -160,14 +160,14 @@ def forward(self, x):
                 n = x.numel() / x.shape[1]
                 self.running_var.copy_(
                     self.running_var * (1 - self.momentum) +
-                    var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
+                    var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
             else:
                 var = self.running_var
-            var = var.to(dtype=x_dtype).view(v_shape)
+            var = var.to(x_dtype).view(v_shape)
             left = var.add(self.eps).sqrt_()
             right = (x + 1) * instance_rms(x, self.eps)
             x = x / left.max(right)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 
 class EvoNorm2dB2(nn.Module):
@@ -195,14 +195,14 @@ def forward(self, x):
                 n = x.numel() / x.shape[1]
                 self.running_var.copy_(
                     self.running_var * (1 - self.momentum) +
-                    var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
+                    var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
             else:
                 var = self.running_var
-            var = var.to(dtype=x_dtype).view(v_shape)
+            var = var.to(x_dtype).view(v_shape)
             left = var.add(self.eps).sqrt_()
             right = instance_rms(x, self.eps) - x
             x = x / left.max(right)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 
 class EvoNorm2dS0(nn.Module):
@@ -231,9 +231,9 @@ def forward(self, x):
         x_dtype = x.dtype
         v_shape = (1, -1, 1, 1)
         if self.v is not None:
-            v = self.v.view(v_shape).to(dtype=x_dtype)
+            v = self.v.view(v_shape).to(x_dtype)
             x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 
 class EvoNorm2dS0a(EvoNorm2dS0):
@@ -247,10 +247,10 @@ def forward(self, x):
         v_shape = (1, -1, 1, 1)
         d = group_std(x, self.groups, self.eps)
         if self.v is not None:
-            v = self.v.view(v_shape).to(dtype=x_dtype)
+            v = self.v.view(v_shape).to(x_dtype)
             x = x * (x * v).sigmoid()
         x = x / d
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 
 class EvoNorm2dS1(nn.Module):
@@ -284,7 +284,7 @@ def forward(self, x):
         v_shape = (1, -1, 1, 1)
         if self.apply_act:
             x = self.act(x) / group_std(x, self.groups, self.eps)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 
 class EvoNorm2dS1a(EvoNorm2dS1):
@@ -299,7 +299,7 @@ def forward(self, x):
         x_dtype = x.dtype
         v_shape = (1, -1, 1, 1)
         x = self.act(x) / group_std(x, self.groups, self.eps)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 
 class EvoNorm2dS2(nn.Module):
@@ -332,7 +332,7 @@ def forward(self, x):
         v_shape = (1, -1, 1, 1)
         if self.apply_act:
             x = self.act(x) / group_rms(x, self.groups, self.eps)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 
 class EvoNorm2dS2a(EvoNorm2dS2):
@@ -347,4 +347,4 @@ def forward(self, x):
         x_dtype = x.dtype
         v_shape = (1, -1, 1, 1)
         x = self.act(x) / group_rms(x, self.groups, self.eps)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
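Every change in this file is the same mechanical edit: `.to(dtype=x_dtype)` becomes `.to(x_dtype)`. Both spellings perform the same dtype cast in eager mode; per the commit message, the keyword form was tripping up the script + trace combination, so the positional overload is used instead. A quick eager-mode check of the equivalence (shapes are illustrative):

    import torch

    x = torch.randn(2, 8, 4, 4, dtype=torch.float16)
    # Compute statistics in float32 for stability, then cast back to the input dtype,
    # mirroring the group_rms / EvoNorm pattern above.
    rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(1e-5).sqrt_()
    assert rms.to(x.dtype).dtype == rms.to(dtype=x.dtype).dtype == torch.float16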

timm/models/layers/norm_act.py

Lines changed: 6 additions & 15 deletions
@@ -6,6 +6,7 @@
 from torch import nn as nn
 from torch.nn import functional as F
 
+from .trace_utils import _assert
 from .create_act import get_act_layer
 
 
@@ -29,9 +30,10 @@ def __init__(
         else:
             self.act = nn.Identity()
 
-    def _forward_jit(self, x):
-        """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function
-        """
+    def forward(self, x):
+        # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
+        _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')
+
         # exponential_average_factor is set to self.momentum
         # (when it is available) only so that it gets updated
         # in ONNX graph when this node is exported to ONNX.
@@ -63,7 +65,7 @@ def _forward_jit(self, x):
         passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
         used for normalization (i.e. in eval mode when buffers are not None).
         """
-        return F.batch_norm(
+        x = F.batch_norm(
             x,
             # If buffers are not to be tracked, ensure that they won't be updated
             self.running_mean if not self.training or self.track_running_stats else None,
@@ -74,17 +76,6 @@ def _forward_jit(self, x):
             exponential_average_factor,
             self.eps,
         )
-
-    @torch.jit.ignore
-    def _forward_python(self, x):
-        return super(BatchNormAct2d, self).forward(x)
-
-    def forward(self, x):
-        # FIXME cannot call parent forward() and maintain jit.script compatibility?
-        if torch.jit.is_scripting():
-            x = self._forward_jit(x)
-        else:
-            x = self._forward_python(x)
         x = self.drop(x)
         x = self.act(x)
         return x
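The `_assert` imported from `.trace_utils` is what lets the new shape check survive both scripting and FX tracing, where a plain `assert` on a Proxy either disappears or fails. A minimal sketch of what such a helper can look like (assuming it simply defers to `torch._assert`, which is symbolically traceable, with a fallback for older PyTorch):

    import torch

    try:
        # torch._assert is a traceable wrapper around Python's assert (recorded as a graph node under FX).
        from torch import _assert
    except ImportError:
        def _assert(condition: bool, message: str):
            assert condition, message

    # Usage sketch, matching the check added to BatchNormAct2d.forward above:
    # _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')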
