Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/models_utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
.. currentmodule:: spatialdata.models

.. autofunction:: get_model
.. autofunction:: validate_element
.. autodata:: SpatialElement
.. autofunction:: get_axes_names
.. autofunction:: get_spatial_axes
Expand Down
5 changes: 5 additions & 0 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,11 @@ def _write_element(
if parsed_formats is None:
parsed_formats = _parse_formats(formats=parsed_formats)

if element_type != "tables":
from spatialdata.models import validate_element

validate_element(element)

if element_type == "images":
write_image(
image=element,
Expand Down
3 changes: 1 addition & 2 deletions src/spatialdata/_io/io_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def write_points(
"""
axes = get_axes_names(points)
transformations = _get_transformations(points)
assert transformations is not None # mypy: validate_element() in _write_element guarantees this

store_root = group.store_path.store.root
path = store_root / group.path / "points.parquet"
Expand Down Expand Up @@ -95,6 +96,4 @@ def write_points(
axes=list(axes),
attrs=attrs,
)
if transformations is None:
raise ValueError(f"No transformations specified for element '{group.basename}'. Cannot write.")
overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations)
6 changes: 2 additions & 4 deletions src/spatialdata/_io/io_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,7 @@ def _write_raster_dataarray(

data = raster_data.data
transformations = _get_transformations(raster_data)
if transformations is None:
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
assert transformations is not None # mypy: validate_element() in _write_element guarantees this
input_axes: tuple[str, ...] = tuple(raster_data.dims)
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
storage_options = _prepare_storage_options(storage_options)
Expand Down Expand Up @@ -439,8 +438,7 @@ def _write_raster_datatree(
assert len(d) == 1
xdata = d.values().__iter__().__next__()
transformations = _get_transformations_xarray(xdata)
if transformations is None:
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
assert transformations is not None # mypy: validate_element() in _write_element guarantees this

parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
storage_options = _prepare_storage_options(storage_options)
Expand Down
3 changes: 1 addition & 2 deletions src/spatialdata/_io/io_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def write_shapes(

axes = get_axes_names(shapes)
transformations = _get_transformations(shapes)
if transformations is None:
raise ValueError(f"{group.basename} does not have any transformations and can therefore not be written.")
assert transformations is not None # mypy: validate_element() in _write_element guarantees this
if isinstance(element_format, ShapesFormatV01):
attrs = _write_shapes_v01(shapes, group, element_format)
elif isinstance(element_format, ShapesFormatV02 | ShapesFormatV03):
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TableModel,
get_model,
get_table_keys,
validate_element,
)

__all__ = [
Expand All @@ -51,6 +52,7 @@
"points_dask_dataframe_to_geopandas",
"check_target_region_column_symmetry",
"get_table_keys",
"validate_element",
"get_channel_names",
"set_channel_names",
"force_2d",
Expand Down
32 changes: 32 additions & 0 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,11 @@ def _check_transforms_present(cls, data: DataArray | DataTree) -> None:
f"No transformation found for `{data}`. At least one transformation is required for "
f"raster elements, e.g. images, labels."
)
if len(parsed_transform) == 0:
raise ValueError(
f"The transformations dict for `{data}` is empty. At least one transformation is required for "
f"raster elements, e.g. images, labels."
)

@classmethod
def _check_chunk_size_not_too_large(cls, data: DataArray | DataTree) -> None:
Expand Down Expand Up @@ -479,6 +484,11 @@ def validate(cls, data: GeoDataFrame) -> None:
)
if cls.TRANSFORM_KEY not in data.attrs:
raise ValueError(f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." + SUGGESTION)
if not data.attrs[cls.TRANSFORM_KEY]:
raise ValueError(
f":class:`geopandas.GeoDataFrame` has an empty `{TRANSFORM_KEY}` dict. "
f"At least one transformation is required." + SUGGESTION
)
if len(data) > 0:
n = data.geometry.iloc[0]._ndim
if n != 2:
Expand Down Expand Up @@ -672,6 +682,11 @@ def validate(cls, data: DaskDataFrame) -> None:
raise ValueError(
f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." + SUGGESTION
)
if not data.attrs[cls.TRANSFORM_KEY]:
raise ValueError(
f":attr:`dask.dataframe.core.DataFrame.attrs` has an empty `{cls.TRANSFORM_KEY}` dict. "
f"At least one transformation is required." + SUGGESTION
)
if ATTRS_KEY in data.attrs and "feature_key" in data.attrs[ATTRS_KEY]:
feature_key = data.attrs[ATTRS_KEY][cls.FEATURE_KEY]
if feature_key not in data.columns:
Expand Down Expand Up @@ -1293,6 +1308,23 @@ def _validate_and_return(
raise TypeError(f"Unsupported type {type(e)}")


def validate_element(e: SpatialElement) -> None:
"""
Validate a spatial element against its model schema.

Parameters
----------
e
The spatial element to validate.

Raises
------
ValueError
If the element is invalid (e.g. missing or empty transformations, wrong dtypes).
"""
get_model(e, validate=True)


def get_table_keys(table: AnnData) -> tuple[str | list[str], str, str]:
"""
Get the table keys giving information about what spatial element is annotated.
Expand Down
26 changes: 26 additions & 0 deletions tests/core/operations/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,3 +796,29 @@ def test_transform_until_0_0_15(points):

transform(points, transformation=t0, maintain_positioning=True)
transform(points, to_coordinate_system="global", maintain_positioning=True)


@pytest.mark.parametrize(
"element_fixture,kwargs",
[
("image2d", {"images": {}}),
("image2d_multiscale", {"images": {}}),
("labels2d", {"labels": {}}),
("labels2d_multiscale", {"labels": {}}),
("circles", {"shapes": {}}),
("points_0", {"points": {}}),
],
)
def test_write_fails_after_removing_all_transformations(
full_sdata: SpatialData, tmp_path: Path, element_fixture: str, kwargs: dict
) -> None:
"""Writing should fail when all transformations are removed from an element already in a SpatialData."""
# Build a valid SpatialData first (passes __setitem__ validation)
container_key = next(iter(kwargs))
sdata = SpatialData(**{container_key: {element_fixture: full_sdata[element_fixture]}})

# Mutate in-place after construction, bypassing __setitem__ validation
remove_transformation(sdata[element_fixture], remove_all=True)

with pytest.raises(ValueError, match="transform"):
sdata.write(tmp_path / "sdata.zarr")
Loading