
Commit a2e4a4c

Add quickgelu vit clip variants, simplify get_norm_layer and allow string args in vit norm/act. Add metaclip CLIP weights
1 parent c55bc41 commit a2e4a4c

4 files changed: 163 additions & 42 deletions

timm/layers/activations.py

Lines changed: 14 additions & 0 deletions
@@ -157,3 +157,17 @@ def __init__(self, inplace: bool = False):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.gelu(input, approximate='tanh')
+
+
+def quick_gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
+    return x * torch.sigmoid(1.702 * x)
+
+
+class QuickGELU(nn.Module):
+    """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
+    """
+    def __init__(self, inplace: bool = False):
+        super(QuickGELU, self).__init__()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return quick_gelu(input)
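
For context, quick_gelu is the sigmoid-based approximation of GELU used by the original OpenAI CLIP weights, and the QuickGELU module is a thin wrapper so it can stand in wherever an nn.Module activation is expected (the inplace arg is accepted but ignored). A quick numerical sanity check, assuming the patched module above is importable as timm.layers.activations:

import torch
import torch.nn.functional as F

from timm.layers.activations import quick_gelu, QuickGELU

x = torch.linspace(-4, 4, steps=9)
print(quick_gelu(x))                              # x * sigmoid(1.702 * x)
print(QuickGELU()(x))                             # module wrapper, same values
print((quick_gelu(x) - F.gelu(x)).abs().max())    # small but nonzero approximation error vs exact GELU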

timm/layers/create_act.py

Lines changed: 4 additions & 2 deletions
@@ -29,6 +29,7 @@
     selu=F.selu,
     gelu=gelu,
     gelu_tanh=gelu_tanh,
+    quick_gelu=quick_gelu,
     sigmoid=sigmoid,
     tanh=tanh,
     hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
@@ -42,7 +43,7 @@
     mish=F.mish if _has_mish else mish_jit,
     hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
     hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
-    hard_mish=hard_mish_jit
+    hard_mish=hard_mish_jit,
 )
 
 _ACT_FN_ME = dict(
@@ -73,6 +74,7 @@
     selu=nn.SELU,
     gelu=GELU,
     gelu_tanh=GELUTanh,
+    quick_gelu=QuickGELU,
     sigmoid=Sigmoid,
     tanh=Tanh,
     hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
@@ -87,7 +89,7 @@
     mish=nn.Mish if _has_mish else MishJit,
     hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
     hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
-    hard_mish=HardMishJit
+    hard_mish=HardMishJit,
 )
 
 _ACT_LAYER_ME = dict(
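
With the registry entries above, the new activation can be resolved by name through timm's activation factory. A minimal sketch, assuming the usual timm.layers exports of get_act_layer / create_act_layer:

import torch
from timm.layers import get_act_layer, create_act_layer

act_cls = get_act_layer('quick_gelu')      # -> QuickGELU class
act = create_act_layer('quick_gelu')       # instantiated module
print(act_cls, act(torch.randn(2, 8)).shape)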

timm/layers/create_norm.py

Lines changed: 12 additions & 12 deletions
@@ -4,12 +4,14 @@
 
 Copyright 2022 Ross Wightman
 """
-import types
 import functools
+import types
+from typing import Type
 
 import torch.nn as nn
 
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
+from torchvision.ops import FrozenBatchNorm2d
 
 _NORM_MAP = dict(
     batchnorm=nn.BatchNorm2d,
@@ -19,6 +21,8 @@
     groupnorm1=GroupNorm1,
     layernorm=LayerNorm,
     layernorm2d=LayerNorm2d,
+    rmsnorm=RmsNorm,
+    frozenbatchnorm2d=FrozenBatchNorm2d,
 )
 _NORM_TYPES = {m for n, m in _NORM_MAP.items()}
 
@@ -30,7 +34,10 @@ def create_norm_layer(layer_name, num_features, **kwargs):
 
 
 def get_norm_layer(norm_layer):
-    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
+    if not norm_layer:
+        # None or '' should return None
+        return None
+    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
     norm_kwargs = {}
 
     # unbind partial fn, so args can be rebound later
@@ -40,16 +47,9 @@ def get_norm_layer(norm_layer):
 
     if isinstance(norm_layer, str):
         layer_name = norm_layer.replace('_', '')
-        norm_layer = _NORM_MAP.get(layer_name, None)
-    elif norm_layer in _NORM_TYPES:
-        norm_layer = norm_layer
-    elif isinstance(norm_layer, types.FunctionType):
-        # if function type, assume it is a lambda/fn that creates a norm layer
-        norm_layer = norm_layer
+        norm_layer = _NORM_MAP[layer_name]
     else:
-        type_name = norm_layer.__name__.lower().replace('_', '')
-        norm_layer = _NORM_MAP.get(type_name, None)
-        assert norm_layer is not None, f"No equivalent norm layer for {type_name}"
+        norm_layer = norm_layer
 
     if norm_kwargs:
         norm_layer = functools.partial(norm_layer, **norm_kwargs)  # bind/rebind args
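
The simplified get_norm_layer now short-circuits on None/'' and indexes _NORM_MAP directly for strings (an unknown name raises KeyError), while classes, lambdas, and partials pass through with any partial keyword args rebound. An illustrative sketch of the resulting behaviour, assuming get_norm_layer is re-exported from timm.layers:

import functools
import torch.nn as nn
from timm.layers import get_norm_layer

print(get_norm_layer('layernorm2d'))   # LayerNorm2d class from the string map
print(get_norm_layer(None))            # None / '' now return None instead of asserting
ln = get_norm_layer(functools.partial(nn.LayerNorm, eps=1e-6))
print(ln(16))                          # keyword args of a partial are unbound and rebound onto the result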

timm/models/vision_transformer.py

Lines changed: 133 additions & 28 deletions
@@ -27,18 +27,20 @@
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Callable, List, Optional, Sequence, Tuple, Union
+from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.jit import Final
 
+
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
-    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn
+    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
+    get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -414,10 +416,10 @@ def __init__(
             drop_path_rate: float = 0.,
             weight_init: str = '',
             embed_layer: Callable = PatchEmbed,
-            norm_layer: Optional[Callable] = None,
-            act_layer: Optional[Callable] = None,
-            block_fn: Callable = Block,
-            mlp_layer: Callable = Mlp,
+            norm_layer: Optional[LayerType] = None,
+            act_layer: Optional[LayerType] = None,
+            block_fn: Type[nn.Module] = Block,
+            mlp_layer: Type[nn.Module] = Mlp,
     ):
         """
         Args:
@@ -450,8 +452,8 @@ def __init__(
         assert global_pool in ('', 'avg', 'token', 'map')
         assert class_token or global_pool != 'token'
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
-        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
-        act_layer = act_layer or nn.GELU
+        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+        act_layer = get_act_layer(act_layer) or nn.GELU
 
         self.num_classes = num_classes
         self.global_pool = global_pool
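
Because the constructor now routes norm_layer / act_layer through get_norm_layer / get_act_layer (see the two hunks above), both can be passed as registry strings rather than classes or partials. A minimal sketch with intentionally tiny dimensions chosen here for illustration, not taken from the commit:

import torch
from timm.models.vision_transformer import VisionTransformer

tiny = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3, num_classes=10,
    norm_layer='layernorm', act_layer='quick_gelu')   # strings resolve via the layer registries
print(tiny(torch.randn(1, 3, 224, 224)).shape)        # torch.Size([1, 10])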
@@ -1415,46 +1417,75 @@ def _cfg(url='', **kwargs):
         hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
-    'vit_base_patch16_clip_224.datacompxl': _cfg(
-        hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
-    'vit_base_patch16_clip_224.dfn2b': _cfg(
-        hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_large_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
+    'vit_huge_patch14_clip_224.laion2b': _cfg(
+        hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_giant_patch14_clip_224.laion2b': _cfg(
+        hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_gigantic_patch14_clip_224.laion2b': _cfg(
+        hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
+
+    'vit_base_patch32_clip_224.datacompxl': _cfg(
+        hf_hub_id='laion/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_base_patch32_clip_256.datacompxl': _cfg(
+        hf_hub_id='laion/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 256, 256), num_classes=512),
+    'vit_base_patch16_clip_224.datacompxl': _cfg(
+        hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_large_patch14_clip_224.datacompxl': _cfg(
         hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+
+    'vit_base_patch16_clip_224.dfn2b': _cfg(
+        hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_large_patch14_clip_224.dfn2b': _cfg(
         hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
-    'vit_huge_patch14_clip_224.laion2b': _cfg(
-        hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_huge_patch14_clip_224.dfn5b': _cfg(
         hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_huge_patch14_clip_378.dfn5b': _cfg(
         hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
         hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
+
+    'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg(
+        hf_hub_id='facebook/metaclip-b32-fullcc2.5b',
+        hf_hub_filename='metaclip_b32_fullcc2.5b.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg(
+        hf_hub_id='facebook/metaclip-b16-fullcc2.5b',
+        hf_hub_filename='metaclip_b16_fullcc2.5b.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg(
+        hf_hub_id='facebook/metaclip-l14-fullcc2.5b',
+        hf_hub_filename='metaclip_l14_fullcc2.5b.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+    'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg(
+        hf_hub_id='facebook/metaclip-h14-fullcc2.5b',
+        hf_hub_filename='metaclip_h14_fullcc2.5b.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
-    'vit_giant_patch14_clip_224.laion2b': _cfg(
-        hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
-    'vit_gigantic_patch14_clip_224.laion2b': _cfg(
-        hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
 
     'vit_base_patch32_clip_224.openai': _cfg(
         hf_hub_id='timm/',
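
The MetaCLIP entries above follow timm's usual 'architecture.tag' pretrained naming. A quick way to confirm which CLIP tags the installed version knows about, assuming timm.list_pretrained with a wildcard filter is available in that install:

import timm

for name in timm.list_pretrained('vit_*_clip_224.*'):
    print(name)   # e.g. vit_base_patch16_clip_224.metaclip_2pt5b once this commit is included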
@@ -2078,6 +2109,80 @@ def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransform
         'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+
+@register_model
+def vit_base_patch32_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-B/32 CLIP image tower @ 224x224
+    """
+    model_args = dict(
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-B/16 CLIP image tower w/ QuickGELU act
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act
+    """
+    from timm.layers import get_act_layer
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_clip_quickgelu_336(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act.
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch14_clip_quickgelu_378(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 # Experimental models below
 
 @register_model
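
The new *_quickgelu_* entrypoints above build under the corresponding non-quickgelu variant names, so they share those variants' pretrained configs and differ only in constructing the tower with QuickGELU. A hedged usage sketch that builds one without downloading any weights:

import timm
import torch

model = timm.create_model('vit_base_patch16_clip_quickgelu_224', pretrained=False, num_classes=0)
feats = model(torch.randn(1, 3, 224, 224))
print(type(model.blocks[0].mlp.act).__name__, feats.shape)   # QuickGELU, torch.Size([1, 768])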
