Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/941.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`~ndcube.extra_coords.TimeTableCoordinate` and `~ndcube.extra_coords.QuantityTableCoordinate` now accept a single N-D table, representing one coordinate that varies over several pixel dimensions. This allows, for example, a 2-D ``Time`` table indexed by (raster scan, raster step) to be attached to an `~ndcube.NDCube` via ``extra_coords.add`` and sliced along either axis.
20 changes: 17 additions & 3 deletions ndcube/extra_coords/extra_coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,17 @@ def mapping(self):

# The mapping is from the array index (position in the list) to the
# pixel dimensions (numbers in the list)
lts = [list([lt[0]] if isinstance(lt[0], Integral) else lt[0]) for lt in self._lookup_tables]
converter = partial(convert_between_array_and_pixel_axes, naxes=len(self._ndcube.shape))
pixel_indicies = [list(converter(np.array(ids))) for ids in lts]
pixel_indicies = []
for lut_axis, lut in self._lookup_tables:
ids = [lut_axis] if isinstance(lut_axis, Integral) else list(lut_axis)
pixel_ids = list(converter(np.array(ids)))
if lut._model_inputs_are_pixel_ordered:
# Single N-D tables expose their model inputs in pixel order,
# i.e. reversed with respect to the array-ordered axes given
# to `add`.
pixel_ids = pixel_ids[::-1]
pixel_indicies.append(pixel_ids)
return tuple(reduce(list.__add__, pixel_indicies))

@mapping.setter
Expand Down Expand Up @@ -360,7 +368,6 @@ def _getitem_lookup_tables(self, item):
n_dropped_dims = np.cumsum([isinstance(i, Integral) for i in item])
for lut_axis, lut in self._lookup_tables:
lut_axes = (lut_axis,) if not isinstance(lut_axis, tuple) else lut_axis
new_lut_axes = tuple(ax - n_dropped_dims[ax] for ax in lut_axes)
lut_slice = tuple(item[i] for i in lut_axes)
if isinstance(lut_slice, tuple) and len(lut_slice) == 1:
lut_slice = lut_slice[0]
Expand All @@ -370,6 +377,13 @@ def _getitem_lookup_tables(self, item):
if sliced_lut.is_scalar():
dropped_tables.add(sliced_lut)
else:
kept_axes = lut_axes
if sliced_lut.n_inputs < len(lut_axes):
# The sliced table lost pixel dimensions (e.g. an N-D
# table sliced with an integer), so drop the
# integer-sliced axes from the table's axes.
kept_axes = tuple(ax for ax in lut_axes if not isinstance(item[ax], Integral))
new_lut_axes = tuple(ax - n_dropped_dims[ax] for ax in kept_axes)
new_lookup_tables.add((new_lut_axes, sliced_lut))
new_extra_coords = type(self)()
new_extra_coords._lookup_tables = list(new_lookup_tables)
Expand Down
129 changes: 104 additions & 25 deletions ndcube/extra_coords/table_coord.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, *
'method': interpolation,
**kwargs}

if len(lookup_table) == 1:
if lookup_table.shape == (1,):
t = Length1Tabular(points, lookup_table, **kwargs)
else:
t = TabularND(points, lookup_table, **kwargs)
Expand Down Expand Up @@ -282,6 +282,25 @@ def model(self):
Generate the Astropy Model for this LookupTable.
"""

@property
def _model_inputs_are_pixel_ordered(self):
"""
True when this coordinate's model inputs are in pixel order.

