Skip to content

Commit f4902b2

Browse files
authored
transforms to have multi-sample trait (#6003)
Fixes #6002 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 54f4cfc commit f4902b2

File tree

10 files changed

+112
-100
lines changed

10 files changed

+112
-100
lines changed

monai/transforms/__init__.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -453,18 +453,8 @@
453453
ZoomD,
454454
ZoomDict,
455455
)
456-
from .transform import (
457-
LazyTrait,
458-
LazyTransform,
459-
MapTransform,
460-
MultiSampleTrait,
461-
Randomizable,
462-
RandomizableTrait,
463-
RandomizableTransform,
464-
ThreadUnsafe,
465-
Transform,
466-
apply_transform,
467-
)
456+
from .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ThreadUnsafe
457+
from .transform import LazyTransform, MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform
468458
from .utility.array import (
469459
AddChannel,
470460
AddCoordinateChannels,

monai/transforms/croppad/array.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from monai.data.meta_tensor import MetaTensor
3131
from monai.data.utils import get_random_patch, get_valid_patch_size
3232
from monai.transforms.inverse import InvertibleTransform, TraceableTransform
33+
from monai.transforms.traits import MultiSampleTrait
3334
from monai.transforms.transform import Randomizable, Transform
3435
from monai.transforms.utils import (
3536
compute_divisible_spatial_size,
@@ -683,7 +684,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
683684
return super().__call__(img=img, randomize=randomize)
684685

685686

686-
class RandSpatialCropSamples(Randomizable, TraceableTransform):
687+
class RandSpatialCropSamples(Randomizable, TraceableTransform, MultiSampleTrait):
687688
"""
688689
Crop image with random size or specific size ROI to generate a list of N samples.
689690
It can crop at a random position as center or at the image center. And allows to set
@@ -893,7 +894,7 @@ def inverse(self, img: MetaTensor) -> MetaTensor:
893894
return super().inverse(inv)
894895

895896

896-
class RandWeightedCrop(Randomizable, TraceableTransform):
897+
class RandWeightedCrop(Randomizable, TraceableTransform, MultiSampleTrait):
897898
"""
898899
Samples a list of `num_samples` image patches according to the provided `weight_map`.
899900
@@ -958,7 +959,7 @@ def __call__(
958959
return results
959960

960961

961-
class RandCropByPosNegLabel(Randomizable, TraceableTransform):
962+
class RandCropByPosNegLabel(Randomizable, TraceableTransform, MultiSampleTrait):
962963
"""
963964
Crop random fixed sized regions with the center being a foreground or background voxel
964965
based on the Pos Neg Ratio.
@@ -1118,7 +1119,7 @@ def __call__(
11181119
return results
11191120

11201121

1121-
class RandCropByLabelClasses(Randomizable, TraceableTransform):
1122+
class RandCropByLabelClasses(Randomizable, TraceableTransform, MultiSampleTrait):
11221123
"""
11231124
Crop random fixed sized regions with the center being a class based on the specified ratios of every class.
11241125
The label data can be One-Hot format array or Argmax data. And will return a list of arrays for all the

monai/transforms/croppad/dictionary.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
SpatialPad,
4848
)
4949
from monai.transforms.inverse import InvertibleTransform
50+
from monai.transforms.traits import MultiSampleTrait
5051
from monai.transforms.transform import MapTransform, Randomizable
5152
from monai.transforms.utils import is_positive
5253
from monai.utils import MAX_SEED, Method, PytorchPadMode, deprecated_arg_default, ensure_tuple_rep
@@ -528,7 +529,7 @@ def __init__(
528529
super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys)
529530

530531

531-
class RandSpatialCropSamplesd(Randomizable, MapTransform):
532+
class RandSpatialCropSamplesd(Randomizable, MapTransform, MultiSampleTrait):
532533
"""
533534
Dictionary-based version :py:class:`monai.transforms.RandSpatialCropSamples`.
534535
Crop image with random size or specific size ROI to generate a list of N samples.
@@ -682,7 +683,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
682683
return d
683684

684685

685-
class RandWeightedCropd(Randomizable, MapTransform):
686+
class RandWeightedCropd(Randomizable, MapTransform, MultiSampleTrait):
686687
"""
687688
Samples a list of `num_samples` image patches according to the provided `weight_map`.
688689
@@ -739,7 +740,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable,
739740
return ret
740741

741742

742-
class RandCropByPosNegLabeld(Randomizable, MapTransform):
743+
class RandCropByPosNegLabeld(Randomizable, MapTransform, MultiSampleTrait):
743744
"""
744745
Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`.
745746
Crop random fixed sized regions with the center being a foreground or background voxel
@@ -860,7 +861,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable,
860861
return ret
861862

862863

863-
class RandCropByLabelClassesd(Randomizable, MapTransform):
864+
class RandCropByLabelClassesd(Randomizable, MapTransform, MultiSampleTrait):
864865
"""
865866
Dictionary-based version :py:class:`monai.transforms.RandCropByLabelClasses`.
866867
Crop random fixed sized regions with the center being a class based on the specified ratios of every class.

monai/transforms/nvtx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from __future__ import annotations
1616

17-
from monai.transforms.transform import RandomizableTrait, Transform
17+
from monai.transforms.traits import RandomizableTrait
18+
from monai.transforms.transform import Transform
1819
from monai.utils import optional_import
1920

2021
_nvtx, _ = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")

monai/transforms/spatial/array.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop
3636
from monai.transforms.intensity.array import GaussianSmooth
3737
from monai.transforms.inverse import InvertibleTransform
38+
from monai.transforms.traits import MultiSampleTrait
3839
from monai.transforms.transform import Randomizable, RandomizableTransform, Transform
3940
from monai.transforms.utils import (
4041
convert_pad_mode,
@@ -3045,7 +3046,7 @@ def __call__(
30453046
return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode)
30463047

30473048

3048-
class GridSplit(Transform):
3049+
class GridSplit(Transform, MultiSampleTrait):
30493050
"""
30503051
Split the image into patches based on the provided grid in 2D.
30513052
@@ -3130,7 +3131,7 @@ def _get_params(self, image_size: Sequence[int] | np.ndarray, size: Sequence[int
31303131
return size, steps
31313132

31323133

3133-
class GridPatch(Transform):
3134+
class GridPatch(Transform, MultiSampleTrait):
31343135
"""
31353136
Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps.
31363137
It can sort the patches and return all or a subset of them.
@@ -3257,7 +3258,7 @@ def __call__(self, array: NdarrayOrTensor):
32573258
return output
32583259

32593260

3260-
class RandGridPatch(GridPatch, RandomizableTransform):
3261+
class RandGridPatch(GridPatch, RandomizableTransform, MultiSampleTrait):
32613262
"""
32623263
Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps,
32633264
and with random offset for the minimal corner of the image, (0,0) for 2D and (0,0,0) for 3D.

monai/transforms/spatial/dictionary.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
SpatialResample,
5454
Zoom,
5555
)
56+
from monai.transforms.traits import MultiSampleTrait
5657
from monai.transforms.transform import MapTransform, RandomizableTransform
5758
from monai.transforms.utils import create_grid
5859
from monai.utils import (
@@ -1779,7 +1780,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
17791780
return d
17801781

17811782

1782-
class GridSplitd(MapTransform):
1783+
class GridSplitd(MapTransform, MultiSampleTrait):
17831784
"""
17841785
Split the image into patches based on the provided grid in 2D.
17851786
@@ -1820,7 +1821,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> list[dict[Hashab
18201821
return output
18211822

18221823

1823-
class GridPatchd(MapTransform):
1824+
class GridPatchd(MapTransform, MultiSampleTrait):
18241825
"""
18251826
Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps.
18261827
It can sort the patches and return all or a subset of them.
@@ -1884,7 +1885,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
18841885
return d
18851886

18861887

1887-
class RandGridPatchd(RandomizableTransform, MapTransform):
1888+
class RandGridPatchd(RandomizableTransform, MapTransform, MultiSampleTrait):
18881889
"""
18891890
Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps,
18901891
and with random offset for the minimal corner of the image, (0,0) for 2D and (0,0,0) for 3D.

monai/transforms/traits.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
"""
12+
A collection of generic traits for MONAI transforms.
13+
"""
14+
15+
from __future__ import annotations
16+
17+
__all__ = ["LazyTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe"]
18+
19+
20+
class LazyTrait:
21+
"""
22+
An interface to indicate that the transform has the capability to execute using
23+
MONAI's lazy resampling feature. In order to do this, the implementing class needs
24+
to be able to describe its operation as an affine matrix or grid with accompanying metadata.
25+
This interface can be extended from by people adapting transforms to the MONAI framework as
26+
well as by implementors of MONAI transforms.
27+
"""
28+
29+
@property
30+
def lazy_evaluation(self):
31+
"""
32+
Get whether lazy_evaluation is enabled for this transform instance.
33+
Returns:
34+
True if the transform is operating in a lazy fashion, False if not.
35+
"""
36+
raise NotImplementedError()
37+
38+
@lazy_evaluation.setter
39+
def lazy_evaluation(self, enabled: bool):
40+
"""
41+
Set whether lazy_evaluation is enabled for this transform instance.
42+
Args:
43+
enabled: True if the transform should operate in a lazy fashion, False if not.
44+
"""
45+
raise NotImplementedError()
46+
47+
48+
class RandomizableTrait:
49+
"""
50+
An interface to indicate that the transform has the capability to perform
51+
randomized transforms to the data that it is called upon. This interface
52+
can be extended from by people adapting transforms to the MONAI framework as well as by
53+
implementors of MONAI transforms.
54+
"""
55+
56+
pass
57+
58+
59+
class MultiSampleTrait:
60+
"""
61+
An interface to indicate that the transform has the capability to return multiple samples
62+
given an input, such as when performing random crops of a sample. This interface can be
63+
extended from by people adapting transforms to the MONAI framework as well as by implementors
64+
of MONAI transforms.
65+
"""
66+
67+
pass
68+
69+
70+
class ThreadUnsafe:
71+
"""
72+
A class to denote that the transform will mutate its member variables,
73+
when being applied. Transforms inheriting this class should be used
74+
cautiously in a multi-thread context.
75+
76+
This type is typically used by :py:class:`monai.data.CacheDataset` and
77+
its extensions, where the transform cache is built with multiple threads.
78+
"""
79+
80+
pass

monai/transforms/transform.py

Lines changed: 1 addition & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,14 @@
2525
from monai import config, transforms
2626
from monai.config import KeysCollection
2727
from monai.data.meta_tensor import MetaTensor
28+
from monai.transforms.traits import LazyTrait, RandomizableTrait, ThreadUnsafe
2829
from monai.utils import MAX_SEED, ensure_tuple, first
2930
from monai.utils.enums import TransformBackends
3031
from monai.utils.misc import MONAIEnvVars
3132

3233
__all__ = [
3334
"ThreadUnsafe",
3435
"apply_transform",
35-
"LazyTrait",
36-
"RandomizableTrait",
37-
"MultiSampleTrait",
3836
"Randomizable",
3937
"LazyTransform",
4038
"RandomizableTransform",
@@ -132,69 +130,6 @@ def _log_stats(data, prefix: str | None = "Data"):
132130
raise RuntimeError(f"applying transform {transform}") from e
133131

134132

135-
class LazyTrait:
136-
"""
137-
An interface to indicate that the transform has the capability to execute using
138-
MONAI's lazy resampling feature. In order to do this, the implementing class needs
139-
to be able to describe its operation as an affine matrix or grid with accompanying metadata.
140-
This interface can be extended from by people adapting transforms to the MONAI framework as
141-
well as by implementors of MONAI transforms.
142-
"""
143-
144-
@property
145-
def lazy_evaluation(self):
146-
"""
147-
Get whether lazy_evaluation is enabled for this transform instance.
148-
Returns:
149-
True if the transform is operating in a lazy fashion, False if not.
150-
"""
151-
raise NotImplementedError()
152-
153-
@lazy_evaluation.setter
154-
def lazy_evaluation(self, enabled: bool):
155-
"""
156-
Set whether lazy_evaluation is enabled for this transform instance.
157-
Args:
158-
enabled: True if the transform should operate in a lazy fashion, False if not.
159-
"""
160-
raise NotImplementedError()
161-
162-
163-
class RandomizableTrait:
164-
"""
165-
An interface to indicate that the transform has the capability to perform
166-
randomized transforms to the data that it is called upon. This interface
167-
can be extended from by people adapting transforms to the MONAI framework as well as by
168-
implementors of MONAI transforms.
169-
"""
170-
171-
pass
172-
173-
174-
class MultiSampleTrait:
175-
"""
176-
An interface to indicate that the transform has the capability to return multiple samples
177-
given an input, such as when performing random crops of a sample. This interface can be
178-
extended from by people adapting transforms to the MONAI framework as well as by implementors
179-
of MONAI transforms.
180-
"""
181-
182-
pass
183-
184-
185-
class ThreadUnsafe:
186-
"""
187-
A class to denote that the transform will mutate its member variables,
188-
when being applied. Transforms inheriting this class should be used
189-
cautiously in a multi-thread context.
190-
191-
This type is typically used by :py:class:`monai.data.CacheDataset` and
192-
its extensions, where the transform cache is built with multiple threads.
193-
"""
194-
195-
pass
196-
197-
198133
class Randomizable(ThreadUnsafe, RandomizableTrait):
199134
"""
200135
An interface for handling random state locally, currently based on a class

monai/transforms/utility/array.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
median_filter,
4545
)
4646
from monai.transforms.inverse import InvertibleTransform
47+
from monai.transforms.traits import MultiSampleTrait
4748
from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform
4849
from monai.transforms.utils import (
4950
extreme_points_to_image,
@@ -343,7 +344,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
343344
return out
344345

345346

346-
class SplitDim(Transform):
347+
class SplitDim(Transform, MultiSampleTrait):
347348
"""
348349
Given an image of size X along a certain dimension, return a list of length X containing
349350
images. Useful for converting 3D images into a stack of 2D images, splitting multichannel inputs into
@@ -976,7 +977,7 @@ def __call__(
976977
return data
977978

978979

979-
class FgBgToIndices(Transform):
980+
class FgBgToIndices(Transform, MultiSampleTrait):
980981
"""
981982
Compute foreground and background of the input label data, return the indices.
982983
If no output_shape specified, output data will be 1 dim indices after flattening.
@@ -1017,7 +1018,7 @@ def __call__(
10171018
return fg_indices, bg_indices
10181019

10191020

1020-
class ClassesToIndices(Transform):
1021+
class ClassesToIndices(Transform, MultiSampleTrait):
10211022
backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
10221023

10231024
def __init__(

0 commit comments

Comments
 (0)