|
14 | 14 | import torch |
15 | 15 |
|
16 | 16 |
|
17 | | -def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): |
| 17 | +def one_hot(x, num_classes, on_value=1., off_value=0.): |
18 | 18 | x = x.long().view(-1, 1) |
19 | | - return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) |
| 19 | + return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value) |
20 | 20 |
|
21 | 21 |
|
22 | | -def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): |
| 22 | +def mixup_target(target, num_classes, lam=1., smoothing=0.0): |
23 | 23 | off_value = smoothing / num_classes |
24 | 24 | on_value = 1. - smoothing + off_value |
25 | | - y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) |
26 | | - y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) |
| 25 | + y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value) |
| 26 | + y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value) |
27 | 27 | return y1 * lam + y2 * (1. - lam) |
28 | 28 |
|
29 | 29 |
|
@@ -214,7 +214,7 @@ def __call__(self, x, target): |
214 | 214 | lam = self._mix_pair(x) |
215 | 215 | else: |
216 | 216 | lam = self._mix_batch(x) |
217 | | - target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) |
| 217 | + target = mixup_target(target, self.num_classes, lam, self.label_smoothing) |
218 | 218 | return x, target |
219 | 219 |
|
220 | 220 |
|
@@ -310,7 +310,7 @@ def __call__(self, batch, _=None): |
310 | 310 | else: |
311 | 311 | lam = self._mix_batch_collate(output, batch) |
312 | 312 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64) |
313 | | - target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') |
| 313 | + target = mixup_target(target, self.num_classes, lam, self.label_smoothing) |
314 | 314 | target = target[:batch_size] |
315 | 315 | return output, target |
316 | 316 |
|
0 commit comments