Skip to content

Commit 3c8f6c6

Browse files
5919 fix generalized dice issue (#5929)
Signed-off-by: Yiheng Wang <vennw@nvidia.com> Fixes #5919 . ### Description This PR is used to fix the device issue of function `compute_generalized_dice`, and cuda tensor input will not raise errors. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang <vennw@nvidia.com>
1 parent 0d38e4b commit 3c8f6c6

12 files changed

+55
-25
lines changed

monai/metrics/generalized_dice.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ def compute_generalized_dice(
176176
y_pred_o = y_pred_o.sum(dim=-1)
177177
denom_zeros = denom == 0
178178
generalized_dice_score[denom_zeros] = torch.where(
179-
(y_pred_o == 0)[denom_zeros], torch.tensor(1.0), torch.tensor(0.0)
179+
(y_pred_o == 0)[denom_zeros],
180+
torch.tensor(1.0, device=generalized_dice_score.device),
181+
torch.tensor(0.0, device=generalized_dice_score.device),
180182
)
181183

182184
return generalized_dice_score

tests/test_compute_confusion_matrix.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import unittest
1515
from typing import Any
1616

17+
import numpy as np
1718
import torch
1819
from parameterized import parameterized
1920

@@ -220,6 +221,7 @@ def test_value(self, input_data, expected_value):
220221
input_data["include_background"] = False
221222
result = get_confusion_matrix(**input_data)
222223
assert_allclose(result, expected_value[:, 1:, :], atol=1e-4, rtol=1e-4)
224+
np.testing.assert_equal(result.device, input_data["y_pred"].device)
223225

224226
@parameterized.expand(TEST_CASES_COMPUTE_SAMPLE)
225227
def test_compute_sample(self, input_data, expected_value):

tests/test_compute_f_beta.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,24 @@
1313

1414
import unittest
1515

16-
import numpy
16+
import numpy as np
1717
import torch
1818

1919
from monai.metrics import FBetaScore
2020
from tests.utils import assert_allclose
2121

22+
_device = "cuda:0" if torch.cuda.is_available() else "cpu"
23+
2224

2325
class TestFBetaScore(unittest.TestCase):
24-
def test_expecting_success(self):
26+
def test_expecting_success_and_device(self):
2527
metric = FBetaScore()
26-
metric(
27-
y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
28-
)
29-
assert_allclose(metric.aggregate()[0], torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6)
28+
y_pred = torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]], device=_device)
29+
y = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]], device=_device)
30+
metric(y_pred=y_pred, y=y)
31+
result = metric.aggregate()[0]
32+
assert_allclose(result, torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6)
33+
np.testing.assert_equal(result.device, y_pred.device)
3034

3135
def test_expecting_success2(self):
3236
metric = FBetaScore(beta=0.5)
@@ -58,7 +62,7 @@ def test_with_nan_values(self):
5862
metric = FBetaScore(get_not_nans=True)
5963
metric(
6064
y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]),
61-
y=torch.Tensor([[1, 0, 1], [numpy.NaN, numpy.NaN, numpy.NaN], [1, 0, 1]]),
65+
y=torch.Tensor([[1, 0, 1], [np.NaN, np.NaN, np.NaN], [1, 0, 1]]),
6266
)
6367
assert_allclose(metric.aggregate()[0][0], torch.Tensor([0.727273]), atol=1e-6, rtol=1e-6)
6468

tests/test_compute_generalized_dice.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919

2020
from monai.metrics import GeneralizedDiceScore, compute_generalized_dice
2121

22+
_device = "cuda:0" if torch.cuda.is_available() else "cpu"
23+
2224
# keep background
2325
TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1)
2426
{
25-
"y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
26-
"y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
27+
"y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device),
28+
"y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device),
2729
"include_background": True,
2830
},
2931
[0.8],
@@ -116,7 +118,12 @@
116118
TEST_CASE_9 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [1.0000, 1.0000]]
117119

118120