Single N-D tables span several pixel dimensions with one model. Their
model inputs are exposed in pixel order (reversed array order) so the
resulting WCS follows the APE-14 convention expected by
`~ndcube.NDCube`.
"""
return False

@staticmethod
def _reorder_inputs_to_pixel(model):
"""
Reverse a model's inputs from array order to pixel order.
"""
return models.Mapping(tuple(range(model.n_inputs))[::-1]) | model

@property
def wcs(self):
"""
Expand All @@ -301,15 +320,19 @@ class QuantityTableCoordinate(BaseTableCoordinate):
"""
A lookup table up built on `~astropy.units.Quantity`.

Quantities must be 1-D but more than one can be provided to represent
different dimensions of an N-D coordinate.
Either a single N-D Quantity, or multiple 1-D Quantities can be provided.
A single N-D Quantity represents one physical coordinate which varies over
N pixel dimensions. Multiple 1-D Quantities represent the different
dimensions of an N-D coordinate, with each table corresponding to one
pixel dimension.

Parameters
----------
tables: one or more `~astropy.units.Quantity`
The coordinates. Must be 1 dimensionsal. If coordinate system is >1D,
multiple 1-D Quantities can be provided representing the different
dimensions
The coordinates. Either a single Quantity of any dimensionality
representing one coordinate varying over that many pixel dimensions,
or multiple 1-D Quantities representing the different dimensions of
an N-D coordinate system.

names: `str` or `list` of `str`
Custom names for the components of the QuantityTableCoord. If provided,
Expand All @@ -329,10 +352,10 @@ def __init__(self, *tables, names=None, physical_types=None):
raise u.UnitsError("All tables must have equivalent units.")
ndim = len(tables)
dims = np.array([t.ndim for t in tables])
if any(dims > 1):
if len(tables) > 1 and any(dims > 1):
raise ValueError(
"Currently all tables must be 1-D. If you need >1D support, please "
"raise an issue at https://github.con/sunpy/ndcube/issues")
"Multiple tables can only be provided if they are all 1-D. "
"A single N-D table representing one coordinate is also supported.")

if isinstance(names, str):
names = [names]
Expand Down Expand Up @@ -379,12 +402,26 @@ def _slice_table(self, i, table, item, new_components, whole_slice):
if self.physical_types:
new_components["physical_types"].append(self.physical_types[i])

@property
def _single_nd_table(self):
return len(self.table) == 1 and self.table[0].ndim > 1

def __getitem__(self, item):
if isinstance(item, (slice, Integral)):
item = (item,)
if not (len(item) == len(self.table) or len(item) == self.table[0].ndim):
raise ValueError("Can not slice with incorrect length")

if self._single_nd_table:
# A single N-D table represents one world coordinate, so slicing
# reduces the table but never splits or drops individual world
# components.
ret_table = type(self)(self.table[0][item],
names=self.names,
physical_types=self.physical_types)
ret_table._dropped_world_dimensions = copy.deepcopy(self._dropped_world_dimensions)
return ret_table

new_components = defaultdict(list)
new_components["dropped_world_dimensions"] = copy.deepcopy(self._dropped_world_dimensions)

Expand All @@ -400,7 +437,7 @@ def __getitem__(self, item):

@property
def n_inputs(self):
return len(self.table)
return self.ndim

def is_scalar(self):
return all(t.shape == () for t in self.table)
Expand All @@ -412,12 +449,20 @@ def frame(self):
"""
return _generate_generic_frame(len(self.table), self.unit, self.names, self.physical_types)

@property
def _model_inputs_are_pixel_ordered(self):
# Docstring inherited.
return self._single_nd_table

@property
def model(self):
"""
Generate the Astropy Model for this LookupTable.
"""
return _model_from_quantity(self.table, True)
model = _model_from_quantity(self.table, True)
if self._single_nd_table:
model = self._reorder_inputs_to_pixel(model)
return model

@property
def ndim(self):
Expand All @@ -427,6 +472,8 @@ def ndim(self):
Note this may be different from the number of the dimensions in the
underlying table(s) if different tables represent different dimensions.
"""
if self._single_nd_table:
return self.table[0].ndim
return len(self.table)

@property
Expand All @@ -437,6 +484,8 @@ def shape(self):
Note this may be different from the shape of the underlying table(s)
if different tables represent a different dimensions.
"""
if self._single_nd_table:
return self.table[0].shape
return tuple(len(t) for t in self.table)

def interpolate(self, *new_array_grids, **kwargs):
Expand Down Expand Up @@ -472,10 +521,16 @@ def interpolate(self, *new_array_grids, **kwargs):
raise ValueError("New array grids must all be same shape.")
# Build array grids for non-interpolated table.
old_array_grids = tuple(np.arange(d) for d in self.shape)
# Iterate through tables and interpolate each.
new_tables = [
np.interp(new_grid, old_grid, t.value, **kwargs) * t.unit
for new_grid, old_grid, t in zip(new_array_grids, old_array_grids, self.table)]
if self._single_nd_table:
table = self.table[0]
new_values = scipy.interpolate.interpn(
old_array_grids, table.value, np.stack(new_array_grids, axis=-1), **kwargs)
new_tables = [new_values * table.unit]
else:
# Iterate through tables and interpolate each.
new_tables = [
np.interp(new_grid, old_grid, t.value, **kwargs) * t.unit
for new_grid, old_grid, t in zip(new_array_grids, old_array_grids, self.table)]
# Rebuild return interpolated coord.
new_coord = type(self)(*new_tables, names=self.names, physical_types=self.physical_types)
new_coord._dropped_world_dimensions = self._dropped_world_dimensions
Expand Down Expand Up @@ -699,12 +754,16 @@ def interpolate(self, *new_array_grids, mesh_output=None, **kwargs):

