Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/940.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added the ``NDCube._extra_attrs_to_copy`` extension point. Subclasses can declare instance attributes that are automatically propagated to cubes derived through arithmetic operations (``_new_instance``) and through ``to_nddata`` when the target type is a subclass carrying the same attributes.
18 changes: 18 additions & 0 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,15 @@ class NDCubeBase(NDCubeABC, astropy.nddata.NDData, NDCubeSlicingMixin):
_extra_coords = NDCubeLinkedDescriptor(ExtraCoords)
_global_coords = NDCubeLinkedDescriptor(GlobalCoords)

# Names of additional instance attributes which subclasses want propagated
# (by reference) to instances derived through `_new_instance` (e.g.
# arithmetic operations) and `to_nddata` when the target type carries the
# same attributes. Subclasses should override this with a tuple of
# attribute names. Note these attributes are not automatically modified
# when a cube is sliced; subclasses with shape-dependent attributes must
# handle slicing themselves.
_extra_attrs_to_copy = ()

def __init__(self, data, wcs=None, uncertainty=None, mask=None, meta=None,
unit=None, copy=False, psf=None, *, extra_coords=None, global_coords=None, **kwargs):

Expand Down Expand Up @@ -968,6 +977,9 @@ def _new_instance(self, **kwargs):
new_cube._extra_coords = deepcopy(self.extra_coords)
if self.global_coords is not None:
new_cube._global_coords = deepcopy(self.global_coords)
for attr in self._extra_attrs_to_copy:
if hasattr(self, attr):
setattr(new_cube, attr, getattr(self, attr))
return new_cube

def __neg__(self):
Expand Down Expand Up @@ -1640,6 +1652,12 @@ def to_nddata(self,
array([[1., 1., 1.],
[1., 1., 1.]])
"""
# If the target type carries the same subclass-specific attributes as
# this cube, copy them by default unless explicitly overridden.
if isinstance(nddata_type, type) and issubclass(nddata_type, type(self)):
for attr in self._extra_attrs_to_copy:
if hasattr(self, attr):
kwargs.setdefault(attr, "copy")
# Put all NDData kwargs in a dict
user_kwargs = {"data": data,
"wcs": wcs,
Expand Down
44 changes: 44 additions & 0 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,47 @@ def __init__(self, data, *, spam=None, **kwargs):
assert new_ndd.spam == "Eggs"
assert new_ndd.data is ndc.data
assert new_ndd.wcs is ndc.wcs


class SidecarCube(NDCube):
"""NDCube subclass declaring extra attributes to propagate to derived cubes."""

_extra_attrs_to_copy = ("observer", "calibration_level")

def __init__(self, *args, **kwargs):
self.observer = kwargs.pop("observer", None)
self.calibration_level = kwargs.pop("calibration_level", 0)
super().__init__(*args, **kwargs)


@pytest.fixture
def sidecar_cube(wcs_3d_lt_ln_l):
cube = SidecarCube(np.ones((2, 3, 4)), wcs=wcs_3d_lt_ln_l)
cube.observer = "earth"
cube.calibration_level = 2
return cube


def test_extra_attrs_to_copy_propagate_through_arithmetic(sidecar_cube):
doubled = sidecar_cube * 2
assert type(doubled) is SidecarCube
assert doubled.observer == "earth"
assert doubled.calibration_level == 2

negated = -sidecar_cube
assert negated.observer == "earth"
assert negated.calibration_level == 2


def test_extra_attrs_to_copy_propagate_through_to_nddata(sidecar_cube):
copied = sidecar_cube.to_nddata(nddata_type=SidecarCube)
assert copied.observer == "earth"
assert copied.calibration_level == 2

# Explicit kwargs override the automatic copy.
overridden = sidecar_cube.to_nddata(nddata_type=SidecarCube, observer="sdo")
assert overridden.observer == "sdo"

# Types which do not carry the attributes are unaffected.
plain = sidecar_cube.to_nddata(nddata_type=astropy.nddata.NDData)
assert not hasattr(plain, "observer")