diff --git a/src/spatialdata/_core/deconcatenate.py b/src/spatialdata/_core/deconcatenate.py new file mode 100644 index 000000000..87c60a9e5 --- /dev/null +++ b/src/spatialdata/_core/deconcatenate.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from collections.abc import Iterable + +from spatialdata._core.query.relational_query import match_sdata_to_table +from spatialdata._core.spatialdata import SpatialData + + +def deconcatenate( + full_sdata: SpatialData, + by: str | Iterable[str], + target_coordinate_system: str, + full_sdata_table_name: str = "table", + sdatas_table_names: str | Iterable[str] = "table", + region_key: str = "region", + join: str = "right", +) -> SpatialData | list[SpatialData]: + """ + From a `SpatialData` object containing multiple regions, returns `SpatialData` objects specific to each region identified in `by`. + """ + if full_sdata_table_name not in full_sdata.tables: + raise KeyError("Missing table") + + sdata_table = full_sdata[full_sdata_table_name] + + is_single_region = isinstance(by, str) + deconcat_regions = [by] if is_single_region else list(by) + sdatas_table_names = ( + [sdatas_table_names] * len(deconcat_regions) + if isinstance(sdatas_table_names, str) + else list(sdatas_table_names) + ) + + sdatas = [] + for region, table_name in zip(deconcat_regions, sdatas_table_names): + deconcat_table = sdata_table[sdata_table.obs[region_key] == region] + deconcat_sdata = match_sdata_to_table(full_sdata, table=deconcat_table, table_name=table_name, how=join) + + sdatas.append(deconcat_sdata) + + return sdatas[0] if is_single_region else sdatas diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 68b538e0a..0d9ddab04 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -10,6 +10,7 @@ from spatialdata._core.concatenate import _concatenate_tables, concatenate from spatialdata._core.data_extent import are_extents_equal, get_extent +from spatialdata._core.deconcatenate import deconcatenate from spatialdata._core.operations._utils import transform_to_data_extent from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike @@ -693,3 +694,27 @@ def test_validate_table_in_spatialdata(full_sdata): del full_sdata.points["points_0"] with pytest.warns(UserWarning, match="in the SpatialData object"): full_sdata.validate_table_in_spatialdata(table) + + +def test_deconcatenate(full_sdata): + + regions = ["region1", "region2"] + table_names = ["table1", "table2"] + assert len(regions) == len(table_names) + + # MULTIPLE REGIONS === + sdatas = deconcatenate(full_sdata, by=regions, target_coordinate_system="global", sdatas_table_names=table_names) + + assert isinstance(sdatas, list) + assert len(sdatas) == len(regions) + + for sdata, region, table_name in zip(sdatas, regions, table_names): + assert table_name in sdata.tables + table = sdata.tables[table_name] + assert (table.obs["region"] == region).all() + + # SINGLE REGION === + single = deconcatenate(full_sdata, by=regions[0], target_coordinate_system="global") + + assert not isinstance(single, list) + assert "table" in single.tables