Skip to content

Commit 569d114

Browse files
authored
Fix device problem
Before, the one_hot could only run in device='cuda'. Now it will run on input device automatically.
1 parent 01a0e25 commit 569d114

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

timm/data/mixup.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,17 @@
1414
import torch
1515

1616

17-
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
17+
def one_hot(x, num_classes, on_value=1., off_value=0.):
1818
x = x.long().view(-1, 1)
19+
device = x.device
1920
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
2021

2122

22-
def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
23+
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
2324
off_value = smoothing / num_classes
2425
on_value = 1. - smoothing + off_value
25-
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
26-
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
26+
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
27+
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
2728
return y1 * lam + y2 * (1. - lam)
2829

2930

@@ -214,7 +215,7 @@ def __call__(self, x, target):
214215
lam = self._mix_pair(x)
215216
else:
216217
lam = self._mix_batch(x)
217-
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
218+
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
218219
return x, target
219220

220221

0 commit comments

Comments
 (0)