diff --git a/cf_xarray/helpers.py b/cf_xarray/helpers.py index 5ca5b1b9..108a9c81 100644 --- a/cf_xarray/helpers.py +++ b/cf_xarray/helpers.py @@ -17,25 +17,23 @@ def _guess_bounds_1d(da, dim): """ if dim not in da.dims: (dim,) = da.cf.axes[dim] - ADDED_INDEX = False - if dim not in da.coords: - # For proper alignment in the lines below, we need an index on dim. - da = da.assign_coords({dim: da[dim]}) - ADDED_INDEX = True bound_position = 0.5 - diff = da.diff(dim).pad({dim: (1, 1)}, mode="edge").reset_index(dim) - lower = ( - da.reset_index(dim) - bound_position * diff.isel({dim: slice(0, -1)}) - ).assign_coords({dim: da[dim]}) - upper = ( - da.reset_index(dim) + bound_position * diff.isel({dim: slice(1, None)}) - ).assign_coords({dim: da[dim]}) - result = xr.concat([lower, upper], dim="bounds").transpose(..., "bounds") - - if ADDED_INDEX: - result = result.drop_vars(dim) - return result.drop_attrs(deep=False) + + diff = da.diff(dim).pad({dim: (1, 1)}, mode="edge") + lower = da.copy( + deep=False, + data=da.data - bound_position * diff.isel({dim: slice(0, -1)}).data, + ) + upper = da.copy( + deep=False, + data=da.data + bound_position * diff.isel({dim: slice(1, None)}).data, + ) + return ( + xr.concat([lower, upper], dim="bounds") + .transpose(..., "bounds") + .drop_attrs(deep=False) + ) def _guess_bounds_2d(da, dims): diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 334c8f8c..a188aa39 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -3,6 +3,7 @@ import warnings from textwrap import dedent +import dask.array import matplotlib as mpl import numpy as np import pandas as pd @@ -839,6 +840,21 @@ def test_add_bounds(dims): _check_unchanged(original, ds) +def test_add_bounds_preserves_array_type() -> None: + # Test that the array type of the bounds variable is the same as the original variable. + ds = airds + original = ds.copy(deep=True) + ds = ds.drop_indexes("lat").rename_dims(lat="x") + ds["lat"] = ds.lat.copy(data=dask.array.asarray(ds.lat.data)) + added = ds.cf.add_bounds("lat") + + assert isinstance(added.lat.data, dask.array.Array) + assert isinstance(added.lat_bounds.data, dask.array.Array) + + assert isinstance(ds.lat.data, dask.array.Array) + _check_unchanged(original, ds) + + def test_add_irregularly_spaced_bounds_do_not_overlap() -> None: # Test that added bounds with irregular spacing do not overlap. ds = airds