Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import sys
import types
import typing
import dataclasses
from typing import Generic
from typing import TypeVar

Expand Down Expand Up @@ -175,6 +176,10 @@ def match_is_named_tuple(user_type):
hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))


def match_is_dataclass(user_type):
return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)


def _match_is_optional(user_type):
return _match_is_union(user_type) and sum(
tp is type(None) for tp in _get_args(user_type)) == 1
Expand Down Expand Up @@ -418,6 +423,7 @@ def convert_to_beam_type(typ):
# This MUST appear before the entry for the normal Tuple.
_TypeMapEntry(
match=match_is_named_tuple, arity=0, beam_type=typehints.Any),
_TypeMapEntry(match=match_is_dataclass, arity=0, beam_type=typehints.Any),
_TypeMapEntry(
match=_match_is_primitive(tuple), arity=-1,
beam_type=typehints.Tuple),
Expand Down
25 changes: 25 additions & 0 deletions sdks/python/apache_beam/typehints/row_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import Tuple

from apache_beam.typehints import typehints
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schema_registry import SchemaTypeRegistry

Expand Down Expand Up @@ -127,6 +128,30 @@ def from_user_type(
field_options=field_options,
field_descriptions=field_descriptions)

if match_is_dataclass(user_type):
import dataclasses
fields = [(field.name, field.type)
for field in dataclasses.fields(user_type)]

field_descriptions = getattr(user_type, '_field_descriptions', None)

if _user_type_is_generated(user_type):
return RowTypeConstraint.from_fields(
fields,
schema_id=getattr(user_type, _BEAM_SCHEMA_ID),
schema_options=schema_options,
field_options=field_options,
field_descriptions=field_descriptions)

# TODO(https://github.com/apache/beam/issues/22125): Add user API for
# specifying schema/field options
return RowTypeConstraint(
fields=fields,
user_type=user_type,
schema_options=schema_options,
field_options=field_options,
field_descriptions=field_descriptions)

return None

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/typehints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
from apache_beam.typehints.native_type_compatibility import convert_to_python_type
from apache_beam.typehints.native_type_compatibility import extract_optional_type
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
Expand Down Expand Up @@ -629,7 +630,7 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
Returns schema as a list of (name, python_type) tuples"""
if isinstance(element_type, row_type.RowTypeConstraint):
return named_fields_to_schema(element_type._fields)
elif match_is_named_tuple(element_type):
elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
# if the named tuple's schema is in registry, we just use it instead of
# regenerating one.
Expand Down
61 changes: 61 additions & 0 deletions sdks/python/apache_beam/typehints/schemas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# pytype: skip-file

import dataclasses
import itertools
import pickle
import unittest
Expand Down Expand Up @@ -388,6 +389,24 @@ def test_namedtuple_roundtrip(self, user_type):
self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
self.assert_namedtuple_equivalent(roundtripped.user_type, user_type)

def test_dataclass_roundtrip(self):
@dataclasses.dataclass
class SimpleDataclass:
id: np.int64
name: str

roundtripped = typing_from_runner_api(
typing_to_runner_api(
SimpleDataclass, schema_registry=SchemaTypeRegistry()),
schema_registry=SchemaTypeRegistry())

self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
# The roundtripped user_type is generated as a NamedTuple, so we can't test
# equivalence directly with the dataclass.
# Instead, let's verify annotations.
self.assertEqual(
roundtripped.user_type.__annotations__, SimpleDataclass.__annotations__)

def test_row_type_constraint_to_schema(self):
result_type = typing_to_runner_api(
row_type.RowTypeConstraint.from_fields([
Expand Down Expand Up @@ -646,6 +665,48 @@ def test_trivial_example(self):
expected.row_type.schema.fields,
typing_to_runner_api(MyCuteClass).row_type.schema.fields)

def test_trivial_example_dataclass(self):
@dataclasses.dataclass
class MyCuteDataclass:
name: str
age: Optional[int]
interests: List[str]
height: float
blob: ByteString

expected = schema_pb2.FieldType(
row_type=schema_pb2.RowType(
schema=schema_pb2.Schema(
fields=[
schema_pb2.Field(
name='name',
type=schema_pb2.FieldType(
atomic_type=schema_pb2.STRING),
),
schema_pb2.Field(
name='age',
type=schema_pb2.FieldType(
nullable=True, atomic_type=schema_pb2.INT64)),
schema_pb2.Field(
name='interests',
type=schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(
element_type=schema_pb2.FieldType(
atomic_type=schema_pb2.STRING)))),
schema_pb2.Field(
name='height',
type=schema_pb2.FieldType(
atomic_type=schema_pb2.DOUBLE)),
schema_pb2.Field(
name='blob',
type=schema_pb2.FieldType(
atomic_type=schema_pb2.BYTES)),
])))

self.assertEqual(
expected.row_type.schema.fields,
typing_to_runner_api(MyCuteDataclass).row_type.schema.fields)

def test_user_type_annotated_with_id_after_conversion(self):
MyCuteClass = NamedTuple('MyCuteClass', [
('name', str),
Expand Down
Loading