class TimeTableCoordinate(BaseTableCoordinate):
"""
A lookup table based on a `~astropy.time.Time`, will always be one dimensional.
A lookup table based on a `~astropy.time.Time`.

The table represents a single time coordinate which can vary over one or
more pixel dimensions, i.e. the input `~astropy.time.Time` can be N-D.

Parameters
----------
table: `~astropy.time.Time`
Time coordinates. Only one can be provided and must be 1D.
Time coordinates. Only one can be provided. An N-D table corresponds
to N pixel dimensions.

names: `str` or `list` of `str`
Custom names for the components of the SkyCoord. If provided, a name must
Expand Down Expand Up @@ -735,11 +794,15 @@ def __init__(self, *tables, names=None, physical_types=None, reference_time=None

super().__init__(*tables, mesh=False, names=names, physical_types=physical_types)
self.table = self.table[0]
self.reference_time = reference_time or self.table[0]
self.reference_time = reference_time or self.table.ravel()[0]

def __getitem__(self, item):
if not (isinstance(item, (slice, Integral)) or len(item) == 1):
if isinstance(item, (slice, Integral)):
item = (item,)
if len(item) != max(self.table.ndim, 1):
raise ValueError("Can not slice with incorrect length")
if len(item) == 1:
item = item[0]

return type(self)(self.table[item],
names=self.names,
Expand All @@ -748,7 +811,7 @@ def __getitem__(self, item):

@property
def n_inputs(self):
return 1 # The time table has to be one dimensional
return max(self.table.ndim, 1)

def is_scalar(self):
return self.table.shape == ()
Expand All @@ -763,6 +826,11 @@ def frame(self):
axes_names=self.names,
name="TemporalFrame")

@property
def _model_inputs_are_pixel_ordered(self):
# Docstring inherited.
return self.table.ndim > 1

@property
def model(self):
"""
Expand All @@ -771,9 +839,12 @@ def model(self):
time = self.table
deltas = (time - self.reference_time).to(u.s)

return _model_from_quantity((deltas,), mesh=False)
model = _model_from_quantity((deltas,), mesh=False)
if deltas.ndim > 1:
model = self._reorder_inputs_to_pixel(model)
return model

def interpolate(self, new_array_grids, **kwargs):
def interpolate(self, *new_array_grids, **kwargs):
"""
Interpolate TimeTableCoordinate to new array index grids.

Expand All @@ -794,10 +865,18 @@ def interpolate(self, new_array_grids, **kwargs):
"""
if self.is_scalar():
raise ValueError("Cannot interpolate a scalar TimeTableCoordinate.")
# Build pixel grids for current TimeTableCoord.
old_array_grids = np.arange(len(self.table))
if len(new_array_grids) != self.table.ndim:
raise ValueError(
f"A new array grid must be given for each array axis, i.e. {self.table.ndim}")
# Interpolate using MJD format and convert back to a Time object.
new_table = np.interp(new_array_grids, old_array_grids, self.table.mjd, **kwargs)
if self.table.ndim == 1:
# Build pixel grids for current TimeTableCoord.
old_array_grids = np.arange(len(self.table))
new_table = np.interp(new_array_grids[0], old_array_grids, self.table.mjd, **kwargs)
else:
old_array_grids = tuple(np.arange(d) for d in self.table.shape)
new_table = scipy.interpolate.interpn(
old_array_grids, self.table.mjd, np.stack(new_array_grids, axis=-1), **kwargs)
new_table = Time(new_table, scale=self.table.scale, format="mjd")
new_table.format = self.table.format
# Rebuild new TimeTableCoord and return.
Expand Down
35 changes: 33 additions & 2 deletions ndcube/extra_coords/tests/test_extra_coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,10 @@ def test_extra_coords_index(skycoord_2d_lut, time_lut):
assert sub_ec.wcs.world_axis_names == ("exposure_time",)


@pytest.mark.xfail(reason=">1D Tables not supported")
def test_extra_coords_2d_quantity(quantity_2d_lut):
ec = ExtraCoords()
ec.add("velocity", (0, 1), quantity_2d_lut)
assert ec.wcs.pixel_to_world(0, 0)
assert u.allclose(ec.wcs.pixel_to_world(1, 2), quantity_2d_lut[2, 1])


# Extra Coords with NDCube
Expand Down Expand Up @@ -560,3 +559,35 @@ def test_length1_extra_coord(wave_lut):
sec = ec[item]
assert (sec.wcs.pixel_to_world(0) == wave_lut[item]).all()
assert (sec.wcs.world_to_pixel(wave_lut[item])[0] == [0]).all()


def test_2d_time_extra_coord_through_cube(wcs_3d_lt_ln_l):
cube = NDCube(np.zeros((3, 4, 5)), wcs=wcs_3d_lt_ln_l)
times = Time("2020-01-01T00:00:00") + np.arange(12).reshape(3, 4) * u.s
cube.extra_coords.add("time", (0, 1), times, physical_types="time")

(world_times,) = cube.axis_world_coords("time", wcs=cube.extra_coords)
assert world_times.shape == (3, 4)
assert (world_times == times).all()

# Slicing with ranges keeps the table 2-D.
sub = cube[1:3, 0:2]
(sub_times,) = sub.axis_world_coords("time", wcs=sub.extra_coords)
assert sub_times.shape == (2, 2)
assert (sub_times == times[1:3, 0:2]).all()

# Integer slicing drops the corresponding table dimension.
row = cube[1]
(row_times,) = row.axis_world_coords("time", wcs=row.extra_coords)
assert row_times.shape == (4,)
assert (row_times == times[1]).all()

column = cube[:, 2]
(column_times,) = column.axis_world_coords("time", wcs=column.extra_coords)
assert column_times.shape == (3,)
assert (column_times == times[:, 2]).all()

# Slicing away both table dimensions drops the coordinate.
point = cube[1, 2]
assert point.extra_coords.is_empty
assert len(point.extra_coords._dropped_tables) == 1
Loading