Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ def open_geotiff(source: str, *, dtype=None, window=None,
band: int | None = None,
name: str | None = None,
chunks: int | tuple | None = None,
gpu: bool = False) -> xr.DataArray:
gpu: bool = False,
max_pixels: int | None = None) -> xr.DataArray:
"""Read a GeoTIFF, COG, or VRT file into an xarray.DataArray.

Automatically dispatches to the best backend:
Expand Down Expand Up @@ -216,6 +217,10 @@ def open_geotiff(source: str, *, dtype=None, window=None,
Chunk size for Dask lazy reading.
gpu : bool
Use GPU-accelerated decompression (requires cupy + nvCOMP).
max_pixels : int or None
Maximum allowed pixel count (width * height * samples). None
uses the default (~1 billion). Raise to read legitimately
large files.

Returns
-------
Expand All @@ -225,22 +230,28 @@ def open_geotiff(source: str, *, dtype=None, window=None,
# VRT files
if source.lower().endswith('.vrt'):
return read_vrt(source, dtype=dtype, window=window, band=band,
name=name, chunks=chunks, gpu=gpu)
name=name, chunks=chunks, gpu=gpu,
max_pixels=max_pixels)

# GPU path
if gpu:
return read_geotiff_gpu(source, dtype=dtype,
overview_level=overview_level,
name=name, chunks=chunks)
name=name, chunks=chunks,
max_pixels=max_pixels)

# Dask path (CPU)
if chunks is not None:
return read_geotiff_dask(source, dtype=dtype, chunks=chunks,
overview_level=overview_level, name=name)

kwargs = {}
if max_pixels is not None:
kwargs['max_pixels'] = max_pixels
arr, geo_info = read_to_array(
source, window=window,
overview_level=overview_level, band=band,
**kwargs,
)

height, width = arr.shape[:2]
Expand Down Expand Up @@ -995,7 +1006,8 @@ def read_geotiff_gpu(source: str, *,
dtype=None,
overview_level: int | None = None,
name: str | None = None,
chunks: int | tuple | None = None) -> xr.DataArray:
chunks: int | tuple | None = None,
max_pixels: int | None = None) -> xr.DataArray:
"""Read a GeoTIFF with GPU-accelerated decompression via Numba CUDA.

Decompresses all tiles in parallel on the GPU and returns a
Expand All @@ -1018,6 +1030,9 @@ def read_geotiff_gpu(source: str, *,
chunks, (row, col) tuple for rectangular.
name : str or None
Name for the DataArray.
max_pixels : int or None
Maximum allowed pixel count (width * height * samples). None
uses the default (~1 billion).

Returns
-------
Expand All @@ -1031,12 +1046,15 @@ def read_geotiff_gpu(source: str, *,
"cupy is required for GPU reads. "
"Install it with: pip install cupy-cuda12x")

from ._reader import _FileSource
from ._reader import _FileSource, _check_dimensions, MAX_PIXELS_DEFAULT
from ._header import parse_header, parse_all_ifds
from ._dtypes import tiff_dtype_to_numpy
from ._geotags import extract_geo_info
from ._gpu_decode import gpu_decode_tiles

if max_pixels is None:
max_pixels = MAX_PIXELS_DEFAULT

# Parse metadata on CPU (fast, <1ms)
src = _FileSource(source)
data = src.read_all()
Expand Down Expand Up @@ -1088,6 +1106,8 @@ def read_geotiff_gpu(source: str, *,
width = ifd.width
height = ifd.height

_check_dimensions(width, height, samples, max_pixels)

finally:
src.close()

Expand Down Expand Up @@ -1326,7 +1346,8 @@ def read_vrt(source: str, *, dtype=None, window=None,
band: int | None = None,
name: str | None = None,
chunks: int | tuple | None = None,
gpu: bool = False) -> xr.DataArray:
gpu: bool = False,
max_pixels: int | None = None) -> xr.DataArray:
"""Read a GDAL Virtual Raster Table (.vrt) into an xarray.DataArray.

The VRT's source GeoTIFFs are read via windowed reads and assembled
Expand Down Expand Up @@ -1358,7 +1379,8 @@ def read_vrt(source: str, *, dtype=None, window=None,
"""
from ._vrt import read_vrt as _read_vrt_internal

arr, vrt = _read_vrt_internal(source, window=window, band=band)
arr, vrt = _read_vrt_internal(source, window=window, band=band,
max_pixels=max_pixels)

if name is None:
import os
Expand Down
9 changes: 8 additions & 1 deletion xrspatial/geotiff/_vrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ def parse_vrt(xml_str: str, vrt_dir: str = '.') -> VRTDataset:


def read_vrt(vrt_path: str, *, window=None,
band: int | None = None) -> tuple[np.ndarray, VRTDataset]:
band: int | None = None,
max_pixels: int | None = None) -> tuple[np.ndarray, VRTDataset]:
"""Read a VRT file by assembling pixel data from its source files.

Parameters
Expand Down Expand Up @@ -228,6 +229,12 @@ def read_vrt(vrt_path: str, *, window=None,
out_h = r1 - r0
out_w = c1 - c0

from ._reader import _check_dimensions, MAX_PIXELS_DEFAULT
if max_pixels is None:
max_pixels = MAX_PIXELS_DEFAULT
n_bands = len([vrt.bands[band]] if band is not None else vrt.bands)
_check_dimensions(out_w, out_h, n_bands, max_pixels)

# Select bands
if band is not None:
selected_bands = [vrt.bands[band]]
Expand Down
75 changes: 75 additions & 0 deletions xrspatial/geotiff/tests/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Tests for:
- Unbounded allocation guard (issue #1184)
- VRT path traversal prevention (issue #1185)
- GPU read and VRT read allocation guards (issue #1195)
"""
from __future__ import annotations

