159 changes: 126 additions & 33 deletions src/zarr/codecs/sharding.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import threading
import warnings
from collections.abc import Iterable, Mapping, MutableMapping
from dataclasses import dataclass, replace
from enum import Enum
@@ -51,6 +53,7 @@
morton_order_iter,
)
from zarr.core.metadata.v3 import parse_codecs
from zarr.errors import ZarrUserWarning
from zarr.registry import get_ndbuffer_class, get_pipeline_class

if TYPE_CHECKING:
@@ -64,6 +67,64 @@
ShardMapping = Mapping[tuple[int, ...], Buffer | None]
ShardMutableMapping = MutableMapping[tuple[int, ...], Buffer | None]

# Track in-progress partial shard writes to detect concurrent access
# Key: shard identifier (store_id, path), Value: count of active writes
_active_partial_writes: dict[tuple[int, str], int] = {}
_active_partial_writes_lock = threading.Lock()
_warned_shards: set[tuple[int, str]] = set() # Track shards we've already warned about


def _get_shard_key(byte_setter: ByteSetter) -> tuple[int, str]:
"""Get a unique key for a shard from its byte_setter.

Uses the store's id and path to create a unique identifier.
"""
path = getattr(byte_setter, "path", "")
store = getattr(byte_setter, "store", byte_setter)
return (id(store), path)
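

# A small, hypothetical illustration (not part of this diff) of the key
# semantics: because the key combines id(store) with the path, tracking is
# per-process and per-store-instance.
def _demo_shard_key() -> None:
    class _FakeSetter:
        """Stand-in for a ByteSetter, for illustration only."""

        def __init__(self, store: object, path: str) -> None:
            self.store = store
            self.path = path

    store = object()
    key_a = _get_shard_key(_FakeSetter(store, "c/0/0"))
    key_b = _get_shard_key(_FakeSetter(store, "c/0/0"))
    assert key_a == key_b  # same store instance and path -> same key
    other = _get_shard_key(_FakeSetter(object(), "c/0/0"))
    assert other != key_a  # a different store instance yields a different key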


def _is_full_shard_selection(selection: SelectorTuple, shard_shape: tuple[int, ...]) -> bool:
"""Check if a selection covers the entire shard.

Parameters
----------
selection : SelectorTuple
The selection being written to; may be a tuple of slices/integers, an ndarray, or a bare slice
shard_shape : tuple[int, ...]
The shape of the shard

Returns
-------
bool
True if selection covers the entire shard, False otherwise
"""
# SelectorTuple can be a tuple, ndarray, or slice
# Only tuple selections can potentially cover a full shard
if not isinstance(selection, tuple):
return False

if len(selection) != len(shard_shape):
return False

for sel, dim_size in zip(selection, shard_shape, strict=False):
if isinstance(sel, slice):
start = sel.start or 0
stop = sel.stop if sel.stop is not None else dim_size
step = sel.step or 1
# Check if this slice covers the full dimension
if start != 0 or stop != dim_size or step != 1:
return False
elif isinstance(sel, int):
# Single index selection never covers full dimension (unless dim_size == 1)
if dim_size != 1:
return False
else:
# Unknown selection type, assume partial
return False

return True

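# A minimal sketch (not part of this diff) of how the helper above classifies
# selections; the shard shape is a hypothetical example.
def _demo_full_shard_detection() -> None:
    shard_shape = (4, 4)
    # Slices covering every dimension end-to-end count as a full-shard write.
    assert _is_full_shard_selection((slice(0, 4), slice(0, 4)), shard_shape)
    # A slice over only part of one dimension is a partial write.
    assert not _is_full_shard_selection((slice(0, 2), slice(0, 4)), shard_shape)
    # An integer index cannot cover a dimension of size > 1.
    assert not _is_full_shard_selection((0, slice(0, 4)), shard_shape)
    # Non-tuple selections are conservatively treated as partial.
    assert not _is_full_shard_selection(slice(None), shard_shape)
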

class ShardingCodecIndexLocation(Enum):
"""
@@ -499,43 +560,75 @@ async def _encode_partial_single(
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
chunk_spec = self._get_chunk_spec(shard_spec)

# Check if we need to track this write for concurrent access detection
# Only track partial writes (full shard writes don't have race conditions)
shard_key: tuple[int, str] | None = None
if not _is_full_shard_selection(selection, shard_shape):
    shard_key = _get_shard_key(byte_setter)
    with _active_partial_writes_lock:
        current_count = _active_partial_writes.get(shard_key, 0)
        # Warn only on concurrent access (another write already in progress)
        if current_count > 0 and shard_key not in _warned_shards:
            _warned_shards.add(shard_key)
            warnings.warn(
                "Concurrent partial shard writes detected. "
                "Writing to different regions of the same shard concurrently "
                "may result in data corruption due to read-modify-write race conditions. "
                "Consider aligning your write chunks with shard boundaries, "
                "or use a lock to coordinate writes.",
                ZarrUserWarning,
                stacklevel=6,
            )
        _active_partial_writes[shard_key] = current_count + 1

try:
    shard_reader = await self._load_full_shard_maybe(
        byte_getter=byte_setter,
        prototype=chunk_spec.prototype,
        chunks_per_shard=chunks_per_shard,
    )
    shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
    shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)}

    indexer = list(
        get_indexer(
            selection,
            shape=shard_shape,
            chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
        )
    )

    await self.codec_pipeline.write(
        [
            (
                _ShardingByteSetter(shard_dict, chunk_coords),
                chunk_spec,
                chunk_selection,
                out_selection,
                is_complete_shard,
            )
            for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
        ],
        shard_array,
    )
    buf = await self._encode_shard_dict(
        shard_dict,
        chunks_per_shard=chunks_per_shard,
        buffer_prototype=default_buffer_prototype(),
    )

    if buf is None:
        await byte_setter.delete()
    else:
        await byte_setter.set(buf)
finally:
    # Clean up tracking if we were tracking this write
    if shard_key is not None:
        with _active_partial_writes_lock:
            _active_partial_writes[shard_key] -= 1
            if _active_partial_writes[shard_key] == 0:
                del _active_partial_writes[shard_key]
                _warned_shards.discard(shard_key)

async def _encode_shard_dict(
self,
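
# A minimal, hypothetical sketch of the user-level scenario the new warning
# targets (not part of this diff; the array shapes, MemoryStore, and threading
# setup are illustrative assumptions). Two threads perform partial writes into
# the same (4, 4) shard; each one is a read-modify-write of the whole shard,
# so updates can race and a ZarrUserWarning may be emitted depending on timing.
import threading

import zarr
from zarr.storage import MemoryStore

arr = zarr.create_array(
    store=MemoryStore(),
    shape=(8, 8),
    chunks=(2, 2),
    shards=(4, 4),
    dtype="int32",
)

# Each write touches only half of the shard covering rows 0..3, cols 0..3.
t1 = threading.Thread(target=arr.__setitem__, args=((slice(0, 2), slice(0, 4)), 1))
t2 = threading.Thread(target=arr.__setitem__, args=((slice(2, 4), slice(0, 4)), 2))
t1.start()
t2.start()
t1.join()
t2.join()

# Aligning writes with shard boundaries (e.g. arr[0:4, 0:4] = ...) avoids the
# read-modify-write entirely, as the warning message suggests.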