Skip to content
Draft
2 changes: 1 addition & 1 deletion .github/copilot-instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ validated_df: dy.DataFrame[MySchema] = MySchema.validate(df, cast=True)

1. **Python code**: Run `pixi run pre-commit run` before committing
2. **Rust code**: Run `pixi run postinstall` to rebuild, then run tests
3. **Tests**: Ensure `pixi run test` passes
3. **Tests**: Ensure `pixi run test` passes. If changes might affect storage backends, use `pixi run test -m s3`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope that copilot is clever enough to figure this out and this saves some running time versus always requiring this.

4. **Documentation**: Update docstrings
5. **API changes**: Ensure backward compatibility or document migration path

Expand Down
2 changes: 2 additions & 0 deletions dataframely/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
UInt64,
)
from .config import Config
from .exc import DeserializationError
from .filter_result import FailureInfo
from .functional import (
concat_collection_members,
Expand Down Expand Up @@ -106,4 +107,5 @@
"Array",
"Object",
"Validation",
"DeserializationError",
]
114 changes: 73 additions & 41 deletions dataframely/collection/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dataclasses import asdict
from json import JSONDecodeError
from pathlib import Path
from typing import IO, Annotated, Any, Literal, cast
from typing import IO, Annotated, Any, Literal, cast, overload

