Skip to content

Commit b3eb73a

Browse files
authored
TTA progress bar, use torch for mode, add label meta_data (#1992)
TTA progress bar, use torch for mode, add label meta_data
1 parent d1f4e6f commit b3eb73a

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

monai/data/test_time_augmentation.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
12+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
1313

1414
import numpy as np
1515
import torch
@@ -23,6 +23,14 @@
2323
from monai.transforms.transform import Randomizable
2424
from monai.transforms.utils import allow_missing_keys_mode
2525
from monai.utils.enums import CommonKeys, InverseKeys
26+
from monai.utils.module import optional_import
27+
28+
if TYPE_CHECKING:
29+
from tqdm import tqdm
30+
31+
has_tqdm = True
32+
else:
33+
tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
2634

2735
__all__ = ["TestTimeAugmentation"]
2836

@@ -57,6 +65,7 @@ class TestTimeAugmentation:
5765
return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will return the
5866
full data. Dimensions will be same size as when passing a single image through `inferrer_fn`, with a dimension appended
5967
equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
68+
progress: whether to display a progress bar.
6069
6170
Example:
6271
.. code-block:: python
@@ -80,6 +89,7 @@ def __init__(
8089
image_key=CommonKeys.IMAGE,
8190
label_key=CommonKeys.LABEL,
8291
return_full_data: bool = False,
92+
progress: bool = True,
8393
) -> None:
8494
self.transform = transform
8595
self.batch_size = batch_size
@@ -89,6 +99,7 @@ def __init__(
8999
self.image_key = image_key
90100
self.label_key = label_key
91101
self.return_full_data = return_full_data
102+
self.progress = progress
92103

93104
# check that the transform has at least one random component, and that all random transforms are invertible
94105
self._check_transforms()
@@ -143,7 +154,7 @@ def __call__(
143154

144155
outputs: List[np.ndarray] = []
145156

146-
for batch_data in dl:
157+
for batch_data in tqdm(dl) if has_tqdm and self.progress else dl:
147158

148159
batch_images = batch_data[self.image_key].to(self.device)
149160

@@ -156,6 +167,10 @@ def __call__(
156167

157168
# create a dictionary containing the inferred batch and their transforms
158169
inferred_dict = {self.label_key: batch_output, label_transform_key: batch_data[label_transform_key]}
170+
# if meta dict is present, add that too (required for some inverse transforms)
171+
label_meta_dict_key = self.label_key + "_meta_dict"
172+
if label_meta_dict_key in batch_data:
173+
inferred_dict[label_meta_dict_key] = batch_data[label_meta_dict_key]
159174

160175
# do inverse transformation (allow missing keys as only inverting label)
161176
with allow_missing_keys_mode(self.transform): # type: ignore
@@ -171,7 +186,7 @@ def __call__(
171186
return output
172187

173188
# calculate metrics
174-
mode: np.ndarray = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=output.astype(np.int64))
189+
mode = np.array(torch.mode(torch.Tensor(output.astype(np.int64)), dim=0).values)
175190
mean: np.ndarray = np.mean(output, axis=0) # type: ignore
176191
std: np.ndarray = np.std(output, axis=0) # type: ignore
177192
vvc: float = (np.std(output) / np.mean(output)).item()

tests/min_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def run_testsuit():
118118
"test_ensure_channel_firstd",
119119
"test_handler_early_stop",
120120
"test_handler_transform_inverter",
121+
"test_testtimeaugmentation",
121122
]
122123
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
123124

tests/test_testtimeaugmentation.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,29 +23,37 @@
2323
from monai.networks.nets import UNet
2424
from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, CropForegroundd, DivisiblePadd, RandAffined
2525
from monai.transforms.croppad.dictionary import SpatialPadd
26-
from monai.transforms.spatial.dictionary import Rand2DElasticd, RandFlipd
26+
from monai.transforms.spatial.dictionary import Rand2DElasticd, RandFlipd, Spacingd
2727
from monai.utils import optional_import, set_determinism
2828

2929
if TYPE_CHECKING:
3030
import tqdm
3131

3232
has_tqdm = True
33+
has_nib = True
3334
else:
3435
tqdm, has_tqdm = optional_import("tqdm")
36+
_, has_nib = optional_import("nibabel")
3537

3638
trange = partial(tqdm.trange, desc="training") if has_tqdm else range
3739

3840

3941
class TestTestTimeAugmentation(unittest.TestCase):
4042
@staticmethod
41-
def get_data(num_examples, input_size):
43+
def get_data(num_examples, input_size, include_label=True):
4244
custom_create_test_image_2d = partial(
4345
create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1
4446
)
4547
data = []
4648
for _ in range(num_examples):
4749
im, label = custom_create_test_image_2d()
48-
data.append({"image": im, "label": label})
50+
d = {}
51+
d["image"] = im
52+
d["image_meta_dict"] = {"affine": np.eye(4)}
53+
if include_label:
54+
d["label"] = label
55+
d["label_meta_dict"] = {"affine": np.eye(4)}
56+
data.append(d)
4957
return data[0] if num_examples == 1 else data
5058

5159
def setUp(self) -> None:
@@ -138,6 +146,17 @@ def test_single_transform(self):
138146
tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x)
139147
tta(self.get_data(1, (20, 20)))
140148

149+
def test_image_no_label(self):
150+
transforms = RandFlipd(["image"])
151+
tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, label_key="image")
152+
tta(self.get_data(1, (20, 20), include_label=False))
153+
154+
@unittest.skipUnless(has_nib, "Requires nibabel")
155+
def test_requires_meta_dict(self):
156+
transforms = Compose([RandFlipd("image"), Spacingd("image", (1, 1))])
157+
tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, label_key="image")
158+
tta(self.get_data(1, (20, 20), include_label=False))
159+
141160

142161
if __name__ == "__main__":
143162
unittest.main()

0 commit comments

Comments
 (0)