diff --git a/changelog/937.feature.rst b/changelog/937.feature.rst new file mode 100644 index 000000000..c5b796200 --- /dev/null +++ b/changelog/937.feature.rst @@ -0,0 +1,5 @@ +Added ``NDCube._slice_custom_state``, an overridable no-op hook called at the +end of ``NDCube.__getitem__`` with the newly sliced cube and the sanitized +slice item. Subclasses carrying extra state that tracks the data axes (e.g. +per-frame WCS objects) can override it instead of overriding ``__getitem__`` +and re-deriving the normalized item themselves. diff --git a/ndcube/mixins/ndslicing.py b/ndcube/mixins/ndslicing.py index 9f1d93378..b2b86c9c3 100644 --- a/ndcube/mixins/ndslicing.py +++ b/ndcube/mixins/ndslicing.py @@ -53,4 +53,30 @@ def __getitem__(self, item): if meta_is_sliceable: sliced_cube.meta = meta.slice[item] + self._slice_custom_state(sliced_cube, item) + return sliced_cube + + def _slice_custom_state(self, sliced_cube, item): + """ + Update custom subclass state on a newly sliced cube. + + Called at the end of ``__getitem__``, after the data, WCS, coords and + metadata of ``sliced_cube`` have been set. + + Subclasses carrying extra state that tracks the data axes (for example + a list of per-frame WCS objects) should override this method instead of + ``__getitem__``, mutating ``sliced_cube`` in place. The default + implementation does nothing. + + Parameters + ---------- + sliced_cube : `~ndcube.NDCube` + The new cube produced by slicing this one. + + item : `tuple` + The sanitized slice item: one entry per data axis of the original + cube, containing only `int` and `slice` objects. Any ellipsis has + already been expanded and missing trailing axes filled with + ``slice(None)``. + """ diff --git a/ndcube/tests/test_ndcube_slice_and_crop.py b/ndcube/tests/test_ndcube_slice_and_crop.py index 56b7cd4a4..47e6f931c 100644 --- a/ndcube/tests/test_ndcube_slice_and_crop.py +++ b/ndcube/tests/test_ndcube_slice_and_crop.py @@ -637,3 +637,28 @@ def test_crop_by_values_quantity_table_coordinate(): wcs=cube.extra_coords) assert cropped.shape == (10, 5) np.testing.assert_array_equal(cropped.data, data[3:13, 1:6]) + + +def test_slice_custom_state_hook(ndcube_3d_ln_lt_l): + class StatefulCube(NDCube): + def _slice_custom_state(self, sliced_cube, item): + sliced_cube.seen_items = [*getattr(self, "seen_items", []), item] + + cube = StatefulCube(ndcube_3d_ln_lt_l.data, ndcube_3d_ln_lt_l.wcs) + + # The hook receives the sanitized item: one int/slice per data axis. + assert cube[0].seen_items == [(0, slice(None), slice(None))] + assert cube[..., 1:3].seen_items == [(slice(None), slice(None), slice(1, 3))] + assert cube[:, 1, 1:3].seen_items == [(slice(None), 1, slice(1, 3))] + + # Slicing a sliced cube calls the hook again on the new cube. + chained = cube[..., 1:3][1:2] + assert chained.seen_items == [ + (slice(None), slice(None), slice(1, 3)), + (slice(1, 2), slice(None), slice(None)), + ] + + +def test_slice_custom_state_default_is_noop(ndcube_3d_ln_lt_l): + sliced = ndcube_3d_ln_lt_l[0] + assert not hasattr(sliced, "seen_items")