Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 83 additions & 43 deletions xrspatial/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,19 +1079,31 @@ def _run_head_tail_breaks(agg, module):

def _run_dask_head_tail_breaks(agg):
data = agg.data
data_clean = da.where(da.isinf(data), np.nan, data)
# Persist once so the iterative loop does not re-read from storage on
# every scan. Fuse the three reductions per iteration into a single
# dask.compute() to cut graph traversals from 3N+1 to N+1.
data_clean = da.where(da.isinf(data), np.nan, data).persist()
bins = []
mask = da.isfinite(data_clean)
total_count_lazy = mask.sum()
total_count = int(total_count_lazy.compute())
if total_count == 0:
max_v = float(da.nanmax(data_clean).compute())
return _bin(agg, np.array([max_v]), np.arange(1))

current_total = total_count
while True:
current = da.where(mask, data_clean, np.nan)
mean_v = float(da.nanmean(current).compute())
new_mask = mask & (data_clean > da.nanmean(current))
# Fuse mean and head-count into one graph evaluation.
mean_v, head_count = dask.compute(da.nanmean(current), new_mask.sum())
mean_v = float(mean_v)
head_count = int(head_count)
bins.append(mean_v)
new_mask = mask & (data_clean > mean_v)
head_count = int(new_mask.sum().compute())
total_count = int(mask.sum().compute())
if head_count == 0 or head_count / total_count > 0.40:
if head_count == 0 or head_count / current_total > 0.40:
break
mask = new_mask
current_total = head_count
max_v = float(da.nanmax(data_clean).compute())
bins.append(max_v)
bins = np.array(bins)
Expand Down Expand Up @@ -1366,6 +1378,22 @@ def maximum_breaks(agg: xr.DataArray,
attrs=agg.attrs)


_BOX_PLOT_DEFAULT_SAMPLE = 200_000


def _box_plot_bins_from_sample(finite_np, hinge, max_v):
q1 = float(np.percentile(finite_np, 25))
q2 = float(np.percentile(finite_np, 50))
q3 = float(np.percentile(finite_np, 75))
iqr = q3 - q1
raw_bins = [q1 - hinge * iqr, q1, q2, q3, q3 + hinge * iqr, max_v]
bins = np.sort(np.unique(raw_bins))
bins = bins[bins <= max_v]
if bins[-1] < max_v:
bins = np.append(bins, max_v)
return bins


def _run_box_plot(agg, hinge, module):
data = agg.data
data_clean = module.where(module.isinf(data), np.nan, data)
Expand All @@ -1376,51 +1404,63 @@ def _run_box_plot(agg, hinge, module):
q2 = float(cupy.percentile(finite_data, 50).get())
q3 = float(cupy.percentile(finite_data, 75).get())
max_v = float(cupy.nanmax(finite_data).get())
elif module == da:
q1_l = da.percentile(finite_data, 25)
q2_l = da.percentile(finite_data, 50)
q3_l = da.percentile(finite_data, 75)
max_l = da.nanmax(data_clean)
q1, q2, q3, max_v = dask.compute(q1_l, q2_l, q3_l, max_l)
q1, q2, q3, max_v = q1.item(), q2.item(), q3.item(), max_v.item()
else:
q1 = float(np.percentile(finite_data, 25))
q2 = float(np.percentile(finite_data, 50))
q3 = float(np.percentile(finite_data, 75))
max_v = float(np.nanmax(finite_data))
iqr = q3 - q1
raw_bins = [q1 - hinge * iqr, q1, q2, q3, q3 + hinge * iqr, max_v]
bins = np.sort(np.unique(raw_bins))
bins = bins[bins <= max_v]
if bins[-1] < max_v:
bins = np.append(bins, max_v)
else: # numpy
finite_np = np.asarray(finite_data)
max_v = float(np.nanmax(finite_np))
bins = _box_plot_bins_from_sample(finite_np, hinge, max_v)