import polars as pl
import polars.exceptions as plexc
Expand All @@ -33,7 +33,11 @@
from dataframely._storage.delta import DeltaStorageBackend
from dataframely._storage.parquet import ParquetStorageBackend
from dataframely._typing import LazyFrame, Validation
from dataframely.exc import ValidationError, ValidationRequiredError
from dataframely.exc import (
DeserializationError,
ValidationError,
ValidationRequiredError,
)
from dataframely.filter_result import FailureInfo
from dataframely.random import Generator
from dataframely.schema import _schema_from_dict
Expand Down Expand Up @@ -891,13 +895,13 @@ def read_parquet(
- `"allow"`: The method tries to read the schema data from the parquet
files. If the stored collection schema matches this collection
schema, the collection is read without validation. If the stored
schema mismatches this schema no metadata can be found in
schema mismatches this schema, no valid metadata can be found in
the parquets, or the files have conflicting metadata,
this method automatically runs :meth:`validate` with `cast=True`.
- `"warn"`: The method behaves similarly to `"allow"`. However,
it prints a warning if validation is necessary.
- `"forbid"`: The method never runs validation automatically and only
returns if the metadata stores a collection schema that matches
returns if the metadata stores a valid collection schema that matches
this collection.
- `"skip"`: The method never runs validation and simply reads the
data, entrusting the user that the schema is valid. *Use this option
Expand Down Expand Up @@ -1184,7 +1188,12 @@ def _read(
members=cls.member_schemas().keys(), **kwargs
)

collection_types = _deserialize_types(serialized_collection_types)
# Use strict=False when validation is "allow", "warn" or "skip" to tolerate
# missing or broken collection metadata.
strict = validation == "forbid"
collection_types = _deserialize_types(
serialized_collection_types, strict=strict
)
collection_type = _reconcile_collection_types(collection_types)

if cls._requires_validation_for_reading_parquets(collection_type, validation):
Expand Down Expand Up @@ -1245,27 +1254,42 @@ def read_parquet_metadata_collection(
"""
metadata = pl.read_parquet_metadata(source)
if (schema_metadata := metadata.get(COLLECTION_METADATA_KEY)) is not None:
try:
return deserialize_collection(schema_metadata)
except (JSONDecodeError, plexc.ComputeError):
return None
return deserialize_collection(schema_metadata, strict=False)
return None


def deserialize_collection(data: str) -> type[Collection]:
@overload
def deserialize_collection(
data: str, strict: Literal[True] = True
) -> type[Collection]: ...


@overload
def deserialize_collection(
data: str, strict: Literal[False]
) -> type[Collection] | None: ...


@overload
def deserialize_collection(data: str, strict: bool) -> type[Collection] | None: ...


def deserialize_collection(data: str, strict: bool = True) -> type[Collection] | None:
"""Deserialize a collection from a JSON string.

This method allows to dynamically load a collection from its serialization, without
having to know the collection to load in advance.

Args:
data: The JSON string created via :meth:`Collection.serialize`.
strict: Whether to raise an exception if the collection cannot be deserialized.

Returns:
The collection loaded from the JSON data.

Raises:
ValueError: If the schema format version is not supported.
DeserializationError: If the collection can not be deserialized
and `strict=True`.

Attention:
The returned collection **cannot** be used to create instances of the
Expand All @@ -1280,34 +1304,41 @@ def deserialize_collection(data: str) -> type[Collection]:
See also:
:meth:`Collection.serialize` for additional information on serialization.
"""
decoded = json.loads(data, cls=SchemaJSONDecoder)
if (format := decoded["versions"]["format"]) != SERIALIZATION_FORMAT_VERSION:
raise ValueError(f"Unsupported schema format version: {format}")

annotations: dict[str, Any] = {}
for name, info in decoded["members"].items():
lf_type = LazyFrame[_schema_from_dict(info["schema"])] # type: ignore
if info["is_optional"]:
lf_type = lf_type | None # type: ignore
annotations[name] = Annotated[
lf_type,
CollectionMember(
ignored_in_filters=info["ignored_in_filters"],
inline_for_sampling=info["inline_for_sampling"],
),
]

return type(
f"{decoded['name']}_dynamic",
(Collection,),
{
"__annotations__": annotations,
**{
name: Filter(logic=lambda _, logic=logic: logic) # type: ignore
for name, logic in decoded["filters"].items()
try:
decoded = json.loads(data, cls=SchemaJSONDecoder)
if (format := decoded["versions"]["format"]) != SERIALIZATION_FORMAT_VERSION:
raise ValueError(f"Unsupported schema format version: {format}")

annotations: dict[str, Any] = {}
for name, info in decoded["members"].items():
lf_type = LazyFrame[_schema_from_dict(info["schema"])] # type: ignore
if info["is_optional"]:
lf_type = lf_type | None # type: ignore
annotations[name] = Annotated[
lf_type,
CollectionMember(
ignored_in_filters=info["ignored_in_filters"],
inline_for_sampling=info["inline_for_sampling"],
),
]

return type(
f"{decoded['name']}_dynamic",
(Collection,),
{
"__annotations__": annotations,
**{
name: Filter(logic=lambda _, logic=logic: logic) # type: ignore
for name, logic in decoded["filters"].items()
},
},
},
)
)
except (ValueError, TypeError, JSONDecodeError, plexc.ComputeError) as e:
if strict:
raise DeserializationError(
"The collection could not be deserialized"
) from e
return None


# --------------------------------------- UTILS -------------------------------------- #
Expand All @@ -1333,14 +1364,15 @@ def _extract_keys_if_exist(

def _deserialize_types(
serialized_collection_types: Iterable[str | None],
strict: bool = True,
) -> list[type[Collection]]:
collection_types = []
collection_type: type[Collection] | None = None
for t in serialized_collection_types:
if t is None:
continue
collection_type = deserialize_collection(t)
collection_types.append(collection_type)
collection_type = deserialize_collection(t, strict=strict)
if collection_type is not None:
collection_types.append(collection_type)

return collection_types

Expand Down
7 changes: 7 additions & 0 deletions dataframely/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,10 @@ def __init__(self, attr: str, kls: type) -> None:

class ValidationRequiredError(Exception):
"""Error raised when validation is required when reading a parquet file."""


# ---------------------------------- DESERIALIZATION --------------------------------- #


class DeserializationError(Exception):
"""Error raised when deserialization of a schema or collection fails."""
22 changes: 18 additions & 4 deletions dataframely/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@
from ._typing import DataFrame, LazyFrame, Validation
from .columns import Column, column_from_dict
from .config import Config
from .exc import SchemaError, ValidationError, ValidationRequiredError
from .exc import (
DeserializationError,
SchemaError,
ValidationError,
ValidationRequiredError,
)
from .filter_result import FailureInfo, FilterResult, LazyFilterResult
from .random import Generator

Expand Down Expand Up @@ -1238,8 +1243,13 @@ def _validate_if_needed(
validation: Validation,
source: str,
) -> DataFrame[Self] | LazyFrame[Self]:
# Use strict=False when validation is "allow", "warn" or "skip" to tolerate
# deserialization failures from old serialized formats.
strict = validation == "forbid"
deserialized_schema = (
deserialize_schema(serialized_schema) if serialized_schema else None
deserialize_schema(serialized_schema, strict=strict)
if serialized_schema
else None
)

# Smart validation
Expand Down Expand Up @@ -1347,6 +1357,10 @@ def deserialize_schema(data: str, strict: Literal[True] = True) -> type[Schema]:
def deserialize_schema(data: str, strict: Literal[False]) -> type[Schema] | None: ...


@overload
def deserialize_schema(data: str, strict: bool) -> type[Schema] | None: ...


def deserialize_schema(data: str, strict: bool = True) -> type[Schema] | None:
"""Deserialize a schema from a JSON string.

Expand Down Expand Up @@ -1375,9 +1389,9 @@ def deserialize_schema(data: str, strict: bool = True) -> type[Schema] | None:
if (format := decoded["versions"]["format"]) != SERIALIZATION_FORMAT_VERSION:
raise ValueError(f"Unsupported schema format version: {format}")
return _schema_from_dict(decoded)
except (ValueError, JSONDecodeError, plexc.ComputeError) as e:
except (ValueError, JSONDecodeError, plexc.ComputeError, TypeError) as e:
if strict:
raise e from e
raise DeserializationError("The schema could not be deserialized") from e
return None


Expand Down
Loading
Loading