Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,27 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,
else:
ch_h, ch_w = chunks

# Graph-size guard. Each chunk becomes a delayed task whose Python graph
# entry retains ~1KB. At very large chunk counts the graph itself OOMs
# the driver before any read executes (30TB at chunks=256 => ~500M tasks
# => ~500GB graph on host). Auto-scale chunks up to cap total task count.
_MAX_DASK_CHUNKS = 1_000_000
n_chunks = ((full_h + ch_h - 1) // ch_h) * ((full_w + ch_w - 1) // ch_w)
if n_chunks > _MAX_DASK_CHUNKS:
    import math
    import warnings

    new_ch_h, new_ch_w = ch_h, ch_w
    # Re-count after every scaling pass. A single sqrt-scale of both axes
    # does not guarantee the cap: an axis that already fits in one chunk
    # cannot shed any tasks (strip-like images), and the product of the
    # two per-axis ceilings can land slightly above the target.
    while True:
        n_rows = (full_h + new_ch_h - 1) // new_ch_h
        n_cols = (full_w + new_ch_w - 1) // new_ch_w
        total = n_rows * n_cols
        if total <= _MAX_DASK_CHUNKS:
            break
        excess = total / _MAX_DASK_CHUNKS
        if n_rows > 1 and n_cols > 1:
            # Both axes can still shrink their chunk count: split the
            # reduction evenly so the chunk aspect ratio is preserved.
            h_scale = w_scale = math.sqrt(excess)
        elif n_rows > 1:
            h_scale, w_scale = excess, 1.0
        else:
            h_scale, w_scale = 1.0, excess
        # Grow by at least one pixel per pass so the loop always makes
        # progress, and never beyond the image extent (a chunk larger
        # than the image yields the same single chunk on that axis).
        if n_rows > 1:
            new_ch_h = min(
                full_h, max(new_ch_h + 1, math.ceil(new_ch_h * h_scale)))
        if n_cols > 1:
            new_ch_w = min(
                full_w, max(new_ch_w + 1, math.ceil(new_ch_w * w_scale)))

    warnings.warn(
        f"read_geotiff_dask: requested chunks=({ch_h}, {ch_w}) on a "
        f"{full_h}x{full_w} image would produce {n_chunks} dask tasks, "
        f"exceeding the {_MAX_DASK_CHUNKS}-task cap. Auto-scaling to "
        f"chunks=({new_ch_h}, {new_ch_w}).",
        stacklevel=2,
    )
    ch_h, ch_w = new_ch_h, new_ch_w

# Build dask array from delayed windowed reads
rows = list(range(0, full_h, ch_h))
cols = list(range(0, full_w, ch_w))
Expand Down
19 changes: 14 additions & 5 deletions xrspatial/geotiff/_gpu_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,15 +1813,24 @@ class _DeflateCompOpts(ctypes.Structure):
return None

# For deflate, compute adler32 checksums from uncompressed tiles
# before reading compressed data (need the originals)
# before reading compressed data (need the originals).
# Batch the GPU->CPU transfer so all tiles move in a single DMA
# instead of one .get() per tile (which serializes on the default
# stream and is the dominant cost on the deflate path).
adler_checksums = None
if compression in (8, 32946):
import zlib
import struct
adler_checksums = []
for i in range(n_tiles):
uncomp = d_tile_bufs[i].get().tobytes()
adler_checksums.append(zlib.adler32(uncomp))
adler_checksums = [None] * n_tiles
if n_tiles > 0:
d_contig = cupy.empty(n_tiles * tile_bytes, dtype=cupy.uint8)
for i in range(n_tiles):
d_contig[i * tile_bytes:(i + 1) * tile_bytes] = \
d_tile_bufs[i][:tile_bytes]
host_view = memoryview(d_contig.get())
for i in range(n_tiles):
adler_checksums[i] = zlib.adler32(
host_view[i * tile_bytes:(i + 1) * tile_bytes])

# Read compressed sizes and data back to CPU
comp_sizes = d_comp_sizes.get().astype(int)
Expand Down
Loading