iqr = q3 - q1
raw_bins = [q1 - hinge * iqr, q1, q2, q3, q3 + hinge * iqr, max_v]
bins = np.sort(np.unique(raw_bins))
# Remove bins above max (they'd create empty classes)
bins = bins[bins <= max_v]
if bins[-1] < max_v:
bins = np.append(bins, max_v)
out = _bin(agg, bins, np.arange(len(bins)))
return out


def _run_dask_box_plot(agg, hinge):
"""Dask+numpy box_plot.

Avoids boolean fancy indexing on a dask array (which produces unknown
chunk sizes and forces a chunk-size compute pass). Samples the data
via the same seeded index sampler used by natural_breaks, then
computes percentiles on the finite portion of the sample in numpy.
"""
data = agg.data
data_clean = da.where(da.isinf(data), np.nan, data)
num_data = data_clean.size
num_sample = min(_BOX_PLOT_DEFAULT_SAMPLE, num_data)
sample_idx = _generate_sample_indices(num_data, num_sample)

sample_lazy = data_clean.ravel()[sample_idx]
max_lazy = da.nanmax(data_clean)
sample_np, max_v = dask.compute(sample_lazy, max_lazy)
sample_np = np.asarray(sample_np)
finite_np = sample_np[np.isfinite(sample_np)]
max_v = float(max_v)

bins = _box_plot_bins_from_sample(finite_np, hinge, max_v)
return _bin(agg, bins, np.arange(len(bins)))


def _run_dask_cupy_box_plot(agg, hinge):
"""Dask+cupy box_plot: sample on-device, transfer only the sample."""
data = agg.data
data_cpu = data.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(()))
data_clean = da.where(da.isinf(data_cpu), np.nan, data_cpu)
finite_data = data_clean[da.isfinite(data_clean)]
data_clean = da.where(da.isinf(data), np.nan, data)
num_data = data_clean.size
num_sample = min(_BOX_PLOT_DEFAULT_SAMPLE, num_data)
sample_idx = _generate_sample_indices(num_data, num_sample)

q1_l = da.percentile(finite_data, 25)
q2_l = da.percentile(finite_data, 50)
q3_l = da.percentile(finite_data, 75)
max_l = da.nanmax(data_clean)
q1, q2, q3, max_v = dask.compute(q1_l, q2_l, q3_l, max_l)
q1, q2, q3, max_v = q1.item(), q2.item(), q3.item(), max_v.item()
sample_lazy = data_clean.ravel()[sample_idx]
max_lazy = da.nanmax(data_clean)
sample_cp, max_v = dask.compute(sample_lazy, max_lazy)
sample_np = cupy.asnumpy(sample_cp)
finite_np = sample_np[np.isfinite(sample_np)]
max_v = float(cupy.asnumpy(max_v).item()) if hasattr(max_v, 'get') else float(max_v)

iqr = q3 - q1
raw_bins = [q1 - hinge * iqr, q1, q2, q3, q3 + hinge * iqr, max_v]
bins = np.sort(np.unique(raw_bins))
bins = bins[bins <= max_v]
if bins[-1] < max_v:
bins = np.append(bins, max_v)
out = _bin(agg, bins, np.arange(len(bins)))
return out
bins = _box_plot_bins_from_sample(finite_np, hinge, max_v)
return _bin(agg, bins, np.arange(len(bins)))


@supports_dataset
Expand Down Expand Up @@ -1459,7 +1499,7 @@ def box_plot(agg: xr.DataArray,

mapper = ArrayTypeFunctionMapping(
numpy_func=lambda *args: _run_box_plot(*args, module=np),
dask_func=lambda *args: _run_box_plot(*args, module=da),
dask_func=_run_dask_box_plot,
cupy_func=lambda *args: _run_box_plot(*args, module=cupy),
dask_cupy_func=_run_dask_cupy_box_plot,
)
Expand Down
Loading