@@ -73,45 +73,36 @@ def reset_classifier(self, num_classes, global_pool=None):
     def set_distilled_training(self, enable=True):
         self.distilled_training = enable
 
-    def _intermediate_layers(
-            self,
-            x: torch.Tensor,
-            n: Union[int, Sequence] = 1,
-    ):
-        outputs, num_blocks = [], len(self.blocks)
-        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
-
-        # forward pass
-        x = self.patch_embed(x)
-        x = torch.cat((
-            self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1),
-            x),
-            dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        x = self.patch_drop(x)
-        x = self.norm_pre(x)
-        for i, blk in enumerate(self.blocks):
-            x = blk(x)
-            if i in take_indices:
-                outputs.append(x)
-
-        return outputs
-
-    def forward_features(self, x) -> torch.Tensor:
-        x = self.patch_embed(x)
-        x = torch.cat((
-            self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1),
-            x),
-            dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.blocks, x)
+    def _pos_embed(self, x):
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            pos_embed = resample_abs_pos_embed(
+                self.pos_embed,
+                (H, W),
+                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+            )
+            x = x.view(B, -1, C)
         else:
-            x = self.blocks(x)
-        x = self.norm(x)
-        return x
+            pos_embed = self.pos_embed
+        if self.no_embed_class:
+            # deit-3, updated JAX (big vision)
+            # position embedding does not overlap with class token, add then concat
+            x = x + pos_embed
+            x = torch.cat((
+                self.cls_token.expand(x.shape[0], -1, -1),
+                self.dist_token.expand(x.shape[0], -1, -1),
+                x),
+                dim=1)
+        else:
+            # original timm, JAX, and deit vit impl
+            # pos_embed has entry for class token, concat then add
+            x = torch.cat((
+                self.cls_token.expand(x.shape[0], -1, -1),
+                self.dist_token.expand(x.shape[0], -1, -1),
+                x),
+                dim=1)
+            x = x + pos_embed
+        return self.pos_drop(x)
 
     def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
         x, x_dist = x[:, 0], x[:, 1]
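
For reference, here is a minimal standalone sketch (not part of this diff) of the two positional-embedding orderings that the comments in the new `_pos_embed` describe. The tensor sizes and the two-prefix-token layout (`cls_token` + `dist_token`, as in the distilled model) are illustrative assumptions, not values taken from the actual model config.

```python
import torch

# Toy dimensions (hypothetical, for illustration only).
B, N, C = 2, 196, 768          # batch, patch tokens, embed dim
num_prefix_tokens = 2          # cls token + dist token in the distilled ViT

cls_token = torch.zeros(1, 1, C)
dist_token = torch.zeros(1, 1, C)
patches = torch.randn(B, N, C)

# no_embed_class=True (deit-3 / big vision style):
# pos_embed covers only the patch grid, so add first, then concat prefix tokens.
pos_embed_patches = torch.randn(1, N, C)
x = patches + pos_embed_patches
x = torch.cat((cls_token.expand(B, -1, -1), dist_token.expand(B, -1, -1), x), dim=1)
assert x.shape == (B, N + num_prefix_tokens, C)

# no_embed_class=False (original timm / JAX / deit style):
# pos_embed also has entries for the prefix tokens, so concat first, then add.
pos_embed_full = torch.randn(1, N + num_prefix_tokens, C)
x = torch.cat((cls_token.expand(B, -1, -1), dist_token.expand(B, -1, -1), patches), dim=1)
x = x + pos_embed_full
assert x.shape == (B, N + num_prefix_tokens, C)
```

In the diff itself, the `dynamic_size` branch additionally resamples `self.pos_embed` to the current patch grid via `resample_abs_pos_embed` before the same add/concat logic runs.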