diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index e93f525c..8f4b6f64 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -185,7 +185,8 @@ def open_geotiff(source: str, *, dtype=None, window=None, band: int | None = None, name: str | None = None, chunks: int | tuple | None = None, - gpu: bool = False) -> xr.DataArray: + gpu: bool = False, + max_pixels: int | None = None) -> xr.DataArray: """Read a GeoTIFF, COG, or VRT file into an xarray.DataArray. Automatically dispatches to the best backend: @@ -216,6 +217,10 @@ def open_geotiff(source: str, *, dtype=None, window=None, Chunk size for Dask lazy reading. gpu : bool Use GPU-accelerated decompression (requires cupy + nvCOMP). + max_pixels : int or None + Maximum allowed pixel count (width * height * samples). None + uses the default (~1 billion). Raise to read legitimately + large files. Returns ------- @@ -225,22 +230,28 @@ def open_geotiff(source: str, *, dtype=None, window=None, # VRT files if source.lower().endswith('.vrt'): return read_vrt(source, dtype=dtype, window=window, band=band, - name=name, chunks=chunks, gpu=gpu) + name=name, chunks=chunks, gpu=gpu, + max_pixels=max_pixels) # GPU path if gpu: return read_geotiff_gpu(source, dtype=dtype, overview_level=overview_level, - name=name, chunks=chunks) + name=name, chunks=chunks, + max_pixels=max_pixels) # Dask path (CPU) if chunks is not None: return read_geotiff_dask(source, dtype=dtype, chunks=chunks, overview_level=overview_level, name=name) + kwargs = {} + if max_pixels is not None: + kwargs['max_pixels'] = max_pixels arr, geo_info = read_to_array( source, window=window, overview_level=overview_level, band=band, + **kwargs, ) height, width = arr.shape[:2] @@ -995,7 +1006,8 @@ def read_geotiff_gpu(source: str, *, dtype=None, overview_level: int | None = None, name: str | None = None, - chunks: int | tuple | None = None) -> xr.DataArray: + chunks: int | tuple | None = None, + max_pixels: int | None = None) -> xr.DataArray: """Read a GeoTIFF with GPU-accelerated decompression via Numba CUDA. Decompresses all tiles in parallel on the GPU and returns a @@ -1018,6 +1030,9 @@ def read_geotiff_gpu(source: str, *, chunks, (row, col) tuple for rectangular. name : str or None Name for the DataArray. + max_pixels : int or None + Maximum allowed pixel count (width * height * samples). None + uses the default (~1 billion). Returns ------- @@ -1031,12 +1046,15 @@ def read_geotiff_gpu(source: str, *, "cupy is required for GPU reads. " "Install it with: pip install cupy-cuda12x") - from ._reader import _FileSource + from ._reader import _FileSource, _check_dimensions, MAX_PIXELS_DEFAULT from ._header import parse_header, parse_all_ifds from ._dtypes import tiff_dtype_to_numpy from ._geotags import extract_geo_info from ._gpu_decode import gpu_decode_tiles + if max_pixels is None: + max_pixels = MAX_PIXELS_DEFAULT + # Parse metadata on CPU (fast, <1ms) src = _FileSource(source) data = src.read_all() @@ -1088,6 +1106,8 @@ def read_geotiff_gpu(source: str, *, width = ifd.width height = ifd.height + _check_dimensions(width, height, samples, max_pixels) + finally: src.close() @@ -1326,7 +1346,8 @@ def read_vrt(source: str, *, dtype=None, window=None, band: int | None = None, name: str | None = None, chunks: int | tuple | None = None, - gpu: bool = False) -> xr.DataArray: + gpu: bool = False, + max_pixels: int | None = None) -> xr.DataArray: """Read a GDAL Virtual Raster Table (.vrt) into an xarray.DataArray. The VRT's source GeoTIFFs are read via windowed reads and assembled @@ -1358,7 +1379,8 @@ def read_vrt(source: str, *, dtype=None, window=None, """ from ._vrt import read_vrt as _read_vrt_internal - arr, vrt = _read_vrt_internal(source, window=window, band=band) + arr, vrt = _read_vrt_internal(source, window=window, band=band, + max_pixels=max_pixels) if name is None: import os diff --git a/xrspatial/geotiff/_vrt.py b/xrspatial/geotiff/_vrt.py index 616c84b1..87c9cc03 100644 --- a/xrspatial/geotiff/_vrt.py +++ b/xrspatial/geotiff/_vrt.py @@ -192,7 +192,8 @@ def parse_vrt(xml_str: str, vrt_dir: str = '.') -> VRTDataset: def read_vrt(vrt_path: str, *, window=None, - band: int | None = None) -> tuple[np.ndarray, VRTDataset]: + band: int | None = None, + max_pixels: int | None = None) -> tuple[np.ndarray, VRTDataset]: """Read a VRT file by assembling pixel data from its source files. Parameters @@ -228,6 +229,12 @@ def read_vrt(vrt_path: str, *, window=None, out_h = r1 - r0 out_w = c1 - c0 + from ._reader import _check_dimensions, MAX_PIXELS_DEFAULT + if max_pixels is None: + max_pixels = MAX_PIXELS_DEFAULT + n_bands = len([vrt.bands[band]] if band is not None else vrt.bands) + _check_dimensions(out_w, out_h, n_bands, max_pixels) + # Select bands if band is not None: selected_bands = [vrt.bands[band]] diff --git a/xrspatial/geotiff/tests/test_security.py b/xrspatial/geotiff/tests/test_security.py index a3dd8c88..a230b28b 100644 --- a/xrspatial/geotiff/tests/test_security.py +++ b/xrspatial/geotiff/tests/test_security.py @@ -3,6 +3,7 @@ Tests for: - Unbounded allocation guard (issue #1184) - VRT path traversal prevention (issue #1185) +- GPU read and VRT read allocation guards (issue #1195) """ from __future__ import annotations @@ -116,6 +117,80 @@ def test_normal_read_unaffected(self, tmp_path): arr, _ = read_to_array(path) np.testing.assert_array_equal(arr, expected) + def test_open_geotiff_max_pixels(self, tmp_path): + """open_geotiff passes max_pixels through to the reader.""" + from xrspatial.geotiff import open_geotiff + + expected = np.arange(16, dtype=np.float32).reshape(4, 4) + data = make_minimal_tiff(4, 4, np.dtype('float32'), pixel_data=expected) + path = str(tmp_path / "small_1195.tif") + with open(path, 'wb') as f: + f.write(data) + + # Should succeed with generous limit + da = open_geotiff(path, max_pixels=1_000_000) + np.testing.assert_array_equal(da.values, expected) + + # Should fail with tiny limit + with pytest.raises(ValueError, match="exceed the safety limit"): + open_geotiff(path, max_pixels=10) + + +# --------------------------------------------------------------------------- +# Cat 1b: VRT allocation guard (issue #1195) +# --------------------------------------------------------------------------- + +class TestVRTAllocationGuard: + def test_read_vrt_rejects_huge_dimensions(self, tmp_path): + """read_vrt refuses to allocate when VRT XML claims huge dims.""" + from xrspatial.geotiff._vrt import read_vrt as _read_vrt_internal + + # Create a VRT with oversized dimensions but no actual source data + # needed -- _check_dimensions fires before any file reads + vrt_xml = ''' + + +''' + + vrt_path = str(tmp_path / "huge_1195.vrt") + with open(vrt_path, 'w') as f: + f.write(vrt_xml) + + with pytest.raises(ValueError, match="exceed the safety limit"): + _read_vrt_internal(vrt_path, max_pixels=1_000_000) + + def test_read_vrt_normal_size_ok(self, tmp_path): + """Normal-sized VRT passes the allocation guard.""" + from xrspatial.geotiff._vrt import read_vrt as _read_vrt_internal + + vrt_xml = ''' + + +''' + + vrt_path = str(tmp_path / "small_1195.vrt") + with open(vrt_path, 'w') as f: + f.write(vrt_xml) + + # Should not raise -- 4x4x1 = 16 pixels + arr, vrt = _read_vrt_internal(vrt_path, max_pixels=1_000_000) + assert arr.shape == (4, 4) + + def test_open_geotiff_vrt_max_pixels(self, tmp_path): + """open_geotiff passes max_pixels through to VRT reader.""" + from xrspatial.geotiff import open_geotiff + + vrt_xml = ''' + + +''' + + vrt_path = str(tmp_path / "huge_vrt_1195.vrt") + with open(vrt_path, 'w') as f: + f.write(vrt_xml) + + with pytest.raises(ValueError, match="exceed the safety limit"): + open_geotiff(vrt_path, max_pixels=1_000_000) # --------------------------------------------------------------------------- # Cat 5: VRT path traversal