Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 65 additions & 43 deletions xrspatial/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import numpy as np
import xarray as xr
from scipy.ndimage import map_coordinates as _scipy_map_coords
from scipy.ndimage import zoom as _scipy_zoom

try:
import dask.array as da
Expand All @@ -19,10 +18,8 @@

try:
import cupy
import cupyx.scipy.ndimage as _cupy_ndimage
except ImportError:
cupy = None
_cupy_ndimage = None

from xrspatial.utils import (
ArrayTypeFunctionMapping,
Expand Down Expand Up @@ -64,52 +61,88 @@ def _output_chunks(in_chunks, scale):
for i in range(len(in_chunks)))


# -- NaN-aware zoom (NumPy) --------------------------------------------------
# -- Block-centered coordinate mapping ---------------------------------------

def _nan_aware_zoom_np(data, zoom_yx, order):
"""``scipy.ndimage.zoom`` with NaN-aware weighting.
def _block_centered_coords(n_in, n_out):
"""Return input coordinates for each output pixel using block-centered mapping.

Maps output pixel ``o`` to input pixel ``(o + 0.5) * (n_in / n_out) - 0.5``.
This places each output pixel at the center of its spatial footprint,
matching the convention used by ``_new_coords`` for output coordinate
metadata.
"""
o = np.arange(n_out, dtype=np.float64)
return (o + 0.5) * (n_in / n_out) - 0.5


# -- NaN-aware interpolation (NumPy) ----------------------------------------

def _nan_aware_interp_np(data, out_h, out_w, order):
"""Interpolate *data* to *(out_h, out_w)* with NaN-aware weighting.

Uses ``scipy.ndimage.map_coordinates`` with block-centered coordinate
mapping so that sample positions match the output coordinate metadata.

