diff --git a/xrspatial/classify.py b/xrspatial/classify.py index 1a16d0af..8b3bee2c 100644 --- a/xrspatial/classify.py +++ b/xrspatial/classify.py @@ -1079,19 +1079,31 @@ def _run_head_tail_breaks(agg, module): def _run_dask_head_tail_breaks(agg): data = agg.data - data_clean = da.where(da.isinf(data), np.nan, data) + # Persist once so the iterative loop does not re-read from storage on + # every scan. Fuse the three reductions per iteration into a single + # dask.compute() to cut graph traversals from 3N+1 to N+1. + data_clean = da.where(da.isinf(data), np.nan, data).persist() bins = [] mask = da.isfinite(data_clean) + total_count_lazy = mask.sum() + total_count = int(total_count_lazy.compute()) + if total_count == 0: + max_v = float(da.nanmax(data_clean).compute()) + return _bin(agg, np.array([max_v]), np.arange(1)) + + current_total = total_count while True: current = da.where(mask, data_clean, np.nan) - mean_v = float(da.nanmean(current).compute()) + new_mask = mask & (data_clean > da.nanmean(current)) + # Fuse mean and head-count into one graph evaluation. + mean_v, head_count = dask.compute(da.nanmean(current), new_mask.sum()) + mean_v = float(mean_v) + head_count = int(head_count) bins.append(mean_v) - new_mask = mask & (data_clean > mean_v) - head_count = int(new_mask.sum().compute()) - total_count = int(mask.sum().compute()) - if head_count == 0 or head_count / total_count > 0.40: + if head_count == 0 or head_count / current_total > 0.40: break mask = new_mask + current_total = head_count max_v = float(da.nanmax(data_clean).compute()) bins.append(max_v) bins = np.array(bins) @@ -1366,6 +1378,22 @@ def maximum_breaks(agg: xr.DataArray, attrs=agg.attrs) +_BOX_PLOT_DEFAULT_SAMPLE = 200_000 + + +def _box_plot_bins_from_sample(finite_np, hinge, max_v): + q1 = float(np.percentile(finite_np, 25)) + q2 = float(np.percentile(finite_np, 50)) + q3 = float(np.percentile(finite_np, 75)) + iqr = q3 - q1 + raw_bins = [q1 - hinge * iqr, q1, q2, q3, q3 + hinge * iqr, max_v] + bins = np.sort(np.unique(raw_bins)) + bins = bins[bins <= max_v] + if bins[-1] < max_v: + bins = np.append(bins, max_v) + return bins + + def _run_box_plot(agg, hinge, module): data = agg.data data_clean = module.where(module.isinf(data), np.nan, data) @@ -1376,51 +1404,63 @@ def _run_box_plot(agg, hinge, module): q2 = float(cupy.percentile(finite_data, 50).get()) q3 = float(cupy.percentile(finite_data, 75).get()) max_v = float(cupy.nanmax(finite_data).get()) - elif module == da: - q1_l = da.percentile(finite_data, 25) - q2_l = da.percentile(finite_data, 50) - q3_l = da.percentile(finite_data, 75) - max_l = da.nanmax(data_clean) - q1, q2, q3, max_v = dask.compute(q1_l, q2_l, q3_l, max_l) - q1, q2, q3, max_v = q1.item(), q2.item(), q3.item(), max_v.item() - else: - q1 = float(np.percentile(finite_data, 25)) - q2 = float(np.percentile(finite_data, 50)) - q3 = float(np.percentile(finite_data, 75)) - max_v = float(np.nanmax(finite_data)) + iqr = q3 - q1 + raw_bins = [q1 - hinge * iqr, q1, q2, q3, q3 + hinge * iqr, max_v] + bins = np.sort(np.unique(raw_bins)) + bins = bins[bins <= max_v] + if bins[-1] < max_v: + bins = np.append(bins, max_v) + else: # numpy + finite_np = np.asarray(finite_data) + max_v = float(np.nanmax(finite_np)) + bins = _box_plot_bins_from_sample(finite_np, hinge, max_v) - iqr = q3 - q1 - raw_bins = [q1 - hinge * iqr, q1, q2, q3, q3 + hinge * iqr, max_v] - bins = np.sort(np.unique(raw_bins)) - # Remove bins above max (they'd create empty classes) - bins = bins[bins <= max_v] - if bins[-1] < max_v: - bins = np.append(bins, max_v) out = _bin(agg, bins, np.arange(len(bins))) return out +def _run_dask_box_plot(agg, hinge): + """Dask+numpy box_plot. + + Avoids boolean fancy indexing on a dask array (which produces unknown + chunk sizes and forces a chunk-size compute pass). Samples the data + via the same seeded index sampler used by natural_breaks, then + computes percentiles on the finite portion of the sample in numpy. + """ + data = agg.data + data_clean = da.where(da.isinf(data), np.nan, data) + num_data = data_clean.size + num_sample = min(_BOX_PLOT_DEFAULT_SAMPLE, num_data) + sample_idx = _generate_sample_indices(num_data, num_sample) + + sample_lazy = data_clean.ravel()[sample_idx] + max_lazy = da.nanmax(data_clean) + sample_np, max_v = dask.compute(sample_lazy, max_lazy) + sample_np = np.asarray(sample_np) + finite_np = sample_np[np.isfinite(sample_np)] + max_v = float(max_v) + + bins = _box_plot_bins_from_sample(finite_np, hinge, max_v) + return _bin(agg, bins, np.arange(len(bins))) + + def _run_dask_cupy_box_plot(agg, hinge): + """Dask+cupy box_plot: sample on-device, transfer only the sample.""" data = agg.data - data_cpu = data.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(())) - data_clean = da.where(da.isinf(data_cpu), np.nan, data_cpu) - finite_data = data_clean[da.isfinite(data_clean)] + data_clean = da.where(da.isinf(data), np.nan, data) + num_data = data_clean.size + num_sample = min(_BOX_PLOT_DEFAULT_SAMPLE, num_data) + sample_idx = _generate_sample_indices(num_data, num_sample) - q1_l = da.percentile(finite_data, 25) - q2_l = da.percentile(finite_data, 50) - q3_l = da.percentile(finite_data, 75) - max_l = da.nanmax(data_clean) - q1, q2, q3, max_v = dask.compute(q1_l, q2_l, q3_l, max_l) - q1, q2, q3, max_v = q1.item(), q2.item(), q3.item(), max_v.item() + sample_lazy = data_clean.ravel()[sample_idx] + max_lazy = da.nanmax(data_clean) + sample_cp, max_v = dask.compute(sample_lazy, max_lazy) + sample_np = cupy.asnumpy(sample_cp) + finite_np = sample_np[np.isfinite(sample_np)] + max_v = float(cupy.asnumpy(max_v).item()) if hasattr(max_v, 'get') else float(max_v) - iqr = q3 - q1 - raw_bins = [q1 - hinge * iqr, q1, q2, q3, q3 + hinge * iqr, max_v] - bins = np.sort(np.unique(raw_bins)) - bins = bins[bins <= max_v] - if bins[-1] < max_v: - bins = np.append(bins, max_v) - out = _bin(agg, bins, np.arange(len(bins))) - return out + bins = _box_plot_bins_from_sample(finite_np, hinge, max_v) + return _bin(agg, bins, np.arange(len(bins))) @supports_dataset @@ -1459,7 +1499,7 @@ def box_plot(agg: xr.DataArray, mapper = ArrayTypeFunctionMapping( numpy_func=lambda *args: _run_box_plot(*args, module=np), - dask_func=lambda *args: _run_box_plot(*args, module=da), + dask_func=_run_dask_box_plot, cupy_func=lambda *args: _run_box_plot(*args, module=cupy), dask_cupy_func=_run_dask_cupy_box_plot, )