@@ -118,6 +118,17 @@ def _cfg(url='', **kwargs):
118118 'vit_deit_base_distilled_patch16_384' : _cfg (
119119 url = 'https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth' ,
120120 input_size = (3 , 384 , 384 ), crop_pct = 1.0 , classifier = ('head' , 'head_dist' )),
121+
122+ # ViT ImageNet-21K-P pretraining
123+ 'vit_base_patch16_224_21k_miil' : _cfg (
124+ url = 'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_21k_miil.pth' ,
125+ mean = (0 , 0 , 0 ), std = (1 , 1 , 1 ), crop_pct = 0.875 , interpolation = 'bilinear' , num_classes = 11221 ,
126+ ),
127+ 'vit_base_patch16_224_1k_miil' : _cfg (
128+ url = 'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
129+ '/vit_base_patch16_224_1k_miil_84_4.pth' ,
130+ mean = (0 , 0 , 0 ), std = (1 , 1 , 1 ), crop_pct = 0.875 , interpolation = 'bilinear' ,
131+ ),
121132}
122133
123134
@@ -155,7 +166,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.
155166 def forward (self , x ):
156167 B , N , C = x .shape
157168 qkv = self .qkv (x ).reshape (B , N , 3 , self .num_heads , C // self .num_heads ).permute (2 , 0 , 3 , 1 , 4 )
158- q , k , v = qkv [0 ], qkv [1 ], qkv [2 ] # make torchscript happy (cannot use tensor as tuple)
169+ q , k , v = qkv [0 ], qkv [1 ], qkv [2 ] # make torchscript happy (cannot use tensor as tuple)
159170
160171 attn = (q @ k .transpose (- 2 , - 1 )) * self .scale
161172 attn = attn .softmax (dim = - 1 )
@@ -652,7 +663,7 @@ def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
652663 """
653664 model_kwargs = dict (patch_size = 16 , embed_dim = 192 , depth = 12 , num_heads = 3 , ** kwargs )
654665 model = _create_vision_transformer (
655- 'vit_deit_tiny_distilled_patch16_224' , pretrained = pretrained , distilled = True , ** model_kwargs )
666+ 'vit_deit_tiny_distilled_patch16_224' , pretrained = pretrained , distilled = True , ** model_kwargs )
656667 return model
657668
658669
@@ -663,7 +674,7 @@ def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
663674 """
664675 model_kwargs = dict (patch_size = 16 , embed_dim = 384 , depth = 12 , num_heads = 6 , ** kwargs )
665676 model = _create_vision_transformer (
666- 'vit_deit_small_distilled_patch16_224' , pretrained = pretrained , distilled = True , ** model_kwargs )
677+ 'vit_deit_small_distilled_patch16_224' , pretrained = pretrained , distilled = True , ** model_kwargs )
667678 return model
668679
669680
@@ -674,7 +685,7 @@ def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
674685 """
675686 model_kwargs = dict (patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , ** kwargs )
676687 model = _create_vision_transformer (
677- 'vit_deit_base_distilled_patch16_224' , pretrained = pretrained , distilled = True , ** model_kwargs )
688+ 'vit_deit_base_distilled_patch16_224' , pretrained = pretrained , distilled = True , ** model_kwargs )
678689 return model
679690
680691
@@ -687,3 +698,21 @@ def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
687698 model = _create_vision_transformer (
688699 'vit_deit_base_distilled_patch16_384' , pretrained = pretrained , distilled = True , ** model_kwargs )
689700 return model
701+
@register_model
def vit_base_patch16_224_21k_miil(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929),
    pretrained on ImageNet-21K-P. Weights from: https://github.com/Alibaba-MIIL/ImageNet21K
    """
    # MIIL weights were trained without a qkv bias, hence qkv_bias=False.
    cfg = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
    return _create_vision_transformer('vit_base_patch16_224_21k_miil', pretrained=pretrained, **cfg)
710+
@register_model
def vit_base_patch16_224_1k_miil(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929),
    21K-P pretrained and fine-tuned on 1K. Weights from: https://github.com/Alibaba-MIIL/ImageNet21K
    """
    # MIIL weights were trained without a qkv bias, hence qkv_bias=False.
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
    # Fix: the variant string must match the registered model name / default_cfgs key
    # 'vit_base_patch16_224_1k_miil'. The previous value 'vit_base_patch16_224_1k_miil_84_4'
    # has no default_cfgs entry, so the pretrained config lookup would fail.
    model = _create_vision_transformer('vit_base_patch16_224_1k_miil', pretrained=pretrained, **model_kwargs)
    return model
0 commit comments