Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions xrspatial/geotiff/_crs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,62 @@ def _looks_like_wkt(s: str) -> bool:
return s.lstrip().upper().startswith(_WKT_ROOT_KEYWORDS)


def _validate_crs_arg(crs) -> None:
"""Reject malformed ``crs=`` arguments before they reach the writer.

Closes two gaps in the writer entry points (issue #1971):

* ``bool`` is an ``int`` subclass, so ``crs=True`` and ``crs=False``
would otherwise slip through ``isinstance(crs, int)`` and write
``EPSG=1`` / ``EPSG=0`` to the file. No CRS database resolves
those, so the result is silent metadata corruption.
* An ``int`` EPSG code that pyproj cannot resolve gets written
verbatim into ``ProjectedCSType`` / ``GeographicType``. The
file then round-trips with ``attrs['crs']`` set to the bad
value and only a ``GeoTIFFFallbackWarning`` to tell the caller
something is wrong.

Validates ``crs`` is one of ``None`` (no-op), ``int`` (a valid
EPSG code), or ``str`` (WKT/PROJ -- left for ``_wkt_to_epsg``
downstream). Pyproj is optional; the EPSG-resolves check is
skipped when pyproj is not installed, matching the rest of the
module's pyproj-optional posture. Under
``XRSPATIAL_GEOTIFF_STRICT=1`` the pyproj error is re-raised
instead of being wrapped.
"""
if crs is None:
return
if isinstance(crs, bool):
raise ValueError(
f"crs must be an int (EPSG code), str (WKT/PROJ), or None; "
f"got bool ({crs!r}). bool is an int subclass in Python, so "
f"passing True/False would otherwise be written as EPSG=1 / "
f"EPSG=0 -- neither resolves with any CRS database."
)
if isinstance(crs, int):
try:
from pyproj import CRS
except ImportError:
return
try:
CRS.from_epsg(crs)
except Exception as e:
if _geotiff_strict_mode():
raise
raise ValueError(
f"crs={crs!r} is not a valid EPSG code "
f"(pyproj: {type(e).__name__}: {e}). Pass a valid "
f"EPSG integer, a WKT string, or None."
) from e
return
if isinstance(crs, str):
return
raise TypeError(
f"crs must be int (EPSG code), str (WKT/PROJ), or None; "
f"got {type(crs).__name__}."
)


