Skip to content

Commit 3f1b403

Browse files
committed
[FLINK-38709][table][python] Fix ScalarFunctionSplitter to allow PythonFunction & AsyncFunction work when taking the recursive field of composite type as input
1 parent 05dee4a commit 3f1b403

File tree

5 files changed

+175
-108
lines changed

5 files changed

+175
-108
lines changed

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,8 @@ class ScalarFunctionSplitter(
434434

435435
private var fieldsRexCall: Map[Int, Int] = Map[Int, Int]()
436436

437+
private val extractedRexNodeRefs: mutable.HashSet[RexNode] = mutable.HashSet[RexNode]()
438+
437439
override def visitCall(call: RexCall): RexNode = {
438440
if (needConvert(call)) {
439441
getExtractedRexNode(call)
@@ -454,7 +456,9 @@ class ScalarFunctionSplitter(
454456
new RexInputRef(field.getIndex, field.getType)
455457
case _ =>
456458
val newFieldAccess =
457-
rexBuilder.makeFieldAccess(expr.accept(this), fieldAccess.getField.getIndex)
459+
rexBuilder.makeFieldAccess(
460+
convertInputRefToLocalRefIfNecessary(expr.accept(this)),
461+
fieldAccess.getField.getIndex)
458462
getExtractedRexNode(newFieldAccess)
459463
}
460464
} else {
@@ -468,9 +472,18 @@ class ScalarFunctionSplitter(
468472

469473
override def visitNode(rexNode: RexNode): RexNode = rexNode
470474

475+
private def convertInputRefToLocalRefIfNecessary(node: RexNode): RexNode = {
476+
node match {
477+
case inputRef: RexInputRef if extractedRexNodeRefs.contains(node) =>
478+
new RexLocalRef(inputRef.getIndex, node.getType)
479+
case _ => node
480+
}
481+
}
482+
471483
private def getExtractedRexNode(node: RexNode): RexNode = {
472484
val newNode = new RexInputRef(extractedFunctionOffset + extractedRexNodes.length, node.getType)
473485
extractedRexNodes.append(node)
486+
extractedRexNodeRefs.add(newNode)
474487
newNode
475488
}
476489

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ public void setup() {
6262
+ " a int,\n"
6363
+ " b bigint,\n"
6464
+ " c string,\n"
65-
+ " d ARRAY<INT NOT NULL>\n"
65+
+ " d ARRAY<INT NOT NULL>,\n"
66+
+ " e ROW<f ROW<h int, i double>, g string>"
6667
+ ") WITH (\n"
6768
+ " 'connector' = 'test-simple-table-source'\n"
6869
+ ") ;");
@@ -89,6 +90,7 @@ public void setup() {
8990
@Test
9091
public void testSingleCall() {
9192
String sqlQuery = "SELECT func1(a) FROM MyTable";
93+
util.getTableEnv().explainSql(sqlQuery);
9294
util.verifyRelPlan(sqlQuery);
9395
}
9496

@@ -182,6 +184,12 @@ public void testFieldAccessAfter() {
182184
util.verifyRelPlan(sqlQuery);
183185
}
184186

187+
@Test
188+
public void testCompositeFieldAsInput() {
189+
String sqlQuery = "SELECT func1(e.f.h) from MyTable";
190+
util.verifyRelPlan(sqlQuery);
191+
}
192+
185193
@Test
186194
public void testFieldOperand() {
187195
String sqlQuery = "SELECT func1(func5(a).f0) from MyTable";

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ public void setup() {
6262
+ " a int,\n"
6363
+ " b bigint,\n"
6464
+ " c string,\n"
65-
+ " d ARRAY<INT NOT NULL>\n"
65+
+ " d ARRAY<INT NOT NULL>,\n"
66+
+ " e ROW<f ROW<h int, i double>, g string>\n"
6667
+ ") WITH (\n"
6768
+ " 'connector' = 'test-simple-table-source'\n"
6869
+ ") ;");
@@ -110,6 +111,12 @@ public void testCorrelateWithCast() {
110111
util.verifyRelPlan(sqlQuery);
111112
}
112113

114+
@Test
115+
public void testCorrelateWithCompositeFieldAsInput() {
116+
String sqlQuery = "select * FROM MyTable, LATERAL TABLE(asyncTableFunc(e.f.h))";
117+
util.verifyRelPlan(sqlQuery);
118+
}
119+
113120
/** Test function. */
114121
public static class AsyncFunc extends AsyncTableFunction<String> {
115122

0 commit comments

Comments
 (0)