Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"networkx",
"numba>=0.55.0",
"numpy",
"ome_zarr>=0.12.2",
"ome_zarr>=0.14.0",
"pandas",
"pooch",
"pyarrow",
Expand Down
9 changes: 2 additions & 7 deletions src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from importlib.metadata import version
from typing import TYPE_CHECKING, Any

import spatialdata.models._accessor # noqa: F401

__version__ = version("spatialdata")

_submodules = {
Expand Down Expand Up @@ -129,15 +131,8 @@
"settings",
]

_accessor_loaded = False


def __getattr__(name: str) -> Any:
global _accessor_loaded
if not _accessor_loaded:
_accessor_loaded = True
import spatialdata.models._accessor # noqa: F401

if name in _submodules:
return importlib.import_module(f"spatialdata.{name}")
if name in _LAZY_IMPORTS:
Expand Down
145 changes: 134 additions & 11 deletions src/spatialdata/_io/io_raster.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
from typing import Any, Literal
from typing import Any, Literal, TypeGuard

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -38,6 +39,131 @@
)


def _is_flat_int_sequence(value: object) -> TypeGuard[Sequence[int]]:
if isinstance(value, str | bytes):
return False
if not isinstance(value, Sequence):
return False
return all(isinstance(v, int) for v in value)


def _is_dask_chunk_grid(value: object) -> TypeGuard[Sequence[Sequence[int]]]:
if isinstance(value, str | bytes):
return False
if not isinstance(value, Sequence):
return False
return len(value) > 0 and all(_is_flat_int_sequence(axis_chunks) for axis_chunks in value)


def _is_regular_dask_chunk_grid(chunk_grid: Sequence[Sequence[int]]) -> bool:
"""Check whether a Dask chunk grid is regular (zarr-compatible).

A grid is regular when every axis has at most one unique chunk size among all but the last
chunk, and the last chunk is not larger than the first.

Parameters
----------
chunk_grid
Per-axis tuple of chunk sizes, for instance as returned by ``dask_array.chunks``.

Examples
--------
Triggers ``continue`` on the first ``if`` (single or empty axis):

>>> _is_regular_dask_chunk_grid([(4,)]) # single chunk → True
True
>>> _is_regular_dask_chunk_grid([()]) # empty axis → True
True

Triggers the first ``return False`` (non-uniform interior chunks):

>>> _is_regular_dask_chunk_grid([(4, 4, 3, 4)]) # interior sizes differ → False
False

Triggers the second ``return False`` (last chunk larger than the first):

>>> _is_regular_dask_chunk_grid([(4, 4, 4, 5)]) # last > first → False
False

Exits with ``return True``:

>>> _is_regular_dask_chunk_grid([(4, 4, 4, 4)]) # all equal → True
True
>>> _is_regular_dask_chunk_grid([(4, 4, 4, 1)]) # last < first → True
True

Empty grid (loop never executes) → True:

>>> _is_regular_dask_chunk_grid([])
True

Multi-axis: all axes regular → True; one axis irregular → False:

>>> _is_regular_dask_chunk_grid([(4, 4, 4, 1), (3, 3, 2)])
True
>>> _is_regular_dask_chunk_grid([(4, 4, 4, 1), (4, 4, 3, 4)])
False
"""
# Match Dask's private _check_regular_chunks() logic without depending on its internal API.
for axis_chunks in chunk_grid:
if len(axis_chunks) <= 1:
continue
if len(set(axis_chunks[:-1])) > 1:
return False
if axis_chunks[-1] > axis_chunks[0]:
return False
return True
Comment on lines +58 to +115
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add a docstring with examples (or examples in-line with the code) to show what fails and what does not.

I would add the following:
triggers the continue in the first if:

  • [(4,)]
  • [()]

triggers the first return False

  • [(4, 4, 3, 4)]

triggers the second return False

  • [(4, 4, 4, 5)]

exits with the last return True

  • [(4, 4, 4, 4)], succeeds, all chunks equal
  • [(4, 4, 4, 1)], succeeds, final chunk is < of the initial one

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also: I would add all the examples above in a test, for the function _is_regular_dask_chunk_grid().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added docstring and tests here: 2450bd4



def _chunks_to_zarr_chunks(chunks: object) -> tuple[int, ...] | int | None:
    """Translate a chunk specification into a form zarr accepts.

    Returns the value unchanged for a plain ``int``, a tuple for a flat
    sequence of ints, and the leading chunk size per axis for a *regular*
    Dask chunk grid. Anything else — including an irregular grid — yields
    ``None`` so the caller can decide how to report the problem.
    """
    if isinstance(chunks, int):
        return chunks
    if _is_flat_int_sequence(chunks):
        return tuple(chunks)
    if not _is_dask_chunk_grid(chunks):
        return None
    grid = tuple(tuple(axis_sizes) for axis_sizes in chunks)
    if not _is_regular_dask_chunk_grid(grid):
        return None
    # For a regular grid, the first chunk of each axis is the zarr chunk shape.
    return tuple(axis_sizes[0] for axis_sizes in grid)


def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int:
    """Resolve *chunks* to a zarr-compatible chunk shape, or raise.

    Thin strict wrapper around ``_chunks_to_zarr_chunks``: where that helper
    signals failure with ``None``, this raises a ``ValueError`` with an
    actionable message (rechunk before writing).
    """
    resolved = _chunks_to_zarr_chunks(chunks)
    if resolved is not None:
        return resolved
    raise ValueError(
        'storage_options["chunks"] must resolve to a Zarr chunk shape or a regular Dask chunk grid. '
        "The current raster has irregular Dask chunks, which cannot be written to Zarr. "
        "To fix this, rechunk before writing, for example by passing regular chunks=... "
        "to Image2DModel.parse(...) / Labels2DModel.parse(...)."
    )


