
Commit ea3519a

Fix dynamic_resize for deit models (distilled or no_embed_cls) and vit w/o class tokens
1 parent 4d8ecde commit ea3519a

2 files changed: 36 additions, 44 deletions

timm/models/deit.py

Lines changed: 29 additions & 38 deletions
@@ -73,45 +73,36 @@ def reset_classifier(self, num_classes, global_pool=None):
     def set_distilled_training(self, enable=True):
         self.distilled_training = enable

-    def _intermediate_layers(
-            self,
-            x: torch.Tensor,
-            n: Union[int, Sequence] = 1,
-    ):
-        outputs, num_blocks = [], len(self.blocks)
-        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
-
-        # forward pass
-        x = self.patch_embed(x)
-        x = torch.cat((
-            self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1),
-            x),
-            dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        x = self.patch_drop(x)
-        x = self.norm_pre(x)
-        for i, blk in enumerate(self.blocks):
-            x = blk(x)
-            if i in take_indices:
-                outputs.append(x)
-
-        return outputs
-
-    def forward_features(self, x) -> torch.Tensor:
-        x = self.patch_embed(x)
-        x = torch.cat((
-            self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1),
-            x),
-            dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.blocks, x)
+    def _pos_embed(self, x):
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            pos_embed = resample_abs_pos_embed(
+                self.pos_embed,
+                (H, W),
+                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+            )
+            x = x.view(B, -1, C)
         else:
-            x = self.blocks(x)
-        x = self.norm(x)
-        return x
+            pos_embed = self.pos_embed
+        if self.no_embed_class:
+            # deit-3, updated JAX (big vision)
+            # position embedding does not overlap with class token, add then concat
+            x = x + pos_embed
+            x = torch.cat((
+                self.cls_token.expand(x.shape[0], -1, -1),
+                self.dist_token.expand(x.shape[0], -1, -1),
+                x),
+                dim=1)
+        else:
+            # original timm, JAX, and deit vit impl
+            # pos_embed has entry for class token, concat then add
+            x = torch.cat((
+                self.cls_token.expand(x.shape[0], -1, -1),
+                self.dist_token.expand(x.shape[0], -1, -1),
+                x),
+                dim=1)
+            x = x + pos_embed
+        return self.pos_drop(x)

     def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
         x, x_dist = x[:, 0], x[:, 1]
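
The two branches of the new _pos_embed differ only in ordering: with no_embed_class set (deit-3 / big vision style) the position embedding covers patch tokens only, so it is added before the class and distillation tokens are concatenated; otherwise pos_embed also carries entries for the prefix tokens, so the concat happens first and the add covers the full sequence. A minimal sketch of the two orderings follows, with illustrative shapes rather than the timm code:

# Minimal sketch (not the timm code) of the two prefix-token orderings above.
# Assumed shapes: 2 prefix tokens (cls + dist), a 3x3 patch grid, embed dim 8.
import torch

B, D = 1, 8
patches = torch.randn(B, 9, D)                       # 3*3 patch tokens
cls_token = torch.zeros(B, 1, D)
dist_token = torch.zeros(B, 1, D)

# no_embed_class=True: pos_embed covers patch tokens only -> add, then concat
pos_embed_patches = torch.randn(1, 9, D)
x = patches + pos_embed_patches
x = torch.cat((cls_token, dist_token, x), dim=1)     # (B, 11, D)

# no_embed_class=False: pos_embed also has entries for the prefix tokens
# -> concat first so token and embedding positions line up, then add
pos_embed_full = torch.randn(1, 11, D)
y = torch.cat((cls_token, dist_token, patches), dim=1)
y = y + pos_embed_full                               # (B, 11, D)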

timm/models/vision_transformer.py

Lines changed: 7 additions & 6 deletions
@@ -459,11 +459,8 @@ def __init__(

         embed_args = {}
         if dynamic_size:
-            embed_args.update(dict(
-                strict_img_size=False,
-                flatten=False,  # flatten deferred until after pos embed
-                output_fmt='NHWC',
-            ))
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
@@ -559,7 +556,11 @@ def reset_classifier(self, num_classes: int, global_pool=None):
     def _pos_embed(self, x):
         if self.dynamic_size:
             B, H, W, C = x.shape
-            pos_embed = resample_abs_pos_embed(self.pos_embed, (H, W))
+            pos_embed = resample_abs_pos_embed(
+                self.pos_embed,
+                (H, W),
+                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+            )
             x = x.view(B, -1, C)
         else:
             pos_embed = self.pos_embed
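
The num_prefix_tokens argument matters here because only the patch-grid portion of pos_embed can be interpolated to the new (H, W) grid; any class/distillation entries have to be split off first and re-attached unchanged. Below is a simplified stand-in for resample_abs_pos_embed, assuming a square source grid and a (1, N, D) embedding; the helper name and exact behavior are illustrative, not the timm implementation:

# Simplified stand-in (not the timm implementation) for abs pos embed resampling.
import math
import torch
import torch.nn.functional as F

def resample_pos_embed_sketch(pos_embed, new_hw, num_prefix_tokens=1):
    # Split off prefix (cls/dist) tokens so only the patch grid is interpolated.
    prefix, grid = pos_embed[:, :num_prefix_tokens], pos_embed[:, num_prefix_tokens:]
    old_size = int(math.sqrt(grid.shape[1]))
    grid = grid.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2)   # (1, D, H, W)
    grid = F.interpolate(grid, size=new_hw, mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_hw[0] * new_hw[1], -1)
    return torch.cat((prefix, grid), dim=1) if num_prefix_tokens else grid

# e.g. a 14x14 grid with one class token resampled to a 16x12 grid
pe = torch.randn(1, 1 + 14 * 14, 384)
print(resample_pos_embed_sketch(pe, (16, 12), num_prefix_tokens=1).shape)  # (1, 193, 384)

With num_prefix_tokens=0 the whole embedding is treated as the grid, which is why the commit passes 0 for no_embed_class models and self.num_prefix_tokens (2 for distilled deit: cls + dist) otherwise.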
