diff --git a/xrspatial/zonal.py b/xrspatial/zonal.py index f4c5e53d..62558445 100644 --- a/xrspatial/zonal.py +++ b/xrspatial/zonal.py @@ -1440,65 +1440,88 @@ def _hi_cupy(zones_data, values_data, nodata): @delayed -def _hi_block_stats(z_block, v_block, uzones): - """Per-chunk: return (n_zones, 4) array of [min, max, sum, count].""" - result = np.full((len(uzones), 4), np.nan, dtype=np.float64) - result[:, 3] = 0 # count starts at 0 - for i, z in enumerate(uzones): - mask = (z_block == z) & np.isfinite(v_block) - if not np.any(mask): +def _hi_block_stats(z_block, v_block, nodata): + """Per-chunk: return dict mapping local zone IDs to (min, max, sum, count). + + Each block discovers its own zones, so the driver never has to compute + a global unique-zone set up front. Sparse zones (geographic) stay sparse + in the returned dict instead of being padded to a full (n_zones, 4) array. + """ + finite_v = np.isfinite(v_block) + finite_z = np.isfinite(z_block) + valid = finite_z & finite_v + if not np.any(valid): + return {} + + z_valid = z_block[valid] + v_valid = v_block[valid] + uzones = np.unique(z_valid) + + result = {} + for z in uzones: + if nodata is not None and z == nodata: continue - vals = v_block[mask] - result[i, 0] = vals.min() - result[i, 1] = vals.max() - result[i, 2] = vals.sum() - result[i, 3] = len(vals) + mask = z_valid == z + vals = v_valid[mask] + if vals.size == 0: + continue + result[z.item() if hasattr(z, 'item') else z] = ( + float(vals.min()), + float(vals.max()), + float(vals.sum()), + int(vals.size), + ) return result @delayed -def _hi_reduce(partials_list, uzones): - """Reduce per-block stats to global per-zone HI lookup dict.""" - stacked = np.stack(partials_list) # (n_blocks, n_zones, 4) - g_min = np.nanmin(stacked[:, :, 0], axis=0) - g_max = np.nanmax(stacked[:, :, 1], axis=0) - g_sum = np.nansum(stacked[:, :, 2], axis=0) - g_count = np.nansum(stacked[:, :, 3], axis=0) +def _hi_reduce(partials_list): + """Stream-merge per-block dicts into global hi_lookup. + + Scheduler peak memory is O(n_zones) for the merged dict, rather than + O(n_blocks * n_zones) from a stacked array. Per-block partials arrive + as a Python list but are iterated once and can be released. + """ + merged = {} + for partial in partials_list: + for z, (mn, mx, s, c) in partial.items(): + if z in merged: + om, oM, os_, oc = merged[z] + merged[z] = (min(om, mn), max(oM, mx), os_ + s, oc + c) + else: + merged[z] = (mn, mx, s, c) hi_lookup = {} - for i, z in enumerate(uzones): - if g_count[i] == 0 or g_max[i] == g_min[i]: + for z, (mn, mx, s, c) in merged.items(): + if c == 0 or mx == mn: hi_lookup[z] = np.nan else: - mean = g_sum[i] / g_count[i] - hi_lookup[z] = (mean - g_min[i]) / (g_max[i] - g_min[i]) + hi_lookup[z] = (s / c - mn) / (mx - mn) return hi_lookup def _hi_dask_numpy(zones_data, values_data, nodata): - """Dask+numpy backend for hypsometric integral.""" - # Step 1: find all unique zones across all chunks - unique_zones = _unique_finite_zones(zones_data) - if nodata is not None: - unique_zones = unique_zones[unique_zones != nodata] - - if len(unique_zones) == 0: - return da.full(values_data.shape, np.nan, dtype=np.float64, - chunks=values_data.chunks) + """Dask+numpy backend for hypsometric integral. - # Step 2: per-block aggregation -> global reduce + Single graph evaluation: each block computes its local (zone -> stats) + dict, then a streaming reduce merges them into a lookup table, then + map_blocks paints the result. No up-front `_unique_finite_zones` + compute and no O(n_blocks * n_zones) scheduler-side stack. + """ zones_blocks = zones_data.to_delayed().ravel() values_blocks = values_data.to_delayed().ravel() partials = [ - _hi_block_stats(zb, vb, unique_zones) + _hi_block_stats(zb, vb, nodata) for zb, vb in zip(zones_blocks, values_blocks) ] - # Compute the HI lookup eagerly so map_blocks can use it as a parameter. - hi_lookup = dask.compute(_hi_reduce(partials, unique_zones))[0] + hi_lookup = dask.compute(_hi_reduce(partials))[0] + + if not hi_lookup: + return da.full(values_data.shape, np.nan, dtype=np.float64, + chunks=values_data.chunks) - # Step 3: paint back using map_blocks (preserves chunk structure) def _paint(zones_chunk, values_chunk, hi_map): out = np.full(zones_chunk.shape, np.nan, dtype=np.float64) for z, hi_val in hi_map.items():