
Commit c9db470

Merge pull request #1799 from huggingface/dot_nine_cleanup
Final cleanup before .9 release
2 parents: 5cc87e6 + b9d43c7


73 files changed: +1453 −1236 lines

timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
-from .patch_embed import PatchEmbed, resample_patch_embed
+from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed
 from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords

timm/layers/patch_embed.py

Lines changed: 44 additions & 1 deletion
@@ -9,7 +9,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
-from typing import List, Optional, Callable
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 from torch import nn as nn
@@ -75,6 +75,49 @@ def forward(self, x):
         return x


+class PatchEmbedWithSize(PatchEmbed):
+    """ 2D Image to Patch Embedding
+    """
+    output_fmt: Format
+
+    def __init__(
+            self,
+            img_size: Optional[int] = 224,
+            patch_size: int = 16,
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            norm_layer: Optional[Callable] = None,
+            flatten: bool = True,
+            output_fmt: Optional[str] = None,
+            bias: bool = True,
+    ):
+        super().__init__(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer,
+            flatten=flatten,
+            output_fmt=output_fmt,
+            bias=bias,
+        )
+
+    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
+        B, C, H, W = x.shape
+        if self.img_size is not None:
+            _assert(H % self.patch_size[0] == 0, f"Input image height ({H}) must be divisible by patch size ({self.patch_size[0]}).")
+            _assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")
+
+        x = self.proj(x)
+        grid_size = x.shape[-2:]
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
+        elif self.output_fmt != Format.NCHW:
+            x = nchw_to(x, self.output_fmt)
+        x = self.norm(x)
+        return x, grid_size
+
+
 def resample_patch_embed(
         patch_embed,
         new_size: List[int],
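
The new PatchEmbedWithSize variant returns the post-projection grid size alongside the patch tokens, which is handy when a downstream module needs the feature-map geometry for dynamically sized inputs. A minimal sketch of how it might be used; the constructor arguments and input shape below are illustrative, not taken from the diff:

import torch
from timm.layers import PatchEmbedWithSize

# img_size=None skips the fixed-size check; H and W only need to be divisible by the patch size.
embed = PatchEmbedWithSize(img_size=None, patch_size=16, in_chans=3, embed_dim=768)

x = torch.randn(2, 3, 224, 320)        # 14 x 20 patches of 16 x 16
tokens, grid_size = embed(x)
print(tokens.shape, tuple(grid_size))  # torch.Size([2, 280, 768]) (14, 20)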

timm/layers/pos_embed.py

Lines changed: 13 additions & 11 deletions
@@ -24,29 +24,31 @@ def resample_abs_pos_embed(
         verbose: bool = False,
 ):
     # sort out sizes, assume square if old size not provided
-    new_size = to_2tuple(new_size)
-    new_ntok = new_size[0] * new_size[1]
-    if not old_size:
-        old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
-    old_size = to_2tuple(old_size)
-    if new_size == old_size:  # might not both be same container type
+    num_pos_tokens = posemb.shape[1]
+    num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
+    if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
         return posemb

+    if not old_size:
+        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
+        old_size = hw, hw
+
     if num_prefix_tokens:
         posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
     else:
         posemb_prefix, posemb = None, posemb

     # do the interpolation
+    embed_dim = posemb.shape[-1]
     posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
     posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
-    posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)
-
-    if verbose:
-        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')
+    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)

     # add back extra (class, etc) prefix tokens
     if posemb_prefix is not None:
-        print(posemb_prefix.shape, posemb.shape)
         posemb = torch.cat([posemb_prefix, posemb], dim=1)
+
+    if not torch.jit.is_scripting() and verbose:
+        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')
+
     return posemb
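
The reworked resample_abs_pos_embed short-circuits when the token count already matches, infers a square old grid when old_size is not given, and keeps any prefix (class/distillation) tokens untouched. A hedged usage sketch, with shapes chosen purely for illustration:

import torch
from timm.layers import resample_abs_pos_embed

# Illustrative ViT-B/16 style table: 1 class token + 14x14 grid, embed dim 768.
posemb = torch.randn(1, 1 + 14 * 14, 768)

# Resample for a 24x24 grid (e.g. 384px input at patch 16); the old grid is inferred as 14x14.
posemb_new = resample_abs_pos_embed(posemb, new_size=(24, 24), num_prefix_tokens=1)
print(posemb_new.shape)  # torch.Size([1, 577, 768]) -> 1 prefix token + 24*24 grid tokens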

timm/models/_builder.py

Lines changed: 19 additions & 8 deletions
@@ -22,7 +22,7 @@
 # Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
 _DOWNLOAD_PROGRESS = False
 _CHECK_HASH = False
-
+_USE_OLD_CACHE = int(os.environ.get('TIMM_USE_OLD_CACHE', 0)) > 0

 __all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
            'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
@@ -32,6 +32,7 @@ def _resolve_pretrained_source(pretrained_cfg):
     cfg_source = pretrained_cfg.get('source', '')
     pretrained_url = pretrained_cfg.get('url', None)
     pretrained_file = pretrained_cfg.get('file', None)
+    pretrained_sd = pretrained_cfg.get('state_dict', None)
     hf_hub_id = pretrained_cfg.get('hf_hub_id', None)

     # resolve where to load pretrained weights from
@@ -44,14 +45,21 @@ def _resolve_pretrained_source(pretrained_cfg):
         pretrained_loc = hf_hub_id
     else:
         # default source == timm or unspecified
-        if pretrained_file:
-            # file load override is the highest priority if set
+        if pretrained_sd:
+            # direct state_dict pass through is the highest priority
+            load_from = 'state_dict'
+            pretrained_loc = pretrained_sd
+            assert isinstance(pretrained_loc, dict)
+        elif pretrained_file:
+            # file load override is the second-highest priority if set
             load_from = 'file'
             pretrained_loc = pretrained_file
         else:
-            # next, HF hub is prioritized unless a valid cached version of weights exists already
-            cached_url_valid = check_cached_file(pretrained_url) if pretrained_url else False
-            if hf_hub_id and has_hf_hub(necessary=True) and not cached_url_valid:
+            old_cache_valid = False
+            if _USE_OLD_CACHE:
+                # prioritized old cached weights if exists and env var enabled
+                old_cache_valid = check_cached_file(pretrained_url) if pretrained_url else False
+            if not old_cache_valid and hf_hub_id and has_hf_hub(necessary=True):
                 # hf-hub available as alternate weight source in default_cfg
                 load_from = 'hf-hub'
                 pretrained_loc = hf_hub_id
@@ -106,7 +114,7 @@ def load_custom_pretrained(
     if not load_from:
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
-    if load_from == 'hf-hub':  # FIXME
+    if load_from == 'hf-hub':
         _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
     elif load_from == 'url':
         pretrained_loc = download_cached_file(
@@ -148,7 +156,10 @@ def load_pretrained(
         return

     load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
-    if load_from == 'file':
+    if load_from == 'state_dict':
+        _logger.info(f'Loading pretrained weights from state dict')
+        state_dict = pretrained_loc  # pretrained_loc is the actual state dict for this override
+    elif load_from == 'file':
         _logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
         state_dict = load_state_dict(pretrained_loc)
     elif load_from == 'url':
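
The practical effect of the _builder.py changes: a plain in-memory state dict can now be supplied as the weight source, with priority over file, URL and HF Hub, and the old behavior of preferring an already-cached URL download over the Hub is only kept when the TIMM_USE_OLD_CACHE environment variable is set. A hedged sketch of feeding a state dict through the existing pretrained_cfg_overlay argument of create_model; the model name and checkpoint path below are placeholders:

import torch
import timm

# Placeholder checkpoint path; any state dict with matching keys would do.
state_dict = torch.load('my_finetuned_vit.pth', map_location='cpu')

# With 'state_dict' set in the pretrained cfg, _resolve_pretrained_source returns
# load_from == 'state_dict' and load_pretrained uses the dict directly.
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    pretrained_cfg_overlay=dict(state_dict=state_dict),
)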

timm/models/_pretrained.py

Lines changed: 6 additions & 5 deletions
@@ -11,11 +11,12 @@
 class PretrainedCfg:
     """
     """
-    # weight locations
-    url: Optional[Union[str, Tuple[str, str]]] = None
-    file: Optional[str] = None
-    hf_hub_id: Optional[str] = None
-    hf_hub_filename: Optional[str] = None
+    # weight source locations
+    url: Optional[Union[str, Tuple[str, str]]] = None  # remote URL
+    file: Optional[str] = None  # local / shared filesystem path
+    state_dict: Optional[Dict[str, Any]] = None  # in-memory state dict
+    hf_hub_id: Optional[str] = None  # Hugging Face Hub model id ('organization/model')
+    hf_hub_filename: Optional[str] = None  # Hugging Face Hub filename (overrides default)

     source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
     architecture: Optional[str] = None  # architecture variant can be set when not implicit
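
For completeness, the extended PretrainedCfg fields can also be set directly. A minimal sketch, assuming PretrainedCfg behaves as a regular dataclass and using a placeholder checkpoint path:

from timm.models._pretrained import PretrainedCfg

# Only the weight-source fields from the diff are set here; state_dict stays None
# unless an in-memory dict is passed instead of a file path.
cfg = PretrainedCfg(
    file='/data/checkpoints/beit_base_patch16_224.pth',  # local / shared filesystem path
    architecture='beit_base_patch16_224',
)
print(cfg.file, cfg.state_dict)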

timm/models/beit.py

Lines changed: 17 additions & 7 deletions
@@ -477,6 +477,11 @@ def _cfg(url='', **kwargs):
         hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
     ),
+    'beitv2_base_patch16_224.in1k_ft_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft1k.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
     'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
         hf_hub_id='timm/',
@@ -487,6 +492,11 @@ def _cfg(url='', **kwargs):
         hf_hub_id='timm/',
         crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
     ),
+    'beitv2_large_patch16_224.in1k_ft_in1k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft1k.pth',
+        hf_hub_id='timm/',
+        crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
     'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
         hf_hub_id='timm/',
@@ -515,7 +525,7 @@ def _create_beit(variant, pretrained=False, **kwargs):


 @register_model
-def beit_base_patch16_224(pretrained=False, **kwargs):
+def beit_base_patch16_224(pretrained=False, **kwargs) -> Beit:
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
@@ -524,7 +534,7 @@ def beit_base_patch16_224(pretrained=False, **kwargs):


 @register_model
-def beit_base_patch16_384(pretrained=False, **kwargs):
+def beit_base_patch16_384(pretrained=False, **kwargs) -> Beit:
     model_args = dict(
         img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
@@ -533,7 +543,7 @@ def beit_base_patch16_384(pretrained=False, **kwargs):


 @register_model
-def beit_large_patch16_224(pretrained=False, **kwargs):
+def beit_large_patch16_224(pretrained=False, **kwargs) -> Beit:
     model_args = dict(
         patch_size=16, embed_dim=1024, depth=24, num_heads=16,
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
@@ -542,7 +552,7 @@ def beit_large_patch16_224(pretrained=False, **kwargs):


 @register_model
-def beit_large_patch16_384(pretrained=False, **kwargs):
+def beit_large_patch16_384(pretrained=False, **kwargs) -> Beit:
     model_args = dict(
         img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
@@ -551,7 +561,7 @@ def beit_large_patch16_384(pretrained=False, **kwargs):


 @register_model
-def beit_large_patch16_512(pretrained=False, **kwargs):
+def beit_large_patch16_512(pretrained=False, **kwargs) -> Beit:
     model_args = dict(
         img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
@@ -560,7 +570,7 @@ def beit_large_patch16_512(pretrained=False, **kwargs):


 @register_model
-def beitv2_base_patch16_224(pretrained=False, **kwargs):
+def beitv2_base_patch16_224(pretrained=False, **kwargs) -> Beit:
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
@@ -569,7 +579,7 @@ def beitv2_base_patch16_224(pretrained=False, **kwargs):


 @register_model
-def beitv2_large_patch16_224(pretrained=False, **kwargs):
+def beitv2_large_patch16_224(pretrained=False, **kwargs) -> Beit:
     model_args = dict(
         patch_size=16, embed_dim=1024, depth=24, num_heads=16,
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
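
The two new beitv2 *.in1k_ft_in1k weight tags are selectable by name, and the added -> Beit annotations only help static type checkers; runtime behavior is unchanged. A short usage sketch with the tag added above:

import torch
import timm

# Pretrained tag added in this commit; weights resolve via the HF Hub ('timm/') or the listed URL.
model = timm.create_model('beitv2_base_patch16_224.in1k_ft_in1k', pretrained=True).eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])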
