|
27 | 27 | import math |
28 | 28 | from collections import OrderedDict |
29 | 29 | from functools import partial |
30 | | -from typing import Callable, List, Optional, Sequence, Tuple, Union |
| 30 | +from typing import Callable, List, Optional, Sequence, Tuple, Type, Union |
31 | 31 |
|
32 | 32 | import torch |
33 | 33 | import torch.nn as nn |
34 | 34 | import torch.nn.functional as F |
35 | 35 | import torch.utils.checkpoint |
36 | 36 | from torch.jit import Final |
37 | 37 |
|
| 38 | + |
38 | 39 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ |
39 | 40 | OPENAI_CLIP_MEAN, OPENAI_CLIP_STD |
40 | 41 | from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \ |
41 | | - trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn |
| 42 | + trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ |
| 43 | + get_act_layer, get_norm_layer, LayerType |
42 | 44 | from ._builder import build_model_with_cfg |
43 | 45 | from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv |
44 | 46 | from ._registry import generate_default_cfgs, register_model, register_model_deprecations |
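
The import hunk above brings in `Type` from `typing` and `get_act_layer`, `get_norm_layer`, `LayerType` from `timm.layers`. A minimal sketch of what the two helpers do, assuming timm's layer registries (the string names shown are registry names, not part of this diff):

```python
# Sketch only: the helpers map string specs to layer classes/partials, and pass
# classes or partials through unchanged, so existing callers are unaffected.
from functools import partial

import torch.nn as nn
from timm.layers import get_act_layer, get_norm_layer

act_cls = get_act_layer('quick_gelu')    # QuickGELU-style activation class
norm_cls = get_norm_layer('layernorm')   # LayerNorm (or a thin timm wrapper)

assert get_act_layer(nn.GELU) is nn.GELU                       # classes pass through
print(act_cls, norm_cls, get_norm_layer(partial(nn.LayerNorm, eps=1e-6)))
```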
@@ -414,10 +416,10 @@ def __init__( |
414 | 416 | drop_path_rate: float = 0., |
415 | 417 | weight_init: str = '', |
416 | 418 | embed_layer: Callable = PatchEmbed, |
417 | | - norm_layer: Optional[Callable] = None, |
418 | | - act_layer: Optional[Callable] = None, |
419 | | - block_fn: Callable = Block, |
420 | | - mlp_layer: Callable = Mlp, |
| 419 | + norm_layer: Optional[LayerType] = None, |
| 420 | + act_layer: Optional[LayerType] = None, |
| 421 | + block_fn: Type[nn.Module] = Block, |
| 422 | + mlp_layer: Type[nn.Module] = Mlp, |
421 | 423 | ): |
422 | 424 | """ |
423 | 425 | Args: |
@@ -450,8 +452,8 @@ def __init__( |
450 | 452 | assert global_pool in ('', 'avg', 'token', 'map') |
451 | 453 | assert class_token or global_pool != 'token' |
452 | 454 | use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm |
453 | | - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) |
454 | | - act_layer = act_layer or nn.GELU |
| 455 | + norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) |
| 456 | + act_layer = get_act_layer(act_layer) or nn.GELU |
455 | 457 |
|
456 | 458 | self.num_classes = num_classes |
457 | 459 | self.global_pool = global_pool |
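
Because `__init__` now routes `norm_layer` and `act_layer` through `get_norm_layer()` / `get_act_layer()`, callers can pass string specs as well as classes or partials. A minimal sketch (the model name and kwargs are illustrative, not taken from this diff):

```python
import timm

# String layer specs flow from create_model kwargs into VisionTransformer,
# where get_norm_layer()/get_act_layer() resolve them to module classes.
model = timm.create_model(
    'vit_tiny_patch16_224',
    pretrained=False,
    act_layer='quick_gelu',   # resolved by get_act_layer()
    norm_layer='layernorm',   # resolved by get_norm_layer()
)
print(type(model.blocks[0].mlp.act).__name__)
```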
@@ -1415,46 +1417,75 @@ def _cfg(url='', **kwargs): |
1415 | 1417 | hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K', |
1416 | 1418 | hf_hub_filename='open_clip_pytorch_model.bin', |
1417 | 1419 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), |
1418 | | - 'vit_base_patch16_clip_224.datacompxl': _cfg( |
1419 | | - hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K', |
1420 | | - hf_hub_filename='open_clip_pytorch_model.bin', |
1421 | | - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), |
1422 | | - 'vit_base_patch16_clip_224.dfn2b': _cfg( |
1423 | | - hf_hub_id='apple/DFN2B-CLIP-ViT-B-16', |
1424 | | - hf_hub_filename='open_clip_pytorch_model.bin', |
1425 | | - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), |
1426 | 1420 | 'vit_large_patch14_clip_224.laion2b': _cfg( |
1427 | 1421 | hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K', |
1428 | 1422 | hf_hub_filename='open_clip_pytorch_model.bin', |
1429 | 1423 | mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768), |
| 1424 | + 'vit_huge_patch14_clip_224.laion2b': _cfg( |
| 1425 | + hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', |
| 1426 | + hf_hub_filename='open_clip_pytorch_model.bin', |
| 1427 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), |
| 1428 | + 'vit_giant_patch14_clip_224.laion2b': _cfg( |
| 1429 | + hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', |
| 1430 | + hf_hub_filename='open_clip_pytorch_model.bin', |
| 1431 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), |
| 1432 | + 'vit_gigantic_patch14_clip_224.laion2b': _cfg( |
| 1433 | + hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k', |
| 1434 | + hf_hub_filename='open_clip_pytorch_model.bin', |
| 1435 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280), |
| 1436 | + |
| 1437 | + 'vit_base_patch32_clip_224.datacompxl': _cfg( |
| 1438 | + hf_hub_id='laion/', |
| 1439 | + hf_hub_filename='open_clip_pytorch_model.bin', |
| 1440 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), |
| 1441 | + 'vit_base_patch32_clip_256.datacompxl': _cfg( |
| 1442 | + hf_hub_id='laion/', |
| 1443 | + hf_hub_filename='open_clip_pytorch_model.bin', |
| 1444 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, |
| 1445 | + crop_pct=1.0, input_size=(3, 256, 256), num_classes=512), |
| 1446 | + 'vit_base_patch16_clip_224.datacompxl': _cfg( |
| 1447 | + hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K', |
| 1448 | + hf_hub_filename='open_clip_pytorch_model.bin', |
| 1449 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), |
1430 | 1450 | 'vit_large_patch14_clip_224.datacompxl': _cfg( |
1431 | 1451 | hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K', |
1432 | 1452 | hf_hub_filename='open_clip_pytorch_model.bin', |
1433 | 1453 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), |
| 1454 | + |
| 1455 | + 'vit_base_patch16_clip_224.dfn2b': _cfg( |
| 1456 | + hf_hub_id='apple/DFN2B-CLIP-ViT-B-16', |
| 1457 | + hf_hub_filename='open_clip_pytorch_model.bin', |
| 1458 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), |
1434 | 1459 | 'vit_large_patch14_clip_224.dfn2b': _cfg( |
1435 | 1460 | hf_hub_id='apple/DFN2B-CLIP-ViT-L-14', |
1436 | 1461 | hf_hub_filename='open_clip_pytorch_model.bin', |
1437 | 1462 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), |
1438 | | - 'vit_huge_patch14_clip_224.laion2b': _cfg( |
1439 | | - hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', |
1440 | | - hf_hub_filename='open_clip_pytorch_model.bin', |
1441 | | - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), |
1442 | 1463 | 'vit_huge_patch14_clip_224.dfn5b': _cfg( |
1443 | 1464 | hf_hub_id='apple/DFN5B-CLIP-ViT-H-14', |
1444 | 1465 | hf_hub_filename='open_clip_pytorch_model.bin', |
1445 | 1466 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), |
1446 | 1467 | 'vit_huge_patch14_clip_378.dfn5b': _cfg( |
1447 | 1468 | hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378', |
1448 | 1469 | hf_hub_filename='open_clip_pytorch_model.bin', |
| 1470 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, |
| 1471 | + crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024), |
| 1472 | + |
| 1473 | + 'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg( |
| 1474 | + hf_hub_id='facebook/metaclip-b32-fullcc2.5b', |
| 1475 | + hf_hub_filename='metaclip_b32_fullcc2.5b.bin', |
| 1476 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), |
| 1477 | + 'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg( |
| 1478 | + hf_hub_id='facebook/metaclip-b16-fullcc2.5b', |
| 1479 | + hf_hub_filename='metaclip_b16_fullcc2.5b.bin', |
| 1480 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), |
| 1481 | + 'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg( |
| 1482 | + hf_hub_id='facebook/metaclip-l14-fullcc2.5b', |
| 1483 | + hf_hub_filename='metaclip_l14_fullcc2.5b.bin', |
| 1484 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), |
| 1485 | + 'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg( |
| 1486 | + hf_hub_id='facebook/metaclip-h14-fullcc2.5b', |
| 1487 | + hf_hub_filename='metaclip_h14_fullcc2.5b.bin', |
1449 | 1488 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), |
1450 | | - 'vit_giant_patch14_clip_224.laion2b': _cfg( |
1451 | | - hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', |
1452 | | - hf_hub_filename='open_clip_pytorch_model.bin', |
1453 | | - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), |
1454 | | - 'vit_gigantic_patch14_clip_224.laion2b': _cfg( |
1455 | | - hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k', |
1456 | | - hf_hub_filename='open_clip_pytorch_model.bin', |
1457 | | - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280), |
1458 | 1489 |
|
1459 | 1490 | 'vit_base_patch32_clip_224.openai': _cfg( |
1460 | 1491 | hf_hub_id='timm/', |
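
Beyond regrouping the OpenCLIP-style entries, the `default_cfgs` hunk above adds new pretrained tags (the DataComp-XL B/32 variants and the MetaCLIP 2.5B family). A minimal usage sketch for one of the tags defined above (assumes network access to the Hugging Face Hub):

```python
import timm

# Any tag defined above is selected as '<model>.<tag>'; num_classes=0 drops the
# CLIP projection head and returns pooled image features.
model = timm.create_model(
    'vit_huge_patch14_clip_224.laion2b',
    pretrained=True,
    num_classes=0,
).eval()
```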
@@ -2078,6 +2109,80 @@ def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransform |
2078 | 2109 | 'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) |
2079 | 2110 | return model |
2080 | 2111 |
|
| 2112 | + |
| 2113 | +@register_model |
| 2114 | +def vit_base_patch32_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer: |
| 2115 | +    """ ViT-B/32 CLIP image tower @ 224x224 w/ QuickGELU act
| 2116 | + """ |
| 2117 | + model_args = dict( |
| 2118 | + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, |
| 2119 | + norm_layer=nn.LayerNorm, act_layer='quick_gelu') |
| 2120 | + model = _create_vision_transformer( |
| 2121 | + 'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) |
| 2122 | + return model |
| 2123 | + |
| 2124 | + |
| 2125 | +@register_model |
| 2126 | +def vit_base_patch16_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer: |
| 2127 | + """ ViT-B/16 CLIP image tower w/ QuickGELU act |
| 2128 | + """ |
| 2129 | + model_args = dict( |
| 2130 | + patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, |
| 2131 | + norm_layer=nn.LayerNorm, act_layer='quick_gelu') |
| 2132 | + model = _create_vision_transformer( |
| 2133 | + 'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) |
| 2134 | + return model |
| 2135 | + |
| 2136 | + |
| 2137 | +@register_model |
| 2138 | +def vit_large_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer: |
| 2139 | + """ ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act |
| 2140 | + """ |
| 2142 | + model_args = dict( |
| 2143 | + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, |
| 2144 | + norm_layer=nn.LayerNorm, act_layer='quick_gelu') |
| 2145 | + model = _create_vision_transformer( |
| 2146 | + 'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) |
| 2147 | + return model |
| 2148 | + |
| 2149 | + |
| 2150 | +@register_model |
| 2151 | +def vit_large_patch14_clip_quickgelu_336(pretrained=False, **kwargs) -> VisionTransformer: |
| 2152 | + """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act |
| 2153 | + """ |
| 2154 | + model_args = dict( |
| 2155 | + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, |
| 2156 | + norm_layer=nn.LayerNorm, act_layer='quick_gelu') |
| 2157 | + model = _create_vision_transformer( |
| 2158 | + 'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs)) |
| 2159 | + return model |
| 2160 | + |
| 2161 | + |
| 2162 | +@register_model |
| 2163 | +def vit_huge_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer: |
| 2164 | + """ ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act. |
| 2165 | + """ |
| 2166 | + model_args = dict( |
| 2167 | + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, |
| 2168 | + norm_layer=nn.LayerNorm, act_layer='quick_gelu') |
| 2169 | + model = _create_vision_transformer( |
| 2170 | + 'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) |
| 2171 | + return model |
| 2172 | + |
| 2173 | + |
| 2174 | +@register_model |
| 2175 | +def vit_huge_patch14_clip_quickgelu_378(pretrained=False, **kwargs) -> VisionTransformer: |
| 2176 | + """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act |
| 2177 | + """ |
| 2178 | + model_args = dict( |
| 2179 | + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, |
| 2180 | + norm_layer=nn.LayerNorm, act_layer='quick_gelu') |
| 2181 | + model = _create_vision_transformer( |
| 2182 | + 'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs)) |
| 2183 | + return model |
| 2184 | + |
| 2185 | + |
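
The `*_quickgelu_*` entrypoints added above mirror the existing CLIP towers but fix `norm_layer=nn.LayerNorm` and `act_layer='quick_gelu'`, the activation used by OpenAI-style CLIP image towers. A minimal sketch of instantiating one (the printed class name depends on timm's activation registry):

```python
import timm

model = timm.create_model('vit_base_patch16_clip_quickgelu_224', pretrained=False, num_classes=0)
# The MLP activation inside each block should now be a QuickGELU-style module
# (x * sigmoid(1.702 * x)) rather than nn.GELU.
print(type(model.blocks[0].mlp.act).__name__)
```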
2081 | 2186 | # Experimental models below |
2082 | 2187 |
|
2083 | 2188 | @register_model |
|