Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from monai.config import DtypeLike
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter
Expand Down Expand Up @@ -483,7 +484,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:

"""
img = convert_to_tensor(img, track_meta=get_track_meta())
img_t = convert_to_tensor(img, track_meta=False)
img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img
ret: NdarrayOrTensor
if self.minv is not None or self.maxv is not None:
if self.channel_wise:
Expand Down Expand Up @@ -542,7 +543,7 @@ def __call__(self, img: NdarrayOrTensor, factor=None) -> NdarrayOrTensor:
factor = factor if factor is not None else self.factor

img = convert_to_tensor(img, track_meta=get_track_meta())
img_t = convert_to_tensor(img, track_meta=False)
img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img
ret: NdarrayOrTensor
if self.channel_wise:
out = []
Expand Down Expand Up @@ -1168,7 +1169,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
img_t = convert_to_tensor(img, track_meta=False)
img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img
if self.channel_wise:
img_t = torch.stack([self._clip(img=d) for d in img_t]) # type: ignore
else:
Expand Down Expand Up @@ -1433,7 +1434,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
img_t = convert_to_tensor(img, track_meta=False)
img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img
if self.channel_wise:
img_t = torch.stack([self._normalize(img=d) for d in img_t]) # type: ignore
else:
Expand Down Expand Up @@ -1530,7 +1531,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:

"""
img = convert_to_tensor(img, track_meta=get_track_meta())
self.img_t = convert_to_tensor(img, track_meta=False)
self.img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img)

# add one to transform axis because a batch axis will be added at dimension 0
savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode)
Expand Down Expand Up @@ -1907,7 +1908,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen

if self.reference_control_points is None or self.floating_control_points is None:
raise RuntimeError("please call the `randomize()` function first.")
img_t = convert_to_tensor(img, track_meta=False)
img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img
img_min, img_max = img_t.min(), img_t.max()
if img_min == img_max:
warn(
Expand Down Expand Up @@ -1952,7 +1953,7 @@ def __init__(self, alpha: float = 0.1) -> None:

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
img = convert_to_tensor(img, track_meta=get_track_meta())
img_t = convert_to_tensor(img, track_meta=False)
img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img
n_dims = len(img_t.shape[1:])

# FT
Expand Down Expand Up @@ -2604,7 +2605,7 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor:
img: image to remap.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
img_ = convert_to_tensor(img, track_meta=False)
img_ = img.as_tensor() if isinstance(img, MetaTensor) else img
# sample noise
vals_to_sample = torch.unique(img_).tolist()
noise = torch.from_numpy(self.R.choice(vals_to_sample, len(vals_to_sample) - 1 + self.kernel_size))
Expand Down
7 changes: 4 additions & 3 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import warnings
from collections.abc import Callable, Iterable, Sequence
from typing import cast

import numpy as np
import torch
Expand Down Expand Up @@ -338,7 +339,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
else:
applied_labels = tuple(get_unique_labels(img, is_onehot, discard=0))
img = convert_to_tensor(img, track_meta=get_track_meta())
img_: torch.Tensor = convert_to_tensor(img, track_meta=False)
img_: torch.Tensor = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img)
if self.independent:
for i in applied_labels:
foreground = img_[i] > 0 if is_onehot else img_[0] == i
Expand Down Expand Up @@ -497,7 +498,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:

if isinstance(img, torch.Tensor):
img = convert_to_tensor(img, track_meta=get_track_meta())
img_ = convert_to_tensor(img, track_meta=False)
img_ = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img)
if hasattr(torch, "isin"): # `isin` is new in torch 1.10.0
appl_lbls = torch.as_tensor(self.applied_labels, device=img_.device)
out = torch.where(torch.isin(img_, appl_lbls), img_, torch.tensor(0.0).to(img_))
Expand Down Expand Up @@ -623,7 +624,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:

"""
img = convert_to_tensor(img, track_meta=get_track_meta())
img_: torch.Tensor = convert_to_tensor(img, track_meta=False)
img_: torch.Tensor = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img)
spatial_dims = len(img_.shape) - 1
img_ = img_.unsqueeze(0) # adds a batch dim
if spatial_dims == 2:
Expand Down
Loading