Skip to content

Commit 8ce9a2c

Browse files
authored
Merge pull request #1222 from Leoooo333/master
Fix mixup/one_hot device problem
2 parents e0ec0f7 + fd592ec commit 8ce9a2c

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

timm/data/mixup.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@
1414
import torch
1515

1616

17-
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
17+
def one_hot(x, num_classes, on_value=1., off_value=0.):
1818
x = x.long().view(-1, 1)
19-
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
19+
return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value)
2020

2121

22-
def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
22+
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
2323
off_value = smoothing / num_classes
2424
on_value = 1. - smoothing + off_value
25-
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
26-
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
25+
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
26+
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
2727
return y1 * lam + y2 * (1. - lam)
2828

2929

@@ -214,7 +214,7 @@ def __call__(self, x, target):
214214
lam = self._mix_pair(x)
215215
else:
216216
lam = self._mix_batch(x)
217-
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
217+
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
218218
return x, target
219219

220220

@@ -310,7 +310,7 @@ def __call__(self, batch, _=None):
310310
else:
311311
lam = self._mix_batch_collate(output, batch)
312312
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
313-
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
313+
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
314314
target = target[:batch_size]
315315
return output, target
316316

0 commit comments

Comments
 (0)