Expand Down Expand Up @@ -116,6 +117,80 @@ def test_normal_read_unaffected(self, tmp_path):
arr, _ = read_to_array(path)
np.testing.assert_array_equal(arr, expected)

def test_open_geotiff_max_pixels(self, tmp_path):
    """Check that ``open_geotiff`` forwards ``max_pixels`` to the reader.

    A 4x4 float32 TIFF (16 pixels) is written to ``tmp_path`` and read
    twice: once with a limit far above 16, which must return the pixel
    data unchanged, and once with a limit below 16, which must raise
    ``ValueError``.
    """
    from xrspatial.geotiff import open_geotiff

    pixels = np.arange(16, dtype=np.float32).reshape(4, 4)
    tiff_file = tmp_path / "small_1195.tif"
    tiff_file.write_bytes(
        make_minimal_tiff(4, 4, np.dtype('float32'), pixel_data=pixels)
    )
    path = str(tiff_file)

    # Limit comfortably above the 16 pixels in the file: read succeeds.
    result = open_geotiff(path, max_pixels=1_000_000)
    np.testing.assert_array_equal(result.values, pixels)

    # Limit below the 16 pixels in the file: the guard must reject it.
    with pytest.raises(ValueError, match="exceed the safety limit"):
        open_geotiff(path, max_pixels=10)


# ---------------------------------------------------------------------------
# Cat 1b: VRT allocation guard (issue #1195)
# ---------------------------------------------------------------------------

class TestVRTAllocationGuard:
    """Tests for the ``max_pixels`` guard on the VRT read path (issue #1195)."""

    @staticmethod
    def _write_vrt(tmp_path, filename, width, height):
        """Write a single-band Float32 VRT with no sources; return its path.

        The VRT declares ``width`` x ``height`` pixels but references no
        source rasters, so only the declared dimensions matter.
        """
        xml = (
            f'<VRTDataset rasterXSize="{width}" rasterYSize="{height}">\n'
            '<VRTRasterBand dataType="Float32" band="1">\n'
            '</VRTRasterBand>\n'
            '</VRTDataset>'
        )
        target = tmp_path / filename
        target.write_text(xml)
        return str(target)

    def test_read_vrt_rejects_huge_dimensions(self, tmp_path):
        """The internal VRT reader raises when the XML declares huge dims."""
        from xrspatial.geotiff._vrt import read_vrt as _read_vrt_internal

        # No source rasters are required: the test expects the dimension
        # check to raise based on the declared raster size alone.
        vrt_path = self._write_vrt(tmp_path, "huge_1195.vrt", 100000, 100000)

        with pytest.raises(ValueError, match="exceed the safety limit"):
            _read_vrt_internal(vrt_path, max_pixels=1_000_000)

    def test_read_vrt_normal_size_ok(self, tmp_path):
        """A small VRT is read without tripping the allocation guard."""
        from xrspatial.geotiff._vrt import read_vrt as _read_vrt_internal

        vrt_path = self._write_vrt(tmp_path, "small_1195.vrt", 4, 4)

        # 4 x 4 x 1 band = 16 pixels, far below the limit: must not raise.
        pixels, _ = _read_vrt_internal(vrt_path, max_pixels=1_000_000)
        assert pixels.shape == (4, 4)

    def test_open_geotiff_vrt_max_pixels(self, tmp_path):
        """``open_geotiff`` forwards ``max_pixels`` to the VRT reader."""
        from xrspatial.geotiff import open_geotiff

        vrt_path = self._write_vrt(
            tmp_path, "huge_vrt_1195.vrt", 100000, 100000
        )

        with pytest.raises(ValueError, match="exceed the safety limit"):
            open_geotiff(vrt_path, max_pixels=1_000_000)

# ---------------------------------------------------------------------------
# Cat 5: VRT path traversal
Expand Down
Loading