97 changes: 60 additions & 37 deletions xrspatial/zonal.py
@@ -1440,65 +1440,88 @@ def _hi_cupy(zones_data, values_data, nodata):


 @delayed
-def _hi_block_stats(z_block, v_block, uzones):
-    """Per-chunk: return (n_zones, 4) array of [min, max, sum, count]."""
-    result = np.full((len(uzones), 4), np.nan, dtype=np.float64)
-    result[:, 3] = 0  # count starts at 0
-    for i, z in enumerate(uzones):
-        mask = (z_block == z) & np.isfinite(v_block)
-        if not np.any(mask):
+def _hi_block_stats(z_block, v_block, nodata):
+    """Per-chunk: return dict mapping local zone IDs to (min, max, sum, count).
+
+    Each block discovers its own zones, so the driver never has to compute
+    a global unique-zone set up front. Sparse zones (geographic) stay sparse
+    in the returned dict instead of being padded to a full (n_zones, 4) array.
+    """
+    finite_v = np.isfinite(v_block)
+    finite_z = np.isfinite(z_block)
+    valid = finite_z & finite_v
+    if not np.any(valid):
+        return {}
+
+    z_valid = z_block[valid]
+    v_valid = v_block[valid]
+    uzones = np.unique(z_valid)
+
+    result = {}
+    for z in uzones:
+        if nodata is not None and z == nodata:
             continue
-        vals = v_block[mask]
-        result[i, 0] = vals.min()
-        result[i, 1] = vals.max()
-        result[i, 2] = vals.sum()
-        result[i, 3] = len(vals)
+        mask = z_valid == z
+        vals = v_valid[mask]
+        if vals.size == 0:
+            continue
+        result[z.item() if hasattr(z, 'item') else z] = (
+            float(vals.min()),
+            float(vals.max()),
+            float(vals.sum()),
+            int(vals.size),
+        )
     return result
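The dict-per-block shape is easiest to see on a toy chunk. Below is a hedged sketch of the same aggregation without the `@delayed` wrapper; the array values and the `nodata` sentinel of 9.0 are invented for illustration:

```python
import numpy as np

# One 2x3 chunk of zone labels and values; the NaN zone and nodata (9.0) drop out.
z_block = np.array([[1.0, 1.0, 2.0],
                    [2.0, np.nan, 9.0]])
v_block = np.array([[10.0, 20.0, 5.0],
                    [7.0, 3.0, 1.0]])
nodata = 9.0

valid = np.isfinite(z_block) & np.isfinite(v_block)
z_valid, v_valid = z_block[valid], v_block[valid]

stats = {}
for z in np.unique(z_valid):
    if z == nodata:
        continue
    vals = v_valid[z_valid == z]
    stats[z.item()] = (float(vals.min()), float(vals.max()),
                       float(vals.sum()), int(vals.size))

print(stats)  # {1.0: (10.0, 20.0, 30.0, 2), 2.0: (5.0, 7.0, 12.0, 2)}
```

Only the two zones actually present in the chunk appear in the dict, which is what keeps sparse zone IDs cheap.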


 @delayed
-def _hi_reduce(partials_list, uzones):
-    """Reduce per-block stats to global per-zone HI lookup dict."""
-    stacked = np.stack(partials_list)  # (n_blocks, n_zones, 4)
-    g_min = np.nanmin(stacked[:, :, 0], axis=0)
-    g_max = np.nanmax(stacked[:, :, 1], axis=0)
-    g_sum = np.nansum(stacked[:, :, 2], axis=0)
-    g_count = np.nansum(stacked[:, :, 3], axis=0)
+def _hi_reduce(partials_list):
+    """Stream-merge per-block dicts into global hi_lookup.
+
+    Scheduler peak memory is O(n_zones) for the merged dict, rather than
+    O(n_blocks * n_zones) from a stacked array. Per-block partials arrive
+    as a Python list but are iterated once and can be released.
+    """
+    merged = {}
+    for partial in partials_list:
+        for z, (mn, mx, s, c) in partial.items():
+            if z in merged:
+                om, oM, os_, oc = merged[z]
+                merged[z] = (min(om, mn), max(oM, mx), os_ + s, oc + c)
+            else:
+                merged[z] = (mn, mx, s, c)

     hi_lookup = {}
-    for i, z in enumerate(uzones):
-        if g_count[i] == 0 or g_max[i] == g_min[i]:
+    for z, (mn, mx, s, c) in merged.items():
+        if c == 0 or mx == mn:
             hi_lookup[z] = np.nan
         else:
-            mean = g_sum[i] / g_count[i]
-            hi_lookup[z] = (mean - g_min[i]) / (g_max[i] - g_min[i])
+            hi_lookup[z] = (s / c - mn) / (mx - mn)
     return hi_lookup
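Because min, max, sum, and count each merge associatively, the reduce can fold block dicts in any order. A minimal sketch with two hand-made partials (numbers invented):

```python
# Two per-block partials for zone 1.0, plus a zone only block b saw.
a = {1.0: (10.0, 20.0, 30.0, 2)}
b = {1.0: (5.0, 12.0, 17.0, 2), 2.0: (1.0, 4.0, 5.0, 3)}

merged = dict(a)
for z, (mn, mx, s, c) in b.items():
    if z in merged:
        om, oM, os_, oc = merged[z]
        merged[z] = (min(om, mn), max(oM, mx), os_ + s, oc + c)
    else:
        merged[z] = (mn, mx, s, c)

print(merged)  # {1.0: (5.0, 20.0, 47.0, 4), 2.0: (1.0, 4.0, 5.0, 3)}

# HI for zone 1.0: mean = 47 / 4 = 11.75, so (11.75 - 5) / (20 - 5) = 0.45
```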


 def _hi_dask_numpy(zones_data, values_data, nodata):
-    """Dask+numpy backend for hypsometric integral."""
-    # Step 1: find all unique zones across all chunks
-    unique_zones = _unique_finite_zones(zones_data)
-    if nodata is not None:
-        unique_zones = unique_zones[unique_zones != nodata]
-
-    if len(unique_zones) == 0:
-        return da.full(values_data.shape, np.nan, dtype=np.float64,
-                       chunks=values_data.chunks)
-
-    # Step 2: per-block aggregation -> global reduce
+    """Dask+numpy backend for hypsometric integral.
+
+    Single graph evaluation: each block computes its local (zone -> stats)
+    dict, then a streaming reduce merges them into a lookup table, then
+    map_blocks paints the result. No up-front `_unique_finite_zones`
+    compute and no O(n_blocks * n_zones) scheduler-side stack.
+    """
     zones_blocks = zones_data.to_delayed().ravel()
     values_blocks = values_data.to_delayed().ravel()

     partials = [
-        _hi_block_stats(zb, vb, unique_zones)
+        _hi_block_stats(zb, vb, nodata)
         for zb, vb in zip(zones_blocks, values_blocks)
     ]

     # Compute the HI lookup eagerly so map_blocks can use it as a parameter.
-    hi_lookup = dask.compute(_hi_reduce(partials, unique_zones))[0]
+    hi_lookup = dask.compute(_hi_reduce(partials))[0]
+
+    if not hi_lookup:
+        return da.full(values_data.shape, np.nan, dtype=np.float64,
+                       chunks=values_data.chunks)

     # Step 3: paint back using map_blocks (preserves chunk structure)
     def _paint(zones_chunk, values_chunk, hi_map):
         out = np.full(zones_chunk.shape, np.nan, dtype=np.float64)
         for z, hi_val in hi_map.items():
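The paint step that the hunk ends on follows the standard `map_blocks` pattern: the lookup dict is an ordinary Python object, so dask ships it to every task as a plain parameter while the chunk structure of the zones array is preserved. A simplified sketch, with invented values; the PR's `_paint` also receives a `values_chunk`, dropped here for brevity:

```python
import numpy as np
import dask.array as da

def paint(zones_chunk, hi_map):
    # Fill each pixel with its zone's HI value; zones absent from the map stay NaN.
    out = np.full(zones_chunk.shape, np.nan, dtype=np.float64)
    for z, hi_val in hi_map.items():
        out[zones_chunk == z] = hi_val
    return out

zones = da.from_array(np.array([[1.0, 2.0],
                                [2.0, 1.0]]), chunks=1)
hi_lookup = {1.0: 0.45, 2.0: 0.30}

result = zones.map_blocks(paint, hi_lookup, dtype=np.float64)
print(result.compute())
# [[0.45 0.3 ]
#  [0.3  0.45]]
```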