From aea89c5b24ec6028ce0d00117c3123a8f080b5fc Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 7 May 2026 10:06:13 +0200 Subject: [PATCH 1/3] Fix write validation not catching empty (rather than None) transformations After remove_transformation(element, remove_all=True) the transformations dict is set to {} rather than None, bypassing the is-None guard in all three IO writers. Changed the check to `not transformations` so both None and empty dicts are caught, and added a parametrized regression test covering images, multiscale images, labels, multiscale labels, shapes, and points. Co-Authored-By: Claude Sonnet 4.6 --- src/spatialdata/_io/io_points.py | 4 ++-- src/spatialdata/_io/io_raster.py | 4 ++-- src/spatialdata/_io/io_shapes.py | 2 +- tests/core/operations/test_transform.py | 27 +++++++++++++++++++++++++ 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 7b00b6ac..4b98fc59 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -95,6 +95,6 @@ def write_points( axes=list(axes), attrs=attrs, ) - if transformations is None: - raise ValueError(f"No transformations specified for element '{group.basename}'. 
Cannot write.") + if not transformations: + raise ValueError(f"{group.basename} does not have any transformations and can therefore not be written.") overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 027ac8ce..ca4c9804 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -369,7 +369,7 @@ def _write_raster_dataarray( data = raster_data.data transformations = _get_transformations(raster_data) - if transformations is None: + if not transformations: raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") input_axes: tuple[str, ...] = tuple(raster_data.dims) parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) @@ -439,7 +439,7 @@ def _write_raster_datatree( assert len(d) == 1 xdata = d.values().__iter__().__next__() transformations = _get_transformations_xarray(xdata) - if transformations is None: + if not transformations: raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 0907a56c..4743781c 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -100,7 +100,7 @@ def write_shapes( axes = get_axes_names(shapes) transformations = _get_transformations(shapes) - if transformations is None: + if not transformations: raise ValueError(f"{group.basename} does not have any transformations and can therefore not be written.") if isinstance(element_format, ShapesFormatV01): attrs = _write_shapes_v01(shapes, group, element_format) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index a4e7077f..9e17a067 100644 --- a/tests/core/operations/test_transform.py +++ 
b/tests/core/operations/test_transform.py @@ -796,3 +796,30 @@ def test_transform_until_0_0_15(points): transform(points, transformation=t0, maintain_positioning=True) transform(points, to_coordinate_system="global", maintain_positioning=True) + + +@pytest.mark.parametrize( + "element_fixture,kwargs", + [ + ("image2d", {"images": {}}), + ("image2d_multiscale", {"images": {}}), + ("labels2d", {"labels": {}}), + ("labels2d_multiscale", {"labels": {}}), + ("circles", {"shapes": {}}), + ("points_0", {"points": {}}), + ], +) +def test_write_fails_after_removing_all_transformations( + full_sdata: SpatialData, tmp_path: Path, element_fixture: str, kwargs: dict +) -> None: + """Writing an element whose transformations have all been removed should raise a ValueError.""" + element = full_sdata[element_fixture] + remove_transformation(element, remove_all=True) + + # Build a minimal SpatialData with only this element and write it to a fresh location + container_key = next(iter(kwargs)) + sdata = SpatialData(**{container_key: {element_fixture: element}}) + tmpdir = tmp_path / "sdata.zarr" + + with pytest.raises(ValueError, match="does not have any transformations"): + sdata.write(tmpdir) From 2f183047b2aeba15000b3611ff38c611e093e734 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 7 May 2026 10:22:58 +0200 Subject: [PATCH 2/3] Move empty-transformation validation into model validate() methods Instead of guarding in each IO writer, call get_model(element) (which already dispatches to the right schema and runs validate()) at the start of _write_element() for all non-table spatial elements. Also fix the is-None guards in all three validate() methods to use `not transformations` / `not data.attrs.get(key)` so that an empty dict {} is caught in addition to None. The IO-level guards added in the previous commit are removed since they are now superseded by the model-level check; assert statements are kept to narrow the type for mypy. 
The regression test is updated to reflect the correct production scenario: element is already inside a SpatialData object when its transformations are removed in-place, so the error fires during write() not at construction. Co-Authored-By: Claude Sonnet 4.6 --- src/spatialdata/_core/spatialdata.py | 5 +++++ src/spatialdata/_io/io_points.py | 3 +-- src/spatialdata/_io/io_raster.py | 6 ++---- src/spatialdata/_io/io_shapes.py | 3 +-- src/spatialdata/models/models.py | 6 +++--- tests/core/operations/test_transform.py | 17 ++++++++--------- 6 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 739b225f..c41b9010 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1225,6 +1225,11 @@ def _write_element( if parsed_formats is None: parsed_formats = _parse_formats(formats=parsed_formats) + if element_type != "tables": + from spatialdata.models.models import get_model + + get_model(element) + if element_type == "images": write_image( image=element, diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 4b98fc59..800a78d3 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -67,6 +67,7 @@ def write_points( """ axes = get_axes_names(points) transformations = _get_transformations(points) + assert transformations is not None store_root = group.store_path.store.root path = store_root / group.path / "points.parquet" @@ -95,6 +96,4 @@ def write_points( axes=list(axes), attrs=attrs, ) - if not transformations: - raise ValueError(f"{group.basename} does not have any transformations and can therefore not be written.") overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index ca4c9804..c0416875 100644 --- a/src/spatialdata/_io/io_raster.py +++ 
b/src/spatialdata/_io/io_raster.py @@ -369,8 +369,7 @@ def _write_raster_dataarray( data = raster_data.data transformations = _get_transformations(raster_data) - if not transformations: - raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") + assert transformations is not None input_axes: tuple[str, ...] = tuple(raster_data.dims) parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) storage_options = _prepare_storage_options(storage_options) @@ -439,8 +438,7 @@ def _write_raster_datatree( assert len(d) == 1 xdata = d.values().__iter__().__next__() transformations = _get_transformations_xarray(xdata) - if not transformations: - raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") + assert transformations is not None parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) storage_options = _prepare_storage_options(storage_options) diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 4743781c..71908e55 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -100,8 +100,7 @@ def write_shapes( axes = get_axes_names(shapes) transformations = _get_transformations(shapes) - if not transformations: - raise ValueError(f"{group.basename} does not have any transformations and can therefore not be written.") + assert transformations is not None if isinstance(element_format, ShapesFormatV01): attrs = _write_shapes_v01(shapes, group, element_format) elif isinstance(element_format, ShapesFormatV02 | ShapesFormatV03): diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 8ca2a199..d8414a5d 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -342,7 +342,7 @@ def _validate_attrs(cls, data: DataArray) -> None: @classmethod def _check_transforms_present(cls, data: DataArray | DataTree) -> None: parsed_transform = 
_get_transformations(data) - if parsed_transform is None: + if not parsed_transform: raise ValueError( f"No transformation found for `{data}`. At least one transformation is required for " f"raster elements, e.g. images, labels." @@ -477,7 +477,7 @@ def validate(cls, data: GeoDataFrame) -> None: "please see https://github.com/scverse/spatialdata/discussions/657 for a solution. Otherwise, " "please correct the radii of the circles before calling the parser function.", ) - if cls.TRANSFORM_KEY not in data.attrs: + if not data.attrs.get(cls.TRANSFORM_KEY): raise ValueError(f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." + SUGGESTION) if len(data) > 0: n = data.geometry.iloc[0]._ndim @@ -668,7 +668,7 @@ def validate(cls, data: DaskDataFrame) -> None: np.int64, ]: raise ValueError(f"Column `{ax}` must be of type `int` or `float`.") - if cls.TRANSFORM_KEY not in data.attrs: + if not data.attrs.get(cls.TRANSFORM_KEY): raise ValueError( f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." 
+ SUGGESTION ) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 9e17a067..0a251d5d 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -812,14 +812,13 @@ def test_transform_until_0_0_15(points): def test_write_fails_after_removing_all_transformations( full_sdata: SpatialData, tmp_path: Path, element_fixture: str, kwargs: dict ) -> None: - """Writing an element whose transformations have all been removed should raise a ValueError.""" - element = full_sdata[element_fixture] - remove_transformation(element, remove_all=True) - - # Build a minimal SpatialData with only this element and write it to a fresh location + """Writing should fail when all transformations are removed from an element already in a SpatialData.""" + # Build a valid SpatialData first (passes __setitem__ validation) container_key = next(iter(kwargs)) - sdata = SpatialData(**{container_key: {element_fixture: element}}) - tmpdir = tmp_path / "sdata.zarr" + sdata = SpatialData(**{container_key: {element_fixture: full_sdata[element_fixture]}}) + + # Mutate in-place after construction, bypassing __setitem__ validation + remove_transformation(sdata[element_fixture], remove_all=True) - with pytest.raises(ValueError, match="does not have any transformations"): - sdata.write(tmpdir) + with pytest.raises(ValueError, match="transform"): + sdata.write(tmp_path / "sdata.zarr") From 8c42cbee31ee955ed1948f54ae060aad21aa83d1 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 7 May 2026 12:28:47 +0200 Subject: [PATCH 3/3] Add validate_element() helper and split None/empty transformation checks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of guarding in each IO writer, call validate_element(element) (a new public helper in spatialdata.models that delegates to get_model) at the start of _write_element() for all non-table spatial elements. 
Validation changes in models.py: - RasterSchema._check_transforms_present: two explicit checks — one for None (key absent) and one for empty dict, with separate messages - ShapesModel.validate / PointsModel.validate: same split into two checks - asserts in IO files are kept solely for mypy type-narrowing, each annotated with a comment explaining that validate_element() guarantees the invariant at runtime New public API: - spatialdata.models.validate_element(e) raises ValueError if the element fails schema validation; documented in docs/api/models_utils.md Co-Authored-By: Claude Sonnet 4.6 --- docs/api/models_utils.md | 1 + src/spatialdata/_core/spatialdata.py | 4 +-- src/spatialdata/_io/io_points.py | 2 +- src/spatialdata/_io/io_raster.py | 4 +-- src/spatialdata/_io/io_shapes.py | 2 +- src/spatialdata/models/__init__.py | 2 ++ src/spatialdata/models/models.py | 38 +++++++++++++++++++++++++--- 7 files changed, 44 insertions(+), 9 deletions(-) diff --git a/docs/api/models_utils.md b/docs/api/models_utils.md index a13751df..c2a74803 100644 --- a/docs/api/models_utils.md +++ b/docs/api/models_utils.md @@ -4,6 +4,7 @@ .. currentmodule:: spatialdata.models .. autofunction:: get_model +.. autofunction:: validate_element .. autodata:: SpatialElement .. autofunction:: get_axes_names .. 
autofunction:: get_spatial_axes diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index c41b9010..760736c6 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1226,9 +1226,9 @@ def _write_element( parsed_formats = _parse_formats(formats=parsed_formats) if element_type != "tables": - from spatialdata.models.models import get_model + from spatialdata.models import validate_element - get_model(element) + validate_element(element) if element_type == "images": write_image( diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 800a78d3..03ef3338 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -67,7 +67,7 @@ def write_points( """ axes = get_axes_names(points) transformations = _get_transformations(points) - assert transformations is not None + assert transformations is not None # mypy: validate_element() in _write_element guarantees this store_root = group.store_path.store.root path = store_root / group.path / "points.parquet" diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index c0416875..4bba7887 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -369,7 +369,7 @@ def _write_raster_dataarray( data = raster_data.data transformations = _get_transformations(raster_data) - assert transformations is not None + assert transformations is not None # mypy: validate_element() in _write_element guarantees this input_axes: tuple[str, ...] 
= tuple(raster_data.dims) parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) storage_options = _prepare_storage_options(storage_options) @@ -438,7 +438,7 @@ def _write_raster_datatree( assert len(d) == 1 xdata = d.values().__iter__().__next__() transformations = _get_transformations_xarray(xdata) - assert transformations is not None + assert transformations is not None # mypy: validate_element() in _write_element guarantees this parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) storage_options = _prepare_storage_options(storage_options) diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 71908e55..3b6e18e3 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -100,7 +100,7 @@ def write_shapes( axes = get_axes_names(shapes) transformations = _get_transformations(shapes) - assert transformations is not None + assert transformations is not None # mypy: validate_element() in _write_element guarantees this if isinstance(element_format, ShapesFormatV01): attrs = _write_shapes_v01(shapes, group, element_format) elif isinstance(element_format, ShapesFormatV02 | ShapesFormatV03): diff --git a/src/spatialdata/models/__init__.py b/src/spatialdata/models/__init__.py index 2f25b4f3..01db9dc2 100644 --- a/src/spatialdata/models/__init__.py +++ b/src/spatialdata/models/__init__.py @@ -27,6 +27,7 @@ TableModel, get_model, get_table_keys, + validate_element, ) __all__ = [ @@ -51,6 +52,7 @@ "points_dask_dataframe_to_geopandas", "check_target_region_column_symmetry", "get_table_keys", + "validate_element", "get_channel_names", "set_channel_names", "force_2d", diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index d8414a5d..ef903c3d 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -342,11 +342,16 @@ def _validate_attrs(cls, data: DataArray) -> None: @classmethod def _check_transforms_present(cls, data: 
DataArray | DataTree) -> None: parsed_transform = _get_transformations(data) - if not parsed_transform: + if parsed_transform is None: raise ValueError( f"No transformation found for `{data}`. At least one transformation is required for " f"raster elements, e.g. images, labels." ) + if len(parsed_transform) == 0: + raise ValueError( + f"The transformations dict for `{data}` is empty. At least one transformation is required for " + f"raster elements, e.g. images, labels." + ) @classmethod def _check_chunk_size_not_too_large(cls, data: DataArray | DataTree) -> None: @@ -477,8 +482,13 @@ def validate(cls, data: GeoDataFrame) -> None: "please see https://github.com/scverse/spatialdata/discussions/657 for a solution. Otherwise, " "please correct the radii of the circles before calling the parser function.", ) - if not data.attrs.get(cls.TRANSFORM_KEY): + if cls.TRANSFORM_KEY not in data.attrs: raise ValueError(f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." + SUGGESTION) + if not data.attrs[cls.TRANSFORM_KEY]: + raise ValueError( + f":class:`geopandas.GeoDataFrame` has an empty `{TRANSFORM_KEY}` dict. " + f"At least one transformation is required." + SUGGESTION + ) if len(data) > 0: n = data.geometry.iloc[0]._ndim if n != 2: @@ -668,10 +678,15 @@ def validate(cls, data: DaskDataFrame) -> None: np.int64, ]: raise ValueError(f"Column `{ax}` must be of type `int` or `float`.") - if not data.attrs.get(cls.TRANSFORM_KEY): + if cls.TRANSFORM_KEY not in data.attrs: raise ValueError( f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." + SUGGESTION ) + if not data.attrs[cls.TRANSFORM_KEY]: + raise ValueError( + f":attr:`dask.dataframe.core.DataFrame.attrs` has an empty `{cls.TRANSFORM_KEY}` dict. " + f"At least one transformation is required." 
+ SUGGESTION + ) if ATTRS_KEY in data.attrs and "feature_key" in data.attrs[ATTRS_KEY]: feature_key = data.attrs[ATTRS_KEY][cls.FEATURE_KEY] if feature_key not in data.columns: @@ -1293,6 +1308,23 @@ def _validate_and_return( raise TypeError(f"Unsupported type {type(e)}") +def validate_element(e: SpatialElement) -> None: + """ + Validate a spatial element against its model schema. + + Parameters + ---------- + e + The spatial element to validate. + + Raises + ------ + ValueError + If the element is invalid (e.g. missing or empty transformations, wrong dtypes). + """ + get_model(e, validate=True) + + def get_table_keys(table: AnnData) -> tuple[str | list[str], str, str]: """ Get the table keys giving information about what spatial element is annotated.