Skip to content

Commit 0968bde

Browse files
author
talrid
committed
vit, tresnet and mobilenetV3 ImageNet-21K-P weights
1 parent b81cd75 commit 0968bde

File tree

3 files changed

+60
-5
lines changed

3 files changed

+60
-5
lines changed

timm/models/mobilenetv3.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ def _cfg(url='', **kwargs):
3939
'mobilenetv3_large_100': _cfg(
4040
interpolation='bicubic',
4141
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
42+
'mobilenetv3_large_100_1k_miil_77_9': _cfg(
43+
interpolation='bilinear', mean=(0, 0, 0), std=(1, 1, 1),
44+
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_1k_miil_77_9.pth'),
45+
'mobilenetv3_large_100_21k_miil': _cfg(
46+
interpolation='bilinear', mean=(0, 0, 0), std=(1, 1, 1),
47+
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_21k_miil.pth', num_classes=11221),
4248
'mobilenetv3_small_075': _cfg(url=''),
4349
'mobilenetv3_small_100': _cfg(url=''),
4450
'mobilenetv3_rw': _cfg(
@@ -367,6 +373,20 @@ def mobilenetv3_large_100(pretrained=False, **kwargs):
367373
return model
368374

369375

376+
@register_model
377+
def mobilenetv3_large_100_1k_miil(pretrained=False, **kwargs):
378+
""" MobileNet V3 """
379+
model = _gen_mobilenet_v3('mobilenetv3_large_100_1k_miil_77_9', 1.0, pretrained=pretrained, **kwargs)
380+
return model
381+
382+
383+
@register_model
384+
def mobilenetv3_large_100_21k_miil(pretrained=False, **kwargs):
385+
""" MobileNet V3 """
386+
model = _gen_mobilenet_v3('mobilenetv3_large_100_21k_miil', 1.0, pretrained=pretrained, **kwargs)
387+
return model
388+
389+
370390
@register_model
371391
def mobilenetv3_small_075(pretrained=False, **kwargs):
372392
""" MobileNet V3 """

timm/models/tresnet.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def _cfg(url='', **kwargs):
3232

3333
default_cfgs = {
3434
'tresnet_m': _cfg(
35-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_80_8-dbc13962.pth'),
35+
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/tresnet_m_1k_miil_83_1.pth'),
36+
'tresnet_m_21k_miil': _cfg(
37+
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/tresnet_m_miil_21k.pth', num_classes=11221),
3638
'tresnet_l': _cfg(
3739
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth'),
3840
'tresnet_xl': _cfg(
@@ -264,6 +266,10 @@ def tresnet_m(pretrained=False, **kwargs):
264266
model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs)
265267
return _create_tresnet('tresnet_m', pretrained=pretrained, **model_kwargs)
266268

269+
@register_model
270+
def tresnet_m_21k_miil(pretrained=False, **kwargs):
271+
model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs)
272+
return _create_tresnet('tresnet_m_21k_miil', pretrained=pretrained, **model_kwargs)
267273

268274
@register_model
269275
def tresnet_l(pretrained=False, **kwargs):

timm/models/vision_transformer.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,17 @@ def _cfg(url='', **kwargs):
118118
'vit_deit_base_distilled_patch16_384': _cfg(
119119
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
120120
input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
121+
122+
# ViT ImageNet-21K-P pretraining
123+
'vit_base_patch16_224_21k_miil': _cfg(
124+
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_21k_miil.pth',
125+
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
126+
),
127+
'vit_base_patch16_224_1k_miil': _cfg(
128+
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
129+
'/vit_base_patch16_224_1k_miil_84_4.pth',
130+
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
131+
),
121132
}
122133

123134

@@ -155,7 +166,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.
155166
def forward(self, x):
156167
B, N, C = x.shape
157168
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
158-
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
169+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
159170

160171
attn = (q @ k.transpose(-2, -1)) * self.scale
161172
attn = attn.softmax(dim=-1)
@@ -652,7 +663,7 @@ def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
652663
"""
653664
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
654665
model = _create_vision_transformer(
655-
'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
666+
'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
656667
return model
657668

658669

@@ -663,7 +674,7 @@ def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
663674
"""
664675
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
665676
model = _create_vision_transformer(
666-
'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
677+
'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
667678
return model
668679

669680

@@ -674,7 +685,7 @@ def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
674685
"""
675686
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
676687
model = _create_vision_transformer(
677-
'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
688+
'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
678689
return model
679690

680691

@@ -687,3 +698,21 @@ def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
687698
model = _create_vision_transformer(
688699
'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
689700
return model
701+
702+
@register_model
703+
def vit_base_patch16_224_21k_miil(pretrained=False, **kwargs):
704+
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
705+
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
706+
"""
707+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
708+
model = _create_vision_transformer('vit_base_patch16_224_21k_miil', pretrained=pretrained, **model_kwargs)
709+
return model
710+
711+
@register_model
712+
def vit_base_patch16_224_1k_miil(pretrained=False, **kwargs):
713+
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
714+
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
715+
"""
716+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
717+
model = _create_vision_transformer('vit_base_patch16_224_1k_miil_84_4', pretrained=pretrained, **model_kwargs)
718+
return model

0 commit comments

Comments
 (0)