For *order* 0 (nearest-neighbour) NaN propagates naturally.
For higher orders the zero-fill / weight-mask trick is used so that
NaN pixels do not corrupt their neighbours.
"""
iy = _block_centered_coords(data.shape[0], out_h)
ix = _block_centered_coords(data.shape[1], out_w)
yy, xx = np.meshgrid(iy, ix, indexing='ij')
coords = np.array([yy.ravel(), xx.ravel()])

if order == 0:
return _scipy_zoom(data, zoom_yx, order=0, mode='nearest')
result = _scipy_map_coords(data, coords, order=0, mode='nearest')
return result.reshape(out_h, out_w)

mask = np.isnan(data)
if not mask.any():
return _scipy_zoom(data, zoom_yx, order=order, mode='nearest')
result = _scipy_map_coords(data, coords, order=order, mode='nearest')
return result.reshape(out_h, out_w)

filled = np.where(mask, 0.0, data)
weights = (~mask).astype(data.dtype)

z_data = _scipy_zoom(filled, zoom_yx, order=order, mode='nearest')
z_wt = _scipy_zoom(weights, zoom_yx, order=order, mode='nearest')
z_data = _scipy_map_coords(filled, coords, order=order, mode='nearest')
z_wt = _scipy_map_coords(weights, coords, order=order, mode='nearest')

result = np.where(z_wt > 0.01,
z_data / np.maximum(z_wt, 1e-10),
np.nan)
return result.reshape(out_h, out_w)

return np.where(z_wt > 0.01,
z_data / np.maximum(z_wt, 1e-10),
np.nan)

# -- NaN-aware interpolation (CuPy) -----------------------------------------

# -- NaN-aware zoom (CuPy) ---------------------------------------------------
def _nan_aware_interp_cupy(data, out_h, out_w, order):
    """CuPy variant of :func:`_nan_aware_interp_np`.

    Same block-centered coordinate mapping and zero-fill / weight-mask
    NaN handling, executed on the GPU via
    ``cupyx.scipy.ndimage.map_coordinates``.
    """
    # Deferred import: cupyx is only needed (and only importable) on GPU.
    from cupyx.scipy.ndimage import map_coordinates as _cupy_map_coords

    # Block-centered sample coordinates, moved onto the device.
    iy = cupy.asarray(_block_centered_coords(data.shape[0], out_h))
    ix = cupy.asarray(_block_centered_coords(data.shape[1], out_w))
    yy, xx = cupy.meshgrid(iy, ix, indexing='ij')
    coords = cupy.array([yy.ravel(), xx.ravel()])

    if order == 0:
        # Nearest-neighbour: NaN values map through unchanged.
        result = _cupy_map_coords(data, coords, order=0, mode='nearest')
        return result.reshape(out_h, out_w)

    mask = cupy.isnan(data)
    if not mask.any():
        result = _cupy_map_coords(data, coords, order=order, mode='nearest')
        return result.reshape(out_h, out_w)

    # Zero-fill NaNs and renormalise by interpolated weights, as in the
    # NumPy implementation.
    filled = cupy.where(mask, 0.0, data)
    weights = (~mask).astype(data.dtype)

    z_data = _cupy_map_coords(filled, coords, order=order, mode='nearest')
    z_wt = _cupy_map_coords(weights, coords, order=order, mode='nearest')

    result = cupy.where(z_wt > 0.01,
                        z_data / cupy.maximum(z_wt, 1e-10),
                        cupy.nan)
    return result.reshape(out_h, out_w)


# -- Block-aggregation kernels (NumPy, numba) --------------------------------
Expand Down Expand Up @@ -283,13 +316,9 @@ def _interp_block_np(block, global_in_h, global_in_w,
oy = np.arange(cum_out_y[yi], cum_out_y[yi + 1], dtype=np.float64)
ox = np.arange(cum_out_x[xi], cum_out_x[xi + 1], dtype=np.float64)

# Map to global input coordinates (same formula scipy.ndimage.zoom uses)
iy = (oy * (global_in_h - 1) / (global_out_h - 1)
if global_out_h > 1
else np.full(len(oy), (global_in_h - 1) / 2.0))
ix = (ox * (global_in_w - 1) / (global_out_w - 1)
if global_out_w > 1
else np.full(len(ox), (global_in_w - 1) / 2.0))
# Map to global input coordinates using block-centered formula
iy = (oy + 0.5) * (global_in_h / global_out_h) - 0.5
ix = (ox + 0.5) * (global_in_w / global_out_w) - 0.5

# Convert to local block coordinates (overlap shifts the origin)
iy_local = iy - (cum_in_y[yi] - depth)
Expand Down Expand Up @@ -331,12 +360,9 @@ def _interp_block_cupy(block, global_in_h, global_in_w,
ox = cupy.arange(int(cum_out_x[xi]), int(cum_out_x[xi + 1]),
dtype=cupy.float64)

iy = (oy * (global_in_h - 1) / (global_out_h - 1)
if global_out_h > 1
else cupy.full(len(oy), (global_in_h - 1) / 2.0))
ix = (ox * (global_in_w - 1) / (global_out_w - 1)
if global_out_w > 1
else cupy.full(len(ox), (global_in_w - 1) / 2.0))
# Map to global input coordinates using block-centered formula
iy = (oy + 0.5) * (global_in_h / global_out_h) - 0.5
ix = (ox + 0.5) * (global_in_w / global_out_w) - 0.5

iy_local = iy - float(cum_in_y[yi] - depth)
ix_local = ix - float(cum_in_x[xi] - depth)
Expand Down Expand Up @@ -410,10 +436,8 @@ def _run_numpy(data, scale_y, scale_x, method):
out_h, out_w = _output_shape(*data.shape, scale_y, scale_x)

if method in INTERP_METHODS:
zy = out_h / data.shape[0]
zx = out_w / data.shape[1]
return _nan_aware_zoom_np(data, (zy, zx),
INTERP_METHODS[method]).astype(np.float32)
return _nan_aware_interp_np(data, out_h, out_w,
INTERP_METHODS[method]).astype(np.float32)

return _AGG_FUNCS[method](data, out_h, out_w).astype(np.float32)

Expand All @@ -423,10 +447,8 @@ def _run_cupy(data, scale_y, scale_x, method):
out_h, out_w = _output_shape(*data.shape, scale_y, scale_x)

if method in INTERP_METHODS:
zy = out_h / data.shape[0]
zx = out_w / data.shape[1]
return _nan_aware_zoom_cupy(data, (zy, zx),
INTERP_METHODS[method]).astype(cupy.float32)
return _nan_aware_interp_cupy(data, out_h, out_w,
INTERP_METHODS[method]).astype(cupy.float32)

# Aggregate: GPU reshape+reduce for integer factors, CPU fallback otherwise
fy, fx = data.shape[0] / out_h, data.shape[1] / out_w
Expand Down
55 changes: 55 additions & 0 deletions xrspatial/tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,48 @@ def test_bilinear_upsample_smooth(self, grid_8x8):
# Verify interior is within tolerance of the linear gradient
assert np.all(np.isfinite(out.values))

    @pytest.mark.parametrize('method', ['nearest', 'bilinear', 'cubic'])
    def test_interp_coordinate_alignment_downsample(self, method):
        """Interpolated values should match output coordinate labels (#1202).

        On a linear gradient where value == x-coordinate, a correct
        block-centered resample produces output values equal to the output
        coordinate labels (within floating-point tolerance).
        """
        # 8x8 raster whose value at every pixel equals its x index.
        data = np.tile(np.arange(8, dtype=np.float32), (8, 1))
        agg = create_test_raster(data, attrs={'res': (1.0, 1.0)},
                                 dims=['y', 'x'])
        out = resample(agg, scale_factor=0.5, method=method)

        # Output x-coords are block-centered: 0.5, 2.5, 4.5, 6.5
        # Values should match because the input is a linear gradient
        # (atol=0.6 leaves room for nearest-neighbour snapping).
        np.testing.assert_allclose(
            out.values[0], out.x.values, atol=0.6,
            err_msg=f"{method}: values should be close to x-coordinates"
        )
        # For bilinear on a linear gradient, the match should be exact
        if method == 'bilinear':
            np.testing.assert_allclose(
                out.values[0], out.x.values, atol=1e-5,
                err_msg="bilinear on linear gradient must be exact"
            )

    def test_bilinear_coordinate_alignment_upsample(self):
        """Upsampled interior pixels should match coordinates on a gradient."""
        # 8x8 raster whose value at every pixel equals its x index.
        data = np.tile(np.arange(8, dtype=np.float32), (8, 1))
        agg = create_test_raster(data, attrs={'res': (1.0, 1.0)},
                                 dims=['y', 'x'])
        out = resample(agg, scale_factor=2.0, method='bilinear')

        # Interior pixels (away from boundary clamping) should be exact
        # for bilinear on a linear gradient. Skip first and last pixel
        # which may be clamped by mode='nearest' boundary handling.
        interior = slice(1, -1)
        np.testing.assert_allclose(
            out.values[0, interior], out.x.values[interior], atol=1e-4,
            err_msg="bilinear: interior values should match x-coordinates"
        )


# ---------------------------------------------------------------------------
# NaN handling
Expand Down Expand Up @@ -281,6 +323,19 @@ def test_interp_parity(self, numpy_and_dask_rasters, method, sf):
np.testing.assert_allclose(dk_out.values, np_out.values,
atol=1e-5, equal_nan=True)

    @pytest.mark.parametrize('method', ['bilinear'])
    def test_dask_coordinate_alignment(self, method):
        """Dask bilinear on a linear gradient should match coordinates (#1202)."""
        # 20x20 gradient split into 8x8 chunks so the resample crosses
        # chunk boundaries, exercising the dask block-mapping path.
        data = np.tile(np.arange(20, dtype=np.float32), (20, 1))
        dk_agg = create_test_raster(data, backend='dask+numpy',
                                    attrs={'res': (1.0, 1.0)},
                                    chunks=(8, 8))
        out = resample(dk_agg, scale_factor=0.5, method=method)
        np.testing.assert_allclose(
            out.values[0], out.x.values, atol=1e-4,
            err_msg="dask bilinear values should match x-coordinates"
        )

@pytest.mark.parametrize('method', ['average', 'min', 'max', 'median', 'mode'])
def test_aggregate_parity(self, numpy_and_dask_rasters, method):
np_agg, dk_agg = numpy_and_dask_rasters
Expand Down
Loading