From 1516eb349e02c36c306fd3cb90287e9b60ee7e2a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 21 Nov 2025 14:31:42 +0000 Subject: [PATCH 01/13] Initial plan From 047cb43a96160277ca9dc65e34d92a2c80f8ba8b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 21 Nov 2025 14:49:07 +0000 Subject: [PATCH 02/13] Add strict parameter to collection deserialization and update scan_parquet error handling Co-authored-by: MoritzPotthoffQC <160181542+MoritzPotthoffQC@users.noreply.github.com> --- dataframely/collection/collection.py | 100 ++++++++++++++++--------- dataframely/schema.py | 11 ++- tests/collection/test_serialization.py | 18 +++++ tests/collection/test_storage.py | 34 +++++++++ tests/schema/test_storage.py | 47 ++++++++++++ 5 files changed, 172 insertions(+), 38 deletions(-) diff --git a/dataframely/collection/collection.py b/dataframely/collection/collection.py index 1576f9f0..c2c43f05 100644 --- a/dataframely/collection/collection.py +++ b/dataframely/collection/collection.py @@ -12,7 +12,7 @@ from dataclasses import asdict from json import JSONDecodeError from pathlib import Path -from typing import IO, Annotated, Any, Literal, cast +from typing import IO, Annotated, Any, Literal, cast, overload import polars as pl import polars.exceptions as plexc @@ -1184,7 +1184,12 @@ def _read( members=cls.member_schemas().keys(), **kwargs ) - collection_types = _deserialize_types(serialized_collection_types) + # Use strict=False when validation is "allow" or "warn" to tolerate + # deserialization failures from old serialized formats + strict = validation not in ("allow", "warn") + collection_types = _deserialize_types( + serialized_collection_types, strict=strict + ) collection_type = _reconcile_collection_types(collection_types) if cls._requires_validation_for_reading_parquets(collection_type, validation): @@ -1245,14 +1250,27 @@ def read_parquet_metadata_collection( """ metadata = pl.read_parquet_metadata(source) if (schema_metadata := metadata.get(COLLECTION_METADATA_KEY)) is not None: - try: - return deserialize_collection(schema_metadata) - except (JSONDecodeError, plexc.ComputeError): - return None + return deserialize_collection(schema_metadata, strict=False) return None -def deserialize_collection(data: str) -> type[Collection]: +@overload +def deserialize_collection( + data: str, strict: Literal[True] = True +) -> type[Collection]: ... + + +@overload +def deserialize_collection( + data: str, strict: Literal[False] +) -> type[Collection] | None: ... + + +@overload +def deserialize_collection(data: str, strict: bool) -> type[Collection] | None: ... + + +def deserialize_collection(data: str, strict: bool = True) -> type[Collection] | None: """Deserialize a collection from a JSON string. This method allows to dynamically load a collection from its serialization, without @@ -1260,12 +1278,13 @@ def deserialize_collection(data: str) -> type[Collection]: Args: data: The JSON string created via :meth:`Collection.serialize`. + strict: Whether to raise an exception if the collection cannot be deserialized. Returns: The collection loaded from the JSON data. Raises: - ValueError: If the schema format version is not supported. + ValueError: If the schema format version is not supported and `strict=True`. 
Attention: The returned collection **cannot** be used to create instances of the @@ -1280,34 +1299,39 @@ def deserialize_collection(data: str) -> type[Collection]: See also: :meth:`Collection.serialize` for additional information on serialization. """ - decoded = json.loads(data, cls=SchemaJSONDecoder) - if (format := decoded["versions"]["format"]) != SERIALIZATION_FORMAT_VERSION: - raise ValueError(f"Unsupported schema format version: {format}") - - annotations: dict[str, Any] = {} - for name, info in decoded["members"].items(): - lf_type = LazyFrame[_schema_from_dict(info["schema"])] # type: ignore - if info["is_optional"]: - lf_type = lf_type | None # type: ignore - annotations[name] = Annotated[ - lf_type, - CollectionMember( - ignored_in_filters=info["ignored_in_filters"], - inline_for_sampling=info["inline_for_sampling"], - ), - ] - - return type( - f"{decoded['name']}_dynamic", - (Collection,), - { - "__annotations__": annotations, - **{ - name: Filter(logic=lambda _, logic=logic: logic) # type: ignore - for name, logic in decoded["filters"].items() + try: + decoded = json.loads(data, cls=SchemaJSONDecoder) + if (format := decoded["versions"]["format"]) != SERIALIZATION_FORMAT_VERSION: + raise ValueError(f"Unsupported schema format version: {format}") + + annotations: dict[str, Any] = {} + for name, info in decoded["members"].items(): + lf_type = LazyFrame[_schema_from_dict(info["schema"])] # type: ignore + if info["is_optional"]: + lf_type = lf_type | None # type: ignore + annotations[name] = Annotated[ + lf_type, + CollectionMember( + ignored_in_filters=info["ignored_in_filters"], + inline_for_sampling=info["inline_for_sampling"], + ), + ] + + return type( + f"{decoded['name']}_dynamic", + (Collection,), + { + "__annotations__": annotations, + **{ + name: Filter(logic=lambda _, logic=logic: logic) # type: ignore + for name, logic in decoded["filters"].items() + }, }, - }, - ) + ) + except (ValueError, JSONDecodeError, plexc.ComputeError) as e: + if strict: + raise e from e + return None # --------------------------------------- UTILS -------------------------------------- # @@ -1333,14 +1357,16 @@ def _extract_keys_if_exist( def _deserialize_types( serialized_collection_types: Iterable[str | None], + strict: bool = True, ) -> list[type[Collection]]: collection_types = [] collection_type: type[Collection] | None = None for t in serialized_collection_types: if t is None: continue - collection_type = deserialize_collection(t) - collection_types.append(collection_type) + collection_type = deserialize_collection(t, strict=strict) + if collection_type is not None: + collection_types.append(collection_type) return collection_types diff --git a/dataframely/schema.py b/dataframely/schema.py index c457ec76..3e5603b6 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -1238,8 +1238,13 @@ def _validate_if_needed( validation: Validation, source: str, ) -> DataFrame[Self] | LazyFrame[Self]: + # Use strict=False when validation is "allow" or "warn" to tolerate + # deserialization failures from old serialized formats + strict = validation not in ("allow", "warn") deserialized_schema = ( - deserialize_schema(serialized_schema) if serialized_schema else None + deserialize_schema(serialized_schema, strict=strict) + if serialized_schema + else None ) # Smart validation @@ -1347,6 +1352,10 @@ def deserialize_schema(data: str, strict: Literal[True] = True) -> type[Schema]: def deserialize_schema(data: str, strict: Literal[False]) -> type[Schema] | None: ... 
+@overload +def deserialize_schema(data: str, strict: bool) -> type[Schema] | None: ... + + def deserialize_schema(data: str, strict: bool = True) -> type[Schema] | None: """Deserialize a schema from a JSON string. diff --git a/tests/collection/test_serialization.py b/tests/collection/test_serialization.py index 7455913f..87089e69 100644 --- a/tests/collection/test_serialization.py +++ b/tests/collection/test_serialization.py @@ -91,3 +91,21 @@ def test_deserialize_unknown_format_version() -> None: serialized = '{"versions": {"format": "invalid"}}' with pytest.raises(ValueError, match=r"Unsupported schema format version"): dy.deserialize_collection(serialized) + + +def test_deserialize_unknown_format_version_strict_false() -> None: + serialized = '{"versions": {"format": "invalid"}}' + result = dy.deserialize_collection(serialized, strict=False) + assert result is None + + +def test_deserialize_invalid_json_strict_false() -> None: + serialized = '{"invalid json' + result = dy.deserialize_collection(serialized, strict=False) + assert result is None + + +def test_deserialize_invalid_json_strict_true() -> None: + serialized = '{"invalid json' + with pytest.raises(json.JSONDecodeError): + dy.deserialize_collection(serialized, strict=True) diff --git a/tests/collection/test_storage.py b/tests/collection/test_storage.py index 6dee93f3..72f2cd1a 100644 --- a/tests/collection/test_storage.py +++ b/tests/collection/test_storage.py @@ -422,6 +422,40 @@ def test_read_write_parquet_schema_json_fallback_corrupt( spy.assert_called_once() +@pytest.mark.parametrize("validation", ["allow", "warn"]) +@pytest.mark.parametrize("lazy", [True, False]) +@pytest.mark.parametrize( + "any_tmp_path", + ["tmp_path", pytest.param("s3_tmp_path", marks=pytest.mark.s3)], + indirect=True, +) +def test_read_write_parquet_old_format_version( + any_tmp_path: str, mocker: pytest_mock.MockerFixture, validation: Any, lazy: bool +) -> None: + """If schema has an old/incompatible format version, we should fall back to + validating when validation is 'allow' or 'warn'.""" + # Arrange + collection = MyCollection.create_empty() + tester = ParquetCollectionStorageTester() + # Use a collection metadata with an invalid format version + old_format_metadata = '{"versions": {"format": "999"}}' + tester.write_untyped( + collection, + any_tmp_path, + lazy, + metadata={COLLECTION_METADATA_KEY: old_format_metadata}, + ) + + # Act + spy = mocker.spy(MyCollection, "validate") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + tester.read(MyCollection, any_tmp_path, lazy, validation=validation) + + # Assert - validation should be called because the old format couldn't be deserialized + spy.assert_called_once() + + @pytest.mark.parametrize("metadata", [None, {COLLECTION_METADATA_KEY: "invalid"}]) @pytest.mark.parametrize( "any_tmp_path", diff --git a/tests/schema/test_storage.py b/tests/schema/test_storage.py index c102f43d..031672f8 100644 --- a/tests/schema/test_storage.py +++ b/tests/schema/test_storage.py @@ -289,6 +289,53 @@ def test_read_write_parquet_validation_skip_invalid_schema( spy.assert_not_called() +# ---------------------------- PARQUET SPECIFICS ---------------------------------- # + + +@pytest.mark.parametrize("validation", ["allow", "warn"]) +@pytest.mark.parametrize("lazy", [True, False]) +@pytest.mark.parametrize( + "any_tmp_path", + ["tmp_path", pytest.param("s3_tmp_path", marks=pytest.mark.s3)], + indirect=True, +) +def test_read_write_parquet_old_format_version( + any_tmp_path: str, 
+ mocker: pytest_mock.MockerFixture, + validation: Validation, + lazy: bool, +) -> None: + """If schema has an old/incompatible format version, we should fall back to + validating when validation is 'allow' or 'warn'.""" + # Arrange + from fsspec import AbstractFileSystem, url_to_fs + + from dataframely._storage.constants import SCHEMA_METADATA_KEY + + schema = create_schema("test", {"a": dy.Int64(), "b": dy.String()}) + df = schema.create_empty() + + # Write directly with custom metadata containing an old format version + fs: AbstractFileSystem = url_to_fs(any_tmp_path)[0] + file_path = fs.sep.join([any_tmp_path, "test.parquet"]) + old_format_metadata = '{"versions": {"format": "999"}}' + df.write_parquet(file_path, metadata={SCHEMA_METADATA_KEY: old_format_metadata}) + + # Act + spy = mocker.spy(schema, "validate") + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + if lazy: + schema.scan_parquet(file_path, validation=validation) + else: + schema.read_parquet(file_path, validation=validation) + + # Assert - validation should be called because the old format couldn't be deserialized + spy.assert_called_once() + + # ---------------------------- DELTA LAKE SPECIFICS ---------------------------------- # From ea5a3aee1c4f2e3e3116110031a727360ff9e27c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 21 Nov 2025 14:51:31 +0000 Subject: [PATCH 03/13] Fix docstring in collection storage test Co-authored-by: MoritzPotthoffQC <160181542+MoritzPotthoffQC@users.noreply.github.com> --- tests/collection/test_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/collection/test_storage.py b/tests/collection/test_storage.py index 72f2cd1a..4271a38c 100644 --- a/tests/collection/test_storage.py +++ b/tests/collection/test_storage.py @@ -432,7 +432,7 @@ def test_read_write_parquet_schema_json_fallback_corrupt( def test_read_write_parquet_old_format_version( any_tmp_path: str, mocker: pytest_mock.MockerFixture, validation: Any, lazy: bool ) -> None: - """If schema has an old/incompatible format version, we should fall back to + """If collection has an old/incompatible format version, we should fall back to validating when validation is 'allow' or 'warn'.""" # Arrange collection = MyCollection.create_empty() From bfde519922e805a8f4ee21c119609fe2795fd781 Mon Sep 17 00:00:00 2001 From: Moritz Potthoff Date: Fri, 21 Nov 2025 18:20:06 +0100 Subject: [PATCH 04/13] fixes --- dataframely/collection/collection.py | 11 +++--- tests/collection/test_serialization.py | 47 +++++++++++++++++--------- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/dataframely/collection/collection.py b/dataframely/collection/collection.py index c2c43f05..b2f9e663 100644 --- a/dataframely/collection/collection.py +++ b/dataframely/collection/collection.py @@ -891,13 +891,13 @@ def read_parquet( - `"allow"`: The method tries to read the schema data from the parquet files. If the stored collection schema matches this collection schema, the collection is read without validation. If the stored - schema mismatches this schema no metadata can be found in + schema mismatches this schema, no valid metadata can be found in the parquets, or the files have conflicting metadata, this method automatically runs :meth:`validate` with `cast=True`. - `"warn"`: The method behaves similarly to `"allow"`. However, it prints a warning if validation is necessary. 
- `"forbid"`: The method never runs validation automatically and only - returns if the metadata stores a collection schema that matches + returns if the metadata stores a valid collection schema that matches this collection. - `"skip"`: The method never runs validation and simply reads the data, entrusting the user that the schema is valid. *Use this option @@ -1185,7 +1185,7 @@ def _read( ) # Use strict=False when validation is "allow" or "warn" to tolerate - # deserialization failures from old serialized formats + # missing or broken collection metadata. strict = validation not in ("allow", "warn") collection_types = _deserialize_types( serialized_collection_types, strict=strict @@ -1285,6 +1285,8 @@ def deserialize_collection(data: str, strict: bool = True) -> type[Collection] | Raises: ValueError: If the schema format version is not supported and `strict=True`. + TypeError: If the schema content is invalid and `strict=True`. + JSONDecodeError: If the provided data is not valid JSON and `strict=True`. Attention: The returned collection **cannot** be used to create instances of the @@ -1328,7 +1330,7 @@ def deserialize_collection(data: str, strict: bool = True) -> type[Collection] | }, }, ) - except (ValueError, JSONDecodeError, plexc.ComputeError) as e: + except (ValueError, TypeError, JSONDecodeError, plexc.ComputeError) as e: if strict: raise e from e return None @@ -1360,7 +1362,6 @@ def _deserialize_types( strict: bool = True, ) -> list[type[Collection]]: collection_types = [] - collection_type: type[Collection] | None = None for t in serialized_collection_types: if t is None: continue diff --git a/tests/collection/test_serialization.py b/tests/collection/test_serialization.py index 87089e69..1feb7197 100644 --- a/tests/collection/test_serialization.py +++ b/tests/collection/test_serialization.py @@ -87,25 +87,40 @@ def test_roundtrip_matches(collection: type[dy.Collection]) -> None: # ----------------------------- DESERIALIZATION FAILURES ----------------------------- # -def test_deserialize_unknown_format_version() -> None: +@pytest.mark.parametrize("strict", [True, False]) +def test_deserialize_unknown_format_version(strict: bool) -> None: serialized = '{"versions": {"format": "invalid"}}' - with pytest.raises(ValueError, match=r"Unsupported schema format version"): - dy.deserialize_collection(serialized) + if strict: + with pytest.raises(ValueError, match=r"Unsupported schema format version"): + dy.deserialize_collection(serialized) + else: + assert dy.deserialize_collection(serialized, strict=False) is None -def test_deserialize_unknown_format_version_strict_false() -> None: - serialized = '{"versions": {"format": "invalid"}}' - result = dy.deserialize_collection(serialized, strict=False) - assert result is None - - -def test_deserialize_invalid_json_strict_false() -> None: +@pytest.mark.parametrize("strict", [True, False]) +def test_deserialize_invalid_json_strict_false(strict: bool) -> None: serialized = '{"invalid json' - result = dy.deserialize_collection(serialized, strict=False) - assert result is None + if strict: + with pytest.raises(json.JSONDecodeError): + dy.deserialize_collection(serialized, strict=True) + else: + assert dy.deserialize_collection(serialized, strict=False) is None -def test_deserialize_invalid_json_strict_true() -> None: - serialized = '{"invalid json' - with pytest.raises(json.JSONDecodeError): - dy.deserialize_collection(serialized, strict=True) +@pytest.mark.parametrize("strict", [True, False]) +def test_deserialize_invalid_member_schema(strict: 
bool) -> None: + collection = create_collection( + "test", + { + "s1": create_schema("schema1", {"a": dy.Int64()}), + "s2": create_schema("schema2", {"a": dy.Int64()}), + }, + ) + serialized = collection.serialize() + broken = serialized.replace("primary_key", "primary_keys") + + if strict: + with pytest.raises(TypeError): + dy.deserialize_collection(broken, strict=strict) + else: + assert dy.deserialize_collection(broken, strict=False) is None From 88bb72efa7ab98f970ab7e341c448e869c646d26 Mon Sep 17 00:00:00 2001 From: Moritz Potthoff Date: Fri, 21 Nov 2025 18:36:22 +0100 Subject: [PATCH 05/13] start fixing tests --- dataframely/collection/collection.py | 4 +-- dataframely/schema.py | 6 ++--- tests/collection/test_serialization.py | 1 - tests/collection/test_storage.py | 36 +++++++++++++++----------- tests/schema/test_storage.py | 16 ++++++++---- 5 files changed, 37 insertions(+), 26 deletions(-) diff --git a/dataframely/collection/collection.py b/dataframely/collection/collection.py index b2f9e663..85a5870a 100644 --- a/dataframely/collection/collection.py +++ b/dataframely/collection/collection.py @@ -1184,9 +1184,9 @@ def _read( members=cls.member_schemas().keys(), **kwargs ) - # Use strict=False when validation is "allow" or "warn" to tolerate + # Use strict=False when validation is "allow", "warn" or "skip" to tolerate # missing or broken collection metadata. - strict = validation not in ("allow", "warn") + strict = validation == "forbid" collection_types = _deserialize_types( serialized_collection_types, strict=strict ) diff --git a/dataframely/schema.py b/dataframely/schema.py index 3e5603b6..da7c2b37 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -1238,9 +1238,9 @@ def _validate_if_needed( validation: Validation, source: str, ) -> DataFrame[Self] | LazyFrame[Self]: - # Use strict=False when validation is "allow" or "warn" to tolerate - # deserialization failures from old serialized formats - strict = validation not in ("allow", "warn") + # Use strict=False when validation is "allow", "warn" or "skip" to tolerate + # deserialization failures from old serialized formats. 
+ strict = validation == "forbid" deserialized_schema = ( deserialize_schema(serialized_schema, strict=strict) if serialized_schema diff --git a/tests/collection/test_serialization.py b/tests/collection/test_serialization.py index 1feb7197..77a8e4b0 100644 --- a/tests/collection/test_serialization.py +++ b/tests/collection/test_serialization.py @@ -113,7 +113,6 @@ def test_deserialize_invalid_member_schema(strict: bool) -> None: "test", { "s1": create_schema("schema1", {"a": dy.Int64()}), - "s2": create_schema("schema2", {"a": dy.Int64()}), }, ) serialized = collection.serialize() diff --git a/tests/collection/test_storage.py b/tests/collection/test_storage.py index 4271a38c..911cc455 100644 --- a/tests/collection/test_storage.py +++ b/tests/collection/test_storage.py @@ -422,38 +422,44 @@ def test_read_write_parquet_schema_json_fallback_corrupt( spy.assert_called_once() -@pytest.mark.parametrize("validation", ["allow", "warn"]) +@pytest.mark.parametrize("validation", ["forbid", "allow", "skip", "warn"]) @pytest.mark.parametrize("lazy", [True, False]) @pytest.mark.parametrize( "any_tmp_path", ["tmp_path", pytest.param("s3_tmp_path", marks=pytest.mark.s3)], indirect=True, ) -def test_read_write_parquet_old_format_version( +def test_read_write_parquet_old_schema_contents( any_tmp_path: str, mocker: pytest_mock.MockerFixture, validation: Any, lazy: bool ) -> None: - """If collection has an old/incompatible format version, we should fall back to - validating when validation is 'allow' or 'warn'.""" + """If collection has an old/incompatible schema content, we should fall back to + validating when validation is 'allow' or 'warn', and raise otherwise.""" # Arrange collection = MyCollection.create_empty() tester = ParquetCollectionStorageTester() - # Use a collection metadata with an invalid format version - old_format_metadata = '{"versions": {"format": "999"}}' tester.write_untyped( collection, any_tmp_path, lazy, - metadata={COLLECTION_METADATA_KEY: old_format_metadata}, + metadata={ + COLLECTION_METADATA_KEY: collection.serialize().replace( + "primary_key", "primary_keys" + ) + }, ) - # Act - spy = mocker.spy(MyCollection, "validate") - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=UserWarning) - tester.read(MyCollection, any_tmp_path, lazy, validation=validation) - - # Assert - validation should be called because the old format couldn't be deserialized - spy.assert_called_once() + # Act & Assert + if validation == "forbid": + with pytest.raises(ValidationRequiredError): + tester.read(MyCollection, any_tmp_path, lazy, validation=validation) + else: + spy = mocker.spy(MyCollection, "validate") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + tester.read(MyCollection, any_tmp_path, lazy, validation=validation) + + # Validation should be called because the old format couldn't be deserialized + spy.assert_called_once() @pytest.mark.parametrize("metadata", [None, {COLLECTION_METADATA_KEY: "invalid"}]) diff --git a/tests/schema/test_storage.py b/tests/schema/test_storage.py index 031672f8..b2ac6e25 100644 --- a/tests/schema/test_storage.py +++ b/tests/schema/test_storage.py @@ -292,7 +292,7 @@ def test_read_write_parquet_validation_skip_invalid_schema( # ---------------------------- PARQUET SPECIFICS ---------------------------------- # -@pytest.mark.parametrize("validation", ["allow", "warn"]) +@pytest.mark.parametrize("validation", ["allow", "warn", "skip", "forbid"]) @pytest.mark.parametrize("lazy", [True, False]) 
@pytest.mark.parametrize( "any_tmp_path", @@ -305,8 +305,8 @@ def test_read_write_parquet_old_format_version( validation: Validation, lazy: bool, ) -> None: - """If schema has an old/incompatible format version, we should fall back to - validating when validation is 'allow' or 'warn'.""" + """If schema has an old/incompatible content, we should fall back to validating when + validation is 'allow', 'warn' or 'skip' or raise otherwise.""" # Arrange from fsspec import AbstractFileSystem, url_to_fs @@ -318,8 +318,14 @@ def test_read_write_parquet_old_format_version( # Write directly with custom metadata containing an old format version fs: AbstractFileSystem = url_to_fs(any_tmp_path)[0] file_path = fs.sep.join([any_tmp_path, "test.parquet"]) - old_format_metadata = '{"versions": {"format": "999"}}' - df.write_parquet(file_path, metadata={SCHEMA_METADATA_KEY: old_format_metadata}) + df.write_parquet( + file_path, + metadata={ + SCHEMA_METADATA_KEY: schema.serialize().replace( + "primary_key", "primary_keys" + ) + }, + ) # Act spy = mocker.spy(schema, "validate") From 30b5514a381ed7528e48022d82b4dfa918508919 Mon Sep 17 00:00:00 2001 From: Moritz Potthoff Date: Mon, 24 Nov 2025 14:10:00 +0100 Subject: [PATCH 06/13] wip --- dataframely/testing/storage.py | 26 ++++++++++++++-- tests/schema/test_storage.py | 55 +++++++++++++++++++++------------- 2 files changed, 58 insertions(+), 23 deletions(-) diff --git a/dataframely/testing/storage.py b/dataframely/testing/storage.py index 9807e6cb..615097f4 100644 --- a/dataframely/testing/storage.py +++ b/dataframely/testing/storage.py @@ -29,11 +29,11 @@ class SchemaStorageTester(ABC): def write_typed( self, schema: type[S], df: dy.DataFrame[S], path: str, lazy: bool ) -> None: - """Write a schema to the backend without recording schema information.""" + """Write a schema to the backend and record schema information.""" @abstractmethod def write_untyped(self, df: pl.DataFrame, path: str, lazy: bool) -> None: - """Write a schema to the backend and record schema information.""" + """Write a schema to the backend without recording schema information.""" @overload def read( @@ -51,6 +51,11 @@ def read( ) -> dy.LazyFrame[S] | dy.DataFrame[S]: """Read from the backend, using schema information if available.""" + @abstractmethod + def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: + """Overwrite the metadata stored at the given path with the provided + metadata.""" + class ParquetSchemaStorageTester(SchemaStorageTester): """Testing interface for the parquet storage functionality of Schema.""" @@ -93,6 +98,11 @@ def read( else: return schema.read_parquet(self._wrap_path(path), validation=validation) + def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: + target = self._wrap_path(path) + data = pl.read_parquet(target) + data.write_parquet(target, metadata=metadata) + class DeltaSchemaStorageTester(SchemaStorageTester): """Testing interface for the deltalake storage functionality of Schema.""" @@ -122,6 +132,18 @@ def read( return schema.scan_delta(path, validation=validation) return schema.read_delta(path, validation=validation) + def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: + df = pl.read_delta(path) + df.head(0).write_delta( + path, + delta_write_options={ + "commit_properties": deltalake.CommitProperties( + custom_metadata=metadata + ), + }, + mode="overwrite", + ) + # ------------------------------- Collection ------------------------------------------- diff --git a/tests/schema/test_storage.py 
b/tests/schema/test_storage.py index b2ac6e25..09dde46b 100644 --- a/tests/schema/test_storage.py +++ b/tests/schema/test_storage.py @@ -11,7 +11,7 @@ import dataframely as dy from dataframely import Validation from dataframely._storage.delta import DeltaStorageBackend -from dataframely.exc import ValidationRequiredError +from dataframely.exc import ValidationError, ValidationRequiredError from dataframely.testing import create_schema from dataframely.testing.storage import ( DeltaSchemaStorageTester, @@ -292,6 +292,7 @@ def test_read_write_parquet_validation_skip_invalid_schema( # ---------------------------- PARQUET SPECIFICS ---------------------------------- # +@pytest.mark.parametrize("tester", TESTERS) @pytest.mark.parametrize("validation", ["allow", "warn", "skip", "forbid"]) @pytest.mark.parametrize("lazy", [True, False]) @pytest.mark.parametrize( @@ -299,7 +300,8 @@ def test_read_write_parquet_validation_skip_invalid_schema( ["tmp_path", pytest.param("s3_tmp_path", marks=pytest.mark.s3)], indirect=True, ) -def test_read_write_parquet_old_format_version( +def test_read_write_parquet_old_metadata_contents( + tester: SchemaStorageTester, any_tmp_path: str, mocker: pytest_mock.MockerFixture, validation: Validation, @@ -308,18 +310,14 @@ def test_read_write_parquet_old_format_version( """If schema has an old/incompatible content, we should fall back to validating when validation is 'allow', 'warn' or 'skip' or raise otherwise.""" # Arrange - from fsspec import AbstractFileSystem, url_to_fs - from dataframely._storage.constants import SCHEMA_METADATA_KEY - schema = create_schema("test", {"a": dy.Int64(), "b": dy.String()}) + schema = create_schema("test", {"a": dy.Int64()}) df = schema.create_empty() - # Write directly with custom metadata containing an old format version - fs: AbstractFileSystem = url_to_fs(any_tmp_path)[0] - file_path = fs.sep.join([any_tmp_path, "test.parquet"]) - df.write_parquet( - file_path, + tester.write_typed(schema, df, any_tmp_path, lazy=lazy) + tester.set_metadata( + any_tmp_path, metadata={ SCHEMA_METADATA_KEY: schema.serialize().replace( "primary_key", "primary_keys" @@ -327,19 +325,34 @@ def test_read_write_parquet_old_format_version( }, ) - # Act - spy = mocker.spy(schema, "validate") - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=UserWarning) + # Act and assert + def read() -> pl.DataFrame | pl.LazyFrame: if lazy: - schema.scan_parquet(file_path, validation=validation) + return schema.scan_parquet(any_tmp_path, validation=validation) else: - schema.read_parquet(file_path, validation=validation) - - # Assert - validation should be called because the old format couldn't be deserialized - spy.assert_called_once() + return schema.read_parquet(any_tmp_path, validation=validation) + + if validation == "forbid": + with pytest.raises(ValidationError): + read() + elif validation == "allow": + spy = mocker.spy(schema, "validate") + out = read() + assert_frame_equal(df.lazy(), out.lazy()) + spy.assert_called_once() + elif validation == "warn": + spy = mocker.spy(schema, "validate") + with pytest.warns( + UserWarning, match=r"requires validation: current schema does not match" + ): + out = read() + assert_frame_equal(df.lazy(), out.lazy()) + spy.assert_called_once() + elif validation == "skip": + spy = mocker.spy(schema, "validate") + out = read() + assert_frame_equal(df.lazy(), out.lazy()) + spy.assert_not_called() # ---------------------------- DELTA LAKE SPECIFICS ---------------------------------- # From 
037b68cd7b01d41de9ca2e4c6cf5eff3794ac5fa Mon Sep 17 00:00:00 2001 From: Moritz Potthoff Date: Mon, 24 Nov 2025 14:47:08 +0100 Subject: [PATCH 07/13] fix --- dataframely/schema.py | 2 +- dataframely/testing/storage.py | 15 +++++++++ tests/collection/test_storage.py | 25 +++++++++------ tests/schema/test_storage.py | 53 +++++++++++++++----------------- 4 files changed, 57 insertions(+), 38 deletions(-) diff --git a/dataframely/schema.py b/dataframely/schema.py index da7c2b37..13c5620f 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -1384,7 +1384,7 @@ def deserialize_schema(data: str, strict: bool = True) -> type[Schema] | None: if (format := decoded["versions"]["format"]) != SERIALIZATION_FORMAT_VERSION: raise ValueError(f"Unsupported schema format version: {format}") return _schema_from_dict(decoded) - except (ValueError, JSONDecodeError, plexc.ComputeError) as e: + except (ValueError, JSONDecodeError, plexc.ComputeError, TypeError) as e: if strict: raise e from e return None diff --git a/dataframely/testing/storage.py b/dataframely/testing/storage.py index 615097f4..85f8ee07 100644 --- a/dataframely/testing/storage.py +++ b/dataframely/testing/storage.py @@ -45,6 +45,11 @@ def read( self, schema: type[S], path: str, lazy: Literal[False], validation: Validation ) -> dy.DataFrame[S]: ... + @overload + def read( + self, schema: type[S], path: str, lazy: bool, validation: Validation + ) -> dy.LazyFrame[S] | dy.DataFrame[S]: ... + @abstractmethod def read( self, schema: type[S], path: str, lazy: bool, validation: Validation @@ -88,6 +93,11 @@ def read( self, schema: type[S], path: str, lazy: Literal[False], validation: Validation ) -> dy.DataFrame[S]: ... + @overload + def read( + self, schema: type[S], path: str, lazy: bool, validation: Validation + ) -> dy.LazyFrame[S] | dy.DataFrame[S]: ... + def read( self, schema: type[S], path: str, lazy: bool, validation: Validation ) -> dy.LazyFrame[S] | dy.DataFrame[S]: @@ -125,6 +135,11 @@ def read( self, schema: type[S], path: str, lazy: Literal[False], validation: Validation ) -> dy.DataFrame[S]: ... + @overload + def read( + self, schema: type[S], path: str, lazy: bool, validation: Validation + ) -> dy.LazyFrame[S] | dy.DataFrame[S]: ... 
+ def read( self, schema: type[S], path: str, lazy: bool, validation: Validation ) -> dy.DataFrame[S] | dy.LazyFrame[S]: diff --git a/tests/collection/test_storage.py b/tests/collection/test_storage.py index 911cc455..51f3a368 100644 --- a/tests/collection/test_storage.py +++ b/tests/collection/test_storage.py @@ -449,17 +449,24 @@ def test_read_write_parquet_old_schema_contents( ) # Act & Assert - if validation == "forbid": - with pytest.raises(ValidationRequiredError): + match validation: + case "forbid": + with pytest.raises(ValidationRequiredError): + tester.read(MyCollection, any_tmp_path, lazy, validation=validation) + case "allow": + spy = mocker.spy(MyCollection, "validate") tester.read(MyCollection, any_tmp_path, lazy, validation=validation) - else: - spy = mocker.spy(MyCollection, "validate") - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=UserWarning) + spy.assert_called_once() + case "warn": + spy = mocker.spy(MyCollection, "validate") + with pytest.warns(UserWarning): + tester.read(MyCollection, any_tmp_path, lazy, validation=validation) + spy.assert_called_once() + case "skip": + spy = mocker.spy(MyCollection, "validate") tester.read(MyCollection, any_tmp_path, lazy, validation=validation) - - # Validation should be called because the old format couldn't be deserialized - spy.assert_called_once() + # Validation should NOT be called because we are skipping it + spy.assert_not_called() @pytest.mark.parametrize("metadata", [None, {COLLECTION_METADATA_KEY: "invalid"}]) diff --git a/tests/schema/test_storage.py b/tests/schema/test_storage.py index 09dde46b..6600acc7 100644 --- a/tests/schema/test_storage.py +++ b/tests/schema/test_storage.py @@ -11,7 +11,7 @@ import dataframely as dy from dataframely import Validation from dataframely._storage.delta import DeltaStorageBackend -from dataframely.exc import ValidationError, ValidationRequiredError +from dataframely.exc import ValidationRequiredError from dataframely.testing import create_schema from dataframely.testing.storage import ( DeltaSchemaStorageTester, @@ -326,33 +326,30 @@ def test_read_write_parquet_old_metadata_contents( ) # Act and assert - def read() -> pl.DataFrame | pl.LazyFrame: - if lazy: - return schema.scan_parquet(any_tmp_path, validation=validation) - else: - return schema.read_parquet(any_tmp_path, validation=validation) - - if validation == "forbid": - with pytest.raises(ValidationError): - read() - elif validation == "allow": - spy = mocker.spy(schema, "validate") - out = read() - assert_frame_equal(df.lazy(), out.lazy()) - spy.assert_called_once() - elif validation == "warn": - spy = mocker.spy(schema, "validate") - with pytest.warns( - UserWarning, match=r"requires validation: current schema does not match" - ): - out = read() - assert_frame_equal(df.lazy(), out.lazy()) - spy.assert_called_once() - elif validation == "skip": - spy = mocker.spy(schema, "validate") - out = read() - assert_frame_equal(df.lazy(), out.lazy()) - spy.assert_not_called() + match validation: + case "forbid": + with pytest.raises( + TypeError, match=r"got an unexpected keyword argument 'primary_keys'" + ): + tester.read(schema, any_tmp_path, lazy=lazy, validation=validation) + case "allow": + spy = mocker.spy(schema, "validate") + out = tester.read(schema, any_tmp_path, lazy=lazy, validation=validation) + assert_frame_equal(df.lazy(), out.lazy()) + spy.assert_called_once() + case "warn": + spy = mocker.spy(schema, "validate") + with pytest.warns(UserWarning, match=r"requires validation"): + out = 
tester.read( + schema, any_tmp_path, lazy=lazy, validation=validation + ) + assert_frame_equal(df.lazy(), out.lazy()) + spy.assert_called_once() + case "skip": + spy = mocker.spy(schema, "validate") + out = tester.read(schema, any_tmp_path, lazy=lazy, validation=validation) + assert_frame_equal(df.lazy(), out.lazy()) + spy.assert_not_called() # ---------------------------- DELTA LAKE SPECIFICS ---------------------------------- # From 9e65575797ef114af181a538afbf55eebede0f7a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 28 Nov 2025 17:41:48 +0000 Subject: [PATCH 08/13] feat: Add DeserializationError exception and parametrize tests over storage backends Co-authored-by: MoritzPotthoffQC <160181542+MoritzPotthoffQC@users.noreply.github.com> --- dataframely/__init__.py | 2 ++ dataframely/collection/collection.py | 10 +++++-- dataframely/exc.py | 7 +++++ dataframely/schema.py | 9 ++++-- dataframely/testing/storage.py | 38 ++++++++++++++++++++++++++ tests/collection/test_serialization.py | 6 ++-- tests/collection/test_storage.py | 19 +++++++------ tests/schema/test_serialization.py | 8 +++--- tests/schema/test_storage.py | 6 ++-- 9 files changed, 82 insertions(+), 23 deletions(-) diff --git a/dataframely/__init__.py b/dataframely/__init__.py index 616fb04b..802d9231 100644 --- a/dataframely/__init__.py +++ b/dataframely/__init__.py @@ -51,6 +51,7 @@ UInt64, ) from .config import Config +from .exc import DeserializationError from .filter_result import FailureInfo from .functional import ( concat_collection_members, @@ -106,4 +107,5 @@ "Array", "Object", "Validation", + "DeserializationError", ] diff --git a/dataframely/collection/collection.py b/dataframely/collection/collection.py index 85a5870a..4fc7c95a 100644 --- a/dataframely/collection/collection.py +++ b/dataframely/collection/collection.py @@ -33,7 +33,11 @@ from dataframely._storage.delta import DeltaStorageBackend from dataframely._storage.parquet import ParquetStorageBackend from dataframely._typing import LazyFrame, Validation -from dataframely.exc import ValidationError, ValidationRequiredError +from dataframely.exc import ( + DeserializationError, + ValidationError, + ValidationRequiredError, +) from dataframely.filter_result import FailureInfo from dataframely.random import Generator from dataframely.schema import _schema_from_dict @@ -1332,7 +1336,9 @@ def deserialize_collection(data: str, strict: bool = True) -> type[Collection] | ) except (ValueError, TypeError, JSONDecodeError, plexc.ComputeError) as e: if strict: - raise e from e + raise DeserializationError( + "The collection could not be deserialized" + ) from e return None diff --git a/dataframely/exc.py b/dataframely/exc.py index a87497a7..bc324200 100644 --- a/dataframely/exc.py +++ b/dataframely/exc.py @@ -41,3 +41,10 @@ def __init__(self, attr: str, kls: type) -> None: class ValidationRequiredError(Exception): """Error raised when validation is required when reading a parquet file.""" + + +# ---------------------------------- DESERIALIZATION --------------------------------- # + + +class DeserializationError(Exception): + """Error raised when deserialization of a schema or collection fails.""" diff --git a/dataframely/schema.py b/dataframely/schema.py index 13c5620f..64096bf6 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -40,7 +40,12 @@ from ._typing import DataFrame, LazyFrame, Validation from .columns import Column, column_from_dict from .config import Config -from .exc import 
SchemaError, ValidationError, ValidationRequiredError +from .exc import ( + DeserializationError, + SchemaError, + ValidationError, + ValidationRequiredError, +) from .filter_result import FailureInfo, FilterResult, LazyFilterResult from .random import Generator @@ -1386,7 +1391,7 @@ def deserialize_schema(data: str, strict: bool = True) -> type[Schema] | None: return _schema_from_dict(decoded) except (ValueError, JSONDecodeError, plexc.ComputeError, TypeError) as e: if strict: - raise e from e + raise DeserializationError("The schema could not be deserialized") from e return None diff --git a/dataframely/testing/storage.py b/dataframely/testing/storage.py index 85f8ee07..dc8e68ef 100644 --- a/dataframely/testing/storage.py +++ b/dataframely/testing/storage.py @@ -184,6 +184,11 @@ def write_untyped( def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: """Read from the backend, using collection information if available.""" + @abstractmethod + def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: + """Overwrite the metadata stored at the given path with the provided + metadata.""" + class ParquetCollectionStorageTester(CollectionStorageTester): def write_typed( @@ -230,6 +235,22 @@ def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: else: return collection.read_parquet(path, **kwargs) + def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: + fs: AbstractFileSystem = url_to_fs(path)[0] + prefix = ( + "" + if fs.protocol == "file" + else ( + f"{fs.protocol}://" + if isinstance(fs.protocol, str) + else f"{fs.protocol[0]}://" + ) + ) + for file in fs.glob(fs.sep.join([path, "**", "*.parquet"])): + file_path = f"{prefix}{file}" + df = pl.read_parquet(file_path) + df.write_parquet(file_path, metadata=metadata) + class DeltaCollectionStorageTester(CollectionStorageTester): def write_typed( @@ -259,6 +280,23 @@ def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: return collection.scan_delta(source=path, **kwargs) return collection.read_delta(source=path, **kwargs) + def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: + fs: AbstractFileSystem = url_to_fs(path)[0] + # For delta, we need to update metadata on each member table + for entry in fs.ls(path): + member_path = entry["name"] if isinstance(entry, dict) else entry + if fs.isdir(member_path): + df = pl.read_delta(member_path) + df.head(0).write_delta( + member_path, + delta_write_options={ + "commit_properties": deltalake.CommitProperties( + custom_metadata=metadata + ), + }, + mode="overwrite", + ) + # ------------------------------------ Failure info ------------------------------------ class FailureInfoStorageTester(ABC): diff --git a/tests/collection/test_serialization.py b/tests/collection/test_serialization.py index 77a8e4b0..d7c6bc52 100644 --- a/tests/collection/test_serialization.py +++ b/tests/collection/test_serialization.py @@ -91,7 +91,7 @@ def test_roundtrip_matches(collection: type[dy.Collection]) -> None: def test_deserialize_unknown_format_version(strict: bool) -> None: serialized = '{"versions": {"format": "invalid"}}' if strict: - with pytest.raises(ValueError, match=r"Unsupported schema format version"): + with pytest.raises(dy.DeserializationError): dy.deserialize_collection(serialized) else: assert dy.deserialize_collection(serialized, strict=False) is None @@ -101,7 +101,7 @@ def test_deserialize_unknown_format_version(strict: bool) -> None: def test_deserialize_invalid_json_strict_false(strict: 
bool) -> None: serialized = '{"invalid json' if strict: - with pytest.raises(json.JSONDecodeError): + with pytest.raises(dy.DeserializationError): dy.deserialize_collection(serialized, strict=True) else: assert dy.deserialize_collection(serialized, strict=False) is None @@ -119,7 +119,7 @@ def test_deserialize_invalid_member_schema(strict: bool) -> None: broken = serialized.replace("primary_key", "primary_keys") if strict: - with pytest.raises(TypeError): + with pytest.raises(dy.DeserializationError): dy.deserialize_collection(broken, strict=strict) else: assert dy.deserialize_collection(broken, strict=False) is None diff --git a/tests/collection/test_storage.py b/tests/collection/test_storage.py index 51f3a368..8038fe78 100644 --- a/tests/collection/test_storage.py +++ b/tests/collection/test_storage.py @@ -14,7 +14,7 @@ from dataframely._storage.constants import COLLECTION_METADATA_KEY from dataframely._storage.delta import DeltaStorageBackend from dataframely.collection.collection import _reconcile_collection_types -from dataframely.exc import ValidationRequiredError +from dataframely.exc import DeserializationError, ValidationRequiredError from dataframely.testing.storage import ( CollectionStorageTester, DeltaCollectionStorageTester, @@ -422,6 +422,7 @@ def test_read_write_parquet_schema_json_fallback_corrupt( spy.assert_called_once() +@pytest.mark.parametrize("tester", TESTERS) @pytest.mark.parametrize("validation", ["forbid", "allow", "skip", "warn"]) @pytest.mark.parametrize("lazy", [True, False]) @pytest.mark.parametrize( @@ -429,18 +430,20 @@ def test_read_write_parquet_schema_json_fallback_corrupt( ["tmp_path", pytest.param("s3_tmp_path", marks=pytest.mark.s3)], indirect=True, ) -def test_read_write_parquet_old_schema_contents( - any_tmp_path: str, mocker: pytest_mock.MockerFixture, validation: Any, lazy: bool +def test_read_write_old_metadata_contents( + tester: CollectionStorageTester, + any_tmp_path: str, + mocker: pytest_mock.MockerFixture, + validation: Any, + lazy: bool, ) -> None: """If collection has an old/incompatible schema content, we should fall back to validating when validation is 'allow' or 'warn', and raise otherwise.""" # Arrange collection = MyCollection.create_empty() - tester = ParquetCollectionStorageTester() - tester.write_untyped( - collection, + tester.write_typed(collection, any_tmp_path, lazy) + tester.set_metadata( any_tmp_path, - lazy, metadata={ COLLECTION_METADATA_KEY: collection.serialize().replace( "primary_key", "primary_keys" @@ -451,7 +454,7 @@ def test_read_write_parquet_old_schema_contents( # Act & Assert match validation: case "forbid": - with pytest.raises(ValidationRequiredError): + with pytest.raises(DeserializationError): tester.read(MyCollection, any_tmp_path, lazy, validation=validation) case "allow": spy = mocker.spy(MyCollection, "validate") diff --git a/tests/schema/test_serialization.py b/tests/schema/test_serialization.py index 78e3bd62..bad2b794 100644 --- a/tests/schema/test_serialization.py +++ b/tests/schema/test_serialization.py @@ -99,7 +99,7 @@ def test_deserialize_unknown_column_type() -> None: "rules": {} } """ - with pytest.raises(ValueError): + with pytest.raises(dy.DeserializationError): dy.deserialize_schema(serialized) @@ -112,13 +112,13 @@ def test_deserialize_unknown_rule_type() -> None: "rules": {"a": {"rule_type": "unknown"}} } """ - with pytest.raises(ValueError): + with pytest.raises(dy.DeserializationError): dy.deserialize_schema(serialized) def test_deserialize_invalid_type() -> None: serialized = 
'{"__type__": "unknown", "value": "foo"}' - with pytest.raises(TypeError): + with pytest.raises(dy.DeserializationError): dy.deserialize_schema(serialized) @@ -127,5 +127,5 @@ def test_deserialize_invalid_type() -> None: def test_deserialize_unknown_format_version() -> None: serialized = '{"versions": {"format": "invalid"}}' - with pytest.raises(ValueError, match=r"Unsupported schema format version"): + with pytest.raises(dy.DeserializationError): dy.deserialize_schema(serialized) diff --git a/tests/schema/test_storage.py b/tests/schema/test_storage.py index 6600acc7..0c66fcc1 100644 --- a/tests/schema/test_storage.py +++ b/tests/schema/test_storage.py @@ -11,7 +11,7 @@ import dataframely as dy from dataframely import Validation from dataframely._storage.delta import DeltaStorageBackend -from dataframely.exc import ValidationRequiredError +from dataframely.exc import DeserializationError, ValidationRequiredError from dataframely.testing import create_schema from dataframely.testing.storage import ( DeltaSchemaStorageTester, @@ -328,9 +328,7 @@ def test_read_write_parquet_old_metadata_contents( # Act and assert match validation: case "forbid": - with pytest.raises( - TypeError, match=r"got an unexpected keyword argument 'primary_keys'" - ): + with pytest.raises(DeserializationError): tester.read(schema, any_tmp_path, lazy=lazy, validation=validation) case "allow": spy = mocker.spy(schema, "validate") From 96d336a7129ade62c0dd3d4c2f132d48ed81fbd4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 28 Nov 2025 17:59:15 +0000 Subject: [PATCH 09/13] fix: Use set_metadata instead of metadata kwarg in test_read_write_parquet_schema_json_fallback_corrupt Co-authored-by: MoritzPotthoffQC <160181542+MoritzPotthoffQC@users.noreply.github.com> --- tests/collection/test_storage.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/collection/test_storage.py b/tests/collection/test_storage.py index 8038fe78..514352d8 100644 --- a/tests/collection/test_storage.py +++ b/tests/collection/test_storage.py @@ -405,10 +405,9 @@ def test_read_write_parquet_schema_json_fallback_corrupt( # Arrange collection = MyCollection.create_empty() tester = ParquetCollectionStorageTester() - tester.write_untyped( - collection, + tester.write_untyped(collection, any_tmp_path, lazy) + tester.set_metadata( any_tmp_path, - lazy, metadata={COLLECTION_METADATA_KEY: "} this is not a valid JSON {"}, ) From aae41bf9d47a19817a96ae53477eaf4c8dacbd98 Mon Sep 17 00:00:00 2001 From: Moritz Potthoff Date: Mon, 1 Dec 2025 11:23:11 +0100 Subject: [PATCH 10/13] cleanup --- dataframely/testing/storage.py | 46 +++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/dataframely/testing/storage.py b/dataframely/testing/storage.py index dc8e68ef..23dd6b7b 100644 --- a/dataframely/testing/storage.py +++ b/dataframely/testing/storage.py @@ -194,6 +194,11 @@ class ParquetCollectionStorageTester(CollectionStorageTester): def write_typed( self, collection: dy.Collection, path: str, lazy: bool, **kwargs: Any ) -> None: + if "metadata" in kwargs: + raise KeyError( + "`metadata` kwarg will be ignored in `write_typed`. Use `set_metadata`." 
+ ) + # Polars does not support partitioning via kwarg on sink_parquet if lazy: kwargs.pop("partition_by", None) @@ -206,6 +211,11 @@ def write_typed( def write_untyped( self, collection: dy.Collection, path: str, lazy: bool, **kwargs: Any ) -> None: + if "metadata" in kwargs: + raise KeyError( + "Cannot set metadata through `write_untyped`. Use `set_metadata`." + ) + if lazy: collection.sink_parquet(path, **kwargs) else: @@ -217,17 +227,8 @@ def _delete_meta(file: str) -> None: df.write_parquet(file) fs: AbstractFileSystem = url_to_fs(path)[0] - prefix = ( - "" - if fs.protocol == "file" - else ( - f"{fs.protocol}://" - if isinstance(fs.protocol, str) - else f"{fs.protocol[0]}://" - ) - ) for file in fs.glob(fs.sep.join([path, "**", "*.parquet"])): - _delete_meta(f"{prefix}{file}") + _delete_meta(f"{self._get_prefix(fs)}{file}") def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: if lazy: @@ -237,7 +238,14 @@ def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: fs: AbstractFileSystem = url_to_fs(path)[0] - prefix = ( + for file in fs.glob(fs.sep.join([path, "*.parquet"])): + file_path = f"{self._get_prefix(fs)}{file}" + df = pl.read_parquet(file_path) + df.write_parquet(file_path, metadata=metadata) + + @staticmethod + def _get_prefix(fs: AbstractFileSystem) -> str: + return ( "" if fs.protocol == "file" else ( @@ -246,16 +254,17 @@ def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: else f"{fs.protocol[0]}://" ) ) - for file in fs.glob(fs.sep.join([path, "**", "*.parquet"])): - file_path = f"{prefix}{file}" - df = pl.read_parquet(file_path) - df.write_parquet(file_path, metadata=metadata) class DeltaCollectionStorageTester(CollectionStorageTester): def write_typed( self, collection: dy.Collection, path: str, lazy: bool, **kwargs: Any ) -> None: + if "metadata" in kwargs: + raise KeyError( + "`metadata` kwarg will be ignored in `write_typed`. Use `set_metadata`." + ) + extra_kwargs = {} if partition_by := kwargs.pop("partition_by", None): extra_kwargs["delta_write_options"] = {"partition_by": partition_by} @@ -265,6 +274,10 @@ def write_typed( def write_untyped( self, collection: dy.Collection, path: str, lazy: bool, **kwargs: Any ) -> None: + if "metadata" in kwargs: + raise KeyError( + "Cannot set metadata through `write_untyped`. Use `set_metadata`." 
+ ) collection.write_delta(path, **kwargs) # For each member table, write an empty commit @@ -283,8 +296,7 @@ def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: fs: AbstractFileSystem = url_to_fs(path)[0] # For delta, we need to update metadata on each member table - for entry in fs.ls(path): - member_path = entry["name"] if isinstance(entry, dict) else entry + for member_path in fs.ls(path): if fs.isdir(member_path): df = pl.read_delta(member_path) df.head(0).write_delta( From a38a6ee5207ed0221cc8d43e80c891f03d40412c Mon Sep 17 00:00:00 2001 From: Moritz Potthoff Date: Mon, 1 Dec 2025 11:36:19 +0100 Subject: [PATCH 11/13] fix --- dataframely/collection/collection.py | 5 ++--- dataframely/testing/storage.py | 7 +++++-- tests/collection/test_storage.py | 7 ++++--- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/dataframely/collection/collection.py b/dataframely/collection/collection.py index 4fc7c95a..74b92e40 100644 --- a/dataframely/collection/collection.py +++ b/dataframely/collection/collection.py @@ -1288,9 +1288,8 @@ def deserialize_collection(data: str, strict: bool = True) -> type[Collection] | The collection loaded from the JSON data. Raises: - ValueError: If the schema format version is not supported and `strict=True`. - TypeError: If the schema content is invalid and `strict=True`. - JSONDecodeError: If the provided data is not valid JSON and `strict=True`. + DeserializationError: If the collection can not be deserialized + and `strict=True`. Attention: The returned collection **cannot** be used to create instances of the diff --git a/dataframely/testing/storage.py b/dataframely/testing/storage.py index 23dd6b7b..696c02a5 100644 --- a/dataframely/testing/storage.py +++ b/dataframely/testing/storage.py @@ -228,7 +228,7 @@ def _delete_meta(file: str) -> None: fs: AbstractFileSystem = url_to_fs(path)[0] for file in fs.glob(fs.sep.join([path, "**", "*.parquet"])): - _delete_meta(f"{self._get_prefix(fs)}{file}") + _delete_meta(self._prefix_path(file, fs)) def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: if lazy: @@ -239,10 +239,13 @@ def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: fs: AbstractFileSystem = url_to_fs(path)[0] for file in fs.glob(fs.sep.join([path, "*.parquet"])): - file_path = f"{self._get_prefix(fs)}{file}" + file_path = self._prefix_path(file, fs) df = pl.read_parquet(file_path) df.write_parquet(file_path, metadata=metadata) + def _prefix_path(self, path: str, fs: AbstractFileSystem) -> str: + return f"{self._get_prefix(fs)}{path}" + @staticmethod def _get_prefix(fs: AbstractFileSystem) -> str: return ( diff --git a/tests/collection/test_storage.py b/tests/collection/test_storage.py index 514352d8..9ff416df 100644 --- a/tests/collection/test_storage.py +++ b/tests/collection/test_storage.py @@ -1,7 +1,6 @@ # Copyright (c) QuantCo 2025-2025 # SPDX-License-Identifier: BSD-3-Clause -import warnings from typing import Any import polars as pl @@ -413,8 +412,10 @@ def test_read_write_parquet_schema_json_fallback_corrupt( # Act spy = mocker.spy(MyCollection, "validate") - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=UserWarning) + if validation == "warn": + with pytest.warns(UserWarning): + tester.read(MyCollection, any_tmp_path, lazy, validation=validation) + else: tester.read(MyCollection, any_tmp_path, 
lazy, validation=validation) # Assert From 674af3988f55ebdc52c05b6046632f9d52dee0e0 Mon Sep 17 00:00:00 2001 From: Moritz Potthoff Date: Mon, 1 Dec 2025 11:52:35 +0100 Subject: [PATCH 12/13] fix? --- dataframely/testing/storage.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dataframely/testing/storage.py b/dataframely/testing/storage.py index 696c02a5..a52e338a 100644 --- a/dataframely/testing/storage.py +++ b/dataframely/testing/storage.py @@ -299,7 +299,8 @@ def read(self, collection: type[C], path: str, lazy: bool, **kwargs: Any) -> C: def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: fs: AbstractFileSystem = url_to_fs(path)[0] # For delta, we need to update metadata on each member table - for member_path in fs.ls(path): + for entry in fs.ls(path): + member_path = entry["name"] if isinstance(entry, dict) else entry if fs.isdir(member_path): df = pl.read_delta(member_path) df.head(0).write_delta( From f0f137e69f6d571d6e44c9730842e35c090a5a6b Mon Sep 17 00:00:00 2001 From: Moritz Potthoff Date: Mon, 1 Dec 2025 13:42:48 +0100 Subject: [PATCH 13/13] fix --- .github/copilot-instructions.md | 2 +- dataframely/testing/storage.py | 32 ++++++++++++++++---------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index d3bf7319..774a0ce1 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -202,7 +202,7 @@ validated_df: dy.DataFrame[MySchema] = MySchema.validate(df, cast=True) 1. **Python code**: Run `pixi run pre-commit run` before committing 2. **Rust code**: Run `pixi run postinstall` to rebuild, then run tests -3. **Tests**: Ensure `pixi run test` passes +3. **Tests**: Ensure `pixi run test` passes. If changes might affect storage backends, use `pixi run test -m s3`. 4. **Documentation**: Update docstrings 5. 
**API changes**: Ensure backward compatibility or document migration path diff --git a/dataframely/testing/storage.py b/dataframely/testing/storage.py index a52e338a..b4cf14b5 100644 --- a/dataframely/testing/storage.py +++ b/dataframely/testing/storage.py @@ -189,6 +189,21 @@ def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: """Overwrite the metadata stored at the given path with the provided metadata.""" + def _prefix_path(self, path: str, fs: AbstractFileSystem) -> str: + return f"{self._get_prefix(fs)}{path}" + + @staticmethod + def _get_prefix(fs: AbstractFileSystem) -> str: + return ( + "" + if fs.protocol == "file" + else ( + f"{fs.protocol}://" + if isinstance(fs.protocol, str) + else f"{fs.protocol[0]}://" + ) + ) + class ParquetCollectionStorageTester(CollectionStorageTester): def write_typed( @@ -243,21 +258,6 @@ def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: df = pl.read_parquet(file_path) df.write_parquet(file_path, metadata=metadata) - def _prefix_path(self, path: str, fs: AbstractFileSystem) -> str: - return f"{self._get_prefix(fs)}{path}" - - @staticmethod - def _get_prefix(fs: AbstractFileSystem) -> str: - return ( - "" - if fs.protocol == "file" - else ( - f"{fs.protocol}://" - if isinstance(fs.protocol, str) - else f"{fs.protocol[0]}://" - ) - ) - class DeltaCollectionStorageTester(CollectionStorageTester): def write_typed( @@ -300,7 +300,7 @@ def set_metadata(self, path: str, metadata: dict[str, Any]) -> None: fs: AbstractFileSystem = url_to_fs(path)[0] # For delta, we need to update metadata on each member table for entry in fs.ls(path): - member_path = entry["name"] if isinstance(entry, dict) else entry + member_path = self._prefix_path(entry, fs) if fs.isdir(member_path): df = pl.read_delta(member_path) df.head(0).write_delta(
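
For reference, a minimal runnable sketch of the behavior this series converges
on (it assumes dataframely with these patches installed; `MySchema` and
"data.parquet" are stand-ins, and the tampered-metadata trick mirrors the
tests above):

    import dataframely as dy
    from dataframely._storage.constants import SCHEMA_METADATA_KEY


    class MySchema(dy.Schema):
        a = dy.Int64()


    # Simulate stale/incompatible stored metadata, as the tests do: rename a
    # serialized field so deserialization of the stored schema fails.
    broken = MySchema.serialize().replace("primary_key", "primary_keys")

    # strict=True (the default) wraps any failure in DeserializationError ...
    try:
        dy.deserialize_schema(broken)
    except dy.DeserializationError:
        pass

    # ... while strict=False tolerates the failure and returns None instead.
    assert dy.deserialize_schema(broken, strict=False) is None
    assert dy.deserialize_collection('{"invalid json', strict=False) is None

    # On read, strictness follows the validation mode: "forbid" surfaces the
    # DeserializationError, "allow"/"warn" fall back to validate() with
    # cast=True, and "skip" reads the data without validating.
    MySchema.create_empty().write_parquet(
        "data.parquet", metadata={SCHEMA_METADATA_KEY: broken}
    )
    df = MySchema.read_parquet("data.parquet", validation="allow")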