 import torch.utils.checkpoint as checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead, \
+    resample_patch_embed
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -622,6 +623,18 @@ def checkpoint_filter_fn(state_dict, model):
     for k, v in state_dict.items():
         if any([n in k for n in ('relative_position_index', 'relative_coords_table', 'attn_mask')]):
             continue  # skip buffers that should not be persistent
+
+        if 'patch_embed.proj.weight' in k:
+            _, _, H, W = model.patch_embed.proj.weight.shape
+            if v.shape[-2] != H or v.shape[-1] != W:
+                v = resample_patch_embed(
+                    v,
+                    (H, W),
+                    interpolation='bicubic',
+                    antialias=True,
+                    verbose=True,
+                )
+
         if not native_checkpoint:
             # skip layer remapping for updated checkpoints
             k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
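For context, `resample_patch_embed` resizes the conv patchify kernel so that weights trained with one patch size can be loaded into a model built with another; the new branch above applies it only when the checkpoint kernel's spatial shape differs from the model's. A minimal standalone sketch of the same call (the kernel shape and the 8x8 target size are illustrative assumptions, not taken from this commit):

# Sketch: resample a pretrained patch-embedding kernel to a new patch size.
# The (96, 3, 4, 4) shape and (8, 8) target are illustrative only.
import torch
from timm.layers import resample_patch_embed

pretrained_kernel = torch.randn(96, 3, 4, 4)   # conv patchify weight from a 4x4-patch checkpoint
resampled = resample_patch_embed(
    pretrained_kernel,
    (8, 8),                    # target patch size of the model being loaded
    interpolation='bicubic',
    antialias=True,
    verbose=True,
)
print(resampled.shape)         # torch.Size([96, 3, 8, 8])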