From cb14aab92fa5bdf3b60f8bd6cdf176822e96145d Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Thu, 16 Apr 2026 09:04:47 -0700 Subject: [PATCH] Cap dask graph size in read_geotiff_dask and batch adler32 transfers read_geotiff_dask built one delayed task per chunk with no upper bound. For very large files at small chunk sizes the Python graph itself OOMs the driver before any pixel read runs (30TB at chunks=256 would produce ~125M chunks, ~500M tasks, ~500GB graph on the host). Cap total chunks at 1,000,000 and auto-scale the requested chunks size upward, emitting a UserWarning so callers know their request was adjusted. _nvcomp_batch_compress on the deflate path copied every uncompressed tile GPU->CPU one at a time with .get().tobytes() purely to compute the zlib adler32 trailer. Each per-tile .get() is a sync point on the default stream. Batch all tiles into a single contiguous device buffer, transfer once, then compute adler32 from a host memoryview slice per tile. --- xrspatial/geotiff/__init__.py | 21 +++++++++++++++++++++ xrspatial/geotiff/_gpu_decode.py | 19 ++++++++++++++----- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 8f4b6f64..4003fb85 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -937,6 +937,27 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512, else: ch_h, ch_w = chunks + # Graph-size guard. Each chunk becomes a delayed task whose Python graph + # entry retains ~1KB. At very large chunk counts the graph itself OOMs + # the driver before any read executes (30TB at chunks=256 => ~500M tasks + # => ~500GB graph on host). Auto-scale chunks up to cap total task count. + _MAX_DASK_CHUNKS = 1_000_000 + n_chunks = ((full_h + ch_h - 1) // ch_h) * ((full_w + ch_w - 1) // ch_w) + if n_chunks > _MAX_DASK_CHUNKS: + import math + scale = math.sqrt(n_chunks / _MAX_DASK_CHUNKS) + new_ch_h = int(math.ceil(ch_h * scale)) + new_ch_w = int(math.ceil(ch_w * scale)) + import warnings + warnings.warn( + f"read_geotiff_dask: requested chunks=({ch_h}, {ch_w}) on a " + f"{full_h}x{full_w} image would produce {n_chunks} dask tasks, " + f"exceeding the {_MAX_DASK_CHUNKS}-task cap. Auto-scaling to " + f"chunks=({new_ch_h}, {new_ch_w}).", + stacklevel=2, + ) + ch_h, ch_w = new_ch_h, new_ch_w + # Build dask array from delayed windowed reads rows = list(range(0, full_h, ch_h)) cols = list(range(0, full_w, ch_w)) diff --git a/xrspatial/geotiff/_gpu_decode.py b/xrspatial/geotiff/_gpu_decode.py index 480d0e48..420908d0 100644 --- a/xrspatial/geotiff/_gpu_decode.py +++ b/xrspatial/geotiff/_gpu_decode.py @@ -1813,15 +1813,24 @@ class _DeflateCompOpts(ctypes.Structure): return None # For deflate, compute adler32 checksums from uncompressed tiles - # before reading compressed data (need the originals) + # before reading compressed data (need the originals). + # Batch the GPU->CPU transfer so all tiles move in a single DMA + # instead of one .get() per tile (which serializes on the default + # stream and is the dominant cost on the deflate path). adler_checksums = None if compression in (8, 32946): import zlib import struct - adler_checksums = [] - for i in range(n_tiles): - uncomp = d_tile_bufs[i].get().tobytes() - adler_checksums.append(zlib.adler32(uncomp)) + adler_checksums = [None] * n_tiles + if n_tiles > 0: + d_contig = cupy.empty(n_tiles * tile_bytes, dtype=cupy.uint8) + for i in range(n_tiles): + d_contig[i * tile_bytes:(i + 1) * tile_bytes] = \ + d_tile_bufs[i][:tile_bytes] + host_view = memoryview(d_contig.get()) + for i in range(n_tiles): + adler_checksums[i] = zlib.adler32( + host_view[i * tile_bytes:(i + 1) * tile_bytes]) # Read compressed sizes and data back to CPU comp_sizes = d_comp_sizes.get().astype(int)