diff --git a/docs/api/models_utils.md b/docs/api/models_utils.md index a13751dfb..c2a74803b 100644 --- a/docs/api/models_utils.md +++ b/docs/api/models_utils.md @@ -4,6 +4,7 @@ .. currentmodule:: spatialdata.models .. autofunction:: get_model +.. autofunction:: validate_element .. autodata:: SpatialElement .. autofunction:: get_axes_names .. autofunction:: get_spatial_axes diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 739b225fe..760736c6a 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1225,6 +1225,11 @@ def _write_element( if parsed_formats is None: parsed_formats = _parse_formats(formats=parsed_formats) + if element_type != "tables": + from spatialdata.models import validate_element + + validate_element(element) + if element_type == "images": write_image( image=element, diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 7b00b6acc..03ef33389 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -67,6 +67,7 @@ def write_points( """ axes = get_axes_names(points) transformations = _get_transformations(points) + assert transformations is not None # mypy: validate_element() in _write_element guarantees this store_root = group.store_path.store.root path = store_root / group.path / "points.parquet" @@ -95,6 +96,4 @@ def write_points( axes=list(axes), attrs=attrs, ) - if transformations is None: - raise ValueError(f"No transformations specified for element '{group.basename}'. Cannot write.") overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 027ac8cec..4bba78874 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -369,8 +369,7 @@ def _write_raster_dataarray( data = raster_data.data transformations = _get_transformations(raster_data) - if transformations is None: - raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") + assert transformations is not None # mypy: validate_element() in _write_element guarantees this input_axes: tuple[str, ...] = tuple(raster_data.dims) parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) storage_options = _prepare_storage_options(storage_options) @@ -439,8 +438,7 @@ def _write_raster_datatree( assert len(d) == 1 xdata = d.values().__iter__().__next__() transformations = _get_transformations_xarray(xdata) - if transformations is None: - raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") + assert transformations is not None # mypy: validate_element() in _write_element guarantees this parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) storage_options = _prepare_storage_options(storage_options) diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 0907a56c2..3b6e18e39 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -100,8 +100,7 @@ def write_shapes( axes = get_axes_names(shapes) transformations = _get_transformations(shapes) - if transformations is None: - raise ValueError(f"{group.basename} does not have any transformations and can therefore not be written.") + assert transformations is not None # mypy: validate_element() in _write_element guarantees this if isinstance(element_format, ShapesFormatV01): attrs = _write_shapes_v01(shapes, group, element_format) elif isinstance(element_format, ShapesFormatV02 | ShapesFormatV03): diff --git a/src/spatialdata/models/__init__.py b/src/spatialdata/models/__init__.py index 2f25b4f3c..01db9dc2d 100644 --- a/src/spatialdata/models/__init__.py +++ b/src/spatialdata/models/__init__.py @@ -27,6 +27,7 @@ TableModel, get_model, get_table_keys, + validate_element, ) __all__ = [ @@ -51,6 +52,7 @@ "points_dask_dataframe_to_geopandas", "check_target_region_column_symmetry", "get_table_keys", + "validate_element", "get_channel_names", "set_channel_names", "force_2d", diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 8ca2a1999..ef903c3df 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -347,6 +347,11 @@ def _check_transforms_present(cls, data: DataArray | DataTree) -> None: f"No transformation found for `{data}`. At least one transformation is required for " f"raster elements, e.g. images, labels." ) + if len(parsed_transform) == 0: + raise ValueError( + f"The transformations dict for `{data}` is empty. At least one transformation is required for " + f"raster elements, e.g. images, labels." + ) @classmethod def _check_chunk_size_not_too_large(cls, data: DataArray | DataTree) -> None: @@ -479,6 +484,11 @@ def validate(cls, data: GeoDataFrame) -> None: ) if cls.TRANSFORM_KEY not in data.attrs: raise ValueError(f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." + SUGGESTION) + if not data.attrs[cls.TRANSFORM_KEY]: + raise ValueError( + f":class:`geopandas.GeoDataFrame` has an empty `{TRANSFORM_KEY}` dict. " + f"At least one transformation is required." + SUGGESTION + ) if len(data) > 0: n = data.geometry.iloc[0]._ndim if n != 2: @@ -672,6 +682,11 @@ def validate(cls, data: DaskDataFrame) -> None: raise ValueError( f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." + SUGGESTION ) + if not data.attrs[cls.TRANSFORM_KEY]: + raise ValueError( + f":attr:`dask.dataframe.core.DataFrame.attrs` has an empty `{cls.TRANSFORM_KEY}` dict. " + f"At least one transformation is required." + SUGGESTION + ) if ATTRS_KEY in data.attrs and "feature_key" in data.attrs[ATTRS_KEY]: feature_key = data.attrs[ATTRS_KEY][cls.FEATURE_KEY] if feature_key not in data.columns: @@ -1293,6 +1308,23 @@ def _validate_and_return( raise TypeError(f"Unsupported type {type(e)}") +def validate_element(e: SpatialElement) -> None: + """ + Validate a spatial element against its model schema. + + Parameters + ---------- + e + The spatial element to validate. + + Raises + ------ + ValueError + If the element is invalid (e.g. missing or empty transformations, wrong dtypes). + """ + get_model(e, validate=True) + + def get_table_keys(table: AnnData) -> tuple[str | list[str], str, str]: """ Get the table keys giving information about what spatial element is annotated. diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index a4e7077f8..0a251d5d2 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -796,3 +796,29 @@ def test_transform_until_0_0_15(points): transform(points, transformation=t0, maintain_positioning=True) transform(points, to_coordinate_system="global", maintain_positioning=True) + + +@pytest.mark.parametrize( + "element_fixture,kwargs", + [ + ("image2d", {"images": {}}), + ("image2d_multiscale", {"images": {}}), + ("labels2d", {"labels": {}}), + ("labels2d_multiscale", {"labels": {}}), + ("circles", {"shapes": {}}), + ("points_0", {"points": {}}), + ], +) +def test_write_fails_after_removing_all_transformations( + full_sdata: SpatialData, tmp_path: Path, element_fixture: str, kwargs: dict +) -> None: + """Writing should fail when all transformations are removed from an element already in a SpatialData.""" + # Build a valid SpatialData first (passes __setitem__ validation) + container_key = next(iter(kwargs)) + sdata = SpatialData(**{container_key: {element_fixture: full_sdata[element_fixture]}}) + + # Mutate in-place after construction, bypassing __setitem__ validation + remove_transformation(sdata[element_fixture], remove_all=True) + + with pytest.raises(ValueError, match="transform"): + sdata.write(tmp_path / "sdata.zarr")