Skip to content

Commit 663cfd0

Browse files
authored
5163 metatensor support for OneOf (#5217)
Signed-off-by: Wenqi Li <wenqil@nvidia.com> Fixes #5163 ### Description enhance metatensor compatibility for `OneOf` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent ce6daf7 commit 663cfd0

File tree

2 files changed

+54
-23
lines changed

2 files changed

+54
-23
lines changed

monai/transforms/compose.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import numpy as np
1919

20+
import monai
2021
from monai.transforms.inverse import InvertibleTransform
2122

2223
# For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)
@@ -254,26 +255,27 @@ def __call__(self, data):
254255
_transform = self.transforms[index]
255256
data = apply_transform(_transform, data, self.map_items, self.unpack_items, self.log_stats)
256257
# if the data is a mapping (dictionary), append the OneOf transform to the end
257-
if isinstance(data, Mapping):
258-
for key in data.keys():
259-
if self.trace_key(key) in data:
258+
if isinstance(data, monai.data.MetaTensor):
259+
self.push_transform(data, extra_info={"index": index})
260+
elif isinstance(data, Mapping):
261+
for key in data: # dictionary not change size during iteration
262+
if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
260263
self.push_transform(data, key, extra_info={"index": index})
261264
return data
262265

263266
def inverse(self, data):
264267
if len(self.transforms) == 0:
265268
return data
266-
if not isinstance(data, Mapping):
267-
raise RuntimeError("Inverse only implemented for Mapping (dictionary) data")
268269

269-
# loop until we get an index and then break (since they'll all be the same)
270270
index = None
271-
for key in data.keys():
272-
if self.trace_key(key) in data:
273-
# get the index of the applied OneOf transform
274-
index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
275-
# and then remove the OneOf transform
276-
self.pop_transform(data, key)
271+
if isinstance(data, monai.data.MetaTensor):
272+
index = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["index"]
273+
elif isinstance(data, Mapping):
274+
for key in data:
275+
if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
276+
index = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
277+
else:
278+
raise RuntimeError("Inverse only implemented for Mapping (dictionary) or MetaTensor data.")
277279
if index is None:
278280
# no invertible transforms have been applied
279281
return data

tests/test_one_of.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@
1515
import numpy as np
1616
from parameterized import parameterized
1717

18+
from monai.data import MetaTensor
1819
from monai.transforms import (
1920
InvertibleTransform,
2021
OneOf,
22+
RandScaleIntensity,
2123
RandScaleIntensityd,
24+
RandShiftIntensity,
2225
RandShiftIntensityd,
26+
Resize,
2327
Resized,
2428
TraceableTransform,
2529
Transform,
@@ -106,10 +110,10 @@ def __init__(self, keys):
106110

107111
KEYS = ["x", "y"]
108112
TEST_INVERSES = [
109-
(OneOf((InvA(KEYS), InvB(KEYS))), True),
110-
(OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True),
111-
(OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True),
112-
(OneOf((NonInv(KEYS), NonInv(KEYS))), False),
113+
(OneOf((InvA(KEYS), InvB(KEYS))), True, True),
114+
(OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True, False),
115+
(OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True, False),
116+
(OneOf((NonInv(KEYS), NonInv(KEYS))), False, False),
113117
]
114118

115119

@@ -148,13 +152,17 @@ def _match(a, b):
148152
_match(p, f)
149153

150154
@parameterized.expand(TEST_INVERSES)
151-
def test_inverse(self, transform, invertible):
152-
data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)}
155+
def test_inverse(self, transform, invertible, use_metatensor):
156+
data = {k: (i + 1) * 10.0 if not use_metatensor else MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)}
153157
fwd_data = transform(data)
154158

155159
if invertible:
156160
for k in KEYS:
157-
t = fwd_data[TraceableTransform.trace_key(k)][-1]
161+
t = (
162+
fwd_data[TraceableTransform.trace_key(k)][-1]
163+
if not use_metatensor
164+
else fwd_data[k].applied_operations[-1]
165+
)
158166
# make sure the OneOf index was stored
159167
self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__)
160168
# make sure index exists and is in bounds
@@ -166,9 +174,11 @@ def test_inverse(self, transform, invertible):
166174
if invertible:
167175
for k in KEYS:
168176
# check transform was removed
169-
self.assertTrue(
170-
len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)])
171-
)
177+
if not use_metatensor:
178+
self.assertTrue(
179+
len(fwd_inv_data[TraceableTransform.trace_key(k)])
180+
< len(fwd_data[TraceableTransform.trace_key(k)])
181+
)
172182
# check data is same as original (and different from forward)
173183
self.assertEqual(fwd_inv_data[k], data[k])
174184
self.assertNotEqual(fwd_inv_data[k], fwd_data[k])
@@ -186,15 +196,34 @@ def test_inverse_compose(self):
186196
RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0),
187197
]
188198
),
199+
OneOf(
200+
[
201+
RandScaleIntensityd(keys="img", factors=0.5, prob=1.0),
202+
RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0),
203+
]
204+
),
189205
]
190206
)
191207
transform.set_random_state(seed=0)
192208
result = transform({"img": np.ones((1, 101, 102, 103))})
193-
194209
result = transform.inverse(result)
195210
# invert to the original spatial shape
196211
self.assertTupleEqual(result["img"].shape, (1, 101, 102, 103))
197212

213+
def test_inverse_metatensor(self):
214+
transform = Compose(
215+
[
216+
Resize(spatial_size=[100, 100, 100]),
217+
OneOf([RandScaleIntensity(factors=0.5, prob=1.0), RandShiftIntensity(offsets=0.5, prob=1.0)]),
218+
OneOf([RandScaleIntensity(factors=0.5, prob=1.0), RandShiftIntensity(offsets=0.5, prob=1.0)]),
219+
]
220+
)
221+
transform.set_random_state(seed=0)
222+
result = transform(np.ones((1, 101, 102, 103)))
223+
self.assertTupleEqual(result.shape, (1, 100, 100, 100))
224+
result = transform.inverse(result)
225+
self.assertTupleEqual(result.shape, (1, 101, 102, 103))
226+
198227
def test_one_of(self):
199228
p = OneOf((A(), B(), C()), (1, 2, 1))
200229
counts = [0] * 3

0 commit comments

Comments
 (0)