@@ -14,16 +14,17 @@
 import torch
 
 
-def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
+def one_hot(x, num_classes, on_value=1., off_value=0.):
     x = x.long().view(-1, 1)
+    device = x.device
     return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
 
 
-def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
+def mixup_target(target, num_classes, lam=1., smoothing=0.0):
     off_value = smoothing / num_classes
     on_value = 1. - smoothing + off_value
-    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
-    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
+    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
+    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
     return y1 * lam + y2 * (1. - lam)
 
 
@@ -214,7 +215,7 @@ def __call__(self, x, target):
             lam = self._mix_pair(x)
         else:
             lam = self._mix_batch(x)
-        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
         return x, target
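Below is a minimal usage sketch (not part of the commit) illustrating the effect of this change: `mixup_target` and `one_hot` now place the one-hot and mixed targets on whatever device the incoming `target` tensor lives on, so the hardcoded `device='cuda'` argument is no longer needed. The label values and hyperparameters are made up for illustration, and the sketch assumes `mixup_target` is imported from the patched module shown above.

import torch

# CPU integer labels for a batch of 4 samples, 4 classes
targets = torch.tensor([0, 2, 1, 3])

# Smoothed, mixed soft targets; device is inherited from `targets` (CPU here)
mixed = mixup_target(targets, num_classes=4, lam=0.7, smoothing=0.1)
print(mixed.shape, mixed.device)   # torch.Size([4, 4]) cpu

# On a CUDA machine the same call works with GPU labels, with no device kwarg
if torch.cuda.is_available():
    mixed_gpu = mixup_target(targets.cuda(), num_classes=4, lam=0.7, smoothing=0.1)
    print(mixed_gpu.device)        # cuda:0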
|