Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 25 additions & 31 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,15 @@ async def set_if_not_exists(self, default: Buffer) -> None:


class _ShardIndex(NamedTuple):
# the chunk grid shape of a single shard
chunks_per_shard: tuple[int, ...]
# dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
offsets_and_lengths: npt.NDArray[np.uint64]

@property
def chunks_per_shard(self) -> tuple[int, ...]:
result = tuple(self.offsets_and_lengths.shape[0:-1])
# The cast is required until https://github.com/numpy/numpy/pull/27211 is merged
return cast("tuple[int, ...]", result)

def _localize_chunk(self, chunk_coords: tuple[int, ...]) -> tuple[int, ...]:
return tuple(
chunk_i % shard_i
for chunk_i, shard_i in zip(chunk_coords, self.offsets_and_lengths.shape, strict=False)
for chunk_i, shard_i in zip(chunk_coords, self.chunks_per_shard, strict=False)
)

def is_all_empty(self) -> bool:
Expand Down Expand Up @@ -171,25 +167,24 @@ def get_chunk_slices_vectorized(
valid : ndarray of shape (n_chunks,)
Boolean mask indicating which chunks are non-empty.
"""
# Handle 0-dimensional arrays (n_dims == 0)
# Handle 0-dimensional arrays (n_dims == 0): the shard holds a single
# chunk, so every coordinate maps to the same flat entry.
if chunk_coords_array.shape[1] == 0:
# offsets_and_lengths has shape (2,) for 0D, reshape to (1, 2)
offsets_and_lengths = self.offsets_and_lengths.reshape(1, 2)
starts = offsets_and_lengths[:, 0]
lengths = offsets_and_lengths[:, 1]
valid = starts != MAX_UINT_64
ends = starts + lengths
return starts, ends, valid

# Localize coordinates via modulo (vectorized)
shard_shape = np.array(self.offsets_and_lengths.shape[:-1], dtype=np.uint64)
localized = chunk_coords_array.astype(np.uint64) % shard_shape

# Build index tuple for advanced indexing
index_tuple = tuple(localized[:, i] for i in range(localized.shape[1]))

# Fetch all offsets and lengths at once
offsets_and_lengths = self.offsets_and_lengths[index_tuple]
offsets_and_lengths = self.offsets_and_lengths.reshape(-1, 2)
offsets_and_lengths = np.broadcast_to(
offsets_and_lengths, (chunk_coords_array.shape[0], 2)
)
else:
# Localize coordinates via modulo (vectorized)
shard_shape = np.array(self.chunks_per_shard, dtype=np.uint64)
localized = chunk_coords_array.astype(np.uint64) % shard_shape

# Build index tuple for advanced indexing
index_tuple = tuple(localized[:, i] for i in range(localized.shape[1]))

# Fetch all offsets and lengths at once
offsets_and_lengths = self.offsets_and_lengths[index_tuple]

starts = offsets_and_lengths[:, 0]
lengths = offsets_and_lengths[:, 1]

Expand All @@ -215,7 +210,7 @@ def is_dense(self, chunk_byte_length: int) -> bool:
sorted_offsets_and_lengths = sorted(
[
(offset, length)
for offset, length in self.offsets_and_lengths
for offset, length in self.offsets_and_lengths.reshape(-1, 2)
if offset != MAX_UINT_64
],
key=itemgetter(0),
Expand All @@ -236,7 +231,7 @@ def is_dense(self, chunk_byte_length: int) -> bool:
def create_empty(cls, chunks_per_shard: tuple[int, ...]) -> _ShardIndex:
offsets_and_lengths = np.zeros(chunks_per_shard + (2,), dtype="<u8", order="C")
offsets_and_lengths.fill(MAX_UINT_64)
return cls(offsets_and_lengths)
return cls(chunks_per_shard, offsets_and_lengths)


class _ShardReader(ShardMapping):
Expand Down Expand Up @@ -280,7 +275,7 @@ def __len__(self) -> int:
return int(self.index.offsets_and_lengths.size / 2)

def __iter__(self) -> Iterator[tuple[int, ...]]:
return c_order_iter(self.index.offsets_and_lengths.shape[:-1])
return c_order_iter(self.index.chunks_per_shard)

def to_dict_vectorized(
self,
Expand All @@ -298,8 +293,7 @@ def to_dict_vectorized(
dict mapping chunk coordinate tuples to Buffer or None
"""
starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array)
chunks_per_shard = tuple(self.index.offsets_and_lengths.shape[:-1])
chunk_coords_keys = _morton_order_keys(chunks_per_shard)
chunk_coords_keys = _morton_order_keys(self.index.chunks_per_shard)

