@@ -904,16 +904,17 @@ def _n2p(w, t=True):
             getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))


-def _convert_openai_clip(state_dict, model):
+def _convert_openai_clip(state_dict, model, prefix='visual.'):
     out_dict = {}
     swaps = [
-        ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
+        ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
         ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
         ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
     ]
     for k, v in state_dict.items():
-        if not k.startswith('visual.'):
+        if not k.startswith(prefix):
             continue
+        k = k.replace(prefix, '')
         for sp in swaps:
             k = k.replace(sp[0], sp[1])

@@ -974,6 +975,8 @@ def checkpoint_filter_fn(

     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
+    elif 'module.visual.class_embedding' in state_dict:
+        return _convert_openai_clip(state_dict, model, prefix='module.visual.')

     if "mask_token" in state_dict:
         state_dict = _convert_dinov2(state_dict, model)
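# --- Illustrative sketch, not part of the diff: the names below (_SWAPS, _remap_key)
# are made up for this example. It mirrors how the new `prefix` argument lets the same
# OpenAI-CLIP key remap handle DDP-style checkpoints whose keys live under 'module.visual.':
# the prefix is stripped first, then the swap table is applied to the remainder.
_SWAPS = [
    ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
    ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
    ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
]

def _remap_key(k, prefix='module.visual.'):
    if not k.startswith(prefix):
        return None  # keys outside the image tower (text tower, logit_scale) are skipped
    k = k.replace(prefix, '')  # strip the prefix before applying the swap table
    for old, new in _SWAPS:
        k = k.replace(old, new)
    return k

assert _remap_key('module.visual.conv1.weight') == 'patch_embed.proj.weight'
assert _remap_key('module.visual.transformer.resblocks.0.ln_1.weight') == 'blocks.0.norm1.weight'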
@@ -1416,6 +1419,10 @@ def _cfg(url='', **kwargs):
         hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_base_patch16_clip_224.dfn2b': _cfg(
+        hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_large_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
         hf_hub_filename='open_clip_pytorch_model.bin',
@@ -1424,10 +1431,22 @@ def _cfg(url='', **kwargs):
         hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+    'vit_large_patch14_clip_224.dfn2b': _cfg(
+        hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_huge_patch14_clip_224.dfn5b': _cfg(
+        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_huge_patch14_clip_378.dfn5b': _cfg(
+        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_giant_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
         hf_hub_filename='open_clip_pytorch_model.bin',
@@ -2026,6 +2045,16 @@ def vit_huge_patch14_clip_336(pretrained=False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_huge_patch14_clip_378(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378
+    """
+    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_giant_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
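# --- Usage sketch, not part of the diff, assuming a timm build that includes the cfg
# entries and the vit_huge_patch14_clip_378 registration above. The DFN towers return
# the CLIP image embedding (num_classes 512/768/1024 in the cfgs), not ImageNet logits.
import timm
import torch
from PIL import Image

model = timm.create_model('vit_huge_patch14_clip_378.dfn5b', pretrained=True).eval()
data_cfg = timm.data.resolve_model_data_config(model)              # input size, CLIP mean/std
transform = timm.data.create_transform(**data_cfg, is_training=False)

img = Image.open('example.jpg').convert('RGB')                     # placeholder image path
with torch.no_grad():
    embedding = model(transform(img).unsqueeze(0))                 # shape (1, 1024) for the H/14 tower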