def _prepare_storage_options(
    storage_options: JSONDict | list[JSONDict] | None,
    data: list[da.Array],
) -> list[JSONDict]:
    """Fill in or normalize the ``"chunks"`` entry of write-time storage options.

    Behavior by input shape:

    - ``None``: build one options dict per array in *data*, with chunks
      derived from each array's Dask chunks.
    - single dict without ``"chunks"``: copy it once per array, adding the
      derived chunks.
    - single dict with ``"chunks"``: normalize that value and return the
      single dict as-is. NOTE(review): this path returns a bare dict rather
      than a list (hence the type-ignore) — callers appear to rely on it,
      so it is preserved.
    - list of dicts: per entry, normalize an explicit ``"chunks"`` value or
      derive one from the matching array in *data*.
    """
    if storage_options is None:
        return [{"chunks": _normalize_explicit_chunks(arr.chunks)} for arr in data]

    if isinstance(storage_options, dict):
        if "chunks" in storage_options:
            single = dict(storage_options)
            single["chunks"] = _normalize_explicit_chunks(single["chunks"])
            return single  # type: ignore[return-value]
        return [{**storage_options, "chunks": _normalize_explicit_chunks(arr.chunks)} for arr in data]

    prepared: list[JSONDict] = []
    for index, options in enumerate(storage_options):
        entry = dict(options)
        if "chunks" in entry:
            entry["chunks"] = _normalize_explicit_chunks(entry["chunks"])
        else:
            # Only fall back to the array's own chunks when none were given;
            # data is indexed positionally, matching the options list.
            entry["chunks"] = _normalize_explicit_chunks(data[index].chunks)
        prepared.append(entry)
    return prepared


def _read_multiscale(
store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format
) -> DataArray | DataTree:
Expand Down Expand Up @@ -251,20 +377,18 @@ def _write_raster_dataarray(
if transformations is None:
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
input_axes: tuple[str, ...] = tuple(raster_data.dims)
chunks = raster_data.chunks
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
if storage_options is not None:
if "chunks" not in storage_options and isinstance(storage_options, dict):
storage_options["chunks"] = chunks
else:
storage_options = {"chunks": chunks}
# Scaler needs to be None since we are passing the data already downscaled for the multiscale case.
# We need this because the argument of write_image_ngff is called image while the argument of
storage_options = _prepare_storage_options(storage_options, [data])
# Explicitly disable pyramid generation for single-scale rasters. Recent ome-zarr versions default
# write_image()/write_labels() to scale_factors=(2, 4, 8, 16), which would otherwise write s0, s1, ...
# even when the input is a plain DataArray.
# We need this because the argument of write_image_ngff is called image while the argument of
# write_labels_ngff is called label.
metadata[raster_type] = data
ome_zarr_format = get_ome_zarr_format(raster_format)
write_single_scale_ngff(
group=group,
scale_factors=[],
scaler=None,
fmt=ome_zarr_format,
axes=parsed_axes,
Expand Down Expand Up @@ -322,10 +446,9 @@ def _write_raster_datatree(
transformations = _get_transformations_xarray(xdata)
if transformations is None:
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
chunks = get_pyramid_levels(raster_data, "chunks")

parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
storage_options = [{"chunks": chunk} for chunk in chunks]
storage_options = _prepare_storage_options(storage_options, data)
ome_zarr_format = get_ome_zarr_format(raster_format)
dask_delayed = write_multi_scale_ngff(
pyramid=data,
Expand Down
20 changes: 10 additions & 10 deletions tests/io/test_partial_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ def sdata_with_corrupted_image_chunks_zarrv3(session_tmp_path: Path) -> PartialR
sdata.write(sdata_path)

corrupted = "blobs_image"
os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json") # it will hide the "0" array from the Zarr reader
os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted")
(sdata_path / "images" / corrupted / "0").touch()
os.unlink(sdata_path / "images" / corrupted / "s0" / "zarr.json") # it will hide the "0" array from the Zarr reader
os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted")
(sdata_path / "images" / corrupted / "s0").touch()

not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted]

Expand All @@ -206,9 +206,9 @@ def sdata_with_corrupted_image_chunks_zarrv2(session_tmp_path: Path) -> PartialR
sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01())

corrupted = "blobs_image"
os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") # it will hide the "0" array from the Zarr reader
os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted")
(sdata_path / "images" / corrupted / "0").touch()
os.unlink(sdata_path / "images" / corrupted / "s0" / ".zarray") # it will hide the "0" array from the Zarr reader
os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted")
(sdata_path / "images" / corrupted / "s0").touch()
not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted]

return PartialReadTestCase(
Expand Down Expand Up @@ -315,8 +315,8 @@ def sdata_with_missing_image_chunks_zarrv3(
sdata.write(sdata_path)

corrupted = "blobs_image"
os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json")
os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted")
os.unlink(sdata_path / "images" / corrupted / "s0" / "zarr.json")
os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted")

not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted]

Expand All @@ -339,8 +339,8 @@ def sdata_with_missing_image_chunks_zarrv2(
sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01())

corrupted = "blobs_image"
os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray")
os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted")
os.unlink(sdata_path / "images" / corrupted / "s0" / ".zarray")
os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted")

not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted]

Expand Down
Loading