Skip to content

Commit fd592ec

Browse files
committed
Fix an issue with FastCollateMixup still using device
1 parent 569d114 commit fd592ec

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

timm/data/mixup.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616

1717
def one_hot(x, num_classes, on_value=1., off_value=0.):
1818
x = x.long().view(-1, 1)
19-
device = x.device
20-
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
19+
return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value)
2120

2221

2322
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
@@ -311,7 +310,7 @@ def __call__(self, batch, _=None):
311310
else:
312311
lam = self._mix_batch_collate(output, batch)
313312
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
314-
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
313+
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
315314
target = target[:batch_size]
316315
return output, target
317316

0 commit comments

Comments
 (0)