Skip to content

Commit d44c7cb

Browse files
5991 add utility enhancements for lazy resampling (#6017)
Part of #5991 . ### Description This PR adds the minor utility enhancements from #5860 relate to lazy resampling, and it belongs to the first part mentioned in #5991 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang <vennw@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 11745a6 commit d44c7cb

File tree

5 files changed

+23
-13
lines changed

5 files changed

+23
-13
lines changed

monai/data/meta_tensor.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
2626
from monai.utils import look_up_option
2727
from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
28-
from monai.utils.type_conversion import convert_data_type, convert_to_numpy, convert_to_tensor
28+
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor
2929

3030
__all__ = ["MetaTensor"]
3131

@@ -461,7 +461,7 @@ def affine(self) -> torch.Tensor:
461461
@affine.setter
462462
def affine(self, d: NdarrayTensor) -> None:
463463
"""Set the affine."""
464-
self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"))
464+
self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.double)
465465

466466
@property
467467
def pixdim(self):
@@ -479,10 +479,18 @@ def peek_pending_shape(self):
479479
return tuple(convert_to_numpy(self.shape, wrap_sequence=True).tolist()[1:]) if res is None else res
480480

481481
def peek_pending_affine(self):
482-
res = None
483-
if self.pending_operations:
484-
res = self.pending_operations[-1].get(LazyAttr.AFFINE, None)
485-
return self.affine if res is None else res
482+
res = self.affine
483+
for p in self.pending_operations:
484+
next_matrix = convert_to_tensor(p.get(LazyAttr.AFFINE))
485+
if next_matrix is None:
486+
continue
487+
res = convert_to_dst_type(res, next_matrix)[0]
488+
res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix)
489+
return res
490+
491+
def peek_pending_rank(self):
492+
a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine
493+
return 1 if a is None else int(max(1, len(a) - 1))
486494

487495
def new_empty(self, size, dtype=None, device=None, requires_grad=False):
488496
"""

monai/transforms/lazy/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def is_affine_shaped(data):
3838
return False
3939
if not hasattr(data, "shape") or len(data.shape) < 2:
4040
return False
41-
return data.shape[-1] in (3, 4) and data.shape[-2] in (3, 4) and data.shape[-1] == data.shape[-2]
41+
return data.shape[-1] in (3, 4) and data.shape[-1] == data.shape[-2]
4242

4343

4444
class DisplacementField:
@@ -129,6 +129,6 @@ def resample(data: torch.Tensor, matrix: NdarrayOrTensor, spatial_size, kwargs:
129129
"padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None),
130130
}
131131
resampler = monai.transforms.SpatialResample(**init_kwargs)
132-
# resampler.lazy_evaluation = False
133-
with resampler.trace_transform(False): # don't track this transform in `data`
132+
# resampler.lazy_evaluation = False # resampler is a lazytransform
133+
with resampler.trace_transform(False): # don't track this transform in `img`
134134
return resampler(img=img, **call_kwargs)

monai/transforms/transform.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,18 +254,17 @@ class LazyTransform(Transform, LazyTrait):
254254
dictionary transforms to simplify implementation of new lazy transforms.
255255
"""
256256

257-
def __init__(self, lazy_evaluation: bool | None = True):
258-
self.lazy_evaluation = lazy_evaluation
257+
_lazy_evaluation: bool = False
259258

260259
@property
261260
def lazy_evaluation(self):
262-
return self.lazy_evaluation
261+
return self._lazy_evaluation
263262

264263
@lazy_evaluation.setter
265264
def lazy_evaluation(self, lazy_evaluation: bool):
266265
if not isinstance(lazy_evaluation, bool):
267266
raise TypeError(f"lazy_evaluation must be a bool but is of type {type(lazy_evaluation)}")
268-
self.lazy_evaluation = lazy_evaluation
267+
self._lazy_evaluation = lazy_evaluation
269268

270269

271270
class RandomizableTransform(Randomizable, Transform):

monai/transforms/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,7 @@ def create_translate(
873873
backend: APIs to use, ``numpy`` or ``torch``.
874874
"""
875875
_backend = look_up_option(backend, TransformBackends)
876+
spatial_dims = int(spatial_dims)
876877
if _backend == TransformBackends.NUMPY:
877878
return _create_translate(spatial_dims=spatial_dims, shift=shift, eye_func=np.eye, array_func=np.asarray)
878879
if _backend == TransformBackends.TORCH:

tests/test_meta_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,9 +515,11 @@ def test_pending_ops(self):
515515
self.assertEqual(m.pending_operations, [])
516516
self.assertEqual(m.peek_pending_shape(), (10, 8))
517517
self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)
518+
self.assertTrue(m.peek_pending_rank() >= 1)
518519
m.push_pending_operation({})
519520
self.assertEqual(m.peek_pending_shape(), (10, 8))
520521
self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)
522+
self.assertTrue(m.peek_pending_rank() >= 1)
521523

522524
@parameterized.expand(TESTS)
523525
def test_multiprocessing(self, device=None, dtype=None):

0 commit comments

Comments
 (0)