diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index d3bf731..774a0ce 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -202,7 +202,7 @@ validated_df: dy.DataFrame[MySchema] = MySchema.validate(df, cast=True) 1. **Python code**: Run `pixi run pre-commit run` before committing 2. **Rust code**: Run `pixi run postinstall` to rebuild, then run tests -3. **Tests**: Ensure `pixi run test` passes +3. **Tests**: Ensure `pixi run test` passes. If changes might affect storage backends, use `pixi run test -m s3`. 4. **Documentation**: Update docstrings 5. **API changes**: Ensure backward compatibility or document migration path diff --git a/dataframely/__init__.py b/dataframely/__init__.py index 616fb04..802d923 100644 --- a/dataframely/__init__.py +++ b/dataframely/__init__.py @@ -51,6 +51,7 @@ UInt64, ) from .config import Config +from .exc import DeserializationError from .filter_result import FailureInfo from .functional import ( concat_collection_members, @@ -106,4 +107,5 @@ "Array", "Object", "Validation", + "DeserializationError", ] diff --git a/dataframely/collection/collection.py b/dataframely/collection/collection.py index 1576f9f..74b92e4 100644 --- a/dataframely/collection/collection.py +++ b/dataframely/collection/collection.py @@ -12,7 +12,7 @@ from dataclasses import asdict from json import JSONDecodeError from pathlib import Path -from typing import IO, Annotated, Any, Literal, cast +from typing import IO, Annotated, Any, Literal, cast, overload import polars as pl import polars.exceptions as plexc @@ -33,7 +33,11 @@ from dataframely._storage.delta import DeltaStorageBackend from dataframely._storage.parquet import ParquetStorageBackend from dataframely._typing import LazyFrame, Validation -from dataframely.exc import ValidationError, ValidationRequiredError +from dataframely.exc import ( + DeserializationError, + ValidationError, + ValidationRequiredError, +) from dataframely.filter_result 
import FailureInfo from dataframely.random import Generator from dataframely.schema import _schema_from_dict @@ -891,13 +895,13 @@ def read_parquet( - `"allow"`: The method tries to read the schema data from the parquet files. If the stored collection schema matches this collection schema, the collection is read without validation. If the stored - schema mismatches this schema no metadata can be found in + schema mismatches this schema, no valid metadata can be found in the parquets, or the files have conflicting metadata, this method automatically runs :meth:`validate` with `cast=True`. - `"warn"`: The method behaves similarly to `"allow"`. However, it prints a warning if validation is necessary. - `"forbid"`: The method never runs validation automatically and only - returns if the metadata stores a collection schema that matches + returns if the metadata stores a valid collection schema that matches this collection. - `"skip"`: The method never runs validation and simply reads the data, entrusting the user that the schema is valid. *Use this option @@ -1184,7 +1188,12 @@ def _read( members=cls.member_schemas().keys(), **kwargs ) - collection_types = _deserialize_types(serialized_collection_types) + # Use strict=False when validation is "allow", "warn" or "skip" to tolerate + # missing or broken collection metadata. 
+ strict = validation == "forbid" + collection_types = _deserialize_types( + serialized_collection_types, strict=strict + ) collection_type = _reconcile_collection_types(collection_types) if cls._requires_validation_for_reading_parquets(collection_type, validation): @@ -1245,14 +1254,27 @@ def read_parquet_metadata_collection( """ metadata = pl.read_parquet_metadata(source) if (schema_metadata := metadata.get(COLLECTION_METADATA_KEY)) is not None: - try: - return deserialize_collection(schema_metadata) - except (JSONDecodeError, plexc.ComputeError): - return None + return deserialize_collection(schema_metadata, strict=False) return None -def deserialize_collection(data: str) -> type[Collection]: +@overload +def deserialize_collection( + data: str, strict: Literal[True] = True +) -> type[Collection]: ... + + +@overload +def deserialize_collection( + data: str, strict: Literal[False] +) -> type[Collection] | None: ... + + +@overload +def deserialize_collection(data: str, strict: bool) -> type[Collection] | None: ... + + +def deserialize_collection(data: str, strict: bool = True) -> type[Collection] | None: """Deserialize a collection from a JSON string. This method allows to dynamically load a collection from its serialization, without @@ -1260,12 +1282,14 @@ def deserialize_collection(data: str) -> type[Collection]: Args: data: The JSON string created via :meth:`Collection.serialize`. + strict: Whether to raise an exception if the collection cannot be deserialized. Returns: The collection loaded from the JSON data. Raises: - ValueError: If the schema format version is not supported. + DeserializationError: If the collection cannot be deserialized + and `strict=True`. Attention: The returned collection **cannot** be used to create instances of the @@ -1280,34 +1304,41 @@ def deserialize_collection(data: str) -> type[Collection]: See also: :meth:`Collection.serialize` for additional information on serialization. 
""" - decoded = json.loads(data, cls=SchemaJSONDecoder) - if (format := decoded["versions"]["format"]) != SERIALIZATION_FORMAT_VERSION: - raise ValueError(f"Unsupported schema format version: {format}") - - annotations: dict[str, Any] = {} - for name, info in decoded["members"].items(): - lf_type = LazyFrame[_schema_from_dict(info["schema"])] # type: ignore - if info["is_optional"]: - lf_type = lf_type | None # type: ignore - annotations[name] = Annotated[ - lf_type, - CollectionMember( - ignored_in_filters=info["ignored_in_filters"], - inline_for_sampling=info["inline_for_sampling"], - ), - ] - - return type( - f"{decoded['name']}_dynamic", - (Collection,), - { - "__annotations__": annotations, - **{ - name: Filter(logic=lambda _, logic=logic: logic) # type: ignore - for name, logic in decoded["filters"].items() + try: + decoded = json.loads(data, cls=SchemaJSONDecoder) + if (format := decoded["versions"]["format"]) != SERIALIZATION_FORMAT_VERSION: + raise ValueError(f"Unsupported schema format version: {format}") + + annotations: dict[str, Any] = {} + for name, info in decoded["members"].items(): + lf_type = LazyFrame[_schema_from_dict(info["schema"])] # type: ignore + if info["is_optional"]: + lf_type = lf_type | None # type: ignore + annotations[name] = Annotated[ + lf_type, + CollectionMember( + ignored_in_filters=info["ignored_in_filters"], + inline_for_sampling=info["inline_for_sampling"], + ), + ] + + return type( + f"{decoded['name']}_dynamic", + (Collection,), + { + "__annotations__": annotations, + **{ + name: Filter(logic=lambda _, logic=logic: logic) # type: ignore + for name, logic in decoded["filters"].items() + }, }, - }, - ) + ) + except (ValueError, TypeError, JSONDecodeError, plexc.ComputeError) as e: + if strict: + raise DeserializationError( + "The collection could not be deserialized" + ) from e + return None # --------------------------------------- UTILS -------------------------------------- # @@ -1333,14 +1364,15 @@ def 
_extract_keys_if_exist( def _deserialize_types( serialized_collection_types: Iterable[str | None], + strict: bool = True, ) -> list[type[Collection]]: collection_types = [] - collection_type: type[Collection] | None = None for t in serialized_collection_types: if t is None: continue - collection_type = deserialize_collection(t) - collection_types.append(collection_type) + collection_type = deserialize_collection(t, strict=strict) + if collection_type is not None: + collection_types.append(collection_type) return collection_types diff --git a/dataframely/exc.py b/dataframely/exc.py index a87497a..bc32420 100644 --- a/dataframely/exc.py +++ b/dataframely/exc.py @@ -41,3 +41,10 @@ def __init__(self, attr: str, kls: type) -> None: class ValidationRequiredError(Exception): """Error raised when validation is required when reading a parquet file.""" + + +# ---------------------------------- DESERIALIZATION --------------------------------- # + + +class DeserializationError(Exception): + """Error raised when deserialization of a schema or collection fails.""" diff --git a/dataframely/schema.py b/dataframely/schema.py index c457ec7..64096bf 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -40,7 +40,12 @@ from ._typing import DataFrame, LazyFrame, Validation from .columns import Column, column_from_dict from .config import Config -from .exc import SchemaError, ValidationError, ValidationRequiredError +from .exc import ( + DeserializationError, + SchemaError, + ValidationError, + ValidationRequiredError, +) from .filter_result import FailureInfo, FilterResult, LazyFilterResult from .random import Generator @@ -1238,8 +1243,13 @@ def _validate_if_needed( validation: Validation, source: str, ) -> DataFrame[Self] | LazyFrame[Self]: + # Use strict=False when validation is "allow", "warn" or "skip" to tolerate + # deserialization failures from old serialized formats. 
+ strict = validation == "forbid" deserialized_schema = ( - deserialize_schema(serialized_schema) if serialized_schema else None + deserialize_schema(serialized_schema, strict=strict) + if serialized_schema + else None ) # Smart validation @@ -1347,6 +1357,10 @@ def deserialize_schema(data: str, strict: Literal[True] = True) -> type[Schema]: def deserialize_schema(data: str, strict: Literal[False]) -> type[Schema] | None: ... +@overload +def deserialize_schema(data: str, strict: bool) -> type[Schema] | None: ... + + def deserialize_schema(data: str, strict: bool = True) -> type[Schema] | None: """Deserialize a schema from a JSON string. @@ -1375,9 +1389,9 @@ def deserialize_schema(data: str, strict: bool = True) -> type[Schema] | None: if (format := decoded["versions"]["format"]) != SERIALIZATION_FORMAT_VERSION: raise ValueError(f"Unsupported schema format version: {format}") return _schema_from_dict(decoded) - except (ValueError, JSONDecodeError, plexc.ComputeError) as e: + except (ValueError, JSONDecodeError, plexc.ComputeError, TypeError) as e: if strict: - raise e from e + raise DeserializationError("The schema could not be deserialized") from e return None diff --git a/dataframely/testing/storage.py b/dataframely/testing/storage.py index 9807e6c..b4cf14b 100644 --- a/dataframely/testing/storage.py +++ b/dataframely/testing/storage.py @@ -29,11 +29,11 @@ class SchemaStorageTester(ABC): def write_typed( self, schema: type[S], df: dy.DataFrame[S], path: str, lazy: bool ) -> None: - """Write a schema to the backend without recording schema information.""" + """Write a schema to the backend and record schema information.""" @abstractmethod def write_untyped(self, df: pl.DataFrame, path: str, lazy: bool) -> None: - """Write a schema to the backend and record schema information.""" + """Write a schema to the backend without recording schema information.""" @overload def read( @@ -45,12 +45,22 @@ def read( self, schema: type[S], path: str, lazy: Literal[False], 
validation: Validation ) -> dy.DataFrame[S]: ... + @overload + def read( + self, schema: type[S], path: str, lazy: bool, validation: Validation + ) -> dy.LazyFrame[S] | dy.DataFrame[S]: ... + @abstractmethod def read( self, schema: type[S], path: str, lazy: bool, validation: Validation ) -> dy.LazyFrame[S] | dy.DataFrame[S]: """Read from the backend, using schema information if available.""" + @abstractmethod + def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: + """Overwrite the metadata stored at the given path with the provided + metadata.""" + class ParquetSchemaStorageTester(SchemaStorageTester): """Testing interface for the parquet storage functionality of Schema.""" @@ -83,6 +93,11 @@ def read( self, schema: type[S], path: str, lazy: Literal[False], validation: Validation ) -> dy.DataFrame[S]: ... + @overload + def read( + self, schema: type[S], path: str, lazy: bool, validation: Validation + ) -> dy.LazyFrame[S] | dy.DataFrame[S]: ... + def read( self, schema: type[S], path: str, lazy: bool, validation: Validation ) -> dy.LazyFrame[S] | dy.DataFrame[S]: @@ -93,6 +108,11 @@ def read( else: return schema.read_parquet(self._wrap_path(path), validation=validation) + def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: + target = self._wrap_path(path) + data = pl.read_parquet(target) + data.write_parquet(target, metadata=metadata) + class DeltaSchemaStorageTester(SchemaStorageTester): """Testing interface for the deltalake storage functionality of Schema.""" @@ -115,6 +135,11 @@ def read( self, schema: type[S], path: str, lazy: Literal[False], validation: Validation ) -> dy.DataFrame[S]: ... + @overload + def read( + self, schema: type[S], path: str, lazy: bool, validation: Validation + ) -> dy.LazyFrame[S] | dy.DataFrame[S]: ... 
+ def read( self, schema: type[S], path: str, lazy: bool, validation: Validation ) -> dy.DataFrame[S] | dy.LazyFrame[S]: @@ -122,6 +147,18 @@ def read( return schema.scan_delta(path, validation=validation) return schema.read_delta(path, validation=validation) + def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: + df = pl.read_delta(path) + df.head(0).write_delta( + path, + delta_write_options={ + "commit_properties": deltalake.CommitProperties( + custom_metadata=metadata + ), + }, + mode="overwrite", + ) + # ------------------------------- Collection ------------------------------------------- @@ -147,11 +184,36 @@ def write_untyped( def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: """Read from the backend, using collection information if available.""" + @abstractmethod + def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: + """Overwrite the metadata stored at the given path with the provided + metadata.""" + + def _prefix_path(self, path: str, fs: AbstractFileSystem) -> str: + return f"{self._get_prefix(fs)}{path}" + + @staticmethod + def _get_prefix(fs: AbstractFileSystem) -> str: + return ( + "" + if fs.protocol == "file" + else ( + f"{fs.protocol}://" + if isinstance(fs.protocol, str) + else f"{fs.protocol[0]}://" + ) + ) + class ParquetCollectionStorageTester(CollectionStorageTester): def write_typed( self, collection: dy.Collection, path: str, lazy: bool, **kwargs: Any ) -> None: + if "metadata" in kwargs: + raise KeyError( + "`metadata` kwarg will be ignored in `write_typed`. Use `set_metadata`." + ) + # Polars does not support partitioning via kwarg on sink_parquet if lazy: kwargs.pop("partition_by", None) @@ -164,6 +226,11 @@ def write_typed( def write_untyped( self, collection: dy.Collection, path: str, lazy: bool, **kwargs: Any ) -> None: + if "metadata" in kwargs: + raise KeyError( + "Cannot set metadata through `write_untyped`. Use `set_metadata`." 
+ ) + if lazy: collection.sink_parquet(path, **kwargs) else: @@ -175,17 +242,8 @@ def _delete_meta(file: str) -> None: df.write_parquet(file) fs: AbstractFileSystem = url_to_fs(path)[0] - prefix = ( - "" - if fs.protocol == "file" - else ( - f"{fs.protocol}://" - if isinstance(fs.protocol, str) - else f"{fs.protocol[0]}://" - ) - ) for file in fs.glob(fs.sep.join([path, "**", "*.parquet"])): - _delete_meta(f"{prefix}{file}") + _delete_meta(self._prefix_path(file, fs)) def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: if lazy: @@ -193,11 +251,23 @@ def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: else: return collection.read_parquet(path, **kwargs) + def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: + fs: AbstractFileSystem = url_to_fs(path)[0] + for file in fs.glob(fs.sep.join([path, "*.parquet"])): + file_path = self._prefix_path(file, fs) + df = pl.read_parquet(file_path) + df.write_parquet(file_path, metadata=metadata) + class DeltaCollectionStorageTester(CollectionStorageTester): def write_typed( self, collection: dy.Collection, path: str, lazy: bool, **kwargs: Any ) -> None: + if "metadata" in kwargs: + raise KeyError( + "`metadata` kwarg will be ignored in `write_typed`. Use `set_metadata`." + ) + extra_kwargs = {} if partition_by := kwargs.pop("partition_by", None): extra_kwargs["delta_write_options"] = {"partition_by": partition_by} @@ -207,6 +277,10 @@ def write_typed( def write_untyped( self, collection: dy.Collection, path: str, lazy: bool, **kwargs: Any ) -> None: + if "metadata" in kwargs: + raise KeyError( + "Cannot set metadata through `write_untyped`. Use `set_metadata`." 
+ ) collection.write_delta(path, **kwargs) # For each member table, write an empty commit @@ -222,6 +296,23 @@ def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: return collection.scan_delta(source=path, **kwargs) return collection.read_delta(source=path, **kwargs) + def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: + fs: AbstractFileSystem = url_to_fs(path)[0] + # For delta, we need to update metadata on each member table + for entry in fs.ls(path): + member_path = self._prefix_path(entry, fs) + if fs.isdir(member_path): + df = pl.read_delta(member_path) + df.head(0).write_delta( + member_path, + delta_write_options={ + "commit_properties": deltalake.CommitProperties( + custom_metadata=metadata + ), + }, + mode="overwrite", + ) + # ------------------------------------ Failure info ------------------------------------ class FailureInfoStorageTester(ABC): diff --git a/tests/collection/test_serialization.py b/tests/collection/test_serialization.py index 7455913..d7c6bc5 100644 --- a/tests/collection/test_serialization.py +++ b/tests/collection/test_serialization.py @@ -87,7 +87,39 @@ def test_roundtrip_matches(collection: type[dy.Collection]) -> None: # ----------------------------- DESERIALIZATION FAILURES ----------------------------- # -def test_deserialize_unknown_format_version() -> None: +@pytest.mark.parametrize("strict", [True, False]) +def test_deserialize_unknown_format_version(strict: bool) -> None: serialized = '{"versions": {"format": "invalid"}}' - with pytest.raises(ValueError, match=r"Unsupported schema format version"): - dy.deserialize_collection(serialized) + if strict: + with pytest.raises(dy.DeserializationError): + dy.deserialize_collection(serialized) + else: + assert dy.deserialize_collection(serialized, strict=False) is None + + +@pytest.mark.parametrize("strict", [True, False]) +def test_deserialize_invalid_json_strict_false(strict: bool) -> None: + serialized = '{"invalid json' + if strict: + 
with pytest.raises(dy.DeserializationError): + dy.deserialize_collection(serialized, strict=True) + else: + assert dy.deserialize_collection(serialized, strict=False) is None + + +@pytest.mark.parametrize("strict", [True, False]) +def test_deserialize_invalid_member_schema(strict: bool) -> None: + collection = create_collection( + "test", + { + "s1": create_schema("schema1", {"a": dy.Int64()}), + }, + ) + serialized = collection.serialize() + broken = serialized.replace("primary_key", "primary_keys") + + if strict: + with pytest.raises(dy.DeserializationError): + dy.deserialize_collection(broken, strict=strict) + else: + assert dy.deserialize_collection(broken, strict=False) is None diff --git a/tests/collection/test_storage.py b/tests/collection/test_storage.py index 6dee93f..9ff416d 100644 --- a/tests/collection/test_storage.py +++ b/tests/collection/test_storage.py @@ -1,7 +1,6 @@ # Copyright (c) QuantCo 2025-2025 # SPDX-License-Identifier: BSD-3-Clause -import warnings from typing import Any import polars as pl @@ -14,7 +13,7 @@ from dataframely._storage.constants import COLLECTION_METADATA_KEY from dataframely._storage.delta import DeltaStorageBackend from dataframely.collection.collection import _reconcile_collection_types -from dataframely.exc import ValidationRequiredError +from dataframely.exc import DeserializationError, ValidationRequiredError from dataframely.testing.storage import ( CollectionStorageTester, DeltaCollectionStorageTester, @@ -405,23 +404,74 @@ def test_read_write_parquet_schema_json_fallback_corrupt( # Arrange collection = MyCollection.create_empty() tester = ParquetCollectionStorageTester() - tester.write_untyped( - collection, + tester.write_untyped(collection, any_tmp_path, lazy) + tester.set_metadata( any_tmp_path, - lazy, metadata={COLLECTION_METADATA_KEY: "} this is not a valid JSON {"}, ) # Act spy = mocker.spy(MyCollection, "validate") - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=UserWarning) + if 
validation == "warn": + with pytest.warns(UserWarning): + tester.read(MyCollection, any_tmp_path, lazy, validation=validation) + else: tester.read(MyCollection, any_tmp_path, lazy, validation=validation) # Assert spy.assert_called_once() +@pytest.mark.parametrize("tester", TESTERS) +@pytest.mark.parametrize("validation", ["forbid", "allow", "skip", "warn"]) +@pytest.mark.parametrize("lazy", [True, False]) +@pytest.mark.parametrize( + "any_tmp_path", + ["tmp_path", pytest.param("s3_tmp_path", marks=pytest.mark.s3)], + indirect=True, +) +def test_read_write_old_metadata_contents( + tester: CollectionStorageTester, + any_tmp_path: str, + mocker: pytest_mock.MockerFixture, + validation: Any, + lazy: bool, +) -> None: + """If collection has an old/incompatible schema content, we should fall back to + validating when validation is 'allow' or 'warn', and raise otherwise.""" + # Arrange + collection = MyCollection.create_empty() + tester.write_typed(collection, any_tmp_path, lazy) + tester.set_metadata( + any_tmp_path, + metadata={ + COLLECTION_METADATA_KEY: collection.serialize().replace( + "primary_key", "primary_keys" + ) + }, + ) + + # Act & Assert + match validation: + case "forbid": + with pytest.raises(DeserializationError): + tester.read(MyCollection, any_tmp_path, lazy, validation=validation) + case "allow": + spy = mocker.spy(MyCollection, "validate") + tester.read(MyCollection, any_tmp_path, lazy, validation=validation) + spy.assert_called_once() + case "warn": + spy = mocker.spy(MyCollection, "validate") + with pytest.warns(UserWarning): + tester.read(MyCollection, any_tmp_path, lazy, validation=validation) + spy.assert_called_once() + case "skip": + spy = mocker.spy(MyCollection, "validate") + tester.read(MyCollection, any_tmp_path, lazy, validation=validation) + # Validation should NOT be called because we are skipping it + spy.assert_not_called() + + @pytest.mark.parametrize("metadata", [None, {COLLECTION_METADATA_KEY: "invalid"}]) @pytest.mark.parametrize( 
"any_tmp_path", diff --git a/tests/schema/test_serialization.py b/tests/schema/test_serialization.py index 78e3bd6..bad2b79 100644 --- a/tests/schema/test_serialization.py +++ b/tests/schema/test_serialization.py @@ -99,7 +99,7 @@ def test_deserialize_unknown_column_type() -> None: "rules": {} } """ - with pytest.raises(ValueError): + with pytest.raises(dy.DeserializationError): dy.deserialize_schema(serialized) @@ -112,13 +112,13 @@ def test_deserialize_unknown_rule_type() -> None: "rules": {"a": {"rule_type": "unknown"}} } """ - with pytest.raises(ValueError): + with pytest.raises(dy.DeserializationError): dy.deserialize_schema(serialized) def test_deserialize_invalid_type() -> None: serialized = '{"__type__": "unknown", "value": "foo"}' - with pytest.raises(TypeError): + with pytest.raises(dy.DeserializationError): dy.deserialize_schema(serialized) @@ -127,5 +127,5 @@ def test_deserialize_invalid_type() -> None: def test_deserialize_unknown_format_version() -> None: serialized = '{"versions": {"format": "invalid"}}' - with pytest.raises(ValueError, match=r"Unsupported schema format version"): + with pytest.raises(dy.DeserializationError): dy.deserialize_schema(serialized) diff --git a/tests/schema/test_storage.py b/tests/schema/test_storage.py index c102f43..0c66fcc 100644 --- a/tests/schema/test_storage.py +++ b/tests/schema/test_storage.py @@ -11,7 +11,7 @@ import dataframely as dy from dataframely import Validation from dataframely._storage.delta import DeltaStorageBackend -from dataframely.exc import ValidationRequiredError +from dataframely.exc import DeserializationError, ValidationRequiredError from dataframely.testing import create_schema from dataframely.testing.storage import ( DeltaSchemaStorageTester, @@ -289,6 +289,67 @@ def test_read_write_parquet_validation_skip_invalid_schema( spy.assert_not_called() +# ---------------------------- PARQUET SPECIFICS ---------------------------------- # + + +@pytest.mark.parametrize("tester", TESTERS) 
+@pytest.mark.parametrize("validation", ["allow", "warn", "skip", "forbid"]) +@pytest.mark.parametrize("lazy", [True, False]) +@pytest.mark.parametrize( + "any_tmp_path", + ["tmp_path", pytest.param("s3_tmp_path", marks=pytest.mark.s3)], + indirect=True, +) +def test_read_write_parquet_old_metadata_contents( + tester: SchemaStorageTester, + any_tmp_path: str, + mocker: pytest_mock.MockerFixture, + validation: Validation, + lazy: bool, +) -> None: + """If schema has an old/incompatible content, we should fall back to validating when + validation is 'allow', 'warn' or 'skip' or raise otherwise.""" + # Arrange + from dataframely._storage.constants import SCHEMA_METADATA_KEY + + schema = create_schema("test", {"a": dy.Int64()}) + df = schema.create_empty() + + tester.write_typed(schema, df, any_tmp_path, lazy=lazy) + tester.set_metadata( + any_tmp_path, + metadata={ + SCHEMA_METADATA_KEY: schema.serialize().replace( + "primary_key", "primary_keys" + ) + }, + ) + + # Act and assert + match validation: + case "forbid": + with pytest.raises(DeserializationError): + tester.read(schema, any_tmp_path, lazy=lazy, validation=validation) + case "allow": + spy = mocker.spy(schema, "validate") + out = tester.read(schema, any_tmp_path, lazy=lazy, validation=validation) + assert_frame_equal(df.lazy(), out.lazy()) + spy.assert_called_once() + case "warn": + spy = mocker.spy(schema, "validate") + with pytest.warns(UserWarning, match=r"requires validation"): + out = tester.read( + schema, any_tmp_path, lazy=lazy, validation=validation + ) + assert_frame_equal(df.lazy(), out.lazy()) + spy.assert_called_once() + case "skip": + spy = mocker.spy(schema, "validate") + out = tester.read(schema, any_tmp_path, lazy=lazy, validation=validation) + assert_frame_equal(df.lazy(), out.lazy()) + spy.assert_not_called() + + # ---------------------------- DELTA LAKE SPECIFICS ---------------------------------- #