From 71fcfceabaf0605dcc07bbbf13efd7ed3165b974 Mon Sep 17 00:00:00 2001
From: Less Wright
Date: Thu, 29 Jul 2021 12:56:13 -0700
Subject: [PATCH] reset_classifier: fix it actually to work, expose toggle
 param for forcing new head to gpu

---
 models/cswin.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/models/cswin.py b/models/cswin.py
index 9a74de1..ae1b2e7 100644
--- a/models/cswin.py
+++ b/models/cswin.py
@@ -326,16 +326,23 @@ def no_weight_decay(self):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=''):
-        if self.num_classes != num_classes:
-            print ('reset head to', num_classes)
+    def reset_classifier(self, num_classes, force=False, to_gpu=False):
+        if self.num_classes != num_classes or force:
+            print("reset head to", num_classes)
             self.num_classes = num_classes
-            self.head = nn.Linear(self.out_dim, num_classes) if num_classes > 0 else nn.Identity()
-            self.head = self.head.cuda()
-            trunc_normal_(self.head.weight, std=.02)
+            self.head = (
+                nn.Linear(self.head.in_features, num_classes)
+                if num_classes > 0
+                else nn.Identity()
+            )
+            if to_gpu:
+                self.head = self.head.cuda()
+            # init new head
+            trunc_normal_(self.head.weight, std=0.02)
             if self.head.bias is not None:
                 nn.init.constant_(self.head.bias, 0)
 
+
     def forward_features(self, x):
         B = x.shape[0]
         x = self.stage1_conv_embed(x)