|
13 | 13 |
|
14 | 14 | import unittest |
15 | 15 |
|
16 | | -import numpy |
| 16 | +import numpy as np |
17 | 17 | import torch |
18 | 18 |
|
19 | 19 | from monai.metrics import FBetaScore |
20 | 20 | from tests.utils import assert_allclose |
21 | 21 |
|
| 22 | +_device = "cuda:0" if torch.cuda.is_available() else "cpu" |
| 23 | + |
22 | 24 |
|
23 | 25 | class TestFBetaScore(unittest.TestCase): |
24 | | - def test_expecting_success(self): |
| 26 | + def test_expecting_success_and_device(self): |
25 | 27 | metric = FBetaScore() |
26 | | - metric( |
27 | | - y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]) |
28 | | - ) |
29 | | - assert_allclose(metric.aggregate()[0], torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6) |
| 28 | + y_pred = torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]], device=_device) |
| 29 | + y = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]], device=_device) |
| 30 | + metric(y_pred=y_pred, y=y) |
| 31 | + result = metric.aggregate()[0] |
| 32 | + assert_allclose(result, torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6) |
| 33 | + np.testing.assert_equal(result.device, y_pred.device) |
30 | 34 |
|
31 | 35 | def test_expecting_success2(self): |
32 | 36 | metric = FBetaScore(beta=0.5) |
@@ -58,7 +62,7 @@ def test_with_nan_values(self): |
58 | 62 | metric = FBetaScore(get_not_nans=True) |
59 | 63 | metric( |
60 | 64 | y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), |
61 | | - y=torch.Tensor([[1, 0, 1], [numpy.NaN, numpy.NaN, numpy.NaN], [1, 0, 1]]), |
| 65 | + y=torch.Tensor([[1, 0, 1], [np.NaN, np.NaN, np.NaN], [1, 0, 1]]), |
62 | 66 | ) |
63 | 67 | assert_allclose(metric.aggregate()[0][0], torch.Tensor([0.727273]), atol=1e-6, rtol=1e-6) |
64 | 68 |
|
|
0 commit comments