From d21c9d45914a255b76f078941650908aa2a7f6de Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Mon, 18 Aug 2025 13:16:45 +0200 Subject: [PATCH 01/12] feat(types): add many new types and tighten nullability/precision checks Add wide-ranging support for additional scalar and parameterized Substrait types, improve parameter handling, and make nullability/precision checks stricter and more correct. Signed-off-by: MBWhite --- src/substrait/builders/type.py | 8 + src/substrait/derivation_expression.py | 104 ++++++++++- src/substrait/extension_registry.py | 141 ++++++++++++--- tests/test_extension_registry.py | 228 +++++++++++++++++-------- 4 files changed, 379 insertions(+), 102 deletions(-) diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py index 39ed5e6..c4877fe 100644 --- a/src/substrait/builders/type.py +++ b/src/substrait/builders/type.py @@ -221,6 +221,14 @@ def precision_timestamp_tz(precision: int, nullable=True) -> stt.Type: ) ) +def timestamp(nullable=True) -> stt.Type: + return stt.Type( + timestamp=stt.Type.Timestamp( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) def struct(types: Iterable[stt.Type], nullable=True) -> stt.Type: return stt.Type( diff --git a/src/substrait/derivation_expression.py b/src/substrait/derivation_expression.py index f4d18d7..28be2d3 100644 --- a/src/substrait/derivation_expression.py +++ b/src/substrait/derivation_expression.py @@ -65,22 +65,116 @@ def _evaluate(x, values: dict): return Type(fp64=Type.FP64(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext): return Type(bool=Type.Boolean(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.StringContext): + return Type(string=Type.String(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.TimestampContext): + return Type(timestamp=Type.Timestamp(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.DateContext): + return Type(date=Type.Date(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.IntervalYearContext): + return Type(interval_year=Type.IntervalYear(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.UuidContext): + return Type(uuid=Type.UUID(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.BinaryContext): + return Type(binary=Type.Binary(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.TimeContext): + return Type(time=Type.Time(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.TimestampTzContext): + return Type(timestamp_tz=Type.TimestampTZ(nullability=nullability)) else: raise Exception(f"Unknown scalar type {type(scalar_type)}") elif parametrized_type: + nullability = ( + Type.NULLABILITY_NULLABLE + if parametrized_type.isnull + else Type.NULLABILITY_REQUIRED + ) if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext): precision = _evaluate(parametrized_type.precision, values) scale = _evaluate(parametrized_type.scale, values) - nullability = ( - Type.NULLABILITY_NULLABLE - if parametrized_type.isnull - else Type.NULLABILITY_REQUIRED - ) return Type( decimal=Type.Decimal( precision=precision, scale=scale, nullability=nullability ) ) + elif isinstance(parametrized_type, SubstraitTypeParser.VarCharContext): + length = _evaluate(parametrized_type.length, values) + return Type( + varchar=Type.VarChar( + length=length, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.FixedCharContext): + length = _evaluate(parametrized_type.length, values) + return Type( + fixed_char=Type.FixedChar( + length=length, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.FixedBinaryContext): + length = _evaluate(parametrized_type.length, values) + return Type( + fixed_binary=Type.FixedBinary( + length=length, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.PrecisionTimestampContext): + precision = _evaluate(parametrized_type.precision, values) + return Type( + precision_timestamp=Type.PrecisionTimestamp( + precision=precision, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext): + precision = _evaluate(parametrized_type.precision, values) + return Type( + precision_timestamp_tz=Type.PrecisionTimestampTZ( + precision=precision, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.IntervalYearContext): + return Type( + interval_year=Type.IntervalYear( + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.StructContext): + types = list(map(lambda x: _evaluate(x,values),parametrized_type.expr())) + return Type( + struct=Type.Struct( + types=types, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.ListContext): + type = _evaluate(parametrized_type.expr(),values) + return Type( + list=Type.List( + type=type, + nullability=nullability, + ) + ) + + elif isinstance(parametrized_type, SubstraitTypeParser.MapContext): + return Type( + map=Type.Map( + key=_evaluate(parametrized_type.key,values), + value=_evaluate(parametrized_type.value,values), + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.NStructContext): + # it gives me a parser error i may have to update the parser + # string `evaluate("NSTRUCT")` from the docs https://substrait.io/types/type_classes/ + # line 1:17 extraneous input ':' + raise NotImplementedError("Named structure type not implemented yet") + # elif isinstance(parametrized_type, SubstraitTypeParser.UserDefinedContext): + raise Exception(f"Unknown parametrized type {type(parametrized_type)}") elif any_type: any_var = any_type.AnyVar() diff --git a/src/substrait/extension_registry.py b/src/substrait/extension_registry.py index c854c02..b2615e5 100644 --- a/src/substrait/extension_registry.py +++ b/src/substrait/extension_registry.py @@ -68,21 +68,24 @@ def normalize_substrait_type_names(typ: str) -> str: raise Exception(f"Unrecognized substrait type {typ}") -def violates_integer_option(actual: int, option, parameters: dict): +def violates_integer_option(actual: int, option, parameters: dict, subset=False): + option_numeric = None if isinstance(option, SubstraitTypeParser.NumericLiteralContext): - return actual != int(str(option.Number())) + option_numeric = int(str(option.Number())) elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext): parameter_name = str(option.Identifier()) - if parameter_name in parameters and parameters[parameter_name] != actual: - return True - else: + + if parameter_name not in parameters: parameters[parameter_name] = actual + option_numeric = parameters[parameter_name] else: raise Exception( f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead" ) - - return False + if subset: + return actual < option_numeric + else: + return actual != option_numeric def types_equal(type1: Type, type2: Type, check_nullability=False): @@ -112,6 +115,27 @@ def handle_parameter_cover( return True +def _check_nullability(check_nullability, parameterized_type, covered, kind) -> bool: + if not check_nullability: + return True + # The ANTLR context stores a Token called ``isnull`` – it is + # present when the type is declared as nullable. + nullability = ( + Type.Nullability.NULLABILITY_NULLABLE + if getattr(parameterized_type, "isnull", None) is not None + else Type.Nullability.NULLABILITY_REQUIRED + ) + # if nullability == Type.Nullability.NULLABILITY_NULLABLE: + # return True # is still true even if the covered is required + # The protobuf message stores its own enum – we compare the two. + covered_nullability = getattr( + getattr(covered, kind), # e.g. covered.varchar + "nullability", + None, + ) + return nullability == covered_nullability + + def covers( covered: Type, covering: SubstraitTypeParser.TypeLiteralContext, @@ -123,7 +147,6 @@ def covers( return handle_parameter_cover( covered, parameter_name, parameters, check_nullability ) - covering: SubstraitTypeParser.TypeDefContext = covering.typeDef() any_type: SubstraitTypeParser.AnyTypeContext = covering.anyType() @@ -142,31 +165,99 @@ def covers( parameterized_type = covering.parameterizedType() if parameterized_type: - if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): - if covered.WhichOneof("kind") != "decimal": + kind = covered.WhichOneof("kind") + if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext): + if kind != "varchar": + return False + if hasattr(parameterized_type, "length") and violates_integer_option( + covered.varchar.length, parameterized_type.length, parameters + ): return False - nullability = ( - Type.NULLABILITY_NULLABLE - if parameterized_type.isnull - else Type.NULLABILITY_REQUIRED + return _check_nullability( + check_nullability, parameterized_type, covered, kind ) - - if ( - check_nullability - and nullability - != covered.__getattribute__(covered.WhichOneof("kind")).nullability + if isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext): + if kind != "fixed_char": + return False + if hasattr(parameterized_type, "length") and violates_integer_option( + covered.fixed_char.length, parameterized_type.length, parameters ): return False + return _check_nullability( + check_nullability, parameterized_type, covered, kind + ) + if isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext): + if kind != "fixed_binary": + return False + if hasattr(parameterized_type, "length") and violates_integer_option( + covered.fixed_binary.length, parameterized_type.length, parameters + ): + return False + # return True + return _check_nullability( + check_nullability, parameterized_type, covered, kind + ) + if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): + if kind != "decimal": + return False + if not _check_nullability( + check_nullability, parameterized_type, covered, kind + ): + return False + # precision / scale are both optional – a missing value means “no limit”. + covered_scale = getattr(covered.decimal, "scale", 0) + param_scale = getattr(parameterized_type, "scale", 0) + covered_prec = getattr(covered.decimal, "precision", 0) + param_prec = getattr(parameterized_type, "precision", 0) return not ( - violates_integer_option( - covered.decimal.scale, parameterized_type.scale, parameters - ) - or violates_integer_option( - covered.decimal.precision, parameterized_type.precision, parameters - ) + violates_integer_option(covered_scale, param_scale, parameters) + or violates_integer_option(covered_prec, param_prec, parameters) ) + if isinstance( + parameterized_type, SubstraitTypeParser.PrecisionTimestampContext + ): + if kind != "precision_timestamp": + return False + if not _check_nullability( + check_nullability, parameterized_type, covered, kind + ): + return False + # return True + covered_prec = getattr(covered.precision_timestamp, "precision", 0) + param_prec = getattr(parameterized_type, "precision", 0) + return not violates_integer_option(covered_prec, param_prec, parameters) + + if isinstance( + parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext + ): + if kind != "precision_timestamp_tz": + return False + if not _check_nullability( + check_nullability, parameterized_type, covered, kind + ): + return False + # return True + covered_prec = getattr(covered.precision_timestamp_tz, "precision", 0) + param_prec = getattr(parameterized_type, "precision", 0) + return not violates_integer_option(covered_prec, param_prec, parameters) + + kind_mapping = { + SubstraitTypeParser.ListContext: "list", + SubstraitTypeParser.MapContext: "map", + SubstraitTypeParser.StructContext: "struct", + SubstraitTypeParser.UserDefinedContext: "user_defined", + SubstraitTypeParser.PrecisionIntervalDayContext: "interval_day", + } + + for ctx_cls, expected_kind in kind_mapping.items(): + if isinstance(parameterized_type, ctx_cls): + if kind != expected_kind: + return False + return _check_nullability( + check_nullability, parameterized_type, covered, kind + ) else: raise Exception(f"Unhandled type {type(parameterized_type)}") diff --git a/tests/test_extension_registry.py b/tests/test_extension_registry.py index f9d63bd..2036a68 100644 --- a/tests/test_extension_registry.py +++ b/tests/test_extension_registry.py @@ -4,6 +4,11 @@ from substrait.gen.proto.type_pb2 import Type from substrait.extension_registry import ExtensionRegistry, covers from substrait.derivation_expression import _parse +from substrait.builders.type import ( + i8, + i16, + decimal, +) content = """%YAML 1.2 --- @@ -104,10 +109,19 @@ value: decimal nullability: DISCRETE return: decimal? + - name: "equal_test" + impls: + - args: + - name: x + value: any + - name: y + value: any + nullability: DISCRETE + return: any """ -registry = ExtensionRegistry() +registry = ExtensionRegistry(load_default_extensions=True) registry.register_extension_dict( yaml.safe_load(content), @@ -115,52 +129,11 @@ ) -def i8(nullable=False): - return Type( - i8=Type.I8( - nullability=Type.NULLABILITY_REQUIRED - if not nullable - else Type.NULLABILITY_NULLABLE - ) - ) - - -def i16(nullable=False): - return Type( - i16=Type.I16( - nullability=Type.NULLABILITY_REQUIRED - if not nullable - else Type.NULLABILITY_NULLABLE - ) - ) - - -def bool(nullable=False): - return Type( - bool=Type.Boolean( - nullability=Type.NULLABILITY_REQUIRED - if not nullable - else Type.NULLABILITY_NULLABLE - ) - ) - - -def decimal(precision, scale, nullable=False): - return Type( - decimal=Type.Decimal( - scale=scale, - precision=precision, - nullability=Type.NULLABILITY_REQUIRED - if not nullable - else Type.NULLABILITY_NULLABLE, - ) - ) - def test_non_existing_urn(): assert ( registry.lookup_function( - urn="non_existent", function_name="add", signature=[i8(), i8()] + urn="non_existent", function_name="add", signature=[i8(nullable=False), i8(nullable=False)] ) is None ) @@ -169,7 +142,8 @@ def test_non_existing_urn(): def test_non_existing_function(): assert ( registry.lookup_function( - urn="extension:test:functions", function_name="sub", signature=[i8(), i8()] + + urn="extension:test:functions", function_name="sub", signature=[i8(nullable=False), i8(nullable=False)] ) is None ) @@ -178,7 +152,7 @@ def test_non_existing_function(): def test_non_existing_function_signature(): assert ( registry.lookup_function( - urn="extension:test:functions", function_name="add", signature=[i8()] + urn="extension:test:functions", function_name="add", signature=[i8(nullable=False)] ) is None ) @@ -186,7 +160,7 @@ def test_non_existing_function_signature(): def test_exact_match(): assert registry.lookup_function( - urn="extension:test:functions", function_name="add", signature=[i8(), i8()] + urn="extension:test:functions", function_name="add", signature=[i8(nullable=False), i8(nullable=False)] )[1] == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED)) @@ -194,7 +168,7 @@ def test_wildcard_match(): assert registry.lookup_function( urn="extension:test:functions", function_name="add", - signature=[i8(), i8(), bool()], + signature=[i8(nullable=False), i8(nullable=False), bool()], )[1] == Type(i16=Type.I16(nullability=Type.NULLABILITY_REQUIRED)) @@ -203,7 +177,7 @@ def test_wildcard_match_fails_with_constraits(): registry.lookup_function( urn="extension:test:functions", function_name="add", - signature=[i8(), i16(), i16()], + signature=[i8(nullable=False), i16(nullable=False), i16(nullable=False)], ) is None ) @@ -214,9 +188,9 @@ def test_wildcard_match_with_constraits(): registry.lookup_function( urn="extension:test:functions", function_name="add", - signature=[i16(), i16(), i8()], + signature=[i16(nullable=False), i16(nullable=False), i8(nullable=False)], )[1] - == i8() + == i8(nullable=False) ) @@ -225,9 +199,9 @@ def test_variadic(): registry.lookup_function( urn="extension:test:functions", function_name="test_fn", - signature=[i8(), i8(), i8()], + signature=[i8(nullable=False), i8(nullable=False), i8(nullable=False)], )[1] - == i8() + == i8(nullable=False) ) @@ -236,16 +210,17 @@ def test_variadic_any(): registry.lookup_function( urn="extension:test:functions", function_name="test_fn_variadic_any", - signature=[i16(), i16(), i16()], + signature=[i16(nullable=False), i16(nullable=False), i16(nullable=False)], )[1] - == i16() + == i16(nullable=False) ) def test_variadic_fails_min_constraint(): assert ( registry.lookup_function( - urn="extension:test:functions", function_name="test_fn", signature=[i8()] + + urn="extension:test:functions", function_name="test_fn", signature=[i8(nullable=False)] ) is None ) @@ -255,8 +230,8 @@ def test_decimal_happy_path(): assert registry.lookup_function( urn="extension:test:functions", function_name="test_decimal", - signature=[decimal(10, 8), decimal(8, 6)], - )[1] == decimal(11, 7) + signature=[decimal(8, 10, nullable=False), decimal(6, 8, nullable=False)], + )[1] == decimal(7, 11, nullable=False) def test_decimal_violates_constraint(): @@ -264,7 +239,7 @@ def test_decimal_violates_constraint(): registry.lookup_function( urn="extension:test:functions", function_name="test_decimal", - signature=[decimal(10, 8), decimal(12, 10)], + signature=[decimal(8, 10, nullable=False), decimal(10, 12, nullable=False)], ) is None ) @@ -274,8 +249,8 @@ def test_decimal_happy_path_discrete(): assert registry.lookup_function( urn="extension:test:functions", function_name="test_decimal_discrete", - signature=[decimal(10, 8, nullable=True), decimal(8, 6)], - )[1] == decimal(11, 7, nullable=True) + signature=[decimal(8, 10, nullable=True), decimal(6, 8, nullable=False)], + )[1] == decimal(7, 11, nullable=True) def test_enum_with_valid_option(): @@ -283,9 +258,9 @@ def test_enum_with_valid_option(): registry.lookup_function( urn="extension:test:functions", function_name="test_enum", - signature=["FLIP", i8()], + signature=["FLIP", i8(nullable=False)], )[1] - == i8() + == i8(nullable=False) ) @@ -294,7 +269,7 @@ def test_enum_with_nonexistent_option(): registry.lookup_function( urn="extension:test:functions", function_name="test_enum", - signature=["NONEXISTENT", i8()], + signature=["NONEXISTENT", i8(nullable=False)], ) is None ) @@ -304,7 +279,7 @@ def test_function_with_nullable_args(): assert registry.lookup_function( urn="extension:test:functions", function_name="add", - signature=[i8(nullable=True), i8()], + signature=[i8(nullable=True), i8(nullable=False)], )[1] == i8(nullable=True) @@ -312,7 +287,7 @@ def test_function_with_declared_output_nullability(): assert registry.lookup_function( urn="extension:test:functions", function_name="add_declared", - signature=[i8(), i8()], + signature=[i8(nullable=False), i8(nullable=False)], )[1] == i8(nullable=True) @@ -320,7 +295,7 @@ def test_function_with_discrete_nullability(): assert registry.lookup_function( urn="extension:test:functions", function_name="add_discrete", - signature=[i8(nullable=True), i8()], + signature=[i8(nullable=True), i8(nullable=False)], )[1] == i8(nullable=True) @@ -329,7 +304,7 @@ def test_function_with_discrete_nullability_nonexisting(): registry.lookup_function( urn="extension:test:functions", function_name="add_discrete", - signature=[i8(), i8()], + signature=[i8(nullable=False), i8(nullable=False)], ) is None ) @@ -337,7 +312,7 @@ def test_function_with_discrete_nullability_nonexisting(): def test_covers(): params = {} - assert covers(i8(), _parse("i8"), params) + assert covers(i8(nullable=False), _parse("i8"), params) assert params == {} @@ -346,18 +321,127 @@ def test_covers_nullability(): assert covers(i8(nullable=True), _parse("i8?"), {}, check_nullability=True) -def test_covers_decimal(): - assert not covers(decimal(10, 8), _parse("decimal<11, A>"), {}) +def test_covers_decimal(nullable=False): + assert not covers(decimal(8, 10), _parse("decimal<11, A>"), {}) def test_covers_decimal_happy_path(): params = {} - assert covers(decimal(10, 8), _parse("decimal<10, A>"), params) + assert covers(decimal(8, 10), _parse("decimal<10, A>"), params) assert params == {"A": 8} def test_covers_any(): - assert covers(decimal(10, 8), _parse("any"), {}) + assert covers(decimal(8, 10), _parse("any"), {}) + + +def test_covers_varchar_length_ok(): + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=15) + ) + param_ctx = _parse("varchar<15>") + assert covers(covered, param_ctx, {}, check_nullability=True) + + +def test_covers_varchar_length_fail(): + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=10) + ) + param_ctx = _parse("varchar<5>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_varchar_nullability(): + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=10) + ) + param_tx = _parse("varchar?<10>") + assert covers(covered, param_tx, {}) + assert not covers(covered, param_tx, {}, True) + param_ctx2 = _parse("varchar<10>") + assert covers(covered, param_ctx2, {}, True) + + +def test_covers_fixed_char_length_ok(): + covered = Type( + fixed_char=Type.FixedChar(nullability=Type.NULLABILITY_REQUIRED, length=8) + ) + param_ctx = _parse("fixedchar<8>") + assert covers(covered, param_ctx, {}) + + +def test_covers_fixed_char_length_fail(): + covered = Type( + fixed_char=Type.FixedChar(nullability=Type.NULLABILITY_REQUIRED, length=8) + ) + param_ctx = _parse("fixedchar<4>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_fixed_binary_length_ok(): + covered = Type( + fixed_binary=Type.FixedBinary(nullability=Type.NULLABILITY_REQUIRED, length=16) + ) + param_ctx = _parse("fixedbinary<16>") + assert covers(covered, param_ctx, {}) + + +def test_covers_fixed_binary_length_fail(): + covered = Type( + fixed_binary=Type.FixedBinary(nullability=Type.NULLABILITY_REQUIRED, length=16) + ) + param_ctx = _parse("fixedbinary<10>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_decimal_precision_scale_fail(): + covered = decimal(8, 10, nullable=False) + param_ctx = _parse("decimal<6, 5>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_ok(): + covered = Type( + precision_timestamp=Type.PrecisionTimestamp( + nullability=Type.NULLABILITY_REQUIRED, precision=5 + ) + ) + param_ctx = _parse("precision_timestamp<5>") + assert covers(covered, param_ctx, {}) + param_ctx = _parse("precision_timestamp") + assert covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_fail(): + covered = Type( + precision_timestamp=Type.PrecisionTimestamp( + nullability=Type.NULLABILITY_REQUIRED, precision=3 + ) + ) + param_ctx = _parse("precision_timestamp<2>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_tz_ok(): + covered = Type( + precision_timestamp_tz=Type.PrecisionTimestampTZ( + nullability=Type.NULLABILITY_REQUIRED, precision=4 + ) + ) + param_ctx = _parse("precision_timestamp_tz<4>") + assert covers(covered, param_ctx, {}) + param_ctx = _parse("precision_timestamp_tz") + assert covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_tz_fail(): + covered = Type( + precision_timestamp_tz=Type.PrecisionTimestampTZ( + nullability=Type.NULLABILITY_REQUIRED, precision=4 + ) + ) + param_ctx = _parse("precision_timestamp_tz<3>") + assert not covers(covered, param_ctx, {}) def test_registry_uri_urn(): From 967e54b4d9431d94d0bb7bb3ae091d0d11074f1f Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Fri, 5 Dec 2025 10:21:14 +0100 Subject: [PATCH 02/12] fix: formatting and linter errors --- src/substrait/builders/type.py | 2 + src/substrait/derivation_expression.py | 24 ++++++--- tests/test_extension_registry.py | 75 ++++++++++++-------------- 3 files changed, 53 insertions(+), 48 deletions(-) diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py index c4877fe..23e51dd 100644 --- a/src/substrait/builders/type.py +++ b/src/substrait/builders/type.py @@ -221,6 +221,7 @@ def precision_timestamp_tz(precision: int, nullable=True) -> stt.Type: ) ) + def timestamp(nullable=True) -> stt.Type: return stt.Type( timestamp=stt.Type.Timestamp( @@ -230,6 +231,7 @@ def timestamp(nullable=True) -> stt.Type: ) ) + def struct(types: Iterable[stt.Type], nullable=True) -> stt.Type: return stt.Type( struct=stt.Type.Struct( diff --git a/src/substrait/derivation_expression.py b/src/substrait/derivation_expression.py index 28be2d3..ebeed7b 100644 --- a/src/substrait/derivation_expression.py +++ b/src/substrait/derivation_expression.py @@ -1,5 +1,7 @@ from typing import Optional -from antlr4 import InputStream, CommonTokenStream + +from antlr4 import CommonTokenStream, InputStream + from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser from substrait.gen.proto.type_pb2 import Type @@ -121,7 +123,9 @@ def _evaluate(x, values: dict): nullability=nullability, ) ) - elif isinstance(parametrized_type, SubstraitTypeParser.PrecisionTimestampContext): + elif isinstance( + parametrized_type, SubstraitTypeParser.PrecisionTimestampContext + ): precision = _evaluate(parametrized_type.precision, values) return Type( precision_timestamp=Type.PrecisionTimestamp( @@ -129,7 +133,9 @@ def _evaluate(x, values: dict): nullability=nullability, ) ) - elif isinstance(parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext): + elif isinstance( + parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext + ): precision = _evaluate(parametrized_type.precision, values) return Type( precision_timestamp_tz=Type.PrecisionTimestampTZ( @@ -144,7 +150,9 @@ def _evaluate(x, values: dict): ) ) elif isinstance(parametrized_type, SubstraitTypeParser.StructContext): - types = list(map(lambda x: _evaluate(x,values),parametrized_type.expr())) + types = list( + map(lambda x: _evaluate(x, values), parametrized_type.expr()) + ) return Type( struct=Type.Struct( types=types, @@ -152,10 +160,10 @@ def _evaluate(x, values: dict): ) ) elif isinstance(parametrized_type, SubstraitTypeParser.ListContext): - type = _evaluate(parametrized_type.expr(),values) + list_type = _evaluate(parametrized_type.expr(), values) return Type( list=Type.List( - type=type, + type=list_type, nullability=nullability, ) ) @@ -163,8 +171,8 @@ def _evaluate(x, values: dict): elif isinstance(parametrized_type, SubstraitTypeParser.MapContext): return Type( map=Type.Map( - key=_evaluate(parametrized_type.key,values), - value=_evaluate(parametrized_type.value,values), + key=_evaluate(parametrized_type.key, values), + value=_evaluate(parametrized_type.value, values), nullability=nullability, ) ) diff --git a/tests/test_extension_registry.py b/tests/test_extension_registry.py index 2036a68..c4964e6 100644 --- a/tests/test_extension_registry.py +++ b/tests/test_extension_registry.py @@ -129,11 +129,12 @@ ) - def test_non_existing_urn(): assert ( registry.lookup_function( - urn="non_existent", function_name="add", signature=[i8(nullable=False), i8(nullable=False)] + urn="non_existent", + function_name="add", + signature=[i8(nullable=False), i8(nullable=False)], ) is None ) @@ -142,8 +143,9 @@ def test_non_existing_urn(): def test_non_existing_function(): assert ( registry.lookup_function( - - urn="extension:test:functions", function_name="sub", signature=[i8(nullable=False), i8(nullable=False)] + urn="extension:test:functions", + function_name="sub", + signature=[i8(nullable=False), i8(nullable=False)], ) is None ) @@ -152,7 +154,9 @@ def test_non_existing_function(): def test_non_existing_function_signature(): assert ( registry.lookup_function( - urn="extension:test:functions", function_name="add", signature=[i8(nullable=False)] + urn="extension:test:functions", + function_name="add", + signature=[i8(nullable=False)], ) is None ) @@ -160,7 +164,9 @@ def test_non_existing_function_signature(): def test_exact_match(): assert registry.lookup_function( - urn="extension:test:functions", function_name="add", signature=[i8(nullable=False), i8(nullable=False)] + urn="extension:test:functions", + function_name="add", + signature=[i8(nullable=False), i8(nullable=False)], )[1] == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED)) @@ -184,43 +190,35 @@ def test_wildcard_match_fails_with_constraits(): def test_wildcard_match_with_constraits(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="add", - signature=[i16(nullable=False), i16(nullable=False), i8(nullable=False)], - )[1] - == i8(nullable=False) - ) + assert registry.lookup_function( + urn="extension:test:functions", + function_name="add", + signature=[i16(nullable=False), i16(nullable=False), i8(nullable=False)], + )[1] == i8(nullable=False) def test_variadic(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="test_fn", - signature=[i8(nullable=False), i8(nullable=False), i8(nullable=False)], - )[1] - == i8(nullable=False) - ) + assert registry.lookup_function( + urn="extension:test:functions", + function_name="test_fn", + signature=[i8(nullable=False), i8(nullable=False), i8(nullable=False)], + )[1] == i8(nullable=False) def test_variadic_any(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="test_fn_variadic_any", - signature=[i16(nullable=False), i16(nullable=False), i16(nullable=False)], - )[1] - == i16(nullable=False) - ) + assert registry.lookup_function( + urn="extension:test:functions", + function_name="test_fn_variadic_any", + signature=[i16(nullable=False), i16(nullable=False), i16(nullable=False)], + )[1] == i16(nullable=False) def test_variadic_fails_min_constraint(): assert ( registry.lookup_function( - - urn="extension:test:functions", function_name="test_fn", signature=[i8(nullable=False)] + urn="extension:test:functions", + function_name="test_fn", + signature=[i8(nullable=False)], ) is None ) @@ -254,14 +252,11 @@ def test_decimal_happy_path_discrete(): def test_enum_with_valid_option(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="test_enum", - signature=["FLIP", i8(nullable=False)], - )[1] - == i8(nullable=False) - ) + assert registry.lookup_function( + urn="extension:test:functions", + function_name="test_enum", + signature=["FLIP", i8(nullable=False)], + )[1] == i8(nullable=False) def test_enum_with_nonexistent_option(): From 1806cb4f6620ca434c838c36ed16e629ac49db5e Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Mon, 8 Dec 2025 16:09:05 +0100 Subject: [PATCH 03/12] feat: added evaluate for NamedStruct --- src/substrait/derivation_expression.py | 20 ++++++--- tests/test_derivation_expression.py | 59 +++++++++++++++++++++++++- 2 files changed, 72 insertions(+), 7 deletions(-) diff --git a/src/substrait/derivation_expression.py b/src/substrait/derivation_expression.py index ebeed7b..85f6fd8 100644 --- a/src/substrait/derivation_expression.py +++ b/src/substrait/derivation_expression.py @@ -1,10 +1,12 @@ +import pdb from typing import Optional from antlr4 import CommonTokenStream, InputStream +from datafusion.functions import named_struct from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser -from substrait.gen.proto.type_pb2 import Type +from substrait.gen.proto.type_pb2 import NamedStruct, Type def _evaluate(x, values: dict): @@ -177,11 +179,17 @@ def _evaluate(x, values: dict): ) ) elif isinstance(parametrized_type, SubstraitTypeParser.NStructContext): - # it gives me a parser error i may have to update the parser - # string `evaluate("NSTRUCT")` from the docs https://substrait.io/types/type_classes/ - # line 1:17 extraneous input ':' - raise NotImplementedError("Named structure type not implemented yet") - # elif isinstance(parametrized_type, SubstraitTypeParser.UserDefinedContext): + names = list(map(lambda k: k.getText(), parametrized_type.Identifier())) + struct = Type.Struct( + types=list( + map(lambda k: _evaluate(k, values), parametrized_type.expr()) + ), + nullability=nullability, + ) + return NamedStruct( + names=names, + struct=struct, + ) raise Exception(f"Unknown parametrized type {type(parametrized_type)}") elif any_type: diff --git a/tests/test_derivation_expression.py b/tests/test_derivation_expression.py index 4b11b3d..d9d2bcd 100644 --- a/tests/test_derivation_expression.py +++ b/tests/test_derivation_expression.py @@ -1,7 +1,8 @@ -from substrait.gen.proto.type_pb2 import Type +from substrait.gen.proto.type_pb2 import NamedStruct, Type from substrait.derivation_expression import evaluate + def test_simple_arithmetic(): assert evaluate("1 + 1") == 2 @@ -113,3 +114,59 @@ def func(P1, S1, P2, S2): ) == func_eval ) + + +def test_struct_simple(): + """Test simple struct with two i32 fields.""" + result = evaluate("struct", {}) + expected = Type( + struct=Type.Struct( + types=[ + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + ], + nullability=Type.NULLABILITY_REQUIRED, + ) + ) + assert result == expected + + +def test_nstruct_simple(): + """Test named struct with field names and types.""" + result = evaluate("nStruct", {}) + expected = NamedStruct( + names=["a", "b"], + struct=Type.Struct( + types=[ + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + ], + nullability=Type.NULLABILITY_REQUIRED, + ) + ) + assert result == expected + + +def test_nstruct_nested(): + """Test named struct with nested struct field.""" + result = evaluate("nStruct>", {}) + expected = NamedStruct( + names=["a", "b", "c"], + struct=Type.Struct( + types=[ + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + Type( + struct=Type.Struct( + types=[ + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + Type(fp32=Type.FP32(nullability=Type.NULLABILITY_REQUIRED)), + ], + nullability=Type.NULLABILITY_REQUIRED, + ) + ), + ], + nullability=Type.NULLABILITY_REQUIRED, + ) + ) + assert result == expected From 3a0ffd8980fc63f79764d6007b0af71c43f6dddf Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Mon, 8 Dec 2025 16:09:26 +0100 Subject: [PATCH 04/12] fix: removed builder for timestamp --- src/substrait/builders/type.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py index 23e51dd..92b560a 100644 --- a/src/substrait/builders/type.py +++ b/src/substrait/builders/type.py @@ -1,4 +1,5 @@ from typing import Iterable + import substrait.gen.proto.type_pb2 as stt @@ -222,16 +223,6 @@ def precision_timestamp_tz(precision: int, nullable=True) -> stt.Type: ) -def timestamp(nullable=True) -> stt.Type: - return stt.Type( - timestamp=stt.Type.Timestamp( - nullability=stt.Type.NULLABILITY_NULLABLE - if nullable - else stt.Type.NULLABILITY_REQUIRED, - ) - ) - - def struct(types: Iterable[stt.Type], nullable=True) -> stt.Type: return stt.Type( struct=stt.Type.Struct( From 3be974083af86ec4ac1483ad6d03182db0cde76f Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Mon, 8 Dec 2025 16:09:56 +0100 Subject: [PATCH 05/12] ref: extension registry with support to list map and struct --- src/substrait/extension_registry.py | 245 ++++++++++++++++------------ tests/test_extension_registry.py | 71 +++++++- 2 files changed, 204 insertions(+), 112 deletions(-) diff --git a/src/substrait/extension_registry.py b/src/substrait/extension_registry.py index b2615e5..a78eefe 100644 --- a/src/substrait/extension_registry.py +++ b/src/substrait/extension_registry.py @@ -1,17 +1,19 @@ -import yaml import itertools import re -from substrait.gen.proto.type_pb2 import Type -from importlib.resources import files as importlib_files from collections import defaultdict +from importlib.resources import files as importlib_files from pathlib import Path from typing import Optional, Union -from .derivation_expression import evaluate, _evaluate, _parse + +import yaml + from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser from substrait.gen.json import simple_extensions as se +from substrait.gen.proto.type_pb2 import Type from substrait.simple_extension_utils import build_simple_extensions -from .bimap import UriUrnBiDiMap +from .bimap import UriUrnBiDiMap +from .derivation_expression import _evaluate, _parse, evaluate DEFAULT_URN_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions" @@ -68,7 +70,7 @@ def normalize_substrait_type_names(typ: str) -> str: raise Exception(f"Unrecognized substrait type {typ}") -def violates_integer_option(actual: int, option, parameters: dict, subset=False): +def violates_integer_option(actual: int, option, parameters: dict): option_numeric = None if isinstance(option, SubstraitTypeParser.NumericLiteralContext): option_numeric = int(str(option.Number())) @@ -82,10 +84,7 @@ def violates_integer_option(actual: int, option, parameters: dict, subset=False) raise Exception( f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead" ) - if subset: - return actual < option_numeric - else: - return actual != option_numeric + return actual != option_numeric def types_equal(type1: Type, type2: Type, check_nullability=False): @@ -165,101 +164,133 @@ def covers( parameterized_type = covering.parameterizedType() if parameterized_type: - kind = covered.WhichOneof("kind") - if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext): - if kind != "varchar": - return False - if hasattr(parameterized_type, "length") and violates_integer_option( - covered.varchar.length, parameterized_type.length, parameters - ): - return False + return _cover_parametrize_type( + covered, parameterized_type, parameters, check_nullability + ) - return _check_nullability( - check_nullability, parameterized_type, covered, kind - ) - if isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext): - if kind != "fixed_char": - return False - if hasattr(parameterized_type, "length") and violates_integer_option( - covered.fixed_char.length, parameterized_type.length, parameters - ): - return False - return _check_nullability( - check_nullability, parameterized_type, covered, kind - ) - if isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext): - if kind != "fixed_binary": - return False - if hasattr(parameterized_type, "length") and violates_integer_option( - covered.fixed_binary.length, parameterized_type.length, parameters - ): - return False - # return True - return _check_nullability( - check_nullability, parameterized_type, covered, kind - ) - if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): - if kind != "decimal": - return False - if not _check_nullability( - check_nullability, parameterized_type, covered, kind - ): - return False - # precision / scale are both optional – a missing value means “no limit”. - covered_scale = getattr(covered.decimal, "scale", 0) - param_scale = getattr(parameterized_type, "scale", 0) - covered_prec = getattr(covered.decimal, "precision", 0) - param_prec = getattr(parameterized_type, "precision", 0) - return not ( - violates_integer_option(covered_scale, param_scale, parameters) - or violates_integer_option(covered_prec, param_prec, parameters) - ) - if isinstance( - parameterized_type, SubstraitTypeParser.PrecisionTimestampContext - ): - if kind != "precision_timestamp": - return False - if not _check_nullability( - check_nullability, parameterized_type, covered, kind - ): - return False - # return True - covered_prec = getattr(covered.precision_timestamp, "precision", 0) - param_prec = getattr(parameterized_type, "precision", 0) - return not violates_integer_option(covered_prec, param_prec, parameters) +def check_violates_integer_option_parameters( + covered, parameterized_type, attributes, parameters +): + for attr in attributes: + if not hasattr(covered, attr) and not hasattr(parameterized_type, attr): + return False + covered_attr = getattr(covered, attr) + param_attr = getattr(parameterized_type, attr) + if violates_integer_option(covered_attr, param_attr, parameters): + return True + return False + - if isinstance( - parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext +def _cover_parametrize_type( + covered: Type, + parameterized_type: SubstraitTypeParser.ParameterizedTypeContext, + parameters: dict, + check_nullability=False, +): + kind = covered.WhichOneof("kind") + + if not _check_nullability(check_nullability, parameterized_type, covered, kind): + return False + + if ( + isinstance(parameterized_type, SubstraitTypeParser.VarCharContext) + and kind == "varchar" + ): + if hasattr( + parameterized_type, "length" + ) and check_violates_integer_option_parameters( + covered.varchar, parameterized_type, ["length"], parameters ): - if kind != "precision_timestamp_tz": - return False - if not _check_nullability( - check_nullability, parameterized_type, covered, kind - ): + return False + elif ( + isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext) + and kind == "fixed_char" + ): + if hasattr( + parameterized_type, "length" + ) and check_violates_integer_option_parameters( + covered.fixed_char, parameterized_type, ["length"], parameters + ): + return False + + elif ( + isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext) + and kind == "fixed_binary" + ): + if hasattr( + parameterized_type, "length" + ) and check_violates_integer_option_parameters( + covered.fixed_binary, parameterized_type, ["length"], parameters + ): + return False + elif ( + isinstance(parameterized_type, SubstraitTypeParser.DecimalContext) + and kind == "decimal" + ): + if not _check_nullability(check_nullability, parameterized_type, covered, kind): + return False + return not check_violates_integer_option_parameters( + covered.decimal, parameterized_type, ["scale", "precision"], parameters + ) + elif ( + isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext) + and kind == "precision_timestamp" + ): + return not check_violates_integer_option_parameters( + covered.precision_timestamp, parameterized_type, ["precision"], parameters + ) + elif ( + isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext) + and kind == "precision_timestamp_tz" + ): + return not check_violates_integer_option_parameters( + covered.precision_timestamp_tz, + parameterized_type, + ["precision"], + parameters, + ) + + elif ( + isinstance(parameterized_type, SubstraitTypeParser.ListContext) + and kind == "list" + ): + covered_element_type = covered.list.type + param_element_ctx = parameterized_type.expr() + return covers( + covered_element_type, param_element_ctx, parameters, check_nullability + ) + + elif ( + isinstance(parameterized_type, SubstraitTypeParser.MapContext) and kind == "map" + ): + covered_key_type = covered.map.key + covered_value_type = covered.map.value + param_key_ctx = parameterized_type.key + param_value_ctx = parameterized_type.value + return covers( + covered_key_type, param_key_ctx, parameters, check_nullability + ) and covers(covered_value_type, param_value_ctx, parameters, check_nullability) + + elif ( + isinstance(parameterized_type, SubstraitTypeParser.StructContext) + and kind == "struct" + ): + covered_types = covered.struct.types + param_types = parameterized_type.expr() or [] + if not isinstance(param_types, list): + param_types = [param_types] + if len(covered_types) != len(param_types): + return False + for covered_field, param_field_ctx in zip(covered_types, param_types): + if not covers( + covered_field, param_field_ctx, parameters, check_nullability + ): # type: ignore return False - # return True - covered_prec = getattr(covered.precision_timestamp_tz, "precision", 0) - param_prec = getattr(parameterized_type, "precision", 0) - return not violates_integer_option(covered_prec, param_prec, parameters) - - kind_mapping = { - SubstraitTypeParser.ListContext: "list", - SubstraitTypeParser.MapContext: "map", - SubstraitTypeParser.StructContext: "struct", - SubstraitTypeParser.UserDefinedContext: "user_defined", - SubstraitTypeParser.PrecisionIntervalDayContext: "interval_day", - } - - for ctx_cls, expected_kind in kind_mapping.items(): - if isinstance(parameterized_type, ctx_cls): - if kind != expected_kind: - return False - return _check_nullability( - check_nullability, parameterized_type, covered, kind - ) - else: - raise Exception(f"Unhandled type {type(parameterized_type)}") + else: + # Unsupported params type or type miss match + return False + return True class FunctionEntry: @@ -322,14 +353,12 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]: output_type = evaluate(self.impl.return_, parameters) if self.nullability == se.NullabilityHandling.MIRROR: - sig_contains_nullable = any( - [ - p.__getattribute__(p.WhichOneof("kind")).nullability - == Type.NULLABILITY_NULLABLE - for p in signature - if isinstance(p, Type) - ] - ) + sig_contains_nullable = any([ + p.__getattribute__(p.WhichOneof("kind")).nullability + == Type.NULLABILITY_NULLABLE + for p in signature + if isinstance(p, Type) + ]) output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = ( Type.NULLABILITY_NULLABLE if sig_contains_nullable diff --git a/tests/test_extension_registry.py b/tests/test_extension_registry.py index c4964e6..1b17f5c 100644 --- a/tests/test_extension_registry.py +++ b/tests/test_extension_registry.py @@ -1,14 +1,22 @@ import pytest import yaml -from substrait.gen.proto.type_pb2 import Type -from substrait.extension_registry import ExtensionRegistry, covers -from substrait.derivation_expression import _parse from substrait.builders.type import ( + decimal, i8, i16, - decimal, + i32, + struct, +) +from substrait.builders.type import ( + list as list_, +) +from substrait.builders.type import ( + map as map_, ) +from substrait.derivation_expression import _parse +from substrait.extension_registry import ExtensionRegistry, covers +from substrait.gen.proto.type_pb2 import Type content = """%YAML 1.2 --- @@ -318,6 +326,11 @@ def test_covers_nullability(): def test_covers_decimal(nullable=False): assert not covers(decimal(8, 10), _parse("decimal<11, A>"), {}) + assert covers(decimal(8, 10), _parse("decimal<10, A>"), {}) + assert covers(decimal(8, 10), _parse("decimal<10, 8>"), {}) + assert not covers(decimal(8, 10), _parse("decimal<10, 9>"), {}) + assert not covers(decimal(8, 10), _parse("decimal<11, 8>"), {}) + assert not covers(decimal(8, 10), _parse("decimal<11, 9>"), {}) def test_covers_decimal_happy_path(): @@ -568,3 +581,53 @@ def test_register_requires_uri(): # During migration, URI is required - this should fail with TypeError with pytest.raises(TypeError): registry.register_extension_dict(yaml.safe_load(content)) + + +def test_covers_list_of_i8(): + """Test that a list of i8 covers list.""" + covered = list_(i8(nullable=False), nullable=False) + param_ctx = _parse("list") + assert covers(covered, param_ctx, {}) + + +def test_covers_map_string_to_i8(): + """Test that a map with string keys and i8 values covers map.""" + covered = map_( + key=Type(string=Type.String(nullability=Type.NULLABILITY_REQUIRED)), + value=i8(nullable=False), + nullable=False, + ) + param_ctx = _parse("map") + assert covers(covered, param_ctx, {}) + + +def test_covers_struct_with_two_fields(): + """Test that a struct with two i8 fields covers struct.""" + covered = struct([i8(nullable=False), i8(nullable=False)], nullable=False) + param_ctx = _parse("struct") + assert covers(covered, param_ctx, {}) + + +def test_covers_list_of_i16_fails_i8(): + """Test that a list of i16 does not cover list.""" + covered = list_(i16(nullable=False), nullable=False) + param_ctx = _parse("list") + assert not covers(covered, param_ctx, {}) + + +def test_covers_map_i8_to_i16_fails(): + """Test that a map with i8 keys and i16 values does not cover map.""" + covered = map_( + key=i8(nullable=False), + value=i16(nullable=False), + nullable=False, + ) + param_ctx = _parse("map") + assert not covers(covered, param_ctx, {}) + + +def test_covers_struct_mismatched_types_fails(): + """Test that a struct with mismatched field types does not cover struct.""" + covered = struct([i32(nullable=False), i8(nullable=False)], nullable=False) + param_ctx = _parse("struct") + assert not covers(covered, param_ctx, {}) From 40e04fe4fa8cddc9314450438178ab4e28a3edcc Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Mon, 8 Dec 2025 16:12:34 +0100 Subject: [PATCH 06/12] fix: removed unused import of pdb --- src/substrait/derivation_expression.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/substrait/derivation_expression.py b/src/substrait/derivation_expression.py index 85f6fd8..97ef8c8 100644 --- a/src/substrait/derivation_expression.py +++ b/src/substrait/derivation_expression.py @@ -1,4 +1,3 @@ -import pdb from typing import Optional from antlr4 import CommonTokenStream, InputStream From 41d7d7483d700da736054e7778c93ff041cdcc8a Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Mon, 8 Dec 2025 16:21:32 +0100 Subject: [PATCH 07/12] feat: restored Unhandled type exception --- src/substrait/extension_registry.py | 66 +++++++++++++---------------- 1 file changed, 30 insertions(+), 36 deletions(-) diff --git a/src/substrait/extension_registry.py b/src/substrait/extension_registry.py index a78eefe..9c6425e 100644 --- a/src/substrait/extension_registry.py +++ b/src/substrait/extension_registry.py @@ -193,20 +193,18 @@ def _cover_parametrize_type( if not _check_nullability(check_nullability, parameterized_type, covered, kind): return False - if ( - isinstance(parameterized_type, SubstraitTypeParser.VarCharContext) - and kind == "varchar" - ): + if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext): + if kind != "varchar": + return False if hasattr( parameterized_type, "length" ) and check_violates_integer_option_parameters( covered.varchar, parameterized_type, ["length"], parameters ): return False - elif ( - isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext) - and kind == "fixed_char" - ): + elif isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext): + if kind != "fixed_char": + return False if hasattr( parameterized_type, "length" ) and check_violates_integer_option_parameters( @@ -214,36 +212,34 @@ def _cover_parametrize_type( ): return False - elif ( - isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext) - and kind == "fixed_binary" - ): + elif isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext): + if kind != "fixed_binary": + return False if hasattr( parameterized_type, "length" ) and check_violates_integer_option_parameters( covered.fixed_binary, parameterized_type, ["length"], parameters ): return False - elif ( - isinstance(parameterized_type, SubstraitTypeParser.DecimalContext) - and kind == "decimal" - ): + elif isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): + if kind != "decimal": + return False if not _check_nullability(check_nullability, parameterized_type, covered, kind): return False return not check_violates_integer_option_parameters( covered.decimal, parameterized_type, ["scale", "precision"], parameters ) - elif ( - isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext) - and kind == "precision_timestamp" - ): + elif isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext): + if kind != "precision_timestamp": + return False return not check_violates_integer_option_parameters( covered.precision_timestamp, parameterized_type, ["precision"], parameters ) - elif ( - isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext) - and kind == "precision_timestamp_tz" + elif isinstance( + parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext ): + if kind != "precision_timestamp_tz": + return False return not check_violates_integer_option_parameters( covered.precision_timestamp_tz, parameterized_type, @@ -251,19 +247,18 @@ def _cover_parametrize_type( parameters, ) - elif ( - isinstance(parameterized_type, SubstraitTypeParser.ListContext) - and kind == "list" - ): + elif isinstance(parameterized_type, SubstraitTypeParser.ListContext): + if kind != "list": + return False covered_element_type = covered.list.type param_element_ctx = parameterized_type.expr() return covers( covered_element_type, param_element_ctx, parameters, check_nullability ) - elif ( - isinstance(parameterized_type, SubstraitTypeParser.MapContext) and kind == "map" - ): + elif isinstance(parameterized_type, SubstraitTypeParser.MapContext): + if kind != "map": + return False covered_key_type = covered.map.key covered_value_type = covered.map.value param_key_ctx = parameterized_type.key @@ -272,10 +267,9 @@ def _cover_parametrize_type( covered_key_type, param_key_ctx, parameters, check_nullability ) and covers(covered_value_type, param_value_ctx, parameters, check_nullability) - elif ( - isinstance(parameterized_type, SubstraitTypeParser.StructContext) - and kind == "struct" - ): + elif isinstance(parameterized_type, SubstraitTypeParser.StructContext): + if kind != "struct": + return False covered_types = covered.struct.types param_types = parameterized_type.expr() or [] if not isinstance(param_types, list): @@ -288,8 +282,8 @@ def _cover_parametrize_type( ): # type: ignore return False else: - # Unsupported params type or type miss match - return False + raise Exception(f"Unhandled type {type(parameterized_type)}") + return True From 801c8dc1ce3661346291e29a6a227c80942958a8 Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Mon, 8 Dec 2025 16:24:26 +0100 Subject: [PATCH 08/12] feat: runned ruf and linter --- src/substrait/derivation_expression.py | 1 - src/substrait/extension_registry.py | 14 ++++++++------ tests/test_derivation_expression.py | 5 ++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/substrait/derivation_expression.py b/src/substrait/derivation_expression.py index 97ef8c8..17950c6 100644 --- a/src/substrait/derivation_expression.py +++ b/src/substrait/derivation_expression.py @@ -1,7 +1,6 @@ from typing import Optional from antlr4 import CommonTokenStream, InputStream -from datafusion.functions import named_struct from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser diff --git a/src/substrait/extension_registry.py b/src/substrait/extension_registry.py index 9c6425e..1d5d6cd 100644 --- a/src/substrait/extension_registry.py +++ b/src/substrait/extension_registry.py @@ -347,12 +347,14 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]: output_type = evaluate(self.impl.return_, parameters) if self.nullability == se.NullabilityHandling.MIRROR: - sig_contains_nullable = any([ - p.__getattribute__(p.WhichOneof("kind")).nullability - == Type.NULLABILITY_NULLABLE - for p in signature - if isinstance(p, Type) - ]) + sig_contains_nullable = any( + [ + p.__getattribute__(p.WhichOneof("kind")).nullability + == Type.NULLABILITY_NULLABLE + for p in signature + if isinstance(p, Type) + ] + ) output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = ( Type.NULLABILITY_NULLABLE if sig_contains_nullable diff --git a/tests/test_derivation_expression.py b/tests/test_derivation_expression.py index d9d2bcd..68c29b0 100644 --- a/tests/test_derivation_expression.py +++ b/tests/test_derivation_expression.py @@ -2,7 +2,6 @@ from substrait.derivation_expression import evaluate - def test_simple_arithmetic(): assert evaluate("1 + 1") == 2 @@ -142,7 +141,7 @@ def test_nstruct_simple(): Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), ], nullability=Type.NULLABILITY_REQUIRED, - ) + ), ) assert result == expected @@ -167,6 +166,6 @@ def test_nstruct_nested(): ), ], nullability=Type.NULLABILITY_REQUIRED, - ) + ), ) assert result == expected From 484513ae6873fb0644b2ec12217a11da2a2b5b6e Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Mon, 8 Dec 2025 16:59:22 +0100 Subject: [PATCH 09/12] fix : removed redundant checks --- src/substrait/extension_registry.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/src/substrait/extension_registry.py b/src/substrait/extension_registry.py index 1d5d6cd..bb70832 100644 --- a/src/substrait/extension_registry.py +++ b/src/substrait/extension_registry.py @@ -196,18 +196,14 @@ def _cover_parametrize_type( if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext): if kind != "varchar": return False - if hasattr( - parameterized_type, "length" - ) and check_violates_integer_option_parameters( + if check_violates_integer_option_parameters( covered.varchar, parameterized_type, ["length"], parameters ): return False elif isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext): if kind != "fixed_char": return False - if hasattr( - parameterized_type, "length" - ) and check_violates_integer_option_parameters( + if check_violates_integer_option_parameters( covered.fixed_char, parameterized_type, ["length"], parameters ): return False @@ -215,9 +211,7 @@ def _cover_parametrize_type( elif isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext): if kind != "fixed_binary": return False - if hasattr( - parameterized_type, "length" - ) and check_violates_integer_option_parameters( + if check_violates_integer_option_parameters( covered.fixed_binary, parameterized_type, ["length"], parameters ): return False @@ -347,14 +341,12 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]: output_type = evaluate(self.impl.return_, parameters) if self.nullability == se.NullabilityHandling.MIRROR: - sig_contains_nullable = any( - [ - p.__getattribute__(p.WhichOneof("kind")).nullability - == Type.NULLABILITY_NULLABLE - for p in signature - if isinstance(p, Type) - ] - ) + sig_contains_nullable = any([ + p.__getattribute__(p.WhichOneof("kind")).nullability + == Type.NULLABILITY_NULLABLE + for p in signature + if isinstance(p, Type) + ]) output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = ( Type.NULLABILITY_NULLABLE if sig_contains_nullable From fb0f498b57cb5c4356a6421baafc9f8cdb25a78f Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Tue, 9 Dec 2025 00:14:06 +0100 Subject: [PATCH 10/12] Fix typo in function name from parametrize to parametrized --- src/substrait/extension_registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/substrait/extension_registry.py b/src/substrait/extension_registry.py index bb70832..f6d9e4f 100644 --- a/src/substrait/extension_registry.py +++ b/src/substrait/extension_registry.py @@ -164,7 +164,7 @@ def covers( parameterized_type = covering.parameterizedType() if parameterized_type: - return _cover_parametrize_type( + return _cover_parametrized_type( covered, parameterized_type, parameters, check_nullability ) @@ -182,7 +182,7 @@ def check_violates_integer_option_parameters( return False -def _cover_parametrize_type( +def _cover_parametrized_type( covered: Type, parameterized_type: SubstraitTypeParser.ParameterizedTypeContext, parameters: dict, From 3d9df5c5e6ea5a371397014a216b62b31e8422b2 Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Tue, 9 Dec 2025 15:07:18 +0100 Subject: [PATCH 11/12] ref: made _cover_parametrize_type clearer --- src/substrait/extension_registry.py | 126 ++++++++++++++-------------- 1 file changed, 64 insertions(+), 62 deletions(-) diff --git a/src/substrait/extension_registry.py b/src/substrait/extension_registry.py index f6d9e4f..ec5a571 100644 --- a/src/substrait/extension_registry.py +++ b/src/substrait/extension_registry.py @@ -194,74 +194,74 @@ def _cover_parametrized_type( return False if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext): - if kind != "varchar": - return False - if check_violates_integer_option_parameters( + return kind == "varchar" and not check_violates_integer_option_parameters( covered.varchar, parameterized_type, ["length"], parameters - ): - return False - elif isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext): - if kind != "fixed_char": - return False - if check_violates_integer_option_parameters( + ) + + if isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext): + return kind == "fixed_char" and not check_violates_integer_option_parameters( covered.fixed_char, parameterized_type, ["length"], parameters - ): - return False + ) - elif isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext): - if kind != "fixed_binary": - return False - if check_violates_integer_option_parameters( + if isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext): + return kind == "fixed_binary" and not check_violates_integer_option_parameters( covered.fixed_binary, parameterized_type, ["length"], parameters - ): - return False - elif isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): - if kind != "decimal": - return False - if not _check_nullability(check_nullability, parameterized_type, covered, kind): - return False - return not check_violates_integer_option_parameters( - covered.decimal, parameterized_type, ["scale", "precision"], parameters ) - elif isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext): - if kind != "precision_timestamp": - return False - return not check_violates_integer_option_parameters( - covered.precision_timestamp, parameterized_type, ["precision"], parameters + + if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): + return ( + kind == "decimal" + and _check_nullability(check_nullability, parameterized_type, covered, kind) + and not check_violates_integer_option_parameters( + covered.decimal, parameterized_type, ["scale", "precision"], parameters + ) ) - elif isinstance( - parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext - ): - if kind != "precision_timestamp_tz": - return False - return not check_violates_integer_option_parameters( - covered.precision_timestamp_tz, - parameterized_type, - ["precision"], + + if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext): + return ( + kind == "precision_timestamp" + and not check_violates_integer_option_parameters( + covered.precision_timestamp, + parameterized_type, + ["precision"], + parameters, + ) + ) + + if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext): + return ( + kind == "precision_timestamp_tz" + and not check_violates_integer_option_parameters( + covered.precision_timestamp_tz, + parameterized_type, + ["precision"], + parameters, + ) + ) + + if isinstance(parameterized_type, SubstraitTypeParser.ListContext): + return kind == "list" and covers( + covered.list.type, + parameterized_type.expr(), parameters, + check_nullability, ) - elif isinstance(parameterized_type, SubstraitTypeParser.ListContext): - if kind != "list": - return False - covered_element_type = covered.list.type - param_element_ctx = parameterized_type.expr() - return covers( - covered_element_type, param_element_ctx, parameters, check_nullability + if isinstance(parameterized_type, SubstraitTypeParser.MapContext): + return ( + kind == "map" + and covers( + covered.map.key, parameterized_type.key, parameters, check_nullability + ) + and covers( + covered.map.value, + parameterized_type.value, + parameters, + check_nullability, + ) ) - elif isinstance(parameterized_type, SubstraitTypeParser.MapContext): - if kind != "map": - return False - covered_key_type = covered.map.key - covered_value_type = covered.map.value - param_key_ctx = parameterized_type.key - param_value_ctx = parameterized_type.value - return covers( - covered_key_type, param_key_ctx, parameters, check_nullability - ) and covers(covered_value_type, param_value_ctx, parameters, check_nullability) - - elif isinstance(parameterized_type, SubstraitTypeParser.StructContext): + if isinstance(parameterized_type, SubstraitTypeParser.StructContext): if kind != "struct": return False covered_types = covered.struct.types @@ -272,13 +272,15 @@ def _cover_parametrized_type( return False for covered_field, param_field_ctx in zip(covered_types, param_types): if not covers( - covered_field, param_field_ctx, parameters, check_nullability - ): # type: ignore + covered_field, + param_field_ctx, + parameters, + check_nullability, # type: ignore + ): return False - else: - raise Exception(f"Unhandled type {type(parameterized_type)}") + return True - return True + raise Exception(f"Unhandled type {type(parameterized_type)}") class FunctionEntry: From f3d73b90c2e44fddcfb77200d83e6a3ebb40eede Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Wed, 10 Dec 2025 10:22:51 +0100 Subject: [PATCH 12/12] fix: removed unnecessary nullability check --- src/substrait/extension_registry.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/substrait/extension_registry.py b/src/substrait/extension_registry.py index ec5a571..58bd6f5 100644 --- a/src/substrait/extension_registry.py +++ b/src/substrait/extension_registry.py @@ -209,12 +209,8 @@ def _cover_parametrized_type( ) if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): - return ( - kind == "decimal" - and _check_nullability(check_nullability, parameterized_type, covered, kind) - and not check_violates_integer_option_parameters( - covered.decimal, parameterized_type, ["scale", "precision"], parameters - ) + return kind == "decimal" and not check_violates_integer_option_parameters( + covered.decimal, parameterized_type, ["scale", "precision"], parameters ) if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext):