diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 23a57ae9fb..77ec206869 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -26,6 +26,7 @@ from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.data.meta_obj import get_track_meta +from monai.data.meta_tensor import MetaTensor from monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter @@ -483,7 +484,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = convert_to_tensor(img, track_meta=False) + img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img ret: NdarrayOrTensor if self.minv is not None or self.maxv is not None: if self.channel_wise: @@ -542,7 +543,7 @@ def __call__(self, img: NdarrayOrTensor, factor=None) -> NdarrayOrTensor: factor = factor if factor is not None else self.factor img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = convert_to_tensor(img, track_meta=False) + img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img ret: NdarrayOrTensor if self.channel_wise: out = [] @@ -1168,7 +1169,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: Apply the transform to `img`. """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = convert_to_tensor(img, track_meta=False) + img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img if self.channel_wise: img_t = torch.stack([self._clip(img=d) for d in img_t]) # type: ignore else: @@ -1433,7 +1434,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: Apply the transform to `img`. """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = convert_to_tensor(img, track_meta=False) + img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img if self.channel_wise: img_t = torch.stack([self._normalize(img=d) for d in img_t]) # type: ignore else: @@ -1530,7 +1531,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) - self.img_t = convert_to_tensor(img, track_meta=False) + self.img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) # add one to transform axis because a batch axis will be added at dimension 0 savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) @@ -1907,7 +1908,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if self.reference_control_points is None or self.floating_control_points is None: raise RuntimeError("please call the `randomize()` function first.") - img_t = convert_to_tensor(img, track_meta=False) + img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img img_min, img_max = img_t.min(), img_t.max() if img_min == img_max: warn( @@ -1952,7 +1953,7 @@ def __init__(self, alpha: float = 0.1) -> None: def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = convert_to_tensor(img, track_meta=False) + img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img n_dims = len(img_t.shape[1:]) # FT @@ -2604,7 +2605,7 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: img: image to remap. """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_ = convert_to_tensor(img, track_meta=False) + img_ = img.as_tensor() if isinstance(img, MetaTensor) else img # sample noise vals_to_sample = torch.unique(img_).tolist() noise = torch.from_numpy(self.R.choice(vals_to_sample, len(vals_to_sample) - 1 + self.kernel_size)) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 47623b748d..0a43a2d820 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -16,6 +16,7 @@ import warnings from collections.abc import Callable, Iterable, Sequence +from typing import cast import numpy as np import torch @@ -338,7 +339,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: else: applied_labels = tuple(get_unique_labels(img, is_onehot, discard=0)) img = convert_to_tensor(img, track_meta=get_track_meta()) - img_: torch.Tensor = convert_to_tensor(img, track_meta=False) + img_: torch.Tensor = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) if self.independent: for i in applied_labels: foreground = img_[i] > 0 if is_onehot else img_[0] == i @@ -497,7 +498,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: if isinstance(img, torch.Tensor): img = convert_to_tensor(img, track_meta=get_track_meta()) - img_ = convert_to_tensor(img, track_meta=False) + img_ = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) if hasattr(torch, "isin"): # `isin` is new in torch 1.10.0 appl_lbls = torch.as_tensor(self.applied_labels, device=img_.device) out = torch.where(torch.isin(img_, appl_lbls), img_, torch.tensor(0.0).to(img_)) @@ -623,7 +624,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_: torch.Tensor = convert_to_tensor(img, track_meta=False) + img_: torch.Tensor = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) spatial_dims = len(img_.shape) - 1 img_ = img_.unsqueeze(0) # adds a batch dim if spatial_dims == 2: