Commit 6bae514

Add pretrained patch embed resizing to swin

1 parent 5c504b4

2 files changed: +26 −2 lines changed

timm/models/swin_transformer.py
Lines changed: 12 additions & 1 deletion

@@ -24,7 +24,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \
-    _assert, use_fused_attn, resize_rel_pos_bias_table
+    _assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq, named_apply

@@ -632,6 +632,17 @@ def checkpoint_filter_fn(state_dict, model):
         if any([n in k for n in ('relative_position_index', 'attn_mask')]):
             continue  # skip buffers that should not be persistent
 
+        if 'patch_embed.proj.weight' in k:
+            _, _, H, W = model.patch_embed.proj.weight.shape
+            if v.shape[-2] != H or v.shape[-1] != W:
+                v = resample_patch_embed(
+                    v,
+                    (H, W),
+                    interpolation='bicubic',
+                    antialias=True,
+                    verbose=True,
+                )
+
         if k.endswith('relative_position_bias_table'):
             m = model.get_submodule(k[:-29])
             if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
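
The added filter logic only kicks in when the checkpoint's patch embedding kernel size differs from the model's. A minimal sketch of what that resampling does, calling timm.layers.resample_patch_embed directly on a hypothetical 4x4 pretrained kernel resized to 8x8 (the tensor here is random and for illustration only):

import torch
from timm.layers import resample_patch_embed

# Hypothetical pretrained Swin patch embedding kernel: (embed_dim, in_chans, 4, 4)
pretrained_kernel = torch.randn(96, 3, 4, 4)

# Resample it to an 8x8 kernel, mirroring the call in checkpoint_filter_fn above.
resized_kernel = resample_patch_embed(
    pretrained_kernel,
    (8, 8),
    interpolation='bicubic',
    antialias=True,
    verbose=True,
)
print(resized_kernel.shape)  # expected: torch.Size([96, 3, 8, 8])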

timm/models/swin_transformer_v2.py
Lines changed: 14 additions & 1 deletion

@@ -21,7 +21,8 @@
 import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
+    resample_patch_embed
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@@ -622,6 +623,18 @@ def checkpoint_filter_fn(state_dict, model):
     for k, v in state_dict.items():
         if any([n in k for n in ('relative_position_index', 'relative_coords_table', 'attn_mask')]):
             continue  # skip buffers that should not be persistent
+
+        if 'patch_embed.proj.weight' in k:
+            _, _, H, W = model.patch_embed.proj.weight.shape
+            if v.shape[-2] != H or v.shape[-1] != W:
+                v = resample_patch_embed(
+                    v,
+                    (H, W),
+                    interpolation='bicubic',
+                    antialias=True,
+                    verbose=True,
+                )
+
         if not native_checkpoint:
             # skip layer remapping for updated checkpoints
             k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
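
With both filter functions handling the mismatch, a Swin model built with a non-default patch size should be able to load stock pretrained weights, with the 4x4 patch embedding resampled on the fly. A usage sketch, not part of the commit, assuming create_model forwards the img_size and patch_size overrides to the model constructor:

import timm

# Run Swin-T at 448x448 with an 8x8 patch embedding so the token grid stays 56x56;
# checkpoint_filter_fn should resample patch_embed.proj.weight from 4x4 to 8x8 on load.
model = timm.create_model(
    'swin_tiny_patch4_window7_224',
    pretrained=True,
    img_size=448,
    patch_size=8,
)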
