 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn
+    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -124,7 +124,8 @@ def __init__(
             init_values=None,
             drop_path=0.,
             act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm
+            norm_layer=nn.LayerNorm,
+            ffn_layer=Mlp,
     ):
         super().__init__()
         self.norm1 = norm_layer(dim)
@@ -141,7 +142,7 @@ def __init__(
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = norm_layer(dim)
-        self.mlp = Mlp(
+        self.mlp = ffn_layer(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
@@ -170,7 +171,8 @@ def __init__(
             init_values=None,
             drop_path=0.,
             act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm
+            norm_layer=nn.LayerNorm,
+            ffn_layer=Mlp,
     ):
         super().__init__()
         self.init_values = init_values
@@ -187,7 +189,7 @@ def __init__(
         self.norm1 = norm_layer(dim)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.mlp = Mlp(
+        self.mlp = ffn_layer(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
@@ -229,7 +231,8 @@ def __init__(
             init_values=None,
             drop_path=0.,
             act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm
+            norm_layer=nn.LayerNorm,
+            ffn_layer=None,  # NOTE: not used
     ):
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -322,7 +325,8 @@ def __init__(
             attn_drop=0.,
             drop_path=0.,
             act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm
+            norm_layer=nn.LayerNorm,
+            ffn_layer=Mlp,
     ):
         super().__init__()
         self.num_parallel = num_parallel
@@ -345,7 +349,7 @@ def __init__(
             ])))
             self.ffns.append(nn.Sequential(OrderedDict([
                 ('norm', norm_layer(dim)),
-                ('mlp', Mlp(
+                ('mlp', ffn_layer(
                     dim,
                     hidden_features=int(dim * mlp_ratio),
                     act_layer=act_layer,
@@ -409,6 +413,7 @@ def __init__(
             norm_layer: Optional[Callable] = None,
             act_layer: Optional[Callable] = None,
             block_fn: Callable = Block,
+            ffn_layer: Callable = Mlp,
     ):
         """
         Args:
@@ -484,7 +489,8 @@ def __init__(
                 attn_drop=attn_drop_rate,
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
-                act_layer=act_layer
+                act_layer=act_layer,
+                ffn_layer=ffn_layer,
             )
             for i in range(depth)])
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
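
A minimal sketch (not part of the diff) of how the new ffn_layer argument could be exercised directly; the tiny config values below are illustrative assumptions, only the ffn_layer/act_layer usage mirrors what the diff adds:

import torch
import torch.nn as nn
from timm.layers import SwiGLUPacked
from timm.models.vision_transformer import VisionTransformer

# Toy-sized ViT with the default Mlp FFN swapped for a packed SwiGLU FFN.
model = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3,
    ffn_layer=SwiGLUPacked, act_layer=nn.SiLU, num_classes=0,
)
feats = model(torch.randn(1, 3, 224, 224))  # -> shape (1, 192): pooled features, no classifier head
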
@@ -808,6 +814,25 @@ def _convert_openai_clip(state_dict, model):
     return out_dict
 
 
+def _convert_dinov2(state_dict, model):
+    import re
+
+    out_dict = {}
+
+    for k, v in state_dict.items():
+        if k == "mask_token":
+            continue
+        elif re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
+            out_dict[k.replace("w12", "fc1")] = v
+            continue
+        elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
+            out_dict[k.replace("w3", "fc2")] = v
+            continue
+
+        out_dict[k] = v
+
+    return out_dict
+
 def checkpoint_filter_fn(
         state_dict,
         model,
@@ -824,6 +849,9 @@ def checkpoint_filter_fn(
     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
 
+    if "mask_token" in state_dict:
+        return _convert_dinov2(state_dict, model)
+
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
             O, I, H, W = model.patch_embed.proj.weight.shape
@@ -1043,6 +1071,20 @@ def _cfg(url='', **kwargs):
         url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
         hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+
+    # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune only)
+    'vit_small_patch14_dinov2': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
+    'vit_base_patch14_dinov2': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
+    'vit_large_patch14_dinov2': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
+    'vit_giant_patch14_dinov2': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
 
     # ViT ImageNet-21K-P pretraining by MILL
     'vit_base_patch16_224_miil.in21k': _cfg(
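
A hedged usage sketch for the new pretrained configs above (assuming a timm build containing this change; num_classes=0 is passed explicitly so the model returns pooled features rather than a randomly initialized classifier):

import timm
import torch

# Any of the four dinov2 entries works the same way; the small variant is used as an example.
model = timm.create_model('vit_small_patch14_dinov2', pretrained=True, num_classes=0).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 518, 518))  # -> (1, 384) image-level features
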
@@ -1855,6 +1897,66 @@ def vit_huge_patch14_xp_224(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_small_patch14_dinov2(pretrained=False, **kwargs):
+    """ ViT-S/14 for DINOv2
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=384, depth=12, num_heads=6,
+        init_values=1.0, img_size=518,
+    )
+
+    model = _create_vision_transformer(
+        'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch14_dinov2(pretrained=False, **kwargs):
+    """ ViT-B/14 for DINOv2
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=768, depth=12, num_heads=12,
+        init_values=1.0, img_size=518,
+    )
+
+    model = _create_vision_transformer(
+        'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_dinov2(pretrained=False, **kwargs):
+    """ ViT-L/14 for DINOv2
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16,
+        init_values=1.0, img_size=518,
+    )
+
+    model = _create_vision_transformer(
+        'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
+    """ ViT-G/14 for DINOv2
+    """
+    # The hidden_features of SwiGLU is calculated by:
+    #     hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+    # When embed_dim=1536, hidden_features=4096
+    # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1.0,
+        mlp_ratio=2.66667 * 2, ffn_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU,
+    )
+
+    model = _create_vision_transformer(
+        'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
     'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
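
To make the hidden-width comment in vit_giant_patch14_dinov2 concrete, a quick stand-alone arithmetic check (plain Python, not part of the diff):

embed_dim = 1536

# DINOv2's SwiGLU sizing: start from 4 * embed_dim, take 2/3, round up to a multiple of 8.
dinov2_hidden = (int(embed_dim * 4 * 2 / 3) + 7) // 8 * 8
assert dinov2_hidden == 4096

# timm path: the block computes hidden_features = int(dim * mlp_ratio), and SwiGLUPacked packs
# the gate and value projections into a single fc1, so it needs twice the DINOv2 width.
packed_hidden = int(embed_dim * (2.66667 * 2))
assert packed_hidden == 8192 == 2 * dinov2_hidden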