diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py
index 8124ea44ea..3b31ac8483 100644
--- a/src/zarr/codecs/sharding.py
+++ b/src/zarr/codecs/sharding.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import threading
+import warnings
 from collections.abc import Iterable, Mapping, MutableMapping
 from dataclasses import dataclass, replace
 from enum import Enum
@@ -51,6 +53,7 @@
     morton_order_iter,
 )
 from zarr.core.metadata.v3 import parse_codecs
+from zarr.errors import ZarrUserWarning
 from zarr.registry import get_ndbuffer_class, get_pipeline_class
 
 if TYPE_CHECKING:
@@ -64,6 +67,64 @@
 ShardMapping = Mapping[tuple[int, ...], Buffer | None]
 ShardMutableMapping = MutableMapping[tuple[int, ...], Buffer | None]
 
+# Track in-progress partial shard writes to detect concurrent access
+# Key: shard identifier (store_id, path), Value: count of active writes
+_active_partial_writes: dict[tuple[int, str], int] = {}
+_active_partial_writes_lock = threading.Lock()
+_warned_shards: set[tuple[int, str]] = set()  # Track shards we've already warned about
+
+
+def _get_shard_key(byte_setter: ByteSetter) -> tuple[int, str]:
+    """Get a unique key for a shard from its byte_setter.
+
+    Uses the store's id and path to create a unique identifier.
+    """
+    path = getattr(byte_setter, "path", "")
+    store = getattr(byte_setter, "store", byte_setter)
+    return (id(store), path)
+
+
+def _is_full_shard_selection(selection: SelectorTuple, shard_shape: tuple[int, ...]) -> bool:
+    """Check if a selection covers the entire shard.
+
+    Parameters
+    ----------
+    selection : SelectorTuple
+        The selection (tuple of slices) being written to
+    shard_shape : tuple[int, ...]
+        The shape of the shard
+
+    Returns
+    -------
+    bool
+        True if selection covers the entire shard, False otherwise
+    """
+    # SelectorTuple can be a tuple, ndarray, or slice
+    # Only tuple selections can potentially cover a full shard
+    if not isinstance(selection, tuple):
+        return False
+
+    if len(selection) != len(shard_shape):
+        return False
+
+    for sel, dim_size in zip(selection, shard_shape, strict=False):
+        if isinstance(sel, slice):
+            start = sel.start or 0
+            stop = sel.stop if sel.stop is not None else dim_size
+            step = sel.step or 1
+            # Check if this slice covers the full dimension
+            if start != 0 or stop != dim_size or step != 1:
+                return False
+        elif isinstance(sel, int):
+            # Single index selection never covers full dimension (unless dim_size == 1)
+            if dim_size != 1:
+                return False
+        else:
+            # Unknown selection type, assume partial
+            return False
+
+    return True
+
 
 class ShardingCodecIndexLocation(Enum):
     """
@@ -499,43 +560,75 @@ async def _encode_partial_single(
         chunks_per_shard = self._get_chunks_per_shard(shard_spec)
         chunk_spec = self._get_chunk_spec(shard_spec)
 
-        shard_reader = await self._load_full_shard_maybe(
-            byte_getter=byte_setter,
-            prototype=chunk_spec.prototype,
-            chunks_per_shard=chunks_per_shard,
-        )
-        shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
-        shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)}
+        # Check if we need to track this write for concurrent access detection
+        # Only track partial writes (full shard writes don't have race conditions)
+        shard_key: tuple[int, str] | None = None
+        if not _is_full_shard_selection(selection, shard_shape):
+            shard_key = _get_shard_key(byte_setter)
+            with _active_partial_writes_lock:
+                current_count = _active_partial_writes.get(shard_key, 0)
+                # Warn only on concurrent access (another write already in progress)
+                if current_count > 0 and shard_key not in _warned_shards:
+                    _warned_shards.add(shard_key)
+                    warnings.warn(
+                        "Concurrent partial shard writes detected. "
+                        "Writing to different regions of the same shard concurrently "
+                        "may result in data corruption due to read-modify-write race conditions. "
+                        "Consider aligning your write chunks with shard boundaries, "
+                        "or use a lock to coordinate writes.",
+                        ZarrUserWarning,
+                        stacklevel=6,
+                    )
+                _active_partial_writes[shard_key] = current_count + 1
 
-        indexer = list(
-            get_indexer(
-                selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
+        try:
+            shard_reader = await self._load_full_shard_maybe(
+                byte_getter=byte_setter,
+                prototype=chunk_spec.prototype,
+                chunks_per_shard=chunks_per_shard,
             )
-        )
-
-        await self.codec_pipeline.write(
-            [
-                (
-                    _ShardingByteSetter(shard_dict, chunk_coords),
-                    chunk_spec,
-                    chunk_selection,
-                    out_selection,
-                    is_complete_shard,
+            shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
+            shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)}
+
+            indexer = list(
+                get_indexer(
+                    selection,
+                    shape=shard_shape,
+                    chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
                 )
-                for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
-            ],
-            shard_array,
-        )
-        buf = await self._encode_shard_dict(
-            shard_dict,
-            chunks_per_shard=chunks_per_shard,
-            buffer_prototype=default_buffer_prototype(),
-        )
+            )
 
-        if buf is None:
-            await byte_setter.delete()
-        else:
-            await byte_setter.set(buf)
+            await self.codec_pipeline.write(
+                [
+                    (
+                        _ShardingByteSetter(shard_dict, chunk_coords),
+                        chunk_spec,
+                        chunk_selection,
+                        out_selection,
+                        is_complete_shard,
+                    )
+                    for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
+                ],
+                shard_array,
+            )
+            buf = await self._encode_shard_dict(
+                shard_dict,
+                chunks_per_shard=chunks_per_shard,
+                buffer_prototype=default_buffer_prototype(),
+            )
+
+            if buf is None:
+                await byte_setter.delete()
+            else:
+                await byte_setter.set(buf)
+        finally:
+            # Clean up tracking if we were tracking this write
+            if shard_key is not None:
+                with _active_partial_writes_lock:
+                    _active_partial_writes[shard_key] -= 1
+                    if _active_partial_writes[shard_key] == 0:
+                        del _active_partial_writes[shard_key]
+                        _warned_shards.discard(shard_key)
 
     async def _encode_shard_dict(
         self,
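A quick way to sanity-check the selection classifier added by this patch: the sketch below is a hypothetical pytest-style test that is not part of the diff. Only the `_is_full_shard_selection` name and its behavior come from the patch; the test body, cases, and location are assumptions, and the import only works once the patch is applied.

# Hypothetical test sketch for the helper introduced above (assumes the patch is applied
# so that _is_full_shard_selection is importable from zarr.codecs.sharding).
from zarr.codecs.sharding import _is_full_shard_selection


def test_is_full_shard_selection() -> None:
    shard_shape = (8, 8)
    # Unit-step slices covering every dimension count as a full-shard write (not tracked).
    assert _is_full_shard_selection((slice(0, 8), slice(0, 8)), shard_shape)
    # Partial extent, non-unit step, or integer indices are treated as partial writes.
    assert not _is_full_shard_selection((slice(0, 4), slice(0, 8)), shard_shape)
    assert not _is_full_shard_selection((slice(0, 8, 2), slice(0, 8)), shard_shape)
    assert not _is_full_shard_selection((0, slice(0, 8)), shard_shape)
    # An integer index only covers a dimension when that dimension has size 1.
    assert _is_full_shard_selection((0, slice(0, 8)), (1, 8))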