Skip to content

Commit 6853b07

Browse files
committed
Improve RepVGG block identity vs. non-identity handling for clarity and fix attn usage. Add comments.
1 parent 0356e77 commit 6853b07

File tree

1 file changed

+42
-10
lines changed

1 file changed

+42
-10
lines changed

timm/models/byobnet.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232

3333
import torch
3434
import torch.nn as nn
35-
import torch.nn.functional as F
3635

3736
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
3837
from .helpers import build_model_with_cfg
@@ -443,7 +442,7 @@ class RepVggBlock(nn.Module):
443442
444443
Adapted from impl at https://github.com/DingXiaoH/RepVGG
445444
446-
This version does not currently support the deploy optimization. It is currently fixed in 'train' model.
445+
This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
447446
"""
448447

449448
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
@@ -461,8 +460,8 @@ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bo
461460
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
462461
groups=groups, drop_block=drop_block, apply_act=False, **layer_args)
463462
self.conv_1x1 = ConvBnAct(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False, **layer_args)
464-
self.attn = None if attn_layer is None else attn_layer(out_chs)
465-
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
463+
self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
464+
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
466465
self.act = act_layer(inplace=True)
467466

468467
def init_weights(self, zero_init_last_bn=False):
@@ -474,14 +473,14 @@ def init_weights(self, zero_init_last_bn=False):
474473

475474
def forward(self, x):
476475
if self.identity is None:
477-
identity = 0
476+
x = self.conv_1x1(x) + self.conv_kxk(x)
478477
else:
479478
identity = self.identity(x)
480-
x = self.conv_1x1(x) + self.conv_kxk(x)
481-
if self.attn is not None:
482-
x = self.attn(x)
483-
x = self.drop_path(x)
484-
x = self.act(x + identity)
479+
x = self.conv_1x1(x) + self.conv_kxk(x)
480+
x = self.drop_path(x) # not in the paper / official impl, experimental
481+
x = x + identity
482+
x = self.attn(x) # no attn in the paper / official impl, experimental
483+
x = self.act(x)
485484
return x
486485

487486

@@ -654,54 +653,87 @@ def _create_byobnet(variant, pretrained=False, **kwargs):
654653

655654
@register_model
656655
def gernet_l(pretrained=False, **kwargs):
656+
""" GEResNet-Large (GENet-Large from official impl)
657+
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
658+
"""
657659
return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs)
658660

659661

660662
@register_model
661663
def gernet_m(pretrained=False, **kwargs):
664+
""" GEResNet-Medium (GENet-Normal from official impl)
665+
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
666+
"""
662667
return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs)
663668

664669

665670
@register_model
666671
def gernet_s(pretrained=False, **kwargs):
672+
""" GEResNet-Small (GENet-Small from official impl)
673+
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
674+
"""
667675
return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)
668676

669677

670678
@register_model
671679
def repvgg_a2(pretrained=False, **kwargs):
680+
""" RepVGG-A2
681+
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
682+
"""
672683
return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs)
673684

674685

675686
@register_model
676687
def repvgg_b0(pretrained=False, **kwargs):
688+
""" RepVGG-B0
689+
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
690+
"""
677691
return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs)
678692

679693

680694
@register_model
681695
def repvgg_b1(pretrained=False, **kwargs):
696+
""" RepVGG-B1
697+
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
698+
"""
682699
return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs)
683700

684701

685702
@register_model
686703
def repvgg_b1g4(pretrained=False, **kwargs):
704+
""" RepVGG-B1g4
705+
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
706+
"""
687707
return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs)
688708

689709

690710
@register_model
691711
def repvgg_b2(pretrained=False, **kwargs):
712+
""" RepVGG-B2
713+
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
714+
"""
692715
return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs)
693716

694717

695718
@register_model
696719
def repvgg_b2g4(pretrained=False, **kwargs):
720+
""" RepVGG-B2g4
721+
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
722+
"""
697723
return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs)
698724

699725

700726
@register_model
701727
def repvgg_b3(pretrained=False, **kwargs):
728+
""" RepVGG-B3
729+
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
730+
"""
702731
return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs)
703732

704733

705734
@register_model
706735
def repvgg_b3g4(pretrained=False, **kwargs):
736+
""" RepVGG-B3g4
737+
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
738+
"""
707739
return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)

0 commit comments

Comments
 (0)