119-
class TestComputeMeanDice(unittest.TestCase):
121+
class TestComputeGeneralizedDiceScore(unittest.TestCase):
122+
@parameterized.expand([TEST_CASE_1])
123+
def test_device(self, input_data, _expected_value):
124+
result = compute_generalized_dice(**input_data)
125+
np.testing.assert_equal(result.device, input_data["y_pred"].device)
126+
120127
# Functional part tests
121128
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9])
122129
def test_value(self, input_data, expected_value):

tests/test_compute_meandice.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ class TestComputeMeanDice(unittest.TestCase):
191191
def test_value(self, input_data, expected_value):
192192
result = compute_dice(**input_data)
193193
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
194+
np.testing.assert_equal(result.device, input_data["y_pred"].device)
194195

195196
@parameterized.expand([TEST_CASE_3])
196197
def test_nans(self, input_data, expected_value):

tests/test_compute_meaniou.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ class TestComputeMeanIoU(unittest.TestCase):
191191
def test_value(self, input_data, expected_value):
192192
result = compute_meaniou(**input_data)
193193
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
194+
np.testing.assert_equal(result.device, input_data["y_pred"].device)
194195

195196
@parameterized.expand([TEST_CASE_3])
196197
def test_nans(self, input_data, expected_value):

tests/test_compute_panoptic_quality.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class TestPanopticQualityMetric(unittest.TestCase):
9696
def test_value(self, input_params, expected_value):
9797
result = compute_panoptic_quality(**input_params)
9898
np.testing.assert_allclose(result.cpu().detach().item(), expected_value, atol=1e-4)
99+
np.testing.assert_equal(result.device, input_params["pred"].device)
99100

100101
@parameterized.expand([TEST_CLS_CASE_1, TEST_CLS_CASE_2, TEST_CLS_CASE_3, TEST_CLS_CASE_4, TEST_CLS_CASE_5])
101102
def test_value_class(self, input_params, y_pred, y_gt, expected_value):

tests/test_compute_variance.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class TestComputeVariance(unittest.TestCase):
113113
def test_value(self, input_data, expected_value):
114114
result = compute_variance(**input_data)
115115
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
116+
np.testing.assert_equal(result.device, input_data["y_pred"].device)
116117

117118
@parameterized.expand([TEST_CASE_5, TEST_CASE_6])
118119
def test_spatial_case(self, input_data, expected_value):

tests/test_hausdorff_distance.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
from monai.metrics import HausdorffDistanceMetric
2121

22+
_device = "cuda:0" if torch.cuda.is_available() else "cpu"
23+
2224

2325
def create_spherical_seg_3d(
2426
radius: float = 20.0, centre: tuple[int, int, int] = (49, 49, 49), im_shape: tuple[int, int, int] = (99, 99, 99)
@@ -116,8 +118,8 @@ def test_value(self, input_data, expected_value):
116118
else:
117119
[seg_1, seg_2] = input_data
118120
ct = 0
119-
seg_1 = torch.tensor(seg_1)
120-
seg_2 = torch.tensor(seg_2)
121+
seg_1 = torch.tensor(seg_1, device=_device)
122+
seg_2 = torch.tensor(seg_2, device=_device)
121123
for metric in ["euclidean", "chessboard", "taxicab"]:
122124
for directed in [True, False]:
123125
hd_metric = HausdorffDistanceMetric(
@@ -130,7 +132,8 @@ def test_value(self, input_data, expected_value):
130132
hd_metric(batch_seg_1, batch_seg_2)
131133
result = hd_metric.aggregate(reduction="mean")
132134
expected_value_curr = expected_value[ct]
133-
np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7)
135+
np.testing.assert_allclose(expected_value_curr, result.cpu(), rtol=1e-7)
136+
np.testing.assert_equal(result.device, seg_1.device)
134137
ct += 1
135138

136139
@parameterized.expand(TEST_CASES_NANS)

tests/test_label_quality_score.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class TestLabelQualityScore(unittest.TestCase):
103103
def test_value(self, input_data, expected_value):
104104
result = label_quality_score(**input_data)
105105
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
106+
np.testing.assert_equal(result.device, input_data["y_pred"].device)
106107

107108
@parameterized.expand([TEST_CASE_6, TEST_CASE_7])
108109
def test_spatial_case(self, input_data, expected_value):

0 commit comments

Comments
 (0)