diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 54d4f63a3..a92a9fb3b 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -33,6 +33,7 @@ import io.substrait.relation.Set; import io.substrait.relation.Sort; import io.substrait.relation.VirtualTableScan; +import io.substrait.relation.physical.ComparisonJoinKey; import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; @@ -440,10 +441,7 @@ public HashJoin hashJoin( return HashJoin.builder() .left(left) .right(right) - .leftKeys( - this.fieldReferences(left, leftKeys.stream().mapToInt(Integer::intValue).toArray())) - .rightKeys( - this.fieldReferences(right, rightKeys.stream().mapToInt(Integer::intValue).toArray())) + .keys(this.comparisonJoinKeys(left, right, leftKeys, rightKeys)) .joinType(joinType) .remap(remap) .build(); @@ -490,15 +488,38 @@ public MergeJoin mergeJoin( return MergeJoin.builder() .left(left) .right(right) - .leftKeys( - this.fieldReferences(left, leftKeys.stream().mapToInt(Integer::intValue).toArray())) - .rightKeys( - this.fieldReferences(right, rightKeys.stream().mapToInt(Integer::intValue).toArray())) + .keys(this.comparisonJoinKeys(left, right, leftKeys, rightKeys)) .joinType(joinType) .remap(remap) .build(); } + /** + * Builds a list of {@link ComparisonJoinKey}s pairing the given left/right field indexes with an + * {@link ComparisonJoinKey.SimpleComparisonType#EQ} comparison. + * + * @param left the left input relation + * @param right the right input relation + * @param leftKeys field indexes from the left relation + * @param rightKeys field indexes from the right relation + * @return the list of equality join keys + */ + public List comparisonJoinKeys( + Rel left, Rel right, List leftKeys, List rightKeys) { + if (leftKeys.size() != rightKeys.size()) { + throw new IllegalArgumentException("Number of left and right keys must be equal."); + } + List keys = new java.util.ArrayList<>(leftKeys.size()); + for (int i = 0; i < leftKeys.size(); i++) { + keys.add( + ComparisonJoinKey.of( + this.fieldReference(left, leftKeys.get(i)), + this.fieldReference(right, rightKeys.get(i)), + ComparisonJoinKey.SimpleComparisonType.EQ)); + } + return keys; + } + /** * Creates a nested loop join between two relations. * diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index b8a28f78c..95ebef2bf 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -43,6 +43,7 @@ import io.substrait.relation.files.FileOrFiles; import io.substrait.relation.physical.AbstractExchangeRel; import io.substrait.relation.physical.BroadcastExchange; +import io.substrait.relation.physical.ComparisonJoinKey; import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.ImmutableBroadcastExchange; import io.substrait.relation.physical.ImmutableExchangeTarget; @@ -64,6 +65,7 @@ import java.util.List; import java.util.Optional; import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.jspecify.annotations.NonNull; /** Converts from {@link io.substrait.proto.Rel} to {@link io.substrait.relation.Rel} */ @@ -861,8 +863,6 @@ protected Set newSet(SetRel rel) { protected Rel newHashJoin(HashJoinRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); - List leftKeys = rel.getLeftKeysList(); - List rightKeys = rel.getRightKeysList(); Type.Struct leftStruct = left.getRecordType(); Type.Struct rightStruct = right.getRecordType(); @@ -877,8 +877,13 @@ protected Rel newHashJoin(HashJoinRel rel) { HashJoin.builder() .left(left) .right(right) - .leftKeys(leftKeys.stream().map(leftConverter::from).collect(Collectors.toList())) - .rightKeys(rightKeys.stream().map(rightConverter::from).collect(Collectors.toList())) + .keys( + comparisonJoinKeys( + rel.getKeysList(), + rel.getLeftKeysList(), + rel.getRightKeysList(), + leftConverter, + rightConverter)) .joinType(HashJoin.JoinType.fromProto(rel.getType())) .postJoinFilter( Optional.ofNullable( @@ -896,8 +901,6 @@ protected Rel newHashJoin(HashJoinRel rel) { protected Rel newMergeJoin(MergeJoinRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); - List leftKeys = rel.getLeftKeysList(); - List rightKeys = rel.getRightKeysList(); Type.Struct leftStruct = left.getRecordType(); Type.Struct rightStruct = right.getRecordType(); @@ -912,8 +915,13 @@ protected Rel newMergeJoin(MergeJoinRel rel) { MergeJoin.builder() .left(left) .right(right) - .leftKeys(leftKeys.stream().map(leftConverter::from).collect(Collectors.toList())) - .rightKeys(rightKeys.stream().map(rightConverter::from).collect(Collectors.toList())) + .keys( + comparisonJoinKeys( + rel.getKeysList(), + rel.getLeftKeysList(), + rel.getRightKeysList(), + leftConverter, + rightConverter)) .joinType(MergeJoin.JoinType.fromProto(rel.getType())) .postJoinFilter( Optional.ofNullable( @@ -929,6 +937,63 @@ protected Rel newMergeJoin(MergeJoinRel rel) { return builder.build(); } + /** + * Builds the {@link ComparisonJoinKey} list for a hash/merge join, preferring the {@code keys} + * field. The deprecated {@code left_keys}/{@code right_keys} fields are only consulted when + * {@code keys} is empty, in which case they are paired up with a {@link + * ComparisonJoinKey.SimpleComparisonType#EQ} comparison. + */ + private List comparisonJoinKeys( + List keys, + List leftKeys, + List rightKeys, + ProtoExpressionConverter leftConverter, + ProtoExpressionConverter rightConverter) { + if (!keys.isEmpty()) { + return keys.stream() + .map(key -> comparisonJoinKey(key, leftConverter, rightConverter)) + .collect(Collectors.toList()); + } + if (leftKeys.size() != rightKeys.size()) { + throw new IllegalArgumentException("Number of left and right keys must be equal."); + } + return IntStream.range(0, leftKeys.size()) + .mapToObj( + i -> + ComparisonJoinKey.of( + leftConverter.from(leftKeys.get(i)), + rightConverter.from(rightKeys.get(i)), + ComparisonJoinKey.SimpleComparisonType.EQ)) + .collect(Collectors.toList()); + } + + private ComparisonJoinKey comparisonJoinKey( + io.substrait.proto.ComparisonJoinKey key, + ProtoExpressionConverter leftConverter, + ProtoExpressionConverter rightConverter) { + io.substrait.proto.ComparisonJoinKey.ComparisonType comparison = key.getComparison(); + final ComparisonJoinKey.ComparisonType comparisonType; + switch (comparison.getInnerTypeCase()) { + case SIMPLE: + comparisonType = + ComparisonJoinKey.SimpleComparison.of( + ComparisonJoinKey.SimpleComparisonType.fromProto(comparison.getSimple())); + break; + case CUSTOM_FUNCTION_REFERENCE: + comparisonType = + ComparisonJoinKey.CustomComparison.of(comparison.getCustomFunctionReference()); + break; + default: + throw new IllegalArgumentException( + "Unsupported comparison type: " + comparison.getInnerTypeCase()); + } + return ComparisonJoinKey.builder() + .left(leftConverter.from(key.getLeft())) + .right(rightConverter.from(key.getRight())) + .comparison(comparisonType) + .build(); + } + protected NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 29c388300..642af37fb 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -9,6 +9,7 @@ import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; import io.substrait.relation.physical.BroadcastExchange; +import io.substrait.relation.physical.ComparisonJoinKey; import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.MultiBucketExchange; @@ -438,14 +439,12 @@ public Optional visit(ExtensionTable extensionTable, EmptyVisitationContext public Optional visit(HashJoin hashJoin, EmptyVisitationContext context) throws E { Optional left = hashJoin.getLeft().accept(this, context); Optional right = hashJoin.getRight().accept(this, context); - Optional> leftKeys = - transformList(hashJoin.getLeftKeys(), context, this::visitFieldReference); - Optional> rightKeys = - transformList(hashJoin.getRightKeys(), context, this::visitFieldReference); + Optional> keys = + transformList(hashJoin.getKeys(), context, this::visitComparisonJoinKey); Optional postFilter = visitOptionalExpression(hashJoin.getPostJoinFilter(), context); - if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { + if (allEmpty(left, right, keys, postFilter)) { return Optional.empty(); } return Optional.of( @@ -453,8 +452,7 @@ public Optional visit(HashJoin hashJoin, EmptyVisitationContext context) th .from(hashJoin) .left(left.orElse(hashJoin.getLeft())) .right(right.orElse(hashJoin.getRight())) - .leftKeys(leftKeys.orElse(hashJoin.getLeftKeys())) - .rightKeys(rightKeys.orElse(hashJoin.getRightKeys())) + .keys(keys.orElse(hashJoin.getKeys())) .postJoinFilter(or(postFilter, hashJoin::getPostJoinFilter)) .build()); } @@ -463,14 +461,12 @@ public Optional visit(HashJoin hashJoin, EmptyVisitationContext context) th public Optional visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws E { Optional left = mergeJoin.getLeft().accept(this, context); Optional right = mergeJoin.getRight().accept(this, context); - Optional> leftKeys = - transformList(mergeJoin.getLeftKeys(), context, this::visitFieldReference); - Optional> rightKeys = - transformList(mergeJoin.getRightKeys(), context, this::visitFieldReference); + Optional> keys = + transformList(mergeJoin.getKeys(), context, this::visitComparisonJoinKey); Optional postFilter = visitOptionalExpression(mergeJoin.getPostJoinFilter(), context); - if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { + if (allEmpty(left, right, keys, postFilter)) { return Optional.empty(); } return Optional.of( @@ -478,8 +474,7 @@ public Optional visit(MergeJoin mergeJoin, EmptyVisitationContext context) .from(mergeJoin) .left(left.orElse(mergeJoin.getLeft())) .right(right.orElse(mergeJoin.getRight())) - .leftKeys(leftKeys.orElse(mergeJoin.getLeftKeys())) - .rightKeys(rightKeys.orElse(mergeJoin.getRightKeys())) + .keys(keys.orElse(mergeJoin.getKeys())) .postJoinFilter(or(postFilter, mergeJoin::getPostJoinFilter)) .build()); } @@ -569,6 +564,21 @@ public Optional visitFieldReference( return Optional.of(FieldReference.builder().inputExpression(inputExpression).build()); } + public Optional visitComparisonJoinKey( + ComparisonJoinKey key, EmptyVisitationContext context) throws E { + Optional left = visitFieldReference(key.getLeft(), context); + Optional right = visitFieldReference(key.getRight(), context); + if (allEmpty(left, right)) { + return Optional.empty(); + } + return Optional.of( + ComparisonJoinKey.builder() + .from(key) + .left(left.orElse(key.getLeft())) + .right(right.orElse(key.getRight())) + .build()); + } + protected Optional> visitFunctionArguments( List funcArgs, EmptyVisitationContext context) throws E { return CopyOnWriteUtils.transformList( diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 0b9aebada..57cb5b4b8 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -47,6 +47,7 @@ import io.substrait.relation.files.FileOrFiles; import io.substrait.relation.physical.AbstractExchangeRel; import io.substrait.relation.physical.BroadcastExchange; +import io.substrait.relation.physical.ComparisonJoinKey; import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.MultiBucketExchange; @@ -154,6 +155,64 @@ private io.substrait.proto.Expression.FieldReference toProto(FieldReference fiel return toProto((Expression) fieldReference).getSelection(); } + private io.substrait.proto.ComparisonJoinKey toProto(ComparisonJoinKey key) { + io.substrait.proto.ComparisonJoinKey.ComparisonType comparison = + key.getComparison() + .accept( + new ComparisonJoinKey.ComparisonTypeVisitor< + io.substrait.proto.ComparisonJoinKey.ComparisonType, RuntimeException>() { + @Override + public io.substrait.proto.ComparisonJoinKey.ComparisonType visit( + ComparisonJoinKey.SimpleComparison simpleComparison) { + return io.substrait.proto.ComparisonJoinKey.ComparisonType.newBuilder() + .setSimple(simpleComparison.getType().toProto()) + .build(); + } + + @Override + public io.substrait.proto.ComparisonJoinKey.ComparisonType visit( + ComparisonJoinKey.CustomComparison customComparison) { + return io.substrait.proto.ComparisonJoinKey.ComparisonType.newBuilder() + .setCustomFunctionReference(customComparison.getCustomFunctionReference()) + .build(); + } + }); + return io.substrait.proto.ComparisonJoinKey.newBuilder() + .setLeft(toProto(key.getLeft())) + .setRight(toProto(key.getRight())) + .setComparison(comparison) + .build(); + } + + /** + * Returns {@code true} when every key is a plain {@link + * ComparisonJoinKey.SimpleComparisonType#EQ} comparison, i.e. the only case the deprecated {@code + * left_keys}/{@code right_keys} fields can represent without losing information. {@code + * IS_NOT_DISTINCT_FROM}, {@code MIGHT_EQUAL} and custom comparisons cannot be expressed by the + * deprecated fields, where an old consumer would silently interpret them as equality. + */ + private boolean isLosslessAsDeprecatedKeys(List keys) { + return keys.stream() + .allMatch( + key -> + key.getComparison() + .accept( + new ComparisonJoinKey.ComparisonTypeVisitor() { + @Override + public Boolean visit( + ComparisonJoinKey.SimpleComparison simpleComparison) { + return simpleComparison.getType() + == ComparisonJoinKey.SimpleComparisonType.EQ; + } + + @Override + public Boolean visit( + ComparisonJoinKey.CustomComparison customComparison) { + return false; + } + })); + } + @Override public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws RuntimeException { final List uniqueGroupingExpressions = @@ -355,6 +414,7 @@ public Rel visit(ExtensionTable extensionTable, EmptyVisitationContext context) } @Override + @SuppressWarnings("deprecation") // intentionally also writes the deprecated left_keys/right_keys public Rel visit(HashJoin hashJoin, EmptyVisitationContext context) throws RuntimeException { HashJoinRel.Builder builder = HashJoinRel.newBuilder() @@ -363,16 +423,19 @@ public Rel visit(HashJoin hashJoin, EmptyVisitationContext context) throws Runti .setRight(toProto(hashJoin.getRight())) .setType(hashJoin.getJoinType().toProto()); - List leftKeys = hashJoin.getLeftKeys(); - List rightKeys = hashJoin.getRightKeys(); - - if (leftKeys.size() != rightKeys.size()) { - throw new IllegalArgumentException("Number of left and right keys must be equal."); + List keys = hashJoin.getKeys(); + builder.addAllKeys(keys.stream().map(this::toProto).collect(Collectors.toList())); + + // Also populate the deprecated left_keys/right_keys when every key is a plain EQ comparison so + // that consumers which have not yet adopted the new keys field keep working. Lossy comparison + // types are intentionally left out of the deprecated fields. + if (isLosslessAsDeprecatedKeys(keys)) { + builder.addAllLeftKeys( + keys.stream().map(k -> toProto(k.getLeft())).collect(Collectors.toList())); + builder.addAllRightKeys( + keys.stream().map(k -> toProto(k.getRight())).collect(Collectors.toList())); } - builder.addAllLeftKeys(leftKeys.stream().map(this::toProto).collect(Collectors.toList())); - builder.addAllRightKeys(rightKeys.stream().map(this::toProto).collect(Collectors.toList())); - hashJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t))); hashJoin @@ -382,6 +445,7 @@ public Rel visit(HashJoin hashJoin, EmptyVisitationContext context) throws Runti } @Override + @SuppressWarnings("deprecation") // intentionally also writes the deprecated left_keys/right_keys public Rel visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws RuntimeException { MergeJoinRel.Builder builder = MergeJoinRel.newBuilder() @@ -390,16 +454,19 @@ public Rel visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws Run .setRight(toProto(mergeJoin.getRight())) .setType(mergeJoin.getJoinType().toProto()); - List leftKeys = mergeJoin.getLeftKeys(); - List rightKeys = mergeJoin.getRightKeys(); - - if (leftKeys.size() != rightKeys.size()) { - throw new IllegalArgumentException("Number of left and right keys must be equal."); + List keys = mergeJoin.getKeys(); + builder.addAllKeys(keys.stream().map(this::toProto).collect(Collectors.toList())); + + // Also populate the deprecated left_keys/right_keys when every key is a plain EQ comparison so + // that consumers which have not yet adopted the new keys field keep working. Lossy comparison + // types are intentionally left out of the deprecated fields. + if (isLosslessAsDeprecatedKeys(keys)) { + builder.addAllLeftKeys( + keys.stream().map(k -> toProto(k.getLeft())).collect(Collectors.toList())); + builder.addAllRightKeys( + keys.stream().map(k -> toProto(k.getRight())).collect(Collectors.toList())); } - builder.addAllLeftKeys(leftKeys.stream().map(this::toProto).collect(Collectors.toList())); - builder.addAllRightKeys(rightKeys.stream().map(this::toProto).collect(Collectors.toList())); - mergeJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t))); mergeJoin diff --git a/core/src/main/java/io/substrait/relation/physical/ComparisonJoinKey.java b/core/src/main/java/io/substrait/relation/physical/ComparisonJoinKey.java new file mode 100644 index 000000000..900e1363d --- /dev/null +++ b/core/src/main/java/io/substrait/relation/physical/ComparisonJoinKey.java @@ -0,0 +1,118 @@ +package io.substrait.relation.physical; + +import io.substrait.expression.FieldReference; +import org.immutables.value.Value; + +/** + * A single key comparison used by {@link HashJoin} and {@link MergeJoin}. + * + *

Models the {@code substrait.ComparisonJoinKey} protobuf message: a {@code left} and {@code + * right} {@link FieldReference} together with the {@link ComparisonType} describing how the two are + * compared. + */ +@Value.Enclosing +@Value.Immutable +public abstract class ComparisonJoinKey { + + public abstract FieldReference getLeft(); + + public abstract FieldReference getRight(); + + public abstract ComparisonType getComparison(); + + public static ImmutableComparisonJoinKey.Builder builder() { + return ImmutableComparisonJoinKey.builder(); + } + + /** Convenience factory for the common case of a {@link SimpleComparison}. */ + public static ComparisonJoinKey of( + FieldReference left, FieldReference right, SimpleComparisonType type) { + return builder().left(left).right(right).comparison(SimpleComparison.of(type)).build(); + } + + /** + * Describes how the two keys of a {@link ComparisonJoinKey} are compared. Models the {@code + * ComparisonType} oneof, which is either a {@link SimpleComparison} or a {@link + * CustomComparison}. + */ + public interface ComparisonType { + R accept(ComparisonTypeVisitor visitor) throws E; + } + + public interface ComparisonTypeVisitor { + R visit(SimpleComparison simpleComparison) throws E; + + R visit(CustomComparison customComparison) throws E; + } + + /** One of the predefined {@link SimpleComparisonType} behaviors. */ + @Value.Immutable + public abstract static class SimpleComparison implements ComparisonType { + public abstract SimpleComparisonType getType(); + + public static SimpleComparison of(SimpleComparisonType type) { + return ImmutableComparisonJoinKey.SimpleComparison.builder().type(type).build(); + } + + @Override + public R accept(ComparisonTypeVisitor visitor) throws E { + return visitor.visit(this); + } + } + + /** + * A custom comparison behavior, given by a reference to a binary function with a boolean return + * type. + */ + @Value.Immutable + public abstract static class CustomComparison implements ComparisonType { + public abstract int getCustomFunctionReference(); + + public static CustomComparison of(int customFunctionReference) { + return ImmutableComparisonJoinKey.CustomComparison.builder() + .customFunctionReference(customFunctionReference) + .build(); + } + + @Override + public R accept(ComparisonTypeVisitor visitor) throws E { + return visitor.visit(this); + } + } + + /** + * The predefined comparison behaviors, mapping {@code ComparisonJoinKey.SimpleComparisonType}. + */ + public enum SimpleComparisonType { + UNSPECIFIED( + io.substrait.proto.ComparisonJoinKey.SimpleComparisonType + .SIMPLE_COMPARISON_TYPE_UNSPECIFIED), + EQ(io.substrait.proto.ComparisonJoinKey.SimpleComparisonType.SIMPLE_COMPARISON_TYPE_EQ), + IS_NOT_DISTINCT_FROM( + io.substrait.proto.ComparisonJoinKey.SimpleComparisonType + .SIMPLE_COMPARISON_TYPE_IS_NOT_DISTINCT_FROM), + MIGHT_EQUAL( + io.substrait.proto.ComparisonJoinKey.SimpleComparisonType + .SIMPLE_COMPARISON_TYPE_MIGHT_EQUAL); + + private final io.substrait.proto.ComparisonJoinKey.SimpleComparisonType proto; + + SimpleComparisonType(io.substrait.proto.ComparisonJoinKey.SimpleComparisonType proto) { + this.proto = proto; + } + + public static SimpleComparisonType fromProto( + io.substrait.proto.ComparisonJoinKey.SimpleComparisonType proto) { + for (SimpleComparisonType v : values()) { + if (v.proto == proto) { + return v; + } + } + throw new IllegalArgumentException("Unknown type: " + proto); + } + + public io.substrait.proto.ComparisonJoinKey.SimpleComparisonType toProto() { + return proto; + } + } +} diff --git a/core/src/main/java/io/substrait/relation/physical/HashJoin.java b/core/src/main/java/io/substrait/relation/physical/HashJoin.java index ca49bcd46..44776a59c 100644 --- a/core/src/main/java/io/substrait/relation/physical/HashJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/HashJoin.java @@ -17,9 +17,28 @@ @Value.Immutable public abstract class HashJoin extends BiRel implements HasExtension { - public abstract List getLeftKeys(); + public abstract List getKeys(); + + /** + * @deprecated the left-hand sides of the join {@link #getKeys()}; use {@link #getKeys()} instead. + */ + @Deprecated + public List getLeftKeys() { + return getKeys().stream() + .map(ComparisonJoinKey::getLeft) + .collect(java.util.stream.Collectors.toList()); + } - public abstract List getRightKeys(); + /** + * @deprecated the right-hand sides of the join {@link #getKeys()}; use {@link #getKeys()} + * instead. + */ + @Deprecated + public List getRightKeys() { + return getKeys().stream() + .map(ComparisonJoinKey::getRight) + .collect(java.util.stream.Collectors.toList()); + } public abstract JoinType getJoinType(); diff --git a/core/src/main/java/io/substrait/relation/physical/MergeJoin.java b/core/src/main/java/io/substrait/relation/physical/MergeJoin.java index 9664868bb..edc0ce4bf 100644 --- a/core/src/main/java/io/substrait/relation/physical/MergeJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/MergeJoin.java @@ -17,9 +17,28 @@ @Value.Immutable public abstract class MergeJoin extends BiRel implements HasExtension { - public abstract List getLeftKeys(); + public abstract List getKeys(); + + /** + * @deprecated the left-hand sides of the join {@link #getKeys()}; use {@link #getKeys()} instead. + */ + @Deprecated + public List getLeftKeys() { + return getKeys().stream() + .map(ComparisonJoinKey::getLeft) + .collect(java.util.stream.Collectors.toList()); + } - public abstract List getRightKeys(); + /** + * @deprecated the right-hand sides of the join {@link #getKeys()}; use {@link #getKeys()} + * instead. + */ + @Deprecated + public List getRightKeys() { + return getKeys().stream() + .map(ComparisonJoinKey::getRight) + .collect(java.util.stream.Collectors.toList()); + } public abstract JoinType getJoinType(); diff --git a/core/src/test/java/io/substrait/type/proto/HashMergeJoinKeysTest.java b/core/src/test/java/io/substrait/type/proto/HashMergeJoinKeysTest.java new file mode 100644 index 000000000..ac296f335 --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/HashMergeJoinKeysTest.java @@ -0,0 +1,251 @@ +package io.substrait.type.proto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.TestBase; +import io.substrait.relation.Rel; +import io.substrait.relation.physical.ComparisonJoinKey; +import io.substrait.relation.physical.ComparisonJoinKey.SimpleComparisonType; +import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.MergeJoin; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; + +/** + * Verifies that {@link HashJoin}/{@link MergeJoin} consume both the deprecated {@code + * left_keys}/{@code right_keys} proto fields and the new {@code keys} field, always produce the new + * {@code keys} field, and additionally produce the deprecated fields when (and only when) every key + * is a plain {@code EQ} comparison they can represent without loss. + */ +class HashMergeJoinKeysTest extends TestBase { + + final Rel leftTable = + sb.namedScan( + Arrays.asList("T1"), + Arrays.asList("a", "b", "c"), + Arrays.asList(R.I64, R.FP64, R.STRING)); + + final Rel rightTable = + sb.namedScan( + Arrays.asList("T2"), + Arrays.asList("d", "e", "f"), + Arrays.asList(R.FP64, R.STRING, R.I64)); + + final HashJoin hashJoin = + sb.hashJoin( + Arrays.asList(0, 1), Arrays.asList(2, 0), HashJoin.JoinType.INNER, leftTable, rightTable); + + final MergeJoin mergeJoin = + sb.mergeJoin( + Arrays.asList(0, 1), + Arrays.asList(2, 0), + MergeJoin.JoinType.INNER, + leftTable, + rightTable); + + @Test + void hashJoinRoundTrip() { + verifyRoundTrip(hashJoin); + } + + @Test + void mergeJoinRoundTrip() { + verifyRoundTrip(mergeJoin); + } + + // The DSL builds equality keys; the deprecated convenience views should derive from them. + @Test + void deprecatedViewsDeriveFromKeys() { + assertEquals( + hashJoin.getKeys().stream().map(ComparisonJoinKey::getLeft).collect(Collectors.toList()), + hashJoin.getLeftKeys()); + assertEquals( + hashJoin.getKeys().stream().map(ComparisonJoinKey::getRight).collect(Collectors.toList()), + hashJoin.getRightKeys()); + } + + // A plain EQ join must populate the new keys field and, for backwards compatibility with + // consumers that have not yet adopted it, the deprecated left_keys/right_keys as well. + @Test + void producesBothKeysForEqJoins() { + io.substrait.proto.HashJoinRel proto = relProtoConverter.toProto(hashJoin).getHashJoin(); + assertEquals(2, proto.getKeysCount()); + assertEquals( + proto.getKeysList().stream() + .map(io.substrait.proto.ComparisonJoinKey::getLeft) + .collect(Collectors.toList()), + proto.getLeftKeysList()); + assertEquals( + proto.getKeysList().stream() + .map(io.substrait.proto.ComparisonJoinKey::getRight) + .collect(Collectors.toList()), + proto.getRightKeysList()); + + io.substrait.proto.MergeJoinRel mergeProto = + relProtoConverter.toProto(mergeJoin).getMergeJoin(); + assertEquals(2, mergeProto.getKeysCount()); + assertEquals( + mergeProto.getKeysList().stream() + .map(io.substrait.proto.ComparisonJoinKey::getLeft) + .collect(Collectors.toList()), + mergeProto.getLeftKeysList()); + assertEquals( + mergeProto.getKeysList().stream() + .map(io.substrait.proto.ComparisonJoinKey::getRight) + .collect(Collectors.toList()), + mergeProto.getRightKeysList()); + } + + // The deprecated fields cannot represent IS_NOT_DISTINCT_FROM, MIGHT_EQUAL or custom comparisons, + // so a join containing any such key must only populate the new keys field. Otherwise an old + // consumer would silently misinterpret those keys as plain equality. + @Test + void producesOnlyNewKeysForLossyComparisons() { + List keys = + Arrays.asList( + ComparisonJoinKey.of( + sb.fieldReference(leftTable, 0), + sb.fieldReference(rightTable, 2), + SimpleComparisonType.EQ), + ComparisonJoinKey.of( + sb.fieldReference(leftTable, 1), + sb.fieldReference(rightTable, 0), + SimpleComparisonType.IS_NOT_DISTINCT_FROM)); + + HashJoin hash = + HashJoin.builder() + .left(leftTable) + .right(rightTable) + .keys(keys) + .joinType(HashJoin.JoinType.INNER) + .build(); + io.substrait.proto.HashJoinRel proto = relProtoConverter.toProto(hash).getHashJoin(); + assertEquals(2, proto.getKeysCount()); + assertTrue(proto.getLeftKeysList().isEmpty()); + assertTrue(proto.getRightKeysList().isEmpty()); + + MergeJoin merge = + MergeJoin.builder() + .left(leftTable) + .right(rightTable) + .keys(keys) + .joinType(MergeJoin.JoinType.INNER) + .build(); + io.substrait.proto.MergeJoinRel mergeProto = relProtoConverter.toProto(merge).getMergeJoin(); + assertEquals(2, mergeProto.getKeysCount()); + assertTrue(mergeProto.getLeftKeysList().isEmpty()); + assertTrue(mergeProto.getRightKeysList().isEmpty()); + } + + // A plan from a legacy producer (only deprecated fields set) is consumed and mapped to EQ keys. + @Test + void consumesLegacyHashJoin() { + io.substrait.proto.HashJoinRel modern = relProtoConverter.toProto(hashJoin).getHashJoin(); + io.substrait.proto.HashJoinRel legacy = + modern.toBuilder() + .clearKeys() + .clearLeftKeys() + .clearRightKeys() + .addAllLeftKeys( + modern.getKeysList().stream() + .map(io.substrait.proto.ComparisonJoinKey::getLeft) + .collect(Collectors.toList())) + .addAllRightKeys( + modern.getKeysList().stream() + .map(io.substrait.proto.ComparisonJoinKey::getRight) + .collect(Collectors.toList())) + .build(); + + Rel result = + protoRelConverter.from(io.substrait.proto.Rel.newBuilder().setHashJoin(legacy).build()); + assertEquals(hashJoin, result); + assertEquals(2, ((HashJoin) result).getKeys().size()); + ((HashJoin) result) + .getKeys() + .forEach( + k -> + assertEquals( + ComparisonJoinKey.SimpleComparison.of(SimpleComparisonType.EQ), + k.getComparison())); + } + + @Test + void consumesLegacyMergeJoin() { + io.substrait.proto.MergeJoinRel modern = relProtoConverter.toProto(mergeJoin).getMergeJoin(); + io.substrait.proto.MergeJoinRel legacy = + modern.toBuilder() + .clearKeys() + .clearLeftKeys() + .clearRightKeys() + .addAllLeftKeys( + modern.getKeysList().stream() + .map(io.substrait.proto.ComparisonJoinKey::getLeft) + .collect(Collectors.toList())) + .addAllRightKeys( + modern.getKeysList().stream() + .map(io.substrait.proto.ComparisonJoinKey::getRight) + .collect(Collectors.toList())) + .build(); + + Rel result = + protoRelConverter.from(io.substrait.proto.Rel.newBuilder().setMergeJoin(legacy).build()); + assertEquals(mergeJoin, result); + } + + // When both the deprecated fields and the new keys are present, keys wins. + @Test + void prefersNewKeysOverDeprecated() { + io.substrait.proto.HashJoinRel modern = relProtoConverter.toProto(hashJoin).getHashJoin(); + // Add bogus deprecated keys that point at different fields than the real keys. + io.substrait.proto.HashJoinRel both = + modern.toBuilder() + .addLeftKeys(modern.getKeys(0).getRight()) + .addRightKeys(modern.getKeys(0).getLeft()) + .build(); + + Rel result = + protoRelConverter.from(io.substrait.proto.Rel.newBuilder().setHashJoin(both).build()); + assertEquals(hashJoin, result); + } + + // Non-EQ simple comparisons and custom comparison functions survive a round trip. + @Test + void fullFidelityRoundTrip() { + List keys = + Arrays.asList( + ComparisonJoinKey.of( + sb.fieldReference(leftTable, 0), + sb.fieldReference(rightTable, 2), + SimpleComparisonType.EQ), + ComparisonJoinKey.of( + sb.fieldReference(leftTable, 1), + sb.fieldReference(rightTable, 0), + SimpleComparisonType.IS_NOT_DISTINCT_FROM), + ComparisonJoinKey.builder() + .left(sb.fieldReference(leftTable, 2)) + .right(sb.fieldReference(rightTable, 1)) + .comparison(ComparisonJoinKey.CustomComparison.of(42)) + .build()); + + Rel hash = + HashJoin.builder() + .left(leftTable) + .right(rightTable) + .keys(keys) + .joinType(HashJoin.JoinType.INNER) + .build(); + verifyRoundTrip(hash); + + Rel merge = + MergeJoin.builder() + .left(leftTable) + .right(rightTable) + .keys(keys) + .joinType(MergeJoin.JoinType.INNER) + .build(); + verifyRoundTrip(merge); + } +}