
Commit 3b4868f

A few more additions to the Gluon Xception models to match the interface of the others.
1 parent 4d505e0 commit 3b4868f

File tree

2 files changed: +63 -18 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ I've leveraged the training scripts in this repository to train a few of the mod
 | gluon_resnet152_v1c | 79.916 (20.084) | 94.842 (5.158) | 60.21 | bicubic | 224 | |
 | gluon_seresnext50_32x4d | 79.912 (20.088) | 94.818 (5.182) | 27.56 | bicubic | 224 | |
 | gluon_resnet152_v1b | 79.692 (20.308) | 94.738 (5.262) | 60.19 | bicubic | 224 | |
+| gluon_xception65 | 79.604 (20.396) | 94.748 (5.252) | 39.92 | bicubic | 299 | |
 | gluon_resnet101_v1c | 79.544 (20.456) | 94.586 (5.414) | 44.57 | bicubic | 224 | |
 | gluon_resnext50_32x4d | 79.356 (20.644) | 94.424 (5.576) | 25.03 | bicubic | 224 | |
 | gluon_resnet101_v1b | 79.304 (20.696) | 94.524 (5.476) | 44.55 | bicubic | 224 | |
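
For context, a minimal sketch of how the newly listed gluon_xception65 is created through timm's model factory; it assumes timm is installed and the pretrained weights referenced in its default_cfg are downloadable:

import torch
import timm

# Build the model with its ImageNet-pretrained weights.
model = timm.create_model('gluon_xception65', pretrained=True)
model.eval()

# Per the default_cfg below: 3x299x299 input, bicubic interpolation.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 299, 299))
print(logits.shape)  # torch.Size([1, 1000])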

timm/models/gluon_xception.py

Lines changed: 62 additions & 18 deletions
@@ -23,6 +23,7 @@
         'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth',
         'input_size': (3, 299, 299),
         'crop_pct': 0.875,
+        'pool_size': (10, 10),
         'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN,
         'std': IMAGENET_DEFAULT_STD,
@@ -35,6 +36,7 @@
         'url': '',
         'input_size': (3, 299, 299),
         'crop_pct': 0.875,
+        'pool_size': (10, 10),
         'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN,
         'std': IMAGENET_DEFAULT_STD,
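
The new 'pool_size' entry records the spatial size of the final feature map before global pooling: a 299x299 input reduced by the network's overall output stride of 32, with padding that rounds up at each stride-2 stage, leaves a 10x10 map. A quick sketch of that arithmetic; the helper name is illustrative, not part of the library:

import math

def final_feature_map_side(input_side, output_stride):
    # 'Same'-style padding rounds up at each downsampling stage,
    # so 299 / 32 = 9.34... rounds up to 10.
    return math.ceil(input_side / output_stride)

assert final_feature_map_side(299, 32) == 10  # matches 'pool_size': (10, 10)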
@@ -181,7 +183,9 @@ class Xception65(nn.Module):
     def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
                  norm_kwargs=None, drop_rate=0., global_pool='avg'):
         super(Xception65, self).__init__()
+        self.num_classes = num_classes
         self.drop_rate = drop_rate
+        self.global_pool = global_pool
         norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
         if output_stride == 32:
             entry_block3_stride = 2
@@ -240,14 +244,26 @@ def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn
             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
         self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
 
+        self.num_features = 2048
         self.conv5 = SeparableConv2d(
-            1536, 2048, 3, stride=1, dilation=exit_block_dilations[1],
+            1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn5 = norm_layer(num_features=2048, **norm_kwargs)
-        self.avgpool = nn.AdaptiveAvgPool2d(1)
-        self.fc = nn.Linear(in_features=2048, out_features=num_classes)
+        self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
+        self.fc = nn.Linear(in_features=self.num_features, out_features=num_classes)
+
+    def get_classifier(self):
+        return self.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        del self.fc
+        if num_classes:
+            self.fc = nn.Linear(self.num_features, num_classes)
+        else:
+            self.fc = None
 
-    def forward(self, x):
+    def forward_features(self, x, pool=True):
         # Entry flow
         x = self.conv1(x)
         x = self.bn1(x)
@@ -284,10 +300,15 @@ def forward(self, x):
         x = self.bn5(x)
         x = self.relu(x)
 
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        if self.drop_rate > 0.:
-            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        if pool:
+            x = select_adaptive_pool2d(x, pool_type=self.global_pool)
+            x = x.view(x.size(0), -1)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        if self.drop_rate:
+            x = F.dropout(x, self.drop_rate, training=self.training)
         x = self.fc(x)
         return x
 
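An aside on the pooling change above: select_adaptive_pool2d is timm's helper that dispatches on a pool_type string rather than hard-coding nn.AdaptiveAvgPool2d(1), which is what allows global_pool to become a constructor argument. A rough functional sketch of the dispatch, simplified and not the library's actual implementation (the real helper supports additional combined modes):

import torch.nn.functional as F

def select_adaptive_pool2d_sketch(x, pool_type='avg', output_size=1):
    # Simplified stand-in for timm's select_adaptive_pool2d.
    if pool_type == 'avg':
        return F.adaptive_avg_pool2d(x, output_size)
    elif pool_type == 'max':
        return F.adaptive_max_pool2d(x, output_size)
    raise ValueError('Invalid pool type: %s' % pool_type)
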
@@ -299,7 +320,9 @@ class Xception71(nn.Module):
     def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
                  norm_kwargs=None, drop_rate=0., global_pool='avg'):
         super(Xception71, self).__init__()
+        self.num_classes = num_classes
         self.drop_rate = drop_rate
+        self.global_pool = global_pool
         norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
         if output_stride == 32:
             entry_block3_stride = 2
@@ -365,14 +388,26 @@ def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn
             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
         self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
 
+        self.num_features = 2048
         self.conv5 = SeparableConv2d(
-            1536, 2048, 3, stride=1, dilation=exit_block_dilations[1],
+            1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn5 = norm_layer(num_features=2048, **norm_kwargs)
-        self.avgpool = nn.AdaptiveAvgPool2d(1)
-        self.fc = nn.Linear(in_features=2048, out_features=num_classes)
+        self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
+        self.fc = nn.Linear(in_features=self.num_features, out_features=num_classes)
+
+    def get_classifier(self):
+        return self.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        del self.fc
+        if num_classes:
+            self.fc = nn.Linear(self.num_features, num_classes)
+        else:
+            self.fc = None
 
-    def forward(self, x):
+    def forward_features(self, x, pool=True):
         # Entry flow
         x = self.conv1(x)
         x = self.bn1(x)
@@ -409,16 +444,23 @@ def forward(self, x):
         x = self.bn5(x)
         x = self.relu(x)
 
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        if self.drop_rate > 0.:
-            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        if pool:
+            x = select_adaptive_pool2d(x, pool_type=self.global_pool)
+            x = x.view(x.size(0), -1)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        if self.drop_rate:
+            x = F.dropout(x, self.drop_rate, training=self.training)
         x = self.fc(x)
         return x
 
 
 @register_model
 def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """ Modified Aligned Xception-65
+    """
     default_cfg = default_cfgs['gluon_xception65']
     model = Xception65(num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
@@ -429,6 +471,8 @@ def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
 
 @register_model
 def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """ Modified Aligned Xception-71
+    """
     default_cfg = default_cfgs['gluon_xception71']
     model = Xception71(num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
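
Taken together, these changes give the Gluon Xception models the same head-handling interface as the other architectures in the repository. A hedged usage sketch of the new methods; shapes follow from num_features = 2048 and the (10, 10) pool size at 299x299 input:

import torch
import timm

model = timm.create_model('gluon_xception65')
model.eval()
x = torch.randn(2, 3, 299, 299)

with torch.no_grad():
    feats = model.forward_features(x)             # pooled features: (2, 2048)
    fmap = model.forward_features(x, pool=False)  # feature map: (2, 2048, 10, 10)

# Swap the classifier head for a 10-class problem; the backbone is kept.
print(model.get_classifier())  # Linear(in_features=2048, out_features=1000, bias=True)
model.reset_classifier(num_classes=10)
with torch.no_grad():
    logits = model(x)          # now (2, 10)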
