From 6d676a110797011251ecbf14a13e3edf7116ae85 Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Mon, 25 Aug 2025 09:57:43 +0200 Subject: [PATCH 1/4] feat:added alias for cast and test --- src/substrait/builders/extended_expression.py | 77 +++++++++---------- .../builders/extended_expression/test_cast.py | 7 +- 2 files changed, 41 insertions(+), 43 deletions(-) diff --git a/src/substrait/builders/extended_expression.py b/src/substrait/builders/extended_expression.py index abda416..a728be2 100644 --- a/src/substrait/builders/extended_expression.py +++ b/src/substrait/builders/extended_expression.py @@ -1,18 +1,19 @@ -from datetime import date import itertools +from datetime import date +from typing import Any, Callable, Iterable, Union + import substrait.gen.proto.algebra_pb2 as stalg -import substrait.gen.proto.type_pb2 as stp import substrait.gen.proto.extended_expression_pb2 as stee import substrait.gen.proto.extensions.extensions_pb2 as ste +import substrait.gen.proto.type_pb2 as stp from substrait.extension_registry import ExtensionRegistry +from substrait.type_inference import infer_extended_expression_schema from substrait.utils import ( - type_num_names, - merge_extension_urns, - merge_extension_uris, merge_extension_declarations, + merge_extension_uris, + merge_extension_urns, + type_num_names, ) -from substrait.type_inference import infer_extended_expression_schema -from typing import Callable, Any, Union, Iterable UnboundExtendedExpression = Callable[ [stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression @@ -21,7 +22,7 @@ def _alias_or_inferred( - alias: Union[Iterable[str], str], + alias: Union[Iterable[str], str, None], op: str, args: Iterable[str], ): @@ -44,7 +45,7 @@ def resolve_expression( def literal( - value: Any, type: stp.Type, alias: Union[Iterable[str], str] = None + value: Any, type: stp.Type, alias: Union[Iterable[str], str, None] = None ) -> UnboundExtendedExpression: """Builds a resolver for ExtendedExpression containing a literal expression""" @@ -154,7 +155,7 @@ def resolve( return resolve -def column(field: Union[str, int], alias: Union[Iterable[str], str] = None): +def column(field: Union[str, int], alias: Union[Iterable[str], str, None] = None): """Builds a resolver for ExtendedExpression containing a FieldReference expression Accepts either an index or a field name of a desired field. @@ -208,7 +209,7 @@ def scalar_function( urn: str, function: str, expressions: Iterable[ExtendedExpressionOrUnbound], - alias: Union[Iterable[str], str] = None, + alias: Union[Iterable[str], str, None] = None, ): """Builds a resolver for ExtendedExpression containing a ScalarFunction expression""" @@ -306,7 +307,7 @@ def aggregate_function( urn: str, function: str, expressions: Iterable[ExtendedExpressionOrUnbound], - alias: Union[Iterable[str], str] = None, + alias: Union[Iterable[str], str, None] = None, ): """Builds a resolver for ExtendedExpression containing a AggregateFunction measure""" @@ -402,7 +403,7 @@ def window_function( function: str, expressions: Iterable[ExtendedExpressionOrUnbound], partitions: Iterable[ExtendedExpressionOrUnbound] = [], - alias: Union[Iterable[str], str] = None, + alias: Union[Iterable[str], str, None] = None, ): """Builds a resolver for ExtendedExpression containing a WindowFunction expression""" @@ -512,7 +513,7 @@ def resolve( def if_then( ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]], _else: ExtendedExpressionOrUnbound, - alias: Union[Iterable[str], str] = None, + alias: Union[Iterable[str], str, None] = None, ): """Builds a resolver for ExtendedExpression containing an IfThen expression""" @@ -551,24 +552,16 @@ def resolve( referred_expr=[ stee.ExpressionReference( expression=stalg.Expression( - if_then=stalg.Expression.IfThen( - **{ - "ifs": [ - stalg.Expression.IfThen.IfClause( - **{ - "if": if_clause[0] - .referred_expr[0] - .expression, - "then": if_clause[1] - .referred_expr[0] - .expression, - } - ) - for if_clause in bound_ifs - ], - "else": bound_else.referred_expr[0].expression, - } - ) + if_then=stalg.Expression.IfThen(**{ + "ifs": [ + stalg.Expression.IfThen.IfClause(**{ + "if": if_clause[0].referred_expr[0].expression, + "then": if_clause[1].referred_expr[0].expression, + }) + for if_clause in bound_ifs + ], + "else": bound_else.referred_expr[0].expression, + }) ), output_names=_alias_or_inferred( alias, @@ -639,12 +632,10 @@ def resolve( switch_expression=stalg.Expression.SwitchExpression( match=bound_match.referred_expr[0].expression, ifs=[ - stalg.Expression.SwitchExpression.IfValue( - **{ - "if": i.referred_expr[0].expression.literal, - "then": t.referred_expr[0].expression, - } - ) + stalg.Expression.SwitchExpression.IfValue(**{ + "if": i.referred_expr[0].expression.literal, + "then": t.referred_expr[0].expression, + }) for i, t in bound_ifs ], **{"else": bound_else.referred_expr[0].expression}, @@ -767,7 +758,11 @@ def resolve( return resolve -def cast(input: ExtendedExpressionOrUnbound, type: stp.Type): +def cast( + input: ExtendedExpressionOrUnbound, + type: stp.Type, + alias: Union[Iterable[str], str, None] = None, +): """Builds a resolver for ExtendedExpression containing a cast expression""" def resolve( @@ -785,7 +780,9 @@ def resolve( failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL, ) ), - output_names=["cast"], # TODO construct name from inputs + output_names=_alias_or_inferred( + alias, "cast", [bound_input.referred_expr[0].output_names[0]] + ), ) ], base_schema=base_schema, diff --git a/tests/builders/extended_expression/test_cast.py b/tests/builders/extended_expression/test_cast.py index 704f80d..bdad8d1 100644 --- a/tests/builders/extended_expression/test_cast.py +++ b/tests/builders/extended_expression/test_cast.py @@ -1,6 +1,6 @@ import substrait.gen.proto.algebra_pb2 as stalg -import substrait.gen.proto.type_pb2 as stt import substrait.gen.proto.extended_expression_pb2 as stee +import substrait.gen.proto.type_pb2 as stt from substrait.builders.extended_expression import cast, literal from substrait.builders.type import i8, i16 from substrait.extension_registry import ExtensionRegistry @@ -37,7 +37,7 @@ def test_cast(): failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL, ) ), - output_names=["cast"], + output_names=["cast(Literal(3))"], ) ], base_schema=named_struct, @@ -48,6 +48,7 @@ def test_cast(): def test_cast_with_extension(): import yaml + import substrait.gen.proto.extensions.extensions_pb2 as ste from substrait.builders.extended_expression import scalar_function @@ -134,7 +135,7 @@ def test_cast_with_extension(): failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL, ) ), - output_names=["cast"], + output_names=["cast(add(Literal(1),Literal(2)))"], ) ], base_schema=named_struct, From 4bd0b3938089f77537b6798a5669dab955bf585c Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Mon, 1 Dec 2025 15:49:28 +0100 Subject: [PATCH 2/4] feat: update proto codegen and add typing_extension dependency --- pyproject.toml | 2 +- src/substrait/gen/json/simple_extensions.py | 26 +++++++++++---------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 677ef57..c1ddd51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com license = {text = "Apache-2.0"} readme = "README.md" requires-python = ">=3.10" -dependencies = ["protobuf >=3.19.1,<6"] +dependencies = ["protobuf >=3.19.1,<6", "typing_extensions"] dynamic = ["version"] [tool.setuptools_scm] diff --git a/src/substrait/gen/json/simple_extensions.py b/src/substrait/gen/json/simple_extensions.py index 2885bb4..e323ac5 100644 --- a/src/substrait/gen/json/simple_extensions.py +++ b/src/substrait/gen/json/simple_extensions.py @@ -7,13 +7,15 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union +from typing_extensions import TypeAlias + class Functions(Enum): INHERITS = 'INHERITS' SEPARATE = 'SEPARATE' -Type = Union[str, Dict[str, Any]] +Type: TypeAlias = Union[str, Dict[str, Any]] class Type1(Enum): @@ -24,7 +26,7 @@ class Type1(Enum): string = 'string' -EnumOptions = List[str] +EnumOptions: TypeAlias = List[str] @dataclass @@ -49,7 +51,7 @@ class TypeArg: description: Optional[str] = None -Arguments = List[Union[EnumerationArg, ValueArg, TypeArg]] +Arguments: TypeAlias = List[Union[EnumerationArg, ValueArg, TypeArg]] @dataclass @@ -58,7 +60,7 @@ class Options1: description: Optional[str] = None -Options = Dict[str, Options1] +Options: TypeAlias = Dict[str, Options1] class ParameterConsistency(Enum): @@ -73,10 +75,10 @@ class VariadicBehavior: parameterConsistency: Optional[ParameterConsistency] = None -Deterministic = bool +Deterministic: TypeAlias = bool -SessionDependent = bool +SessionDependent: TypeAlias = bool class NullabilityHandling(Enum): @@ -85,13 +87,13 @@ class NullabilityHandling(Enum): DISCRETE = 'DISCRETE' -ReturnValue = Type +ReturnValue: TypeAlias = Type -Implementation = Dict[str, str] +Implementation: TypeAlias = Dict[str, str] -Intermediate = Type +Intermediate: TypeAlias = Type class Decomposable(Enum): @@ -100,10 +102,10 @@ class Decomposable(Enum): MANY = 'MANY' -Maxset = float +Maxset: TypeAlias = float -Ordered = bool +Ordered: TypeAlias = bool @dataclass @@ -196,7 +198,7 @@ class TypeParamDef: optional: Optional[bool] = None -TypeParamDefs = List[TypeParamDef] +TypeParamDefs: TypeAlias = List[TypeParamDef] @dataclass From fd834d151a30fbe7128ea1a1588cc83e3a216c1c Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Wed, 3 Dec 2025 13:34:43 +0100 Subject: [PATCH 3/4] fix: run format and linter and reverted pyproject.toml --- pyproject.toml | 2 +- src/substrait/builders/extended_expression.py | 38 ++++++++++++------- src/substrait/gen/json/simple_extensions.py | 4 +- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c1ddd51..677ef57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com license = {text = "Apache-2.0"} readme = "README.md" requires-python = ">=3.10" -dependencies = ["protobuf >=3.19.1,<6", "typing_extensions"] +dependencies = ["protobuf >=3.19.1,<6"] dynamic = ["version"] [tool.setuptools_scm] diff --git a/src/substrait/builders/extended_expression.py b/src/substrait/builders/extended_expression.py index a728be2..88ec5fd 100644 --- a/src/substrait/builders/extended_expression.py +++ b/src/substrait/builders/extended_expression.py @@ -552,16 +552,24 @@ def resolve( referred_expr=[ stee.ExpressionReference( expression=stalg.Expression( - if_then=stalg.Expression.IfThen(**{ - "ifs": [ - stalg.Expression.IfThen.IfClause(**{ - "if": if_clause[0].referred_expr[0].expression, - "then": if_clause[1].referred_expr[0].expression, - }) - for if_clause in bound_ifs - ], - "else": bound_else.referred_expr[0].expression, - }) + if_then=stalg.Expression.IfThen( + **{ + "ifs": [ + stalg.Expression.IfThen.IfClause( + **{ + "if": if_clause[0] + .referred_expr[0] + .expression, + "then": if_clause[1] + .referred_expr[0] + .expression, + } + ) + for if_clause in bound_ifs + ], + "else": bound_else.referred_expr[0].expression, + } + ) ), output_names=_alias_or_inferred( alias, @@ -632,10 +640,12 @@ def resolve( switch_expression=stalg.Expression.SwitchExpression( match=bound_match.referred_expr[0].expression, ifs=[ - stalg.Expression.SwitchExpression.IfValue(**{ - "if": i.referred_expr[0].expression.literal, - "then": t.referred_expr[0].expression, - }) + stalg.Expression.SwitchExpression.IfValue( + **{ + "if": i.referred_expr[0].expression.literal, + "then": t.referred_expr[0].expression, + } + ) for i, t in bound_ifs ], **{"else": bound_else.referred_expr[0].expression}, diff --git a/src/substrait/gen/json/simple_extensions.py b/src/substrait/gen/json/simple_extensions.py index e323ac5..765fbef 100644 --- a/src/substrait/gen/json/simple_extensions.py +++ b/src/substrait/gen/json/simple_extensions.py @@ -5,9 +5,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Union - -from typing_extensions import TypeAlias +from typing import Any, Dict, List, Optional, TypeAlias, Union class Functions(Enum): From 2bf4e67744acee0fd920ab798b650e6f2f7810cc Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Wed, 3 Dec 2025 14:45:45 +0100 Subject: [PATCH 4/4] fix: make file --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index 4e95eb6..670c6eb 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,7 @@ codegen-extensions: --input third_party/substrait/text/simple_extensions_schema.yaml \ --output src/substrait/gen/json/simple_extensions.py \ --output-model-type dataclasses.dataclass \ + --target-python-version 3.10 \ --disable-timestamp lint: