diff --git a/xrspatial/balanced_allocation.py b/xrspatial/balanced_allocation.py index 101c45a1..4bf013bb 100644 --- a/xrspatial/balanced_allocation.py +++ b/xrspatial/balanced_allocation.py @@ -158,7 +158,7 @@ def _allocate_from_costs(cost_stack, source_ids, fill_value=np.nan): else: all_nan = np.all(np.isnan(np.stack(cost_stack, axis=0)), axis=0) - alloc = alloc.astype(np.float32) + alloc = alloc.astype(np.float64) alloc[all_nan] = fill_value return alloc @@ -199,7 +199,7 @@ def _allocate_biased(cost_stack, biases, source_ids, fill_value=np.nan): stacked = np.stack(layers, axis=0) best_idx = np.argmin(stacked, axis=0) - alloc = source_ids[best_idx].astype(np.float32) + alloc = source_ids[best_idx].astype(np.float64) # Mark unreachable cells if da is not None and isinstance(first, da.Array): diff --git a/xrspatial/tests/test_balanced_allocation.py b/xrspatial/tests/test_balanced_allocation.py index 91839723..214a8c6d 100644 --- a/xrspatial/tests/test_balanced_allocation.py +++ b/xrspatial/tests/test_balanced_allocation.py @@ -270,6 +270,45 @@ def test_target_values(): assert {1.0, 2.0} == unique +# ----------------------------------------------------------------------- +# Non-integer source IDs (regression test for #1203) +# ----------------------------------------------------------------------- + +def test_non_integer_source_ids(): + """Non-integer source IDs should still produce balanced territories. + + Before the fix, _allocate_biased cast the allocation array to float32. + The iteration loop then compared float32 values against float64 source + IDs. For non-integer IDs the round-trip float64->float32->float64 + changes the value, so `alloc == sid` never matched, all weights were 0, + and the loop exited immediately without balancing. + """ + h, w = 12, 12 + data = np.zeros((h, w), dtype=np.float64) + data[2, 6] = 0.1 + data[9, 2] = 0.3 + data[9, 10] = 0.7 + + raster = _make_raster(data, chunks=(12, 12)) + friction = _make_raster(np.ones((h, w)), chunks=(12, 12)) + + result = balanced_allocation(raster, friction, tolerance=0.15) + out = _compute(result) + + # All three source IDs should be present + unique = set(np.unique(out[np.isfinite(out)]).astype(np.float32)) + assert len(unique) == 3 + + # Each territory should have at least 20% of cells (target is ~33%) + total = np.sum(np.isfinite(out)) + for sid_f32 in unique: + frac = np.sum(out == sid_f32) / total + assert frac > 0.20, ( + f"Source {sid_f32} only got {frac:.1%} of cells; " + f"balancing probably did not run" + ) + + @pytest.mark.skipif(da is None, reason="dask not installed") def test_balanced_allocation_memory_guard(): """Memory guard should raise before computing N cost surfaces."""