diff --git a/pyproject.toml b/pyproject.toml index e5f3134aa..07ec8140b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "networkx", "numba>=0.55.0", "numpy", - "ome_zarr>=0.12.2", + "ome_zarr>=0.14.0", "pandas", "pooch", "pyarrow", diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 19be99d56..7ba66e710 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -4,6 +4,8 @@ from importlib.metadata import version from typing import TYPE_CHECKING, Any +import spatialdata.models._accessor # noqa: F401 + __version__ = version("spatialdata") _submodules = { @@ -129,15 +131,8 @@ "settings", ] -_accessor_loaded = False - def __getattr__(name: str) -> Any: - global _accessor_loaded - if not _accessor_loaded: - _accessor_loaded = True - import spatialdata.models._accessor # noqa: F401 - if name in _submodules: return importlib.import_module(f"spatialdata.{name}") if name in _LAZY_IMPORTS: diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index df7e1cb8f..a8b2ab2ce 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Sequence from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, TypeGuard import dask.array as da import numpy as np @@ -38,6 +39,126 @@ ) +def _is_flat_int_sequence(value: object) -> TypeGuard[Sequence[int]]: + # e.g. "", "auto" or b"auto" + if isinstance(value, str | bytes): + return False + if not isinstance(value, Sequence): + return False + return all(isinstance(v, int) for v in value) + + +def _is_dask_chunk_grid(value: object) -> TypeGuard[Sequence[Sequence[int]]]: + if isinstance(value, str | bytes): + return False + if not isinstance(value, Sequence): + return False + return len(value) > 0 and all(_is_flat_int_sequence(axis_chunks) for axis_chunks in value) + + +def _is_regular_dask_chunk_grid(chunk_grid: Sequence[Sequence[int]]) -> bool: + """Check whether a Dask chunk grid is regular (zarr-compatible). + + A grid is regular when every axis has at most one unique chunk size among all but the last + chunk, and the last chunk is not larger than the first. + + Parameters + ---------- + chunk_grid + Per-axis tuple of chunk sizes, for instance as returned by ``dask_array.chunks``. + + Examples + -------- + Triggers ``continue`` on the first ``if`` (single or empty axis): + + >>> _is_regular_dask_chunk_grid([(4,)]) # single chunk → True + True + >>> _is_regular_dask_chunk_grid([()]) # empty axis → True + True + + Triggers the first ``return False`` (non-uniform interior chunks): + + >>> _is_regular_dask_chunk_grid([(4, 4, 3, 4)]) # interior sizes differ → False + False + + Triggers the second ``return False`` (last chunk larger than the first): + + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 5)]) # last > first → False + False + + Exits with ``return True``: + + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 4)]) # all equal → True + True + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 1)]) # last < first → True + True + + Empty grid (loop never executes) → True: + + >>> _is_regular_dask_chunk_grid([]) + True + + Multi-axis: all axes regular → True; one axis irregular → False: + + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 1), (3, 3, 2)]) + True + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 1), (4, 4, 3, 4)]) + False + """ + # Match Dask's private _check_regular_chunks() logic without depending on its internal API. + for axis_chunks in chunk_grid: + if len(axis_chunks) <= 1: + continue + if len(set(axis_chunks[:-1])) > 1: + return False + if axis_chunks[-1] > axis_chunks[0]: + return False + return True + + +def _chunks_to_zarr_chunks(chunks: object) -> tuple[int, ...] | int | None: + if isinstance(chunks, int): + return chunks + if _is_flat_int_sequence(chunks): + return tuple(chunks) + if _is_dask_chunk_grid(chunks): + chunk_grid = tuple(tuple(axis_chunks) for axis_chunks in chunks) + if _is_regular_dask_chunk_grid(chunk_grid): + return tuple(axis_chunks[0] for axis_chunks in chunk_grid) + return None + return None + + +def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int: + normalized = _chunks_to_zarr_chunks(chunks) + if normalized is None: + raise ValueError( + 'storage_options["chunks"] must resolve to a Zarr chunk shape or a regular Dask chunk grid. ' + "The current raster has irregular Dask chunks, which cannot be written to Zarr. " + "To fix this, rechunk before writing, for example by passing regular chunks=... " + "to Image2DModel.parse(...) / Labels2DModel.parse(...)." + ) + return normalized + + +def _prepare_storage_options( + storage_options: JSONDict | list[JSONDict] | None, +) -> JSONDict | list[JSONDict] | None: + if storage_options is None: + return None + if isinstance(storage_options, dict): + prepared = dict(storage_options) + if "chunks" in prepared: + prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) + return prepared + + prepared_options = [dict(options) for options in storage_options] + for options in prepared_options: + if "chunks" in options: + options["chunks"] = _normalize_explicit_chunks(options["chunks"]) + return prepared_options + + def _read_multiscale( store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format ) -> DataArray | DataTree: @@ -251,20 +372,18 @@ def _write_raster_dataarray( if transformations is None: raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") input_axes: tuple[str, ...] = tuple(raster_data.dims) - chunks = raster_data.chunks parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) - if storage_options is not None: - if "chunks" not in storage_options and isinstance(storage_options, dict): - storage_options["chunks"] = chunks - else: - storage_options = {"chunks": chunks} - # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. - # We need this because the argument of write_image_ngff is called image while the argument of + storage_options = _prepare_storage_options(storage_options) + # Explicitly disable pyramid generation for single-scale rasters. Recent ome-zarr versions default + # write_image()/write_labels() to scale_factors=(2, 4, 8, 16), which would otherwise write s0, s1, ... + # even when the input is a plain DataArray. + # We need this because the argument of write_image_ngff is called image while the argument of # write_labels_ngff is called label. metadata[raster_type] = data ome_zarr_format = get_ome_zarr_format(raster_format) write_single_scale_ngff( group=group, + scale_factors=[], scaler=None, fmt=ome_zarr_format, axes=parsed_axes, @@ -322,10 +441,9 @@ def _write_raster_datatree( transformations = _get_transformations_xarray(xdata) if transformations is None: raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") - chunks = get_pyramid_levels(raster_data, "chunks") parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) - storage_options = [{"chunks": chunk} for chunk in chunks] + storage_options = _prepare_storage_options(storage_options) ome_zarr_format = get_ome_zarr_format(raster_format) dask_delayed = write_multi_scale_ngff( pyramid=data, diff --git a/tests/io/test_partial_read.py b/tests/io/test_partial_read.py index 7c5d47841..28460046e 100644 --- a/tests/io/test_partial_read.py +++ b/tests/io/test_partial_read.py @@ -184,9 +184,9 @@ def sdata_with_corrupted_image_chunks_zarrv3(session_tmp_path: Path) -> PartialR sdata.write(sdata_path) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json") # it will hide the "0" array from the Zarr reader - os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") - (sdata_path / "images" / corrupted / "0").touch() + os.unlink(sdata_path / "images" / corrupted / "s0" / "zarr.json") # it will hide the "0" array from the Zarr reader + os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted") + (sdata_path / "images" / corrupted / "s0").touch() not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] @@ -206,9 +206,9 @@ def sdata_with_corrupted_image_chunks_zarrv2(session_tmp_path: Path) -> PartialR sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") # it will hide the "0" array from the Zarr reader - os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") - (sdata_path / "images" / corrupted / "0").touch() + os.unlink(sdata_path / "images" / corrupted / "s0" / ".zarray") # it will hide the "0" array from the Zarr reader + os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted") + (sdata_path / "images" / corrupted / "s0").touch() not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] return PartialReadTestCase( @@ -315,8 +315,8 @@ def sdata_with_missing_image_chunks_zarrv3( sdata.write(sdata_path) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json") - os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") + os.unlink(sdata_path / "images" / corrupted / "s0" / "zarr.json") + os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted") not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] @@ -339,8 +339,8 @@ def sdata_with_missing_image_chunks_zarrv2( sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") - os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") + os.unlink(sdata_path / "images" / corrupted / "s0" / ".zarray") + os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted") not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index be07d8be8..209a43046 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Any, Literal +import dask.array as da import dask.dataframe as dd import numpy as np import pandas as pd @@ -18,6 +19,7 @@ from packaging.version import Version from shapely import MultiPolygon, Polygon from upath import UPath +from xarray import DataArray from zarr.errors import GroupNotFoundError import spatialdata.config @@ -30,6 +32,7 @@ SpatialDataContainerFormatType, SpatialDataContainerFormatV01, ) +from spatialdata._io.io_raster import write_image from spatialdata.datasets import blobs from spatialdata.models import Image2DModel from spatialdata.models._utils import get_channel_names @@ -623,6 +626,123 @@ def test_bug_rechunking_after_queried_raster(): queried.write(f) +def test_is_regular_dask_chunk_grid() -> None: + from spatialdata._io.io_raster import _is_regular_dask_chunk_grid + + # Single chunk per axis → continue branch, overall True + assert _is_regular_dask_chunk_grid([(4,)]) is True + # Empty axis → continue branch, overall True + assert _is_regular_dask_chunk_grid([()]) is True + # Non-uniform interior chunks → first return False + assert _is_regular_dask_chunk_grid([(4, 4, 3, 4)]) is False + # Last chunk larger than first → second return False + assert _is_regular_dask_chunk_grid([(4, 4, 4, 5)]) is False + # All chunks equal → True + assert _is_regular_dask_chunk_grid([(4, 4, 4, 4)]) is True + # Last chunk smaller than first → True + assert _is_regular_dask_chunk_grid([(4, 4, 4, 1)]) is True + # Empty grid (no axes) → True + assert _is_regular_dask_chunk_grid([]) is True + # Multi-axis: all axes regular → True + assert _is_regular_dask_chunk_grid([(4, 4, 4, 1), (3, 3, 2)]) is True + # Multi-axis: one axis irregular → False + assert _is_regular_dask_chunk_grid([(4, 4, 4, 1), (4, 4, 3, 4)]) is False + + +def test_write_irregular_dask_chunks_without_explicit_storage_options(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 200, 300), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + sdata = SpatialData(images={"image": image}) + + path = tmp_path / "data.zarr" + with pytest.warns(UserWarning, match="irregular chunk sizes"): + sdata.write(path) + sdata_back = read_zarr(path) + assert sdata_back["image"].chunks == ((3,), (300, 300, 200), (512, 488)) + + +def test_write_image_normalizes_explicit_regular_dask_chunk_grid(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 300, 200), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + group = zarr.open_group(tmp_path / "image.zarr", mode="w") + + write_image(image, group, "image", storage_options={"chunks": image.data.chunks}) + + assert group["s0"].chunks == (3, 300, 512) + + +def test_write_image_rejects_explicit_irregular_dask_chunk_grid(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 200, 300), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + group = zarr.open_group(tmp_path / "image.zarr", mode="w") + + with pytest.raises( + ValueError, + match='storage_options\\["chunks"\\] must resolve to a Zarr chunk shape or a regular Dask chunk grid', + ): + write_image(image, group, "image", storage_options={"chunks": image.data.chunks}) + + +def test_write_image_normalizes_explicit_zarr_chunk_grid(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 200, 300), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + group = zarr.open_group(tmp_path / "image.zarr", mode="w") + + zarr_chunks = (3, 100, 512) # ome zarr rechunks when writing + write_image(image, group, "image", storage_options={"chunks": zarr_chunks}) + + assert group["s0"].chunks == (3, 100, 512) + + +def test_write_image_rejects_string(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 300, 200), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + group = zarr.open_group(tmp_path / "image.zarr", mode="w") + + with pytest.raises( + ValueError, + match='storage_options\\["chunks"\\] must resolve to a Zarr chunk shape or a regular Dask chunk grid', + ): + write_image(image, group, "image", storage_options={"chunks": "auto"}) + + +def test_write_image_rejects_empty_string(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 300, 200), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + group = zarr.open_group(tmp_path / "image.zarr", mode="w") + + with pytest.raises( + ValueError, + match='storage_options\\["chunks"\\] must resolve to a Zarr chunk shape or a regular Dask chunk grid', + ): + write_image(image, group, "image", storage_options={"chunks": ""}) + + +def test_write_image_rejects_byte_string(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 300, 200), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + group = zarr.open_group(tmp_path / "image.zarr", mode="w") + + with pytest.raises( + ValueError, + match='storage_options\\["chunks"\\] must resolve to a Zarr chunk shape or a regular Dask chunk grid', + ): + write_image(image, group, "image", storage_options={"chunks": b"auto"}) + + +def test_single_scale_image_roundtrip_stays_dataarray(tmp_path: Path) -> None: + image = Image2DModel.parse(RNG.random((3, 64, 64)), dims=("c", "y", "x")) + sdata = SpatialData(images={"image": image}) + path = tmp_path / "data.zarr" + + sdata.write(path) + sdata_back = read_zarr(path) + + assert isinstance(sdata_back["image"], DataArray) + image_group = zarr.open_group(path / "images" / "image", mode="r") + assert list(image_group.keys()) == ["s0"] + + @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: # data only in-memory, so the SpatialData object and all its elements are self-contained