From b0b27c880e32cad3ea7ae099f1f88dfa874054ca Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 3 Jun 2026 16:19:48 -0400 Subject: [PATCH] Normalize types in dataclass field type resolving Add a pipeline option to allow fallback to Any --- CHANGES.md | 3 +++ .../apache_beam/options/pipeline_options.py | 9 +++++++ sdks/python/apache_beam/typehints/opcodes.py | 4 ++- .../typehints/trivial_inference.py | 26 +++++++++++++++++++ .../typehints/trivial_inference_test.py | 21 ++++++++++++--- .../python/apache_beam/typehints/typehints.py | 1 + 6 files changed, 60 insertions(+), 4 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index b8b10c352fae..698d88b01fab 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -74,6 +74,9 @@ ## Breaking Changes * X behavior was changed ([#X](https://github.com/apache/beam/issues/X)). +* (Python) Typehints of dataclass fields are honored during type inference. To restore the previous fallback-to-Any behavior, + use pipeline option `--exclude_infer_dataclass_field_type` ([#38797](https://github.com/apache/beam/issues/38797)). + However, fixing forward is recommended. ## Deprecations diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index e3ab13e25122..a5b66ce28ac5 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -888,6 +888,15 @@ def _add_argparse_args(cls, parser): default=False, action='store_true', help='Disable the use of beartype for type checking.') + parser.add_argument( + '--exclude_infer_dataclass_field_type', + default=False, + action='store_true', + help='Exclude certain typehint inference involving dataclass fields ' + 'and resolve to Any (as in beam<=2.74.0). NOTE: this option is ' + 'for backward compatibility only and the exclusion scenarios are ' + 'subject to change or remove in a future version. 
For details see: ' + 'https://beam.apache.org/releases/pydoc/current/apache_beam.typehints.trivial_inference.html#apache_beam.typehints.trivial_inference.resolve_dataclass_field_type') # pylint: disable=line-too-long parser.add_argument( '--runtime_type_check', default=False, diff --git a/sdks/python/apache_beam/typehints/opcodes.py b/sdks/python/apache_beam/typehints/opcodes.py index 963b5e0850b6..53eabdadc4af 100644 --- a/sdks/python/apache_beam/typehints/opcodes.py +++ b/sdks/python/apache_beam/typehints/opcodes.py @@ -42,6 +42,7 @@ from apache_beam.typehints.trivial_inference import Const from apache_beam.typehints.trivial_inference import element_type from apache_beam.typehints.trivial_inference import key_value_types +from apache_beam.typehints.trivial_inference import resolve_dataclass_field_type from apache_beam.typehints.trivial_inference import union from apache_beam.typehints.typehints import Any from apache_beam.typehints.typehints import Dict @@ -451,8 +452,9 @@ def _getattr(o, name): elif inspect.isclass(o) and dataclasses.is_dataclass(o): field = o.__dataclass_fields__.get(name) if field is not None: - return field.type + return resolve_dataclass_field_type(field.type) return Any + else: return Any diff --git a/sdks/python/apache_beam/typehints/trivial_inference.py b/sdks/python/apache_beam/typehints/trivial_inference.py index 68e126a89393..69edfc309281 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference.py +++ b/sdks/python/apache_beam/typehints/trivial_inference.py @@ -774,3 +774,29 @@ def infer_return_type_func(f, input_types, debug=False, depth=0): if debug: print(f, id(f), input_types, '->', result) return result + + +def resolve_dataclass_field_type(x): + """ + Resolve a type to Beam typehint under global pipeline option context. + + Since Beam 2.75.0, typehints of dataclass fields are honored during type + inferences. 
However, in case of breakage (possible scenarios include + incorrect typehints; non-deterministic or nullable types disallowed by + consumer transform but check disabled by Any; tests relying on Any), use the + --exclude_infer_dataclass_field_type option to fall back to Any. + Fields of builtin primitives are always respected. + """ + from apache_beam.options.pipeline_options_context import get_pipeline_options + options = get_pipeline_options() + if options: + from apache_beam.options.pipeline_options import TypeOptions + disabled = options.view_as(TypeOptions).exclude_infer_dataclass_field_type + else: + disabled = False + + if not disabled: + return typehints.normalize(x) + if x in (bool, bytes, complex, float, int, str): + return x + return Any diff --git a/sdks/python/apache_beam/typehints/trivial_inference_test.py b/sdks/python/apache_beam/typehints/trivial_inference_test.py index f421819bdcae..dcb0bac97e80 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference_test.py +++ b/sdks/python/apache_beam/typehints/trivial_inference_test.py @@ -24,6 +24,8 @@ import unittest import apache_beam as beam +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options_context import scoped_pipeline_options from apache_beam.typehints import row_type from apache_beam.typehints import trivial_inference from apache_beam.typehints import typehints @@ -489,15 +491,28 @@ def testPyCallable(self): [int]) def testDataClassFields(self): + @dataclasses.dataclass + class BaseClass: + pass + @dataclasses.dataclass class MyDataClass: id: int name: str + tags: list[str] + custom: BaseClass self.assertReturnType( - typehints.Tuple[int, str], - python_callable.PythonCallableWithSource("lambda x: (x.id, x.name)"), - [MyDataClass]) + typehints.Tuple[int, str, typehints.List[str], BaseClass], + python_callable.PythonCallableWithSource( + "lambda x: (x.id, x.name, x.tags, x.custom)"), [MyDataClass]) + + options = 
PipelineOptions(['--exclude_infer_dataclass_field_type']) + with scoped_pipeline_options(options): + self.assertReturnType( + typehints.Tuple[int, str, typehints.Any, typehints.Any], + python_callable.PythonCallableWithSource( + "lambda x: (x.id, x.name, x.tags, x.custom)"), [MyDataClass]) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 6dc88a93dd39..ffef40de6673 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1453,6 +1453,7 @@ def __getitem__(self, type_params): def normalize(x, none_as_type=False): + """Normalize a type to Beam typehint.""" # None is inconsistantly used for Any, unknown, or NoneType. # Avoid circular imports