Commit c55bc41

DFN CLIP ViT support
1 parent d5f1525 commit c55bc41

File tree: 1 file changed, +32 -3 lines
timm/models/vision_transformer.py

Lines changed: 32 additions & 3 deletions
@@ -904,16 +904,17 @@ def _n2p(w, t=True):
             getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))


-def _convert_openai_clip(state_dict, model):
+def _convert_openai_clip(state_dict, model, prefix='visual.'):
     out_dict = {}
     swaps = [
-        ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
+        ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
         ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
         ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
     ]
     for k, v in state_dict.items():
-        if not k.startswith('visual.'):
+        if not k.startswith(prefix):
             continue
+        k = k.replace(prefix, '')
         for sp in swaps:
             k = k.replace(sp[0], sp[1])

@@ -974,6 +975,8 @@ def checkpoint_filter_fn(

     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
+    elif 'module.visual.class_embedding' in state_dict:
+        return _convert_openai_clip(state_dict, model, prefix='module.visual.')

     if "mask_token" in state_dict:
         state_dict = _convert_dinov2(state_dict, model)
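Note: the two hunks above let the converter handle DFN checkpoints whose keys carry a `module.visual.` prefix (as in Apple's released open_clip weights) in addition to plain `visual.` keys. Below is a minimal sketch of just the key-renaming loop shown in this diff; `strip_and_remap` is an illustrative helper, not the library function, and the real `_convert_openai_clip` does further conversion after this loop.

```python
# Simplified illustration of the new `prefix` handling; the swap table is
# copied from the diff above, but this covers only the renaming loop.
def strip_and_remap(state_dict, prefix='module.visual.'):
    swaps = [
        ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
        ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'),
        ('ln_post', 'norm'), ('ln_', 'norm'),
        ('in_proj_', 'qkv.'), ('out_proj', 'proj'),
        ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
    ]
    out = {}
    for k, v in state_dict.items():
        if not k.startswith(prefix):
            continue  # text-tower / non-visual keys are skipped entirely
        k = k.replace(prefix, '')
        for old, new in swaps:
            k = k.replace(old, new)
        out[k] = v
    return out

# e.g. 'module.visual.conv1.weight'                      -> 'patch_embed.proj.weight'
#      'module.visual.transformer.resblocks.0.ln_1.bias' -> 'blocks.0.norm1.bias'
```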
@@ -1416,6 +1419,10 @@ def _cfg(url='', **kwargs):
         hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_base_patch16_clip_224.dfn2b': _cfg(
+        hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_large_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
         hf_hub_filename='open_clip_pytorch_model.bin',

@@ -1424,10 +1431,22 @@ def _cfg(url='', **kwargs):
         hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+    'vit_large_patch14_clip_224.dfn2b': _cfg(
+        hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_huge_patch14_clip_224.dfn5b': _cfg(
+        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_huge_patch14_clip_378.dfn5b': _cfg(
+        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_giant_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
         hf_hub_filename='open_clip_pytorch_model.bin',
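The four new `.dfn2b` / `.dfn5b` pretrained tags above point at Apple's DFN checkpoints on the Hugging Face Hub. A quick way to check they are registered after installing a build that contains this commit (assumes `timm.list_pretrained`, which accepts a wildcard filter in recent timm versions):

```python
import timm

# Should include the tags added in this commit, e.g.
# 'vit_base_patch16_clip_224.dfn2b', 'vit_large_patch14_clip_224.dfn2b',
# 'vit_huge_patch14_clip_224.dfn5b', 'vit_huge_patch14_clip_378.dfn5b'
print([n for n in timm.list_pretrained('vit_*clip*') if 'dfn' in n])
```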
@@ -2026,6 +2045,16 @@ def vit_huge_patch14_clip_336(pretrained=False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_huge_patch14_clip_378(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378
+    """
+    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_giant_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
