99# See the License for the specific language governing permissions and
1010# limitations under the License.
1111
12- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
12+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Tuple , Union
1313
1414import numpy as np
1515import torch
2323from monai .transforms .transform import Randomizable
2424from monai .transforms .utils import allow_missing_keys_mode
2525from monai .utils .enums import CommonKeys , InverseKeys
26+ from monai .utils .module import optional_import
27+
28+ if TYPE_CHECKING :
29+ from tqdm import tqdm
30+
31+ has_tqdm = True
32+ else :
33+ tqdm , has_tqdm = optional_import ("tqdm" , name = "tqdm" )
2634
2735__all__ = ["TestTimeAugmentation" ]
2836
@@ -57,6 +65,7 @@ class TestTimeAugmentation:
5765 return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will return the
5866 full data. Dimensions will be same size as when passing a single image through `inferrer_fn`, with a dimension appended
5967 equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
68+ progress: whether to display a progress bar.
6069
6170 Example:
6271 .. code-block:: python
@@ -80,6 +89,7 @@ def __init__(
8089 image_key = CommonKeys .IMAGE ,
8190 label_key = CommonKeys .LABEL ,
8291 return_full_data : bool = False ,
92+ progress : bool = True ,
8393 ) -> None :
8494 self .transform = transform
8595 self .batch_size = batch_size
@@ -89,6 +99,7 @@ def __init__(
8999 self .image_key = image_key
90100 self .label_key = label_key
91101 self .return_full_data = return_full_data
102+ self .progress = progress
92103
93104 # check that the transform has at least one random component, and that all random transforms are invertible
94105 self ._check_transforms ()
@@ -143,7 +154,7 @@ def __call__(
143154
144155 outputs : List [np .ndarray ] = []
145156
146- for batch_data in dl :
157+ for batch_data in tqdm ( dl ) if has_tqdm and self . progress else dl :
147158
148159 batch_images = batch_data [self .image_key ].to (self .device )
149160
@@ -156,6 +167,10 @@ def __call__(
156167
157168 # create a dictionary containing the inferred batch and their transforms
158169 inferred_dict = {self .label_key : batch_output , label_transform_key : batch_data [label_transform_key ]}
170+ # if meta dict is present, add that too (required for some inverse transforms)
171+ label_meta_dict_key = self .label_key + "_meta_dict"
172+ if label_meta_dict_key in batch_data :
173+ inferred_dict [label_meta_dict_key ] = batch_data [label_meta_dict_key ]
159174
160175 # do inverse transformation (allow missing keys as only inverting label)
161176 with allow_missing_keys_mode (self .transform ): # type: ignore
@@ -171,7 +186,7 @@ def __call__(
171186 return output
172187
173188 # calculate metrics
174- mode : np . ndarray = np .apply_along_axis ( lambda x : np . bincount ( x ). argmax (), axis = 0 , arr = output .astype (np .int64 ))
189+ mode = np .array ( torch . mode ( torch . Tensor ( output .astype (np .int64 )), dim = 0 ). values )
175190 mean : np .ndarray = np .mean (output , axis = 0 ) # type: ignore
176191 std : np .ndarray = np .std (output , axis = 0 ) # type: ignore
177192 vvc : float = (np .std (output ) / np .mean (output )).item ()
0 commit comments