From 75f3b7a7fb0d7c2bfe7568087a3a1ff03a4980f8 Mon Sep 17 00:00:00 2001 From: Yash Limbad Date: Wed, 18 Mar 2026 16:26:45 +0530 Subject: [PATCH] [CALCITE-7442] Correlated variable has wrong index inside subquery --- .../org/apache/calcite/plan/RelOptUtil.java | 90 ++++++++-- .../org/apache/calcite/rel/core/Join.java | 24 +++ .../calcite/rel/logical/LogicalJoin.java | 8 + .../calcite/rel/rules/FilterJoinRule.java | 39 ++++- .../calcite/sql2rel/RelDecorrelatorTest.java | 156 ++++++++++++++++++ 5 files changed, 304 insertions(+), 13 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java index 274e91f14de1..c1ea964e0333 100644 --- a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java +++ b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java @@ -2957,7 +2957,8 @@ public static boolean classifyFilters( joinFields, nTotalFields, leftFields, - filter); + filter, + joinRel.getInput(0)); leftFilters.add(shiftedFilter); } @@ -2975,7 +2976,8 @@ public static boolean classifyFilters( joinFields, nTotalFields, rightFields, - filter); + filter, + joinRel.getInput(1)); rightFilters.add(shiftedFilter); } filtersToRemove.add(filter); @@ -3079,7 +3081,8 @@ public static boolean classifyFilters( joinFields, nTotalFields, leftFields, - filter); + filter, + joinRel.getInput(0)); leftFilters.add(shiftedFilter); } @@ -3105,7 +3108,8 @@ public static boolean classifyFilters( joinFields, nTotalFields, rightFields, - filter); + filter, + joinRel.getInput(1)); rightFilters.add(shiftedFilter); } filtersToRemove.add(filter); @@ -3140,7 +3144,8 @@ private static RexNode shiftFilter( List joinFields, int nTotalFields, List rightFields, - RexNode filter) { + RexNode filter, + RelNode child) { int[] adjustments = new int[nTotalFields]; for (int i = start; i < end; i++) { adjustments[i] = offset; @@ -3150,7 +3155,9 @@ private static RexNode shiftFilter( rexBuilder, joinFields, rightFields, - 
adjustments)); + adjustments, + offset, + child)); } /** @@ -4752,6 +4759,17 @@ public ImmutableBitSet build() { } return super.visitCall(call); } + + @Override public Void visitSubQuery(RexSubQuery subQuery) { + final Set variablesSet = RelOptUtil.getVariablesUsed(subQuery.rel); + for (CorrelationId id : variablesSet) { + ImmutableBitSet requiredColumns = RelOptUtil.correlationColumns(id, subQuery.rel); + for (int index : requiredColumns) { + bitBuilder.set(index); + } + } + return super.visitSubQuery(subQuery); + } } /** @@ -4766,6 +4784,8 @@ public static class RexInputConverter extends RexShuttle { private final @Nullable List rightDestFields; private final int nLeftDestFields; private final int[] adjustments; + private final int offset; + private final @Nullable RelNode correlateVariableChild; /** * Creates a RexInputConverter. @@ -4784,6 +4804,13 @@ public static class RexInputConverter extends RexShuttle { * @param rightDestFields in the case where the destination is a join, * these are the fields from the right join input * @param adjustments the amount to adjust each field by + * @param offset the amount to shift field accesses by when + * rewriting correlated subqueries + * @param correlateVariableChild the child relation providing the + * correlated variable; if non-null, subqueries + * referencing a correlation variable will have + * their field accesses shifted by {@code offset} + * relative to this child */ private RexInputConverter( RexBuilder rexBuilder, @@ -4791,7 +4818,9 @@ private RexInputConverter( @Nullable List destFields, @Nullable List leftDestFields, @Nullable List rightDestFields, - int[] adjustments) { + int[] adjustments, + int offset, + @Nullable RelNode correlateVariableChild) { this.rexBuilder = rexBuilder; this.srcFields = srcFields; this.destFields = destFields; @@ -4804,6 +4833,8 @@ private RexInputConverter( assert destFields == null; nLeftDestFields = leftDestFields.size(); } + this.offset = offset; + this.correlateVariableChild 
= correlateVariableChild; } public RexInputConverter( @@ -4818,7 +4849,9 @@ public RexInputConverter( null, leftDestFields, rightDestFields, - adjustments); + adjustments, + 0, + null); } public RexInputConverter( @@ -4826,14 +4859,51 @@ public RexInputConverter( @Nullable List srcFields, @Nullable List destFields, int[] adjustments) { - this(rexBuilder, srcFields, destFields, null, null, adjustments); + this(rexBuilder, srcFields, destFields, null, null, adjustments, 0, null); } public RexInputConverter( RexBuilder rexBuilder, @Nullable List srcFields, int[] adjustments) { - this(rexBuilder, srcFields, null, null, null, adjustments); + this(rexBuilder, srcFields, null, null, null, adjustments, 0, null); + } + + public RexInputConverter( + RexBuilder rexBuilder, + @Nullable List srcFields, + @Nullable List destFields, + int[] adjustments, + int offset, + RelNode child) { + this(rexBuilder, srcFields, destFields, null, null, adjustments, offset, child); + } + + @Override public RexNode visitSubQuery(RexSubQuery subQuery) { + boolean[] update = {false}; + List clonedOperands = visitList(subQuery.operands, update); + if (update[0]) { + subQuery = subQuery.clone(subQuery.getType(), clonedOperands); + } + final Set variablesSet = + RelOptUtil.getVariablesUsed(subQuery.rel); + if (!variablesSet.isEmpty() && correlateVariableChild != null) { + for (CorrelationId id : variablesSet) { + RelNode newSubQueryRel = + subQuery.rel.accept(new RelHomogeneousShuttle() { + @Override public RelNode visit(RelNode other) { + RelNode node = + RexUtil.shiftFieldAccess(rexBuilder, other, id, + correlateVariableChild, offset); + return super.visit(node); + } + }); + if (newSubQueryRel != subQuery.rel) { + subQuery = subQuery.clone(newSubQueryRel); + } + } + } + return subQuery; } @Override public RexNode visitInputRef(RexInputRef var) { diff --git a/core/src/main/java/org/apache/calcite/rel/core/Join.java b/core/src/main/java/org/apache/calcite/rel/core/Join.java index 
999c4639f9c7..7e7ce7155ce6 100644 --- a/core/src/main/java/org/apache/calcite/rel/core/Join.java +++ b/core/src/main/java/org/apache/calcite/rel/core/Join.java @@ -355,6 +355,30 @@ public static RelDataType createJoinType( public abstract Join copy(RelTraitSet traitSet, RexNode conditionExpr, RelNode left, RelNode right, JoinRelType joinType, boolean semiJoinDone); + /** + * Creates a copy of this join, overriding condition, inputs and variables + * set; this default implementation ignores {@code variablesSet}. + * + * <p>
General contract as {@link RelNode#copy}. + * + * @param traitSet Traits + * @param conditionExpr Condition + * @param left Left input + * @param right Right input + * @param joinType Join type + * @param semiJoinDone Whether this join has been translated to a + * semi-join + * @param variablesSet Set of variables that are set by the + * LHS and used by the RHS and are not available to + * nodes above this LogicalJoin in the tree + * @return Copy of this join + */ + public Join copy(RelTraitSet traitSet, RexNode conditionExpr, + RelNode left, RelNode right, JoinRelType joinType, boolean semiJoinDone, + Set variablesSet) { + return copy(traitSet, conditionExpr, left, right, joinType, semiJoinDone); + } + /** * Analyzes the join condition. * diff --git a/core/src/main/java/org/apache/calcite/rel/logical/LogicalJoin.java b/core/src/main/java/org/apache/calcite/rel/logical/LogicalJoin.java index 9504f01e5534..c58a925e8cdf 100644 --- a/core/src/main/java/org/apache/calcite/rel/logical/LogicalJoin.java +++ b/core/src/main/java/org/apache/calcite/rel/logical/LogicalJoin.java @@ -181,6 +181,14 @@ public static LogicalJoin create(RelNode left, RelNode right, List hint variablesSet, joinType, semiJoinDone, systemFieldList); } + @Override public LogicalJoin copy(RelTraitSet traitSet, RexNode conditionExpr, RelNode left, + RelNode right, JoinRelType joinType, boolean semiJoinDone, Set variablesSet) { + assert traitSet.containsIfApplicable(Convention.NONE); + return new LogicalJoin(getCluster(), + getCluster().traitSetOf(Convention.NONE), hints, left, right, conditionExpr, + variablesSet, joinType, semiJoinDone, systemFieldList); + } + @Override public RelNode accept(RelShuttle shuttle) { return shuttle.visit(this); } diff --git a/core/src/main/java/org/apache/calcite/rel/rules/FilterJoinRule.java b/core/src/main/java/org/apache/calcite/rel/rules/FilterJoinRule.java index d4e1473ccbfc..6c1184828e2c 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/FilterJoinRule.java 
+++ b/core/src/main/java/org/apache/calcite/rel/rules/FilterJoinRule.java @@ -20,6 +20,7 @@ import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.RelRule; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.JoinRelType; @@ -29,7 +30,9 @@ import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexSubQuery; import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexVisitorImpl; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.tools.RelBuilder; @@ -198,14 +201,43 @@ protected void perform(RelOptRuleCall call, @Nullable Filter filter, return; } + Set leftVariablesSet = new LinkedHashSet<>(); + Set rightVariablesSet = new LinkedHashSet<>(); + + for (RexNode condition : leftFilters) { + condition.accept(new RexVisitorImpl(true) { + @Override public Void visitSubQuery(RexSubQuery subQuery) { + leftVariablesSet.addAll(RelOptUtil.getVariablesUsed(subQuery.rel)); + return super.visitSubQuery(subQuery); + } + }); + } + + for (RexNode condition : rightFilters) { + condition.accept(new RexVisitorImpl(true) { + @Override public Void visitSubQuery(RexSubQuery subQuery) { + rightVariablesSet.addAll(RelOptUtil.getVariablesUsed(subQuery.rel)); + return super.visitSubQuery(subQuery); + } + }); + } + + ImmutableSet.Builder newJoinCorrelationIds = ImmutableSet.builder(); + for (CorrelationId correlationId : join.getVariablesSet()) { + if (!leftVariablesSet.contains(correlationId) + && !rightVariablesSet.contains(correlationId)) { + newJoinCorrelationIds.add(correlationId); + } + } + // create Filters on top of the children if any filters were // pushed to them final RexBuilder rexBuilder = join.getCluster().getRexBuilder(); final 
RelBuilder relBuilder = call.builder(); final RelNode leftRel = - relBuilder.push(join.getLeft()).filter(leftFilters).build(); + relBuilder.push(join.getLeft()).filter(leftVariablesSet, leftFilters).build(); final RelNode rightRel = - relBuilder.push(join.getRight()).filter(rightFilters).build(); + relBuilder.push(join.getRight()).filter(rightVariablesSet, rightFilters).build(); // create the new join node referencing the new children and // containing its new join filters (if there are any) @@ -233,7 +265,8 @@ protected void perform(RelOptRuleCall call, @Nullable Filter filter, leftRel, rightRel, joinType, - join.isSemiJoinDone()); + join.isSemiJoinDone(), + newJoinCorrelationIds.build()); call.getPlanner().onCopy(join, newJoinRel); if (!leftFilters.isEmpty() && filter != null) { call.getPlanner().onCopy(filter, leftRel); diff --git a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java index 2b406c3c0154..77a6b28faf0f 100644 --- a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java +++ b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java @@ -1830,4 +1830,160 @@ public static Frameworks.ConfigBuilder config() { + " LogicalTableScan(table=[[scott, EMP]])\n"; assertThat(after, hasTree(planAfter)); } + + /** Test case for [CALCITE-7442] + * Correlated variable has wrong index inside subquery after FilterJoinRule. 
*/ + @Test void testCorrelatedVariableIndexForInClause() { + final FrameworkConfig frameworkConfig = config().build(); + final RelBuilder builder = RelBuilder.create(frameworkConfig); + final RelOptCluster cluster = builder.getCluster(); + final Planner planner = Frameworks.getPlanner(frameworkConfig); + final String sql = "select e.empno, d.dname, b.ename\n" + + "from emp e\n" + + "inner join dept d\n" + + " on d.deptno = e.deptno\n" + + "inner join bonus b\n" + + " on e.ename = b.ename\n" + + " and b.job in (\n" + + " select b2.job\n" + + " from bonus b2\n" + + " where b2.ename = b.ename)\n" + + "where e.sal > 1000 and d.dname = 'SALES'"; + + final RelNode originalRel; + try { + final SqlNode parse = planner.parse(sql); + final SqlNode validate = planner.validate(parse); + originalRel = planner.rel(validate).rel; + } catch (Exception e) { + throw TestUtil.rethrow(e); + } + + final HepProgram hepProgram = HepProgram.builder() + .addRuleCollection( + ImmutableList.of( + CoreRules.FILTER_INTO_JOIN, + CoreRules.FILTER_SUB_QUERY_TO_CORRELATE)) + .build(); + final Program program = + Programs.of(hepProgram, true, + requireNonNull(cluster.getMetadataProvider())); + final RelNode before = + program.run(cluster.getPlanner(), originalRel, cluster.traitSet(), + Collections.emptyList(), Collections.emptyList()); + + final String planBefore = "LogicalProject(EMPNO=[$0], DNAME=[$9], ENAME=[$11])\n" + + " LogicalJoin(condition=[=($1, $11)], joinType=[inner])\n" + + " LogicalJoin(condition=[=($8, $7)], joinType=[inner])\n" + + " LogicalFilter(condition=[>(CAST($5):DECIMAL(12, 2), 1000.00)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($1, 'SALES')])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalProject(ENAME=[$0], JOB=[$1], SAL=[$2], COMM=[$3])\n" + + " LogicalFilter(condition=[=($1, $4)])\n" + + " LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])\n" + + " LogicalTableScan(table=[[scott, 
BONUS]])\n" + + " LogicalProject(JOB=[$1])\n" + + " LogicalFilter(condition=[=($0, $cor0.ENAME)])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n"; + assertThat(before, hasTree(planBefore)); + + final RelNode after = + RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()), + RuleSets.ofList(Collections.emptyList())); + final String planAfter = "LogicalProject(EMPNO=[$0], DNAME=[$9], ENAME=[$11])\n" + + " LogicalJoin(condition=[=($1, $11)], joinType=[inner])\n" + + " LogicalJoin(condition=[=($8, $7)], joinType=[inner])\n" + + " LogicalFilter(condition=[>(CAST($5):DECIMAL(12, 2), 1000.00)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($1, 'SALES')])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalProject(ENAME=[$0], JOB=[$1], SAL=[$2], COMM=[$3])\n" + + " LogicalJoin(condition=[AND(=($0, $5), =($1, $4))], joinType=[inner])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " LogicalProject(JOB=[$1], ENAME=[$0])\n" + + " LogicalFilter(condition=[IS NOT NULL($0)])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n"; + assertThat(after, hasTree(planAfter)); + } + + /** Test case for [CALCITE-7442] + * Correlated variable has wrong index inside subquery after FilterJoinRule. + * Same as {@link #testCorrelatedVariableIndexForInClause()} but uses EXISTS + * instead of IN. 
*/ + @Test void testCorrelatedVariableIndexForExistsClause() { + final FrameworkConfig frameworkConfig = config().build(); + final RelBuilder builder = RelBuilder.create(frameworkConfig); + final RelOptCluster cluster = builder.getCluster(); + final Planner planner = Frameworks.getPlanner(frameworkConfig); + final String sql = "select e.empno, d.dname, b.ename\n" + + "from emp e\n" + + "inner join dept d\n" + + " on d.deptno = e.deptno\n" + + "inner join bonus b\n" + + " on e.ename = b.ename\n" + + " and exists (\n" + + " select b2.job\n" + + " from bonus b2\n" + + " where b2.ename = b.ename\n" + + " and b2.job = b.job)\n" + + "where e.sal > 1000 and d.dname = 'SALES'"; + + final RelNode originalRel; + try { + final SqlNode parse = planner.parse(sql); + final SqlNode validate = planner.validate(parse); + originalRel = planner.rel(validate).rel; + } catch (Exception e) { + throw TestUtil.rethrow(e); + } + + final HepProgram hepProgram = HepProgram.builder() + .addRuleCollection( + ImmutableList.of( + CoreRules.FILTER_INTO_JOIN, + CoreRules.FILTER_SUB_QUERY_TO_CORRELATE)) + .build(); + final Program program = + Programs.of(hepProgram, true, + requireNonNull(cluster.getMetadataProvider())); + final RelNode before = + program.run(cluster.getPlanner(), originalRel, cluster.traitSet(), + Collections.emptyList(), Collections.emptyList()); + + final String planBefore = "LogicalProject(EMPNO=[$0], DNAME=[$9], ENAME=[$11])\n" + + " LogicalJoin(condition=[=($1, $11)], joinType=[inner])\n" + + " LogicalJoin(condition=[=($8, $7)], joinType=[inner])\n" + + " LogicalFilter(condition=[>(CAST($5):DECIMAL(12, 2), 1000.00)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($1, 'SALES')])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalProject(ENAME=[$0], JOB=[$1], SAL=[$2], COMM=[$3])\n" + + " LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0, 1}])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " 
LogicalAggregate(group=[{0}])\n" + + " LogicalProject(i=[true])\n" + + " LogicalFilter(condition=[AND(=($0, $cor0.ENAME), =($1, $cor0.JOB))])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n"; + + assertThat(before, hasTree(planBefore)); + + final RelNode after = + RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()), + RuleSets.ofList(Collections.emptyList())); + final String planAfter = "LogicalProject(EMPNO=[$0], DNAME=[$9], ENAME=[$11])\n" + + " LogicalJoin(condition=[=($1, $11)], joinType=[inner])\n" + + " LogicalJoin(condition=[=($8, $7)], joinType=[inner])\n" + + " LogicalFilter(condition=[>(CAST($5):DECIMAL(12, 2), 1000.00)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalFilter(condition=[=($1, 'SALES')])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalProject(ENAME=[$0], JOB=[$1], SAL=[$2], COMM=[$3])\n" + + " LogicalJoin(condition=[AND(=($0, $4), =($1, $5))], joinType=[inner])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n" + + " LogicalProject(ENAME=[$0], JOB=[$1], $f2=[true])\n" + + " LogicalFilter(condition=[AND(IS NOT NULL($0), IS NOT NULL($1))])\n" + + " LogicalTableScan(table=[[scott, BONUS]])\n"; + assertThat(after, hasTree(planAfter)); + } }