@@ -82,19 +82,30 @@ def fuse(self):
8282
8383
8484class RepVggDw (nn .Module ):
85- def __init__ (self , ed , kernel_size ):
85+ def __init__ (self , ed , kernel_size , legacy = False ):
8686 super ().__init__ ()
8787 self .conv = ConvNorm (ed , ed , kernel_size , 1 , (kernel_size - 1 ) // 2 , groups = ed )
88- self .conv1 = ConvNorm (ed , ed , 1 , 1 , 0 , groups = ed )
88+ if legacy :
89+ self .conv1 = ConvNorm (ed , ed , 1 , 1 , 0 , groups = ed )
90+ # Make torchscript happy.
91+ self .bn = nn .Identity ()
92+ else :
93+ self .conv1 = nn .Conv2d (ed , ed , 1 , 1 , 0 , groups = ed )
94+ self .bn = nn .BatchNorm2d (ed )
8995 self .dim = ed
96+ self .legacy = legacy
9097
9198 def forward (self , x ):
92- return self .conv (x ) + self .conv1 (x ) + x
99+ return self .bn ( self . conv (x ) + self .conv1 (x ) + x )
93100
94101 @torch .no_grad ()
95102 def fuse (self ):
96103 conv = self .conv .fuse ()
97- conv1 = self .conv1 .fuse ()
104+
105+ if self .legacy :
106+ conv1 = self .conv1 .fuse ()
107+ else :
108+ conv1 = self .conv1
98109
99110 conv_w = conv .weight
100111 conv_b = conv .bias
@@ -112,6 +123,14 @@ def fuse(self):
112123
113124 conv .weight .data .copy_ (final_conv_w )
114125 conv .bias .data .copy_ (final_conv_b )
126+
127+ if not self .legacy :
128+ bn = self .bn
129+ w = bn .weight / (bn .running_var + bn .eps ) ** 0.5
130+ w = conv .weight * w [:, None , None , None ]
131+ b = bn .bias + (conv .bias - bn .running_mean ) * bn .weight / (bn .running_var + bn .eps ) ** 0.5
132+ conv .weight .data .copy_ (w )
133+ conv .bias .data .copy_ (b )
115134 return conv
116135
117136
@@ -127,10 +146,10 @@ def forward(self, x):
127146
128147
129148class RepViTBlock (nn .Module ):
130- def __init__ (self , in_dim , mlp_ratio , kernel_size , use_se , act_layer ):
149+ def __init__ (self , in_dim , mlp_ratio , kernel_size , use_se , act_layer , legacy = False ):
131150 super (RepViTBlock , self ).__init__ ()
132151
133- self .token_mixer = RepVggDw (in_dim , kernel_size )
152+ self .token_mixer = RepVggDw (in_dim , kernel_size , legacy )
134153 self .se = SqueezeExcite (in_dim , 0.25 ) if use_se else nn .Identity ()
135154 self .channel_mixer = RepVitMlp (in_dim , in_dim * mlp_ratio , act_layer )
136155
@@ -155,9 +174,9 @@ def forward(self, x):
155174
156175
157176class RepVitDownsample (nn .Module ):
158- def __init__ (self , in_dim , mlp_ratio , out_dim , kernel_size , act_layer ):
177+ def __init__ (self , in_dim , mlp_ratio , out_dim , kernel_size , act_layer , legacy = False ):
159178 super ().__init__ ()
160- self .pre_block = RepViTBlock (in_dim , mlp_ratio , kernel_size , use_se = False , act_layer = act_layer )
179+ self .pre_block = RepViTBlock (in_dim , mlp_ratio , kernel_size , use_se = False , act_layer = act_layer , legacy = legacy )
161180 self .spatial_downsample = ConvNorm (in_dim , in_dim , kernel_size , 2 , (kernel_size - 1 ) // 2 , groups = in_dim )
162181 self .channel_downsample = ConvNorm (in_dim , out_dim , 1 , 1 )
163182 self .ffn = RepVitMlp (out_dim , out_dim * mlp_ratio , act_layer )
@@ -172,7 +191,7 @@ def forward(self, x):
172191
173192
174193class RepVitClassifier (nn .Module ):
175- def __init__ (self , dim , num_classes , distillation = False , drop = 0. ):
194+ def __init__ (self , dim , num_classes , distillation = False , drop = 0.0 ):
176195 super ().__init__ ()
177196 self .head_drop = nn .Dropout (drop )
178197 self .head = NormLinear (dim , num_classes ) if num_classes > 0 else nn .Identity ()
@@ -211,18 +230,18 @@ def fuse(self):
211230
212231
213232class RepVitStage (nn .Module ):
214- def __init__ (self , in_dim , out_dim , depth , mlp_ratio , act_layer , kernel_size = 3 , downsample = True ):
233+ def __init__ (self , in_dim , out_dim , depth , mlp_ratio , act_layer , kernel_size = 3 , downsample = True , legacy = False ):
215234 super ().__init__ ()
216235 if downsample :
217- self .downsample = RepVitDownsample (in_dim , mlp_ratio , out_dim , kernel_size , act_layer )
236+ self .downsample = RepVitDownsample (in_dim , mlp_ratio , out_dim , kernel_size , act_layer , legacy )
218237 else :
219238 assert in_dim == out_dim
220239 self .downsample = nn .Identity ()
221240
222241 blocks = []
223242 use_se = True
224243 for _ in range (depth ):
225- blocks .append (RepViTBlock (out_dim , mlp_ratio , kernel_size , use_se , act_layer ))
244+ blocks .append (RepViTBlock (out_dim , mlp_ratio , kernel_size , use_se , act_layer , legacy ))
226245 use_se = not use_se
227246
228247 self .blocks = nn .Sequential (* blocks )
@@ -246,7 +265,8 @@ def __init__(
246265 num_classes = 1000 ,
247266 act_layer = nn .GELU ,
248267 distillation = True ,
249- drop_rate = 0. ,
268+ drop_rate = 0.0 ,
269+ legacy = False ,
250270 ):
251271 super (RepVit , self ).__init__ ()
252272 self .grad_checkpointing = False
@@ -275,6 +295,7 @@ def __init__(
275295 act_layer = act_layer ,
276296 kernel_size = kernel_size ,
277297 downsample = downsample ,
298+ legacy = legacy ,
278299 )
279300 )
280301 stage_stride = 2 if downsample else 1
@@ -290,12 +311,9 @@ def __init__(
290311
291312 @torch .jit .ignore
292313 def group_matcher (self , coarse = False ):
293- matcher = dict (
294- stem = r'^stem' , # stem and embed
295- blocks = [(r'^blocks\.(\d+)' , None ), (r'^norm' , (99999 ,))]
296- )
314+ matcher = dict (stem = r'^stem' , blocks = [(r'^blocks\.(\d+)' , None ), (r'^norm' , (99999 ,))]) # stem and embed
297315 return matcher
298-
316+
299317 @torch .jit .ignore
300318 def set_grad_checkpointing (self , enable = True ):
301319 self .grad_checkpointing = enable
@@ -369,15 +387,42 @@ def _cfg(url='', **kwargs):
369387 {
370388 'repvit_m1.dist_in1k' : _cfg (
371389 hf_hub_id = 'timm/' ,
372- # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth'
373390 ),
374391 'repvit_m2.dist_in1k' : _cfg (
375392 hf_hub_id = 'timm/' ,
376- # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth'
377393 ),
378394 'repvit_m3.dist_in1k' : _cfg (
379395 hf_hub_id = 'timm/' ,
380- # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth'
396+ ),
397+ 'repvit_m0_9.dist_300e_in1k' : _cfg (
398+ hf_hub_id = 'timm/' ,
399+ ),
400+ 'repvit_m0_9.dist_450e_in1k' : _cfg (
401+ hf_hub_id = 'timm/' ,
402+ ),
403+ 'repvit_m1_0.dist_300e_in1k' : _cfg (
404+ hf_hub_id = 'timm/' ,
405+ ),
406+ 'repvit_m1_0.dist_450e_in1k' : _cfg (
407+ hf_hub_id = 'timm/' ,
408+ ),
409+ 'repvit_m1_1.dist_300e_in1k' : _cfg (
410+ hf_hub_id = 'timm/' ,
411+ ),
412+ 'repvit_m1_1.dist_450e_in1k' : _cfg (
413+ hf_hub_id = 'timm/' ,
414+ ),
415+ 'repvit_m1_5.dist_300e_in1k' : _cfg (
416+ hf_hub_id = 'timm/' ,
417+ ),
418+ 'repvit_m1_5.dist_450e_in1k' : _cfg (
419+ hf_hub_id = 'timm/' ,
420+ ),
421+ 'repvit_m2_3.dist_300e_in1k' : _cfg (
422+ hf_hub_id = 'timm/' ,
423+ ),
424+ 'repvit_m2_3.dist_450e_in1k' : _cfg (
425+ hf_hub_id = 'timm/' ,
381426 ),
382427 }
383428)
@@ -386,7 +431,9 @@ def _cfg(url='', **kwargs):
def _create_repvit(variant, pretrained=False, **kwargs):
    """Build a RepVit model via timm's model factory.

    Args:
        variant: Registered model variant name.
        pretrained: Whether to load pretrained weights.
        **kwargs: Overrides forwarded to the RepVit constructor; ``out_indices``
            is popped here and routed into ``feature_cfg`` for feature extraction.
    """
    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
    model = build_model_with_cfg(
        RepVit,
        variant,
        pretrained,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs,
    )
    return model
@register_model
def repvit_m1(pretrained=False, **kwargs):
    """
    Constructs a RepViT-M1 model (legacy weight layout)
    """
    model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2), legacy=True)
    return _create_repvit('repvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))
403450
404451
@register_model
def repvit_m2(pretrained=False, **kwargs):
    """
    Constructs a RepViT-M2 model (legacy weight layout)
    """
    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2), legacy=True)
    return _create_repvit('repvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))
412459
413460
@register_model
def repvit_m3(pretrained=False, **kwargs):
    """
    Constructs a RepViT-M3 model (legacy weight layout)
    """
    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2), legacy=True)
    return _create_repvit('repvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))
468+
469+
@register_model
def repvit_m0_9(pretrained=False, **kwargs):
    """
    Constructs a RepViT-M0.9 model
    """
    model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
    return _create_repvit('repvit_m0_9', pretrained=pretrained, **dict(model_args, **kwargs))
477+
478+
@register_model
def repvit_m1_0(pretrained=False, **kwargs):
    """
    Constructs a RepViT-M1.0 model
    """
    model_args = dict(embed_dim=(56, 112, 224, 448), depth=(2, 2, 14, 2))
    return _create_repvit('repvit_m1_0', pretrained=pretrained, **dict(model_args, **kwargs))
486+
487+
@register_model
def repvit_m1_1(pretrained=False, **kwargs):
    """
    Constructs a RepViT-M1.1 model
    """
    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
    return _create_repvit('repvit_m1_1', pretrained=pretrained, **dict(model_args, **kwargs))
495+
496+
@register_model
def repvit_m1_5(pretrained=False, **kwargs):
    """
    Constructs a RepViT-M1.5 model
    """
    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 24, 4))
    return _create_repvit('repvit_m1_5', pretrained=pretrained, **dict(model_args, **kwargs))
504+
505+
@register_model
def repvit_m2_3(pretrained=False, **kwargs):
    """
    Constructs a RepViT-M2.3 model
    """
    model_args = dict(embed_dim=(80, 160, 320, 640), depth=(6, 6, 34, 2))
    return _create_repvit('repvit_m2_3', pretrained=pretrained, **dict(model_args, **kwargs))
0 commit comments