def _wkt_to_epsg(wkt_or_proj: str) -> int | None:
"""Try to extract an EPSG code from a WKT or PROJ string.

Expand Down Expand Up @@ -149,6 +205,7 @@ def _resolve_crs_to_wkt(crs) -> str | None:
other than a string. (A string is passed through verbatim so the
WKT-only path keeps working without pyproj.)
"""
_validate_crs_arg(crs)
if crs is None:
return None
if not isinstance(crs, (int, str)):
Expand Down
4 changes: 3 additions & 1 deletion xrspatial/geotiff/_writers/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
require_transform_for_georeferenced as _require_transform_for_georeferenced,
transform_from_attr as _transform_from_attr,
)
from .._crs import _validate_crs_fallback, _wkt_to_epsg
from .._crs import _validate_crs_arg, _validate_crs_fallback, _wkt_to_epsg
from .._geotags import GeoTransform, RASTER_PIXEL_IS_AREA
from .._runtime import (
GeoTIFFFallbackWarning,
Expand Down Expand Up @@ -499,6 +499,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
extra_tags_list = None

# Resolve crs argument: can be int (EPSG) or str (WKT/PROJ)
_validate_crs_arg(crs)
if isinstance(crs, int):
epsg = crs
elif isinstance(crs, str):
Expand Down Expand Up @@ -798,6 +799,7 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
os.makedirs(tiles_dir, exist_ok=True)

# Resolve CRS
_validate_crs_arg(crs)
epsg = None
wkt_fallback = None
if isinstance(crs, int):
Expand Down
3 changes: 2 additions & 1 deletion xrspatial/geotiff/_writers/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
require_transform_for_georeferenced as _require_transform_for_georeferenced,
transform_from_attr as _transform_from_attr,
)
from .._crs import _validate_crs_fallback, _wkt_to_epsg
from .._crs import _validate_crs_arg, _validate_crs_fallback, _wkt_to_epsg
from .._runtime import GeoTIFFFallbackWarning
from .._validation import (
_validate_3d_writer_dims,
Expand Down Expand Up @@ -310,6 +310,7 @@ def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray,
y_res = None
res_unit = None

_validate_crs_arg(crs)
if isinstance(crs, int):
epsg = crs
elif isinstance(crs, str):
Expand Down
99 changes: 99 additions & 0 deletions xrspatial/geotiff/tests/test_crs_arg_validation_1971.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Validate the writer entry points reject bool / unresolvable EPSG (#1971).

``bool`` is an int subclass, so ``crs=True`` used to slip through
``isinstance(crs, int)`` and write EPSG=1 to the file (with EPSG=0 for
``crs=False``). Integer EPSG codes were also written without a pyproj
round-trip, so any int that does not resolve as a CRS produced a file
with garbage in ``ProjectedCSType`` / ``GeographicType`` and only a
``GeoTIFFFallbackWarning`` to flag it.

Locks down the rejection at all three writer entry points: ``to_geotiff``
(eager), ``write_geotiff_gpu`` (GPU), and ``to_geotiff`` with
``vrt_tiled=True`` (the deprecated VRT-tiled path).
"""
from __future__ import annotations

import io

import numpy as np
import pytest
import xarray as xr

from xrspatial.geotiff import to_geotiff
from xrspatial.geotiff._crs import _validate_crs_arg

pyproj = pytest.importorskip("pyproj")


def _square(dtype=np.float32):
return xr.DataArray(
np.zeros((4, 4), dtype=dtype),
coords={'y': np.arange(4.0), 'x': np.arange(4.0)},
dims=('y', 'x'),
)


@pytest.mark.parametrize("bad_crs", [True, False])
def test_validate_crs_arg_rejects_bool(bad_crs):
with pytest.raises(ValueError, match="bool"):
_validate_crs_arg(bad_crs)


def test_validate_crs_arg_rejects_unresolvable_epsg():
# EPSG:1 does not exist in any CRS database.
with pytest.raises(ValueError, match="EPSG"):
_validate_crs_arg(1)


def test_validate_crs_arg_accepts_valid_epsg():
_validate_crs_arg(4326) # WGS84


def test_validate_crs_arg_accepts_none():
_validate_crs_arg(None)


def test_validate_crs_arg_accepts_str():
# Strings are deferred to ``_wkt_to_epsg`` and the WKT-fallback
# path; the entry-point validator only catches bool and bogus int.
_validate_crs_arg("EPSG:4326")
_validate_crs_arg('PROJCS["foo",GEOGCS["bar"]]')


def test_validate_crs_arg_rejects_non_int_non_str():
with pytest.raises(TypeError, match="crs must be int"):
_validate_crs_arg(4326.0)


@pytest.mark.parametrize("bad_crs", [True, False])
def test_to_geotiff_rejects_bool_crs(bad_crs):
buf = io.BytesIO()
with pytest.raises(ValueError, match="bool"):
to_geotiff(_square(), buf, crs=bad_crs)


def test_to_geotiff_rejects_unresolvable_epsg():
buf = io.BytesIO()
with pytest.raises(ValueError, match="EPSG"):
to_geotiff(_square(), buf, crs=1)


def test_to_geotiff_accepts_valid_epsg():
buf = io.BytesIO()
to_geotiff(_square(), buf, crs=4326)
assert buf.getbuffer().nbytes > 0


def test_to_geotiff_vrt_path_rejects_bool_crs(tmp_path):
# ``to_geotiff(da, '*.vrt')`` dispatches to ``_write_vrt_tiled``,
# which has its own crs resolution block. The validator runs in
# that branch too.
vrt_path = str(tmp_path / "tmp_1971_vrt_tiled.vrt")
with pytest.raises(ValueError, match="bool"):
to_geotiff(_square(), vrt_path, crs=True)


def test_to_geotiff_vrt_path_rejects_unresolvable_epsg(tmp_path):
vrt_path = str(tmp_path / "tmp_1971_vrt_bad_epsg.vrt")
with pytest.raises(ValueError, match="EPSG"):
to_geotiff(_square(), vrt_path, crs=1)
Loading