result: dict[tuple[int, ...], Buffer | None] = {}
for i, coords in enumerate(chunk_coords_keys):
Expand Down Expand Up @@ -712,7 +706,7 @@ async def _decode_shard_index(
)
# This cannot be None because we have the bytes already
index_array = cast(NDBuffer, index_array)
return _ShardIndex(index_array.as_numpy_array())
return _ShardIndex(chunks_per_shard, index_array.as_numpy_array())

async def _encode_shard_index(self, index: _ShardIndex) -> Buffer:
index_bytes = next(
Expand Down
45 changes: 30 additions & 15 deletions tests/test_codecs/test_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from zarr.codecs.sharding import MAX_UINT_64, _ShardIndex
from zarr.core.buffer import NDArrayLike, default_buffer_prototype
from zarr.core.indexing import c_order_iter
from zarr.storage import StorePath, ZipStore

from ..conftest import ArrayRequest
Expand Down Expand Up @@ -567,18 +568,32 @@ def test_sharding_zero_dimensional() -> None:
assert arr[()] == pytest.approx(43.0)


def test_shard_index_get_chunk_slices_vectorized_zero_dimensional() -> None:
"""Directly cover the 0-D path in _ShardIndex.get_chunk_slices_vectorized."""
# For a 0D array offsets_and_lengths has shape (2,) — reshape to (1, 2) inside.
index = _ShardIndex(np.array([10, 4], dtype=np.uint64))
chunk_coords = np.empty((1, 0), dtype=np.uint64)
starts, ends, valid = index.get_chunk_slices_vectorized(chunk_coords)
np.testing.assert_array_equal(starts, np.array([10], dtype=np.uint64))
np.testing.assert_array_equal(ends, np.array([14], dtype=np.uint64))
np.testing.assert_array_equal(valid, np.array([True]))

# Empty/unwritten chunk case
index_empty = _ShardIndex(np.array([MAX_UINT_64, MAX_UINT_64], dtype=np.uint64))
starts_e, _ends_e, valid_e = index_empty.get_chunk_slices_vectorized(chunk_coords)
np.testing.assert_array_equal(starts_e, np.array([MAX_UINT_64], dtype=np.uint64))
np.testing.assert_array_equal(valid_e, np.array([False]))
def test_shard_index_stores_chunks_per_shard_explicitly() -> None:
    """The chunk grid shape is carried as an explicit field on _ShardIndex."""
    # A non-trivial 2-D grid round-trips through the constructor unchanged.
    idx = _ShardIndex.create_empty((2, 3))
    assert idx.chunks_per_shard == (2, 3)

    # A zero-dimensional grid yields the empty tuple, kept distinct from the
    # array's rank rather than inferred from the offsets array shape.
    assert _ShardIndex.create_empty(()).chunks_per_shard == ()


@pytest.mark.parametrize("chunks_per_shard", [(), (3,), (2, 3)])
def test_shard_index_get_chunk_slices_vectorized(chunks_per_shard: tuple[int, ...]) -> None:
    """get_chunk_slices_vectorized works uniformly across chunk grid ranks, including 0-D."""
    index = _ShardIndex.create_empty(chunks_per_shard)

    # Populate only the first chunk of the grid; every other chunk (if any)
    # stays unwritten.
    coords = list(c_order_iter(chunks_per_shard))
    index.set_chunk_slice(coords[0], slice(10, 14))

    n_chunks = len(coords)
    query = np.array(coords, dtype=np.uint64).reshape(n_chunks, len(chunks_per_shard))
    starts, ends, valid = index.get_chunk_slices_vectorized(query)

    # Only the chunk we wrote is reported valid...
    want_valid = np.array([i == 0 for i in range(n_chunks)], dtype=bool)
    np.testing.assert_array_equal(valid, want_valid)
    # ...with the byte range we set, while the rest carry the empty sentinel.
    assert (starts[0], ends[0]) == (10, 14)
    np.testing.assert_array_equal(starts[1:], MAX_UINT_64)
Loading