diff --git a/changelog/940.feature.rst b/changelog/940.feature.rst new file mode 100644 index 000000000..e76a20117 --- /dev/null +++ b/changelog/940.feature.rst @@ -0,0 +1 @@ +Added the ``NDCube._extra_attrs_to_copy`` extension point. Subclasses can declare instance attributes that are automatically propagated to cubes derived through arithmetic operations (``_new_instance``) and through ``to_nddata`` when the target type is a subclass carrying the same attributes. diff --git a/ndcube/ndcube.py b/ndcube/ndcube.py index ee1b0909a..d77853cc9 100644 --- a/ndcube/ndcube.py +++ b/ndcube/ndcube.py @@ -378,6 +378,15 @@ class NDCubeBase(NDCubeABC, astropy.nddata.NDData, NDCubeSlicingMixin): _extra_coords = NDCubeLinkedDescriptor(ExtraCoords) _global_coords = NDCubeLinkedDescriptor(GlobalCoords) + # Names of additional instance attributes which subclasses want propagated + # (by reference) to instances derived through `_new_instance` (e.g. + # arithmetic operations) and `to_nddata` when the target type carries the + # same attributes. Subclasses should override this with a tuple of + # attribute names. Note these attributes are not automatically modified + # when a cube is sliced; subclasses with shape-dependent attributes must + # handle slicing themselves. + _extra_attrs_to_copy = () + def __init__(self, data, wcs=None, uncertainty=None, mask=None, meta=None, unit=None, copy=False, psf=None, *, extra_coords=None, global_coords=None, **kwargs): @@ -968,6 +977,9 @@ def _new_instance(self, **kwargs): new_cube._extra_coords = deepcopy(self.extra_coords) if self.global_coords is not None: new_cube._global_coords = deepcopy(self.global_coords) + for attr in self._extra_attrs_to_copy: + if hasattr(self, attr): + setattr(new_cube, attr, getattr(self, attr)) return new_cube def __neg__(self): @@ -1640,6 +1652,12 @@ def to_nddata(self, array([[1., 1., 1.], [1., 1., 1.]]) """ + # If the target type carries the same subclass-specific attributes as + # this cube, copy them by default unless explicitly overridden. + if isinstance(nddata_type, type) and issubclass(nddata_type, type(self)): + for attr in self._extra_attrs_to_copy: + if hasattr(self, attr): + kwargs.setdefault(attr, "copy") # Put all NDData kwargs in a dict user_kwargs = {"data": data, "wcs": wcs, diff --git a/ndcube/tests/test_ndcube.py b/ndcube/tests/test_ndcube.py index 7719b842c..1592d216f 100644 --- a/ndcube/tests/test_ndcube.py +++ b/ndcube/tests/test_ndcube.py @@ -293,3 +293,47 @@ def __init__(self, data, *, spam=None, **kwargs): assert new_ndd.spam == "Eggs" assert new_ndd.data is ndc.data assert new_ndd.wcs is ndc.wcs + + +class SidecarCube(NDCube): + """NDCube subclass declaring extra attributes to propagate to derived cubes.""" + + _extra_attrs_to_copy = ("observer", "calibration_level") + + def __init__(self, *args, **kwargs): + self.observer = kwargs.pop("observer", None) + self.calibration_level = kwargs.pop("calibration_level", 0) + super().__init__(*args, **kwargs) + + +@pytest.fixture +def sidecar_cube(wcs_3d_lt_ln_l): + cube = SidecarCube(np.ones((2, 3, 4)), wcs=wcs_3d_lt_ln_l) + cube.observer = "earth" + cube.calibration_level = 2 + return cube + + +def test_extra_attrs_to_copy_propagate_through_arithmetic(sidecar_cube): + doubled = sidecar_cube * 2 + assert type(doubled) is SidecarCube + assert doubled.observer == "earth" + assert doubled.calibration_level == 2 + + negated = -sidecar_cube + assert negated.observer == "earth" + assert negated.calibration_level == 2 + + +def test_extra_attrs_to_copy_propagate_through_to_nddata(sidecar_cube): + copied = sidecar_cube.to_nddata(nddata_type=SidecarCube) + assert copied.observer == "earth" + assert copied.calibration_level == 2 + + # Explicit kwargs override the automatic copy. + overridden = sidecar_cube.to_nddata(nddata_type=SidecarCube, observer="sdo") + assert overridden.observer == "sdo" + + # Types which do not carry the attributes are unaffected. + plain = sidecar_cube.to_nddata(nddata_type=astropy.nddata.NDData) + assert not hasattr(plain, "observer")