diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 8f4b6f64..4003fb85 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -937,6 +937,27 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512, else: ch_h, ch_w = chunks + # Graph-size guard. Each chunk becomes a delayed task whose Python graph + # entry retains ~1KB. At very large chunk counts the graph itself OOMs + # the driver before any read executes (30TB at chunks=256 => ~500M tasks + # => ~500GB graph on host). Auto-scale chunks up to cap total task count. + _MAX_DASK_CHUNKS = 1_000_000 + n_chunks = ((full_h + ch_h - 1) // ch_h) * ((full_w + ch_w - 1) // ch_w) + if n_chunks > _MAX_DASK_CHUNKS: + import math + scale = math.sqrt(n_chunks / _MAX_DASK_CHUNKS) + new_ch_h = int(math.ceil(ch_h * scale)) + new_ch_w = int(math.ceil(ch_w * scale)) + import warnings + warnings.warn( + f"read_geotiff_dask: requested chunks=({ch_h}, {ch_w}) on a " + f"{full_h}x{full_w} image would produce {n_chunks} dask tasks, " + f"exceeding the {_MAX_DASK_CHUNKS}-task cap. Auto-scaling to " + f"chunks=({new_ch_h}, {new_ch_w}).", + stacklevel=2, + ) + ch_h, ch_w = new_ch_h, new_ch_w + # Build dask array from delayed windowed reads rows = list(range(0, full_h, ch_h)) cols = list(range(0, full_w, ch_w)) diff --git a/xrspatial/geotiff/_gpu_decode.py b/xrspatial/geotiff/_gpu_decode.py index 480d0e48..420908d0 100644 --- a/xrspatial/geotiff/_gpu_decode.py +++ b/xrspatial/geotiff/_gpu_decode.py @@ -1813,15 +1813,24 @@ class _DeflateCompOpts(ctypes.Structure): return None # For deflate, compute adler32 checksums from uncompressed tiles - # before reading compressed data (need the originals) + # before reading compressed data (need the originals). + # Batch the GPU->CPU transfer so all tiles move in a single DMA + # instead of one .get() per tile (which serializes on the default + # stream and is the dominant cost on the deflate path). adler_checksums = None if compression in (8, 32946): import zlib import struct - adler_checksums = [] - for i in range(n_tiles): - uncomp = d_tile_bufs[i].get().tobytes() - adler_checksums.append(zlib.adler32(uncomp)) + adler_checksums = [None] * n_tiles + if n_tiles > 0: + d_contig = cupy.empty(n_tiles * tile_bytes, dtype=cupy.uint8) + for i in range(n_tiles): + d_contig[i * tile_bytes:(i + 1) * tile_bytes] = \ + d_tile_bufs[i][:tile_bytes] + host_view = memoryview(d_contig.get()) + for i in range(n_tiles): + adler_checksums[i] = zlib.adler32( + host_view[i * tile_bytes:(i + 1) * tile_bytes]) # Read compressed sizes and data back to CPU comp_sizes = d_comp_sizes.get().astype(int)