From c575f22ee8514ffece0d6091ad1f3e0a5429cf63 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Thu, 7 May 2026 13:19:00 +0100 Subject: [PATCH 1/2] perf: skip redundant convert_to_tensor for MetaTensor strip Several transforms invoke ``convert_to_tensor`` twice on the same input: once with ``track_meta=get_track_meta()`` to keep the MetaTensor wrapper, then again with ``track_meta=False`` to obtain a plain-tensor view for the actual computation. The second call routes through ``_convert_tensor(data).to(dtype=..., device=..., memory_format=contiguous_format)`` (``monai/utils/type_conversion.py``) and is functionally equivalent to calling ``MetaTensor.as_tensor()`` for the purpose intended here, since the first call already enforced ``contiguous_format``. Replace the second call with an explicit ``as_tensor()`` strip: img_t = img.as_tensor() if isinstance(img, MetaTensor) else img This avoids a redundant ``convert_to_tensor`` dispatch and a no-op ``Tensor.to(memory_format=contiguous_format)`` per ``__call__``. Sites touched (all in transform ``__call__``, so per-sample per-epoch): * ``KeepLargestConnectedComponent`` (post/array.py) * ``RemoveSmallObjects`` (post/array.py) * ``LabelToContour`` (post/array.py) * ``ScaleIntensity`` (intensity/array.py) * ``ScaleIntensityFixedMean`` (intensity/array.py) * ``ClipIntensityPercentiles`` (intensity/array.py) * ``NormalizeIntensity`` (intensity/array.py) * ``SavitzkyGolaySmooth`` (intensity/array.py) * ``RandHistogramShift`` (intensity/array.py) * ``GibbsNoise`` (intensity/array.py) * ``IntensityRemap`` (intensity/array.py) No behavioral change: the resulting tensor is the same underlying data (``as_tensor()`` returns the wrapped ``torch.Tensor`` without copying). 
Signed-off-by: Soumya Snigdha Kundu --- monai/transforms/intensity/array.py | 17 +++++++++-------- monai/transforms/post/array.py | 7 ++++--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 23a57ae9fb..cb1ca9e6e8 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -26,6 +26,7 @@ from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.data.meta_obj import get_track_meta +from monai.data.meta_tensor import MetaTensor from monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter @@ -483,7 +484,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = convert_to_tensor(img, track_meta=False) + img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) ret: NdarrayOrTensor if self.minv is not None or self.maxv is not None: if self.channel_wise: @@ -542,7 +543,7 @@ def __call__(self, img: NdarrayOrTensor, factor=None) -> NdarrayOrTensor: factor = factor if factor is not None else self.factor img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = convert_to_tensor(img, track_meta=False) + img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) ret: NdarrayOrTensor if self.channel_wise: out = [] @@ -1168,7 +1169,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: Apply the transform to `img`. 
""" img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = convert_to_tensor(img, track_meta=False) + img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) if self.channel_wise: img_t = torch.stack([self._clip(img=d) for d in img_t]) # type: ignore else: @@ -1433,7 +1434,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: Apply the transform to `img`. """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = convert_to_tensor(img, track_meta=False) + img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) if self.channel_wise: img_t = torch.stack([self._normalize(img=d) for d in img_t]) # type: ignore else: @@ -1530,7 +1531,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) - self.img_t = convert_to_tensor(img, track_meta=False) + self.img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) # add one to transform axis because a batch axis will be added at dimension 0 savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) @@ -1907,7 +1908,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if self.reference_control_points is None or self.floating_control_points is None: raise RuntimeError("please call the `randomize()` function first.") - img_t = convert_to_tensor(img, track_meta=False) + img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) img_min, img_max = img_t.min(), img_t.max() if img_min == img_max: warn( @@ -1952,7 +1953,7 @@ def __init__(self, alpha: float = 0.1) -> None: def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = convert_to_tensor(img, track_meta=False) + img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) n_dims = len(img_t.shape[1:]) # 
FT @@ -2604,7 +2605,7 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: img: image to remap. """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_ = convert_to_tensor(img, track_meta=False) + img_ = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) # sample noise vals_to_sample = torch.unique(img_).tolist() noise = torch.from_numpy(self.R.choice(vals_to_sample, len(vals_to_sample) - 1 + self.kernel_size)) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 47623b748d..0a43a2d820 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -16,6 +16,7 @@ import warnings from collections.abc import Callable, Iterable, Sequence +from typing import cast import numpy as np import torch @@ -338,7 +339,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: else: applied_labels = tuple(get_unique_labels(img, is_onehot, discard=0)) img = convert_to_tensor(img, track_meta=get_track_meta()) - img_: torch.Tensor = convert_to_tensor(img, track_meta=False) + img_: torch.Tensor = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) if self.independent: for i in applied_labels: foreground = img_[i] > 0 if is_onehot else img_[0] == i @@ -497,7 +498,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: if isinstance(img, torch.Tensor): img = convert_to_tensor(img, track_meta=get_track_meta()) - img_ = convert_to_tensor(img, track_meta=False) + img_ = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) if hasattr(torch, "isin"): # `isin` is new in torch 1.10.0 appl_lbls = torch.as_tensor(self.applied_labels, device=img_.device) out = torch.where(torch.isin(img_, appl_lbls), img_, torch.tensor(0.0).to(img_)) @@ -623,7 +624,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_: torch.Tensor = convert_to_tensor(img, track_meta=False) + 
img_: torch.Tensor = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) spatial_dims = len(img_.shape) - 1 img_ = img_.unsqueeze(0) # adds a batch dim if spatial_dims == 2: From 9478c5162393d0b6caff05451e8a884a8ebda949 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Fri, 8 May 2026 18:05:17 +0100 Subject: [PATCH 2/2] =?UTF-8?q?Fix=20mypy=20errors=20from=20convert=5Fto?= =?UTF-8?q?=5Ftensor=20=E2=86=92=20as=5Ftensor=20swap?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The new `img.as_tensor() if isinstance(img, MetaTensor) else cast(...)` expression narrowed mypy's inference of `img_t` to `Tensor`, so downstream reassignments from `_clip` / `_normalize` / `interp` (which return `NdarrayOrTensor`) failed type-checking. Annotate `img_t` explicitly as `NdarrayOrTensor` and drop the redundant `cast(torch.Tensor, img)` so the variable matches the type used by the rest of each transform's body. Keep the `cast` only on `SavitzkyGolaySmooth.__call__` where `self.img_t` is declared as `torch.Tensor` in `__init__` and used unconditionally as a Tensor afterwards. 
Signed-off-by: Soumya Snigdha Kundu --- monai/transforms/intensity/array.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index cb1ca9e6e8..77ec206869 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -484,7 +484,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) + img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img ret: NdarrayOrTensor if self.minv is not None or self.maxv is not None: if self.channel_wise: @@ -543,7 +543,7 @@ def __call__(self, img: NdarrayOrTensor, factor=None) -> NdarrayOrTensor: factor = factor if factor is not None else self.factor img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) + img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img ret: NdarrayOrTensor if self.channel_wise: out = [] @@ -1169,7 +1169,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: Apply the transform to `img`. """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) + img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img if self.channel_wise: img_t = torch.stack([self._clip(img=d) for d in img_t]) # type: ignore else: @@ -1434,7 +1434,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: Apply the transform to `img`. 
""" img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) + img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img if self.channel_wise: img_t = torch.stack([self._normalize(img=d) for d in img_t]) # type: ignore else: @@ -1908,7 +1908,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if self.reference_control_points is None or self.floating_control_points is None: raise RuntimeError("please call the `randomize()` function first.") - img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) + img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img img_min, img_max = img_t.min(), img_t.max() if img_min == img_max: warn( @@ -1953,7 +1953,7 @@ def __init__(self, alpha: float = 0.1) -> None: def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img = convert_to_tensor(img, track_meta=get_track_meta()) - img_t = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) + img_t: NdarrayOrTensor = img.as_tensor() if isinstance(img, MetaTensor) else img n_dims = len(img_t.shape[1:]) # FT @@ -2605,7 +2605,7 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: img: image to remap. """ img = convert_to_tensor(img, track_meta=get_track_meta()) - img_ = img.as_tensor() if isinstance(img, MetaTensor) else cast(torch.Tensor, img) + img_ = img.as_tensor() if isinstance(img, MetaTensor) else img # sample noise vals_to_sample = torch.unique(img_).tolist() noise = torch.from_numpy(self.R.choice(vals_to_sample, len(vals_to_sample) - 1 + self.kernel_size))