diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 345c04706d6f..f73a41d98bcf 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -25,6 +25,7 @@
 import sys
 import types
 import typing
+import dataclasses
 from typing import Generic
 from typing import TypeVar
 
@@ -175,6 +176,10 @@ def match_is_named_tuple(user_type):
       hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))
 
 
+def match_is_dataclass(user_type):
+  return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)
+
+
 def _match_is_optional(user_type):
   return _match_is_union(user_type) and sum(
       tp is type(None) for tp in _get_args(user_type)) == 1
@@ -418,6 +423,7 @@ def convert_to_beam_type(typ):
       # This MUST appear before the entry for the normal Tuple.
       _TypeMapEntry(
           match=match_is_named_tuple, arity=0, beam_type=typehints.Any),
+      _TypeMapEntry(match=match_is_dataclass, arity=0, beam_type=typehints.Any),
       _TypeMapEntry(
           match=_match_is_primitive(tuple), arity=-1,
           beam_type=typehints.Tuple),
diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py
index 08838c84a050..bf2b0f90aab6 100644
--- a/sdks/python/apache_beam/typehints/row_type.py
+++ b/sdks/python/apache_beam/typehints/row_type.py
@@ -26,6 +26,7 @@
 from typing import Tuple
 
 from apache_beam.typehints import typehints
+from apache_beam.typehints.native_type_compatibility import match_is_dataclass
 from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
 from apache_beam.typehints.schema_registry import SchemaTypeRegistry
 
@@ -127,6 +128,30 @@ def from_user_type(
           field_options=field_options,
           field_descriptions=field_descriptions)
 
+    if match_is_dataclass(user_type):
+      import dataclasses
+      fields = [(field.name, field.type)
+                for field in dataclasses.fields(user_type)]
+
+      field_descriptions = getattr(user_type, '_field_descriptions', None)
+
+      if _user_type_is_generated(user_type):
+        return RowTypeConstraint.from_fields(
+            fields,
+            schema_id=getattr(user_type, _BEAM_SCHEMA_ID),
+            schema_options=schema_options,
+            field_options=field_options,
+            field_descriptions=field_descriptions)
+
+      # TODO(https://github.com/apache/beam/issues/22125): Add user API for
+      # specifying schema/field options
+      return RowTypeConstraint(
+          fields=fields,
+          user_type=user_type,
+          schema_options=schema_options,
+          field_options=field_options,
+          field_descriptions=field_descriptions)
+
     return None
 
   @staticmethod
diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py
index e9674fa5bc20..5dd8ff290c48 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -96,6 +96,7 @@
 from apache_beam.typehints.native_type_compatibility import _safe_issubclass
 from apache_beam.typehints.native_type_compatibility import convert_to_python_type
 from apache_beam.typehints.native_type_compatibility import extract_optional_type
+from apache_beam.typehints.native_type_compatibility import match_is_dataclass
 from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
 from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
 from apache_beam.typehints.schema_registry import SchemaTypeRegistry
@@ -629,7 +630,7 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
   Returns schema as a list of (name, python_type) tuples"""
   if isinstance(element_type, row_type.RowTypeConstraint):
     return named_fields_to_schema(element_type._fields)
-  elif match_is_named_tuple(element_type):
+  elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
     if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
       # if the named tuple's schema is in registry, we just use it instead of
       # regenerating one.
diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py
index 73db06b9a8d2..5a5d7396ab30 100644
--- a/sdks/python/apache_beam/typehints/schemas_test.py
+++ b/sdks/python/apache_beam/typehints/schemas_test.py
@@ -19,6 +19,7 @@
 
 # pytype: skip-file
 
+import dataclasses
 import itertools
 import pickle
 import unittest
@@ -388,6 +389,24 @@ def test_namedtuple_roundtrip(self, user_type):
     self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
     self.assert_namedtuple_equivalent(roundtripped.user_type, user_type)
 
+  def test_dataclass_roundtrip(self):
+    @dataclasses.dataclass
+    class SimpleDataclass:
+      id: np.int64
+      name: str
+
+    roundtripped = typing_from_runner_api(
+        typing_to_runner_api(
+            SimpleDataclass, schema_registry=SchemaTypeRegistry()),
+        schema_registry=SchemaTypeRegistry())
+
+    self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
+    # The roundtripped user_type is generated as a NamedTuple, so we can't test
+    # equivalence directly with the dataclass.
+    # Instead, let's verify annotations.
+    self.assertEqual(
+        roundtripped.user_type.__annotations__, SimpleDataclass.__annotations__)
+
   def test_row_type_constraint_to_schema(self):
     result_type = typing_to_runner_api(
         row_type.RowTypeConstraint.from_fields([
@@ -646,6 +665,48 @@ def test_trivial_example(self):
         expected.row_type.schema.fields,
         typing_to_runner_api(MyCuteClass).row_type.schema.fields)
 
+  def test_trivial_example_dataclass(self):
+    @dataclasses.dataclass
+    class MyCuteDataclass:
+      name: str
+      age: Optional[int]
+      interests: List[str]
+      height: float
+      blob: ByteString
+
+    expected = schema_pb2.FieldType(
+        row_type=schema_pb2.RowType(
+            schema=schema_pb2.Schema(
+                fields=[
+                    schema_pb2.Field(
+                        name='name',
+                        type=schema_pb2.FieldType(
+                            atomic_type=schema_pb2.STRING),
+                    ),
+                    schema_pb2.Field(
+                        name='age',
+                        type=schema_pb2.FieldType(
+                            nullable=True, atomic_type=schema_pb2.INT64)),
+                    schema_pb2.Field(
+                        name='interests',
+                        type=schema_pb2.FieldType(
+                            array_type=schema_pb2.ArrayType(
+                                element_type=schema_pb2.FieldType(
+                                    atomic_type=schema_pb2.STRING)))),
+                    schema_pb2.Field(
+                        name='height',
+                        type=schema_pb2.FieldType(
+                            atomic_type=schema_pb2.DOUBLE)),
+                    schema_pb2.Field(
+                        name='blob',
+                        type=schema_pb2.FieldType(
+                            atomic_type=schema_pb2.BYTES)),
+                ])))
+
+    self.assertEqual(
+        expected.row_type.schema.fields,
+        typing_to_runner_api(MyCuteDataclass).row_type.schema.fields)
+
   def test_user_type_annotated_with_id_after_conversion(self):
     MyCuteClass = NamedTuple('MyCuteClass', [
         ('name', str),