Skip to content

Commit 5cc87e6

Browse files
authored
Add dinov2 pretrained models (#1797)
Adds DINOv2 small, base, and large pretrained models; fixes the input size; fixes SwiGLU for the DINOv2 ViT-giant; uses SwiGLUPacked to replace GluMlp; cleans up and adds an ffn_layer placeholder for ParallelScalingBlock.
1 parent af48246 commit 5cc87e6

File tree

2 files changed

+113
-11
lines changed

2 files changed

+113
-11
lines changed

timm/layers/mlp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def __init__(
133133

134134
def init_weights(self):
135135
# override init of fc1 w/ gate portion set to weight near zero, bias=1
136-
nn.init.ones_(self.fc1a.bias)
137-
nn.init.normal_(self.fc1a.weight, std=1e-6)
136+
nn.init.ones_(self.fc1_g.bias)
137+
nn.init.normal_(self.fc1_g.weight, std=1e-6)
138138

139139
def forward(self, x):
140140
x_gate = self.fc1_g(x)

timm/models/vision_transformer.py

Lines changed: 111 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
3939
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
4040
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
41-
resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn
41+
resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
4242
from ._builder import build_model_with_cfg
4343
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
4444
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -124,7 +124,8 @@ def __init__(
124124
init_values=None,
125125
drop_path=0.,
126126
act_layer=nn.GELU,
127-
norm_layer=nn.LayerNorm
127+
norm_layer=nn.LayerNorm,
128+
ffn_layer=Mlp,
128129
):
129130
super().__init__()
130131
self.norm1 = norm_layer(dim)
@@ -141,7 +142,7 @@ def __init__(
141142
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
142143

143144
self.norm2 = norm_layer(dim)
144-
self.mlp = Mlp(
145+
self.mlp = ffn_layer(
145146
in_features=dim,
146147
hidden_features=int(dim * mlp_ratio),
147148
act_layer=act_layer,
@@ -170,7 +171,8 @@ def __init__(
170171
init_values=None,
171172
drop_path=0.,
172173
act_layer=nn.GELU,
173-
norm_layer=nn.LayerNorm
174+
norm_layer=nn.LayerNorm,
175+
ffn_layer=Mlp,
174176
):
175177
super().__init__()
176178
self.init_values = init_values
@@ -187,7 +189,7 @@ def __init__(
187189
self.norm1 = norm_layer(dim)
188190
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
189191

190-
self.mlp = Mlp(
192+
self.mlp = ffn_layer(
191193
in_features=dim,
192194
hidden_features=int(dim * mlp_ratio),
193195
act_layer=act_layer,
@@ -229,7 +231,8 @@ def __init__(
229231
init_values=None,
230232
drop_path=0.,
231233
act_layer=nn.GELU,
232-
norm_layer=nn.LayerNorm
234+
norm_layer=nn.LayerNorm,
235+
ffn_layer=None, # NOTE: not used
233236
):
234237
super().__init__()
235238
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -322,7 +325,8 @@ def __init__(
322325
attn_drop=0.,
323326
drop_path=0.,
324327
act_layer=nn.GELU,
325-
norm_layer=nn.LayerNorm
328+
norm_layer=nn.LayerNorm,
329+
ffn_layer=Mlp,
326330
):
327331
super().__init__()
328332
self.num_parallel = num_parallel
@@ -345,7 +349,7 @@ def __init__(
345349
])))
346350
self.ffns.append(nn.Sequential(OrderedDict([
347351
('norm', norm_layer(dim)),
348-
('mlp', Mlp(
352+
('mlp', ffn_layer(
349353
dim,
350354
hidden_features=int(dim * mlp_ratio),
351355
act_layer=act_layer,
@@ -409,6 +413,7 @@ def __init__(
409413
norm_layer: Optional[Callable] = None,
410414
act_layer: Optional[Callable] = None,
411415
block_fn: Callable = Block,
416+
ffn_layer: Callable = Mlp,
412417
):
413418
"""
414419
Args:
@@ -484,7 +489,8 @@ def __init__(
484489
attn_drop=attn_drop_rate,
485490
drop_path=dpr[i],
486491
norm_layer=norm_layer,
487-
act_layer=act_layer
492+
act_layer=act_layer,
493+
ffn_layer=ffn_layer,
488494
)
489495
for i in range(depth)])
490496
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@@ -808,6 +814,25 @@ def _convert_openai_clip(state_dict, model):
808814
return out_dict
809815

810816

817+
def _convert_dinov2(state_dict, model):
818+
import re
819+
820+
out_dict = {}
821+
822+
for k, v in state_dict.items():
823+
if k == "mask_token":
824+
continue
825+
elif re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
826+
out_dict[k.replace("w12", "fc1")] = v
827+
continue
828+
elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
829+
out_dict[k.replace("w3", "fc2")] = v
830+
continue
831+
832+
out_dict[k] = v
833+
834+
return out_dict
835+
811836
def checkpoint_filter_fn(
812837
state_dict,
813838
model,
@@ -824,6 +849,9 @@ def checkpoint_filter_fn(
824849
if 'visual.class_embedding' in state_dict:
825850
return _convert_openai_clip(state_dict, model)
826851

852+
if "mask_token" in state_dict:
853+
return _convert_dinov2(state_dict, model)
854+
827855
for k, v in state_dict.items():
828856
if 'patch_embed.proj.weight' in k:
829857
O, I, H, W = model.patch_embed.proj.weight.shape
@@ -1043,6 +1071,20 @@ def _cfg(url='', **kwargs):
10431071
url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
10441072
hf_hub_id='timm/',
10451073
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
1074+
1075+
# DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune only)
1076+
'vit_small_patch14_dinov2': _cfg(
1077+
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
1078+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
1079+
'vit_base_patch14_dinov2': _cfg(
1080+
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth',
1081+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
1082+
'vit_large_patch14_dinov2': _cfg(
1083+
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth',
1084+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
1085+
'vit_giant_patch14_dinov2': _cfg(
1086+
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth',
1087+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
10461088

10471089
# ViT ImageNet-21K-P pretraining by MILL
10481090
'vit_base_patch16_224_miil.in21k': _cfg(
@@ -1855,6 +1897,66 @@ def vit_huge_patch14_xp_224(pretrained=False, **kwargs):
18551897
return model
18561898

18571899

1900+
@register_model
def vit_small_patch14_dinov2(pretrained=False, **kwargs):
    """ ViT-S/14 for DINOv2 (patch 14, pretrained at 518x518, LayerScale enabled).
    """
    # DINOv2 models use LayerScale (init_values) and a 518x518 input grid.
    args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1.0, img_size=518)
    args.update(kwargs)
    return _create_vision_transformer('vit_small_patch14_dinov2', pretrained=pretrained, **args)
1912+
1913+
1914+
@register_model
def vit_base_patch14_dinov2(pretrained=False, **kwargs):
    """ ViT-B/14 for DINOv2 (patch 14, pretrained at 518x518, LayerScale enabled).
    """
    # DINOv2 models use LayerScale (init_values) and a 518x518 input grid.
    args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1.0, img_size=518)
    args.update(kwargs)
    return _create_vision_transformer('vit_base_patch14_dinov2', pretrained=pretrained, **args)
1926+
1927+
1928+
@register_model
def vit_large_patch14_dinov2(pretrained=False, **kwargs):
    """ ViT-L/14 for DINOv2 (patch 14, pretrained at 518x518, LayerScale enabled).
    """
    # DINOv2 models use LayerScale (init_values) and a 518x518 input grid.
    args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1.0, img_size=518)
    args.update(kwargs)
    return _create_vision_transformer('vit_large_patch14_dinov2', pretrained=pretrained, **args)
1940+
1941+
@register_model
def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
    """ ViT-g/14 for DINOv2 (patch 14, pretrained at 518x518, SwiGLU MLP).
    """
    # The reference SwiGLU hidden size is (int(h * 2 / 3) + 7) // 8 * 8, which
    # gives 4096 for embed_dim=1536. SwiGLUPacked stores the gate and value
    # projections together in fc1, so the requested hidden_features must be
    # doubled: mlp_ratio = 2.66667 * 2 -> int(1536 * mlp_ratio) == 8192.
    args = dict(
        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1.0,
        mlp_ratio=2.66667 * 2, ffn_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU,
    )
    args.update(kwargs)
    return _create_vision_transformer('vit_giant_patch14_dinov2', pretrained=pretrained, **args)
1959+
18581960
register_model_deprecations(__name__, {
18591961
'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
18601962
'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',

0 commit comments

Comments
 (0)