From 0bde784477dd5e7517ee6b137349b8e5679a2dc9 Mon Sep 17 00:00:00 2001 From: Gustavo de Morais Date: Tue, 21 Apr 2026 12:44:07 +0200 Subject: [PATCH 01/10] [FLINK-39392][table] Support conditional traits for PTF table arguments --- .../docs/sql/reference/queries/changelog.md | 10 +- .../functions/BuiltInFunctionDefinitions.java | 31 ++-- .../table/types/inference/StaticArgument.java | 142 +++++++++++++++++- .../types/inference/SystemTypeInference.java | 52 +++++-- .../table/types/inference/TraitCondition.java | 66 ++++++++ .../table/types/inference/TraitContext.java | 44 ++++++ .../bridging/BridgingSqlFunction.java | 9 +- .../StreamExecProcessTableFunction.java | 16 +- .../StreamPhysicalProcessTableFunction.java | 120 ++++++++++++++- ...treamPhysicalProcessTableFunctionRule.java | 38 +++-- .../FlinkChangelogModeInferenceProgram.scala | 13 +- .../plan/stream/sql/ToChangelogTest.java | 17 +++ .../plan/stream/sql/ToChangelogTest.xml | 20 +++ 13 files changed, 511 insertions(+), 67 deletions(-) create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitContext.java diff --git a/docs/content/docs/sql/reference/queries/changelog.md b/docs/content/docs/sql/reference/queries/changelog.md index 60773af2b27f1..2b66ded50d013 100644 --- a/docs/content/docs/sql/reference/queries/changelog.md +++ b/docs/content/docs/sql/reference/queries/changelog.md @@ -177,7 +177,7 @@ This is useful when you need to materialize changelog events into a downstream s ```sql SELECT * FROM TO_CHANGELOG( - input => TABLE source_table, + input => TABLE source_table [PARTITION BY key_col], [op => DESCRIPTOR(op_column_name),] [op_mapping => MAP['INSERT', 'I', 'DELETE', 'D', ...]] ) @@ -185,10 +185,10 @@ SELECT * FROM TO_CHANGELOG( ### Parameters -| Parameter | Required | Description | -|:-------------|:---------|:------------| -| 
`input` | Yes | The input table. Accepts insert-only, retract, and upsert tables. | -| `op` | No | A `DESCRIPTOR` with a single column name for the operation code column. Defaults to `op`. | +| Parameter | Required | Description | +|:-------------|:---------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `input` | Yes | The input table. With `PARTITION BY`, rows with the same key are co-located and run in the same operator instance. Without `PARTITION BY`, each row is processed independently. Accepts insert-only, retract, and upsert tables. For upsert tables, providing `PARTITION BY` is recommended for better performance. | +| `op` | No | A `DESCRIPTOR` with a single column name for the operation code column. Defaults to `op`. | | `op_mapping` | No | A `MAP` mapping change operation names to custom output codes. Keys can contain comma-separated names to map multiple operations to the same code (e.g., `'INSERT, UPDATE_AFTER'`). When provided, only mapped operations are forwarded - unmapped events are dropped. Each change operation may appear at most once across all entries. 
| #### Default op_mapping diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index bbe38ca13d5a5..cd25361923c7e 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -36,6 +36,7 @@ import org.apache.flink.table.types.inference.InputTypeStrategies; import org.apache.flink.table.types.inference.StaticArgument; import org.apache.flink.table.types.inference.StaticArgumentTrait; +import org.apache.flink.table.types.inference.TraitCondition; import org.apache.flink.table.types.inference.TypeStrategies; import org.apache.flink.table.types.inference.strategies.ArrayOfStringArgumentTypeStrategy; import org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies; @@ -785,22 +786,22 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL) .name("TO_CHANGELOG") .kind(PROCESS_TABLE) .staticArguments( - // Row semantics (no PARTITION BY). Accepts updating - // inputs. The planner inserts ChangelogNormalize for - // upsert sources to produce UPDATE_BEFORE and full - // DELETE rows. + // Row semantics (no PARTITION BY). + // With PARTITION BY, switches to set + // semantics for co-located parallel execution. StaticArgument.table( - "input", - Row.class, - false, - EnumSet.of( - StaticArgumentTrait.TABLE, - StaticArgumentTrait.ROW_SEMANTIC_TABLE, - StaticArgumentTrait.SUPPORT_UPDATES, - StaticArgumentTrait.REQUIRE_UPDATE_BEFORE, - // Not strictly necessary but explicitly state that - // we require full deletes. 
- StaticArgumentTrait.REQUIRE_FULL_DELETE)), + "input", + Row.class, + false, + EnumSet.of( + StaticArgumentTrait.TABLE, + StaticArgumentTrait.ROW_SEMANTIC_TABLE, + StaticArgumentTrait.SUPPORT_UPDATES, + StaticArgumentTrait.REQUIRE_UPDATE_BEFORE, + StaticArgumentTrait.REQUIRE_FULL_DELETE)) + .withConditionalTrait( + StaticArgumentTrait.SET_SEMANTIC_TABLE, + TraitCondition.hasPartitionBy()), StaticArgument.scalar("op", DataTypes.DESCRIPTOR(), true), StaticArgument.scalar( "op_mapping", diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java index 3f5c48db8d221..1a027d430b222 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java @@ -31,10 +31,14 @@ import javax.annotation.Nullable; +import java.io.Serializable; +import java.util.ArrayList; import java.util.EnumSet; +import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * Describes an argument in a static signature that is not overloaded and does not support varargs. 
@@ -57,6 +61,7 @@ public class StaticArgument { private final @Nullable Class conversionClass; private final boolean isOptional; private final EnumSet traits; + private final List conditionalTraits; private StaticArgument( String name, @@ -64,11 +69,22 @@ private StaticArgument( @Nullable Class conversionClass, boolean isOptional, EnumSet traits) { + this(name, dataType, conversionClass, isOptional, traits, List.of()); + } + + private StaticArgument( + String name, + @Nullable DataType dataType, + @Nullable Class conversionClass, + boolean isOptional, + EnumSet traits, + List conditionalTraits) { this.name = Preconditions.checkNotNull(name, "Name must not be null."); this.dataType = dataType; this.conversionClass = conversionClass; this.isOptional = isOptional; this.traits = Preconditions.checkNotNull(traits, "Traits must not be null."); + this.conditionalTraits = conditionalTraits; checkName(); checkTraits(traits); checkOptionalType(); @@ -196,6 +212,91 @@ public boolean is(StaticArgumentTrait trait) { return traits.contains(trait); } + /** + * Context-aware trait check. Evaluates conditional trait rules against the given context to + * determine the effective traits. + */ + public boolean is(StaticArgumentTrait trait, TraitContext ctx) { + return resolveTraits(ctx).contains(trait); + } + + /** + * Returns a new {@link StaticArgument} with an additional conditional trait rule. The trait is + * added to the effective trait set when the condition evaluates to {@code true} at planning + * time. Only non-root traits (subtraits of TABLE, SCALAR, or MODEL) are allowed. + * + *
+     * <p>
Multiple conditions for the same trait use OR semantics: the trait is activated if any of + * its conditions is met. + * + *
+     * <p>
Example: + * + *
+     * <pre>{@code
+     * StaticArgument.table("input", Row.class, false, EnumSet.of(TABLE, SUPPORT_UPDATES))
+     *         .withConditionalTrait(SET_SEMANTIC_TABLE, hasPartitionBy());
+     * }</pre>
+ */ + public StaticArgument withConditionalTrait( + final StaticArgumentTrait trait, final TraitCondition condition) { + if (trait == StaticArgumentTrait.SCALAR + || trait == StaticArgumentTrait.TABLE + || trait == StaticArgumentTrait.MODEL) { + throw new IllegalArgumentException( + "Root traits (SCALAR, TABLE, MODEL) cannot be conditional."); + } + final List accumulated = new ArrayList<>(this.conditionalTraits); + accumulated.add(new ConditionalTrait(condition, trait)); + return new StaticArgument(name, dataType, conversionClass, isOptional, traits, accumulated); + } + + /** Whether this argument has conditional trait rules. */ + public boolean hasConditionalTraits() { + return !conditionalTraits.isEmpty(); + } + + /** Whether any conditional trait rule may add the given trait. */ + public boolean hasConditionalTrait(final StaticArgumentTrait trait) { + return conditionalTraits.stream().anyMatch(c -> c.trait == trait); + } + + /** + * Returns a new {@link StaticArgument} with conditional traits resolved against the given + * context. The returned argument has the effective traits baked in and no conditional rules. + */ + public StaticArgument applyConditionalTraits(final TraitContext ctx) { + if (conditionalTraits.isEmpty()) { + return this; + } + return new StaticArgument(name, dataType, conversionClass, isOptional, resolveTraits(ctx)); + } + + /** + * Resolves effective traits by evaluating conditional rules against the context. Returns the + * base traits combined with any conditional traits whose conditions are met. 
+ */ + public EnumSet resolveTraits(final TraitContext ctx) { + if (conditionalTraits.isEmpty()) { + return traits; + } + final EnumSet resolved = EnumSet.copyOf(traits); + for (final ConditionalTrait conditionalTrait : conditionalTraits) { + if (conditionalTrait.condition.test(ctx)) { + removeMutuallyExclusiveTraits(resolved, conditionalTrait.trait); + resolved.add(conditionalTrait.trait); + } + } + return resolved; + } + + /** ROW and SET semantics are mutually exclusive - adding one removes the other. */ + private static void removeMutuallyExclusiveTraits( + final EnumSet traits, final StaticArgumentTrait adding) { + if (adding == StaticArgumentTrait.SET_SEMANTIC_TABLE) { + traits.remove(StaticArgumentTrait.ROW_SEMANTIC_TABLE); + } else if (adding == StaticArgumentTrait.ROW_SEMANTIC_TABLE) { + traits.remove(StaticArgumentTrait.SET_SEMANTIC_TABLE); + } + } + @Override public String toString() { final StringBuilder s = new StringBuilder(); @@ -210,11 +311,13 @@ public String toString() { s.append(dataType); } if (!traits.equals(EnumSet.of(StaticArgumentTrait.SCALAR))) { + final Stream baseTraitNames = + traits.stream().map(Enum::name).map(n -> n.replace('_', ' ')); + final Stream conditionalTraitNames = + conditionalTraits.stream().map(c -> c.trait.name().replace('_', ' ')); s.append(" "); s.append( - traits.stream() - .map(Enum::name) - .map(n -> n.replace('_', ' ')) + Stream.concat(baseTraitNames, conditionalTraitNames) .collect(Collectors.joining(", ", "{", "}"))); } return s.toString(); @@ -233,12 +336,13 @@ public boolean equals(Object o) { && Objects.equals(name, that.name) && Objects.equals(dataType, that.dataType) && Objects.equals(conversionClass, that.conversionClass) - && Objects.equals(traits, that.traits); + && Objects.equals(traits, that.traits) + && Objects.equals(conditionalTraits, that.conditionalTraits); } @Override public int hashCode() { - return Objects.hash(name, dataType, conversionClass, isOptional, traits); + return Objects.hash(name, 
dataType, conversionClass, isOptional, traits, conditionalTraits); } private void checkName() { @@ -354,4 +458,32 @@ private void checkModelNotOptional() { throw new ValidationException("Model arguments must not be optional."); } } + + /** A trait that is conditionally added based on a {@link TraitCondition}. */ + private static final class ConditionalTrait implements Serializable { + private final TraitCondition condition; + private final StaticArgumentTrait trait; + + ConditionalTrait(final TraitCondition condition, final StaticArgumentTrait trait) { + this.condition = condition; + this.trait = trait; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final ConditionalTrait that = (ConditionalTrait) o; + return Objects.equals(condition, that.condition) && trait == that.trait; + } + + @Override + public int hashCode() { + return Objects.hash(condition, trait); + } + } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java index 92fd79f407b67..0d9641b96adda 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java @@ -182,6 +182,32 @@ private static void checkReservedArgs(List staticArgs) { } } + static TraitContext buildTraitContext( + @Nullable final TableSemantics semantics, + final CallContext callContext, + final List staticArgs) { + return new TraitContext() { + @Override + public boolean hasPartitionBy() { + return semantics != null && semantics.partitionByColumns().length > 0; + } + + @Override + public Optional getScalarArgument(final String name, final Class clazz) { + for (int i = 0; i 
< staticArgs.size(); i++) { + final StaticArgument arg = staticArgs.get(i); + if (arg.is(StaticArgumentTrait.SCALAR) && arg.getName().equals(name)) { + if (!callContext.isArgumentLiteral(i)) { + return Optional.empty(); + } + return callContext.getArgumentValue(i, clazz); + } + } + return Optional.empty(); + } + }; + } + private static void checkMultipleTableArgs(List staticArgs) { final List tableArgs = staticArgs.stream() @@ -311,17 +337,21 @@ private List derivePassThroughFields(CallContext callContext) { return IntStream.range(0, staticArgs.size()) .mapToObj( pos -> { - final StaticArgument arg = staticArgs.get(pos); - if (arg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) { + final TableSemantics semantics = + callContext.getTableSemantics(pos).orElse(null); + final TraitContext traitCtx = + buildTraitContext(semantics, callContext, staticArgs); + final StaticArgument resolvedArg = + staticArgs.get(pos).applyConditionalTraits(traitCtx); + if (resolvedArg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) { return DataType.getFields(argDataTypes.get(pos)).stream(); } - if (!arg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)) { + if (semantics == null) { + return Stream.empty(); + } + if (!resolvedArg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)) { return Stream.empty(); } - final TableSemantics semantics = - callContext - .getTableSemantics(pos) - .orElseThrow(IllegalStateException::new); final DataType rowDataType = DataTypes.ROW(DataType.getFields(argDataTypes.get(pos))); final DataType projectedRow = @@ -620,8 +650,12 @@ private static void checkTableArgs( "Table expected for argument '%s'.", staticArg.getName())); } - checkRowSemantics(staticArg, semantics); - checkSetSemantics(staticArg, semantics); + final TraitContext traitCtx = + buildTraitContext(semantics, callContext, staticArgs); + final StaticArgument resolvedArg = + staticArg.applyConditionalTraits(traitCtx); + checkRowSemantics(resolvedArg, semantics); + checkSetSemantics(resolvedArg, semantics); 
tableSemantics.add(semantics); }); checkCoPartitioning(tableSemantics); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java new file mode 100644 index 0000000000000..61db5a2f3e614 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference; + +import org.apache.flink.annotation.PublicEvolving; + +import java.io.Serializable; + +/** + * A condition that determines whether a conditional trait on a {@link StaticArgument} should be + * active for a given call. + * + *
+ * <p>
Conditions are evaluated at planning time using the {@link TraitContext} which provides access + * to the SQL call's properties (PARTITION BY presence, scalar literal values, etc.). + * + *
+ * <p>
Use the static factory methods for common conditions: + * + *
+ * <pre>{@code
+ * import static org.apache.flink.table.types.inference.TraitCondition.*;
+ *
+ * StaticArgument.table("input", Row.class, false, EnumSet.of(TABLE, SUPPORT_UPDATES))
+ *         .withConditionalTrait(SET_SEMANTIC_TABLE, hasPartitionBy());
+ * }</pre>
+ */ +@PublicEvolving +@FunctionalInterface +public interface TraitCondition extends Serializable { + + /** Evaluates this condition against the given context. */ + boolean test(TraitContext ctx); + + /** True when PARTITION BY is provided on the table argument. */ + static TraitCondition hasPartitionBy() { + return TraitContext::hasPartitionBy; + } + + /** True when the named scalar argument equals the expected value. */ + @SuppressWarnings("unchecked") + static TraitCondition argIsEqualTo(final String name, final T expected) { + return ctx -> + ctx.getScalarArgument(name, (Class) expected.getClass()) + .map(expected::equals) + .orElse(false); + } + + /** Negates the given condition. */ + static TraitCondition not(final TraitCondition condition) { + return ctx -> !condition.test(ctx); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitContext.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitContext.java new file mode 100644 index 0000000000000..317a9a0ece054 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitContext.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference; + +import org.apache.flink.annotation.PublicEvolving; + +import java.util.Optional; + +/** + * Read-only context provided to {@link TraitCondition} during trait resolution at planning time. + * + *
+ * <p>
Allows conditions to inspect the SQL call (e.g., whether PARTITION BY was provided, or what + * value a scalar argument has) to decide whether a conditional trait should be active. + */ +@PublicEvolving +public interface TraitContext { + + /** Whether PARTITION BY was provided on this table argument. */ + boolean hasPartitionBy(); + + /** + * Reads a scalar argument value by name. + * + * @return the argument value, or empty if the argument was not provided, is not a literal, or + * cannot be converted to the requested type + */ + Optional getScalarArgument(String name, Class clazz); +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java index 1cad3c7ec86dc..e5080ad94fcbe 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java @@ -273,10 +273,13 @@ public SqlReturnTypeInference getRowTypeInference() { } final StaticArgument arg = args.get(ordinal); final TableCharacteristic.Semantics semantics; - if (arg.is(StaticArgumentTrait.ROW_SEMANTIC_TABLE)) { - semantics = TableCharacteristic.Semantics.ROW; - } else if (arg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)) { + // Report SET if it may apply - which allows the use of Partition BY + // actual semantics resolved later via resolveTraits + if (arg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE) + || arg.hasConditionalTrait(StaticArgumentTrait.SET_SEMANTIC_TABLE)) { semantics = TableCharacteristic.Semantics.SET; + } else if (arg.is(StaticArgumentTrait.ROW_SEMANTIC_TABLE)) { + semantics = TableCharacteristic.Semantics.ROW; } else { return null; } diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java index 12a20833068c8..9fe3467ec4c80 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java @@ -188,7 +188,7 @@ protected Transformation translateToPlanInternal( (RexTableArgCall) operands.get(providedInputArg.i); final StaticArgument tableArg = providedInputArg.e; return createRuntimeTableSemantics( - tableArg, tableArgCall, inputTimeColumns); + invocation, tableArg, tableArgCall, inputTimeColumns); }) .collect(Collectors.toList()); @@ -293,7 +293,10 @@ protected Transformation translateToPlanInternal( } private RuntimeTableSemantics createRuntimeTableSemantics( - StaticArgument tableArg, RexTableArgCall tableArgCall, List inputTimeColumns) { + RexCall call, + StaticArgument tableArg, + RexTableArgCall tableArgCall, + List inputTimeColumns) { final RuntimeChangelogMode consumedChangelogMode = RuntimeChangelogMode.serialize( inputChangelogModes.get(tableArgCall.getInputIndex())); @@ -305,17 +308,20 @@ private RuntimeTableSemantics createRuntimeTableSemantics( } final int timeColumn = inputTimeColumns.get(tableArgCall.getInputIndex()); + final StaticArgument resolvedArg = + tableArg.applyConditionalTraits( + StreamPhysicalProcessTableFunction.buildTraitContext(call, tableArgCall)); return new RuntimeTableSemantics( - tableArg.getName(), + resolvedArg.getName(), tableArgCall.getInputIndex(), dataType, tableArgCall.getPartitionKeys(), tableArgCall.getOrderKeys(), RexTableArgCall.toSortDirections(tableArgCall.getSortOrder()), consumedChangelogMode, - 
tableArg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH), - tableArg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE), + resolvedArg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH), + resolvedArg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE), timeColumn); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java index 754fe1b328492..2122d516596a1 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java @@ -40,6 +40,8 @@ import org.apache.flink.table.types.inference.StaticArgument; import org.apache.flink.table.types.inference.StaticArgumentTrait; import org.apache.flink.table.types.inference.SystemTypeInference; +import org.apache.flink.table.types.inference.TraitContext; +import org.apache.flink.types.ColumnList; import org.apache.flink.types.RowKind; import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableSet; @@ -68,8 +70,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -220,8 +225,7 @@ protected RelDataType deriveRowType() { final RexNode uidRexNode = operands.get(operands.size() - 1); if (uidRexNode.getKind() == SqlKind.DEFAULT) { // Optional for constant or row semantics functions - if (staticArgs.stream() - .noneMatch(arg -> arg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE))) { + if 
(!hasResolvedSetSemantics(staticArgs, operands, rexCall)) { return null; } final String uid = @@ -379,6 +383,118 @@ public static List> getProvidedInputArgs(RexCall call) { .collect(Collectors.toList()); } + /** + * Builds a {@link TraitContext} for resolving conditional traits on a table argument at + * planning time. + */ + public static TraitContext buildTraitContext( + final RexCall call, final RexTableArgCall tableArgCall) { + final List declaredArgs = getStaticArguments(call); + final List operands = call.getOperands(); + + return new TraitContext() { + @Override + public boolean hasPartitionBy() { + return tableArgCall.getPartitionKeys().length > 0; + } + + @Override + public Optional getScalarArgument(final String name, final Class clazz) { + return findScalarLiteral(declaredArgs, operands, name, clazz); + } + }; + } + + /** Checks if any table argument resolves to SET_SEMANTIC_TABLE after applying conditions. */ + private static boolean hasResolvedSetSemantics( + final List staticArgs, + final List operands, + final RexCall rexCall) { + for (int i = 0; i < staticArgs.size(); i++) { + final StaticArgument arg = staticArgs.get(i); + if (!arg.is(StaticArgumentTrait.TABLE)) { + continue; + } + + final RexTableArgCall tableArgCall = (RexTableArgCall) operands.get(i); + + final StaticArgument resolvedArg = + arg.applyConditionalTraits(buildTraitContext(rexCall, tableArgCall)); + if (resolvedArg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)) { + return true; + } + } + return false; + } + + public static List getStaticArguments(final RexCall call) { + final BridgingSqlFunction.WithTableFunction function = + (BridgingSqlFunction.WithTableFunction) call.getOperator(); + return function.getTypeInference() + .getStaticArguments() + .orElseThrow(IllegalStateException::new); + } + + /** + * Extracts a scalar argument value by name. Handles NULL, DEFAULT, DESCRIPTOR, MAP, and + * standard literal values (Boolean, String, Integer, etc.) 
following the same rules as {@link + * CallContext#getArgumentValue}. Does not support time types (Instant, Duration, LocalDate) + * which require the full CallContext bridge. + */ + @SuppressWarnings("unchecked") + private static Optional findScalarLiteral( + final List declaredArgs, + final List operands, + final String name, + final Class clazz) { + for (int i = 0; i < declaredArgs.size(); i++) { + final StaticArgument arg = declaredArgs.get(i); + if (!arg.is(StaticArgumentTrait.SCALAR) || !arg.getName().equals(name)) { + continue; + } + final RexNode operand = operands.get(i); + // NULL and DEFAULT produce empty + if (operand.getKind() == SqlKind.DEFAULT || RexUtil.isNullLiteral(operand, true)) { + return Optional.empty(); + } + // DESCRIPTOR and MAP are RexCalls, not RexLiterals + if (operand.getKind() == SqlKind.DESCRIPTOR && clazz == ColumnList.class) { + final List columns = + ((RexCall) operand) + .getOperands().stream() + .map(RexLiteral::stringValue) + .collect(Collectors.toList()); + return Optional.of((T) ColumnList.of(columns)); + } + if (operand.getKind() == SqlKind.MAP_VALUE_CONSTRUCTOR && clazz == Map.class) { + return Optional.ofNullable((T) extractMap((RexCall) operand)); + } + // Standard literals + if (operand instanceof RexLiteral) { + try { + return Optional.ofNullable(((RexLiteral) operand).getValueAs(clazz)); + } catch (IllegalArgumentException e) { + return Optional.empty(); + } + } + return Optional.empty(); + } + return Optional.empty(); + } + + private static @Nullable Map extractMap(final RexCall mapCall) { + final List operands = mapCall.getOperands(); + final Map map = new LinkedHashMap<>(); + for (int i = 0; i < operands.size(); i += 2) { + final @Nullable String key = RexLiteral.stringValue(operands.get(i)); + final @Nullable String value = RexLiteral.stringValue(operands.get(i + 1)); + if (key != null) { + map.put(key, value); + } + } + return map; + } + public static Set deriveOnTimeFields(RexCall call) { final List operands = 
call.getOperands(); final RexCall onTimeOperand = diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalProcessTableFunctionRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalProcessTableFunctionRule.java index 0785b1527bf88..e79acfc8b211f 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalProcessTableFunctionRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalProcessTableFunctionRule.java @@ -21,13 +21,14 @@ import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionKind; import org.apache.flink.table.planner.calcite.RexTableArgCall; -import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; import org.apache.flink.table.planner.plan.nodes.FlinkConventions; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan; import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalProcessTableFunction; import org.apache.flink.table.planner.plan.rules.physical.common.PhysicalMLPredictTableFunctionRule; import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution; import org.apache.flink.table.planner.utils.ShortcutUtils; +import org.apache.flink.table.types.inference.StaticArgument; +import org.apache.flink.table.types.inference.StaticArgumentTrait; import org.apache.calcite.linq4j.Ord; import org.apache.calcite.plan.RelOptRule; @@ -37,8 +38,6 @@ import org.apache.calcite.rel.convert.ConverterRule; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.TableCharacteristic; -import org.apache.calcite.sql.TableCharacteristic.Semantics; import 
org.checkerframework.checker.nullness.qual.Nullable; import java.util.List; @@ -81,11 +80,9 @@ public boolean matches(RelOptRuleCall call) { final FlinkLogicalTableFunctionScan scan = (FlinkLogicalTableFunctionScan) rel; final RexCall rexCall = (RexCall) scan.getCall(); - final BridgingSqlFunction.WithTableFunction function = - (BridgingSqlFunction.WithTableFunction) rexCall.getOperator(); final List operands = rexCall.getOperands(); final List newInputs = - applyDistributionOnInputs(function, operands, rel.getInputs()); + applyDistributionOnInputs(rexCall, operands, rel.getInputs()); final RelTraitSet providedTraitSet = rel.getTraitSet().replace(FlinkConventions.STREAM_PHYSICAL()); return new StreamPhysicalProcessTableFunction( @@ -93,42 +90,43 @@ public boolean matches(RelOptRuleCall call) { } private static List applyDistributionOnInputs( - BridgingSqlFunction.WithTableFunction function, - List operands, - List inputs) { + RexCall rexCall, List operands, List inputs) { + final List staticArgs = + StreamPhysicalProcessTableFunction.getStaticArguments(rexCall); return Ord.zip(operands).stream() .filter(operand -> operand.e instanceof RexTableArgCall) .map( tableOperand -> { - final int pos = tableOperand.i; final RexTableArgCall tableArgCall = (RexTableArgCall) tableOperand.e; - final TableCharacteristic tableCharacteristic = - function.tableCharacteristic(pos); - assert tableCharacteristic != null; + final StaticArgument tableArg = staticArgs.get(tableOperand.i); + final StaticArgument resolvedTableArg = + tableArg.applyConditionalTraits( + StreamPhysicalProcessTableFunction.buildTraitContext( + rexCall, tableArgCall)); return applyDistributionOnInput( tableArgCall, - tableCharacteristic, + resolvedTableArg, inputs.get(tableArgCall.getInputIndex())); }) .collect(Collectors.toList()); } private static RelNode applyDistributionOnInput( - RexTableArgCall tableOperand, TableCharacteristic tableCharacteristic, RelNode input) { - final FlinkRelDistribution 
requiredDistribution = - deriveDistribution(tableOperand, tableCharacteristic); + RexTableArgCall tableOperand, StaticArgument resolvedTableArg, RelNode input) { + final FlinkRelDistribution distribution = + deriveDistribution(tableOperand, resolvedTableArg); final RelTraitSet requiredTraitSet = input.getCluster() .getPlanner() .emptyTraitSet() - .replace(requiredDistribution) + .replace(distribution) .replace(FlinkConventions.STREAM_PHYSICAL()); return RelOptRule.convert(input, requiredTraitSet); } private static FlinkRelDistribution deriveDistribution( - RexTableArgCall tableOperand, TableCharacteristic tableCharacteristic) { - if (tableCharacteristic.semantics == Semantics.SET) { + RexTableArgCall tableOperand, StaticArgument resolvedTableArg) { + if (resolvedTableArg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)) { final int[] partitionKeys = tableOperand.getPartitionKeys(); if (partitionKeys.length == 0) { return FlinkRelDistribution.SINGLETON(); diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala index b6767e68ac6e0..a621b3b3a660c 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala @@ -1640,6 +1640,10 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti modeBuilder.build() } + /** + * Whether the planner can skip generating UPDATE_BEFORE for this PTF's input. Requires partition + * keys that cover the upsert keys so related events are co-located. 
+ */ private def isPtfUpsert( tableArg: StaticArgument, tableArgCall: RexTableArgCall, @@ -1743,10 +1747,13 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti } StreamPhysicalProcessTableFunction .getProvidedInputArgs(call) - .map(_.e) .foreach { - tableArg => - if (tableArg.is(StaticArgumentTrait.ROW_SEMANTIC_TABLE)) { + arg => + val tableArg = arg.e + val tableArgCall = call.operands.get(arg.i).asInstanceOf[RexTableArgCall] + val traitCtx = StreamPhysicalProcessTableFunction + .buildTraitContext(call, tableArgCall) + if (tableArg.is(StaticArgumentTrait.ROW_SEMANTIC_TABLE, traitCtx)) { throw new ValidationException( s"PTFs that take table arguments with row semantics don't support upsert output. " + s"Table argument '${tableArg.getName}' of function '${call.getOperator.toString}' " + diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ToChangelogTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ToChangelogTest.java index ee9576b2e758c..e98e28fc53af9 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ToChangelogTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ToChangelogTest.java @@ -87,4 +87,21 @@ void testInsertOnlySource() { util.verifyRelPlan( "SELECT * FROM TO_CHANGELOG(input => TABLE insert_only_source)", CHANGELOG_MODE); } + + @Test + void testSetSemanticsWithPartitionBy() { + util.tableEnv() + .executeSql( + "CREATE TABLE retract_source (" + + " id INT," + + " name STRING," + + " PRIMARY KEY (id) NOT ENFORCED" + + ") WITH (" + + " 'connector' = 'values'," + + " 'changelog-mode' = 'I,UB,UA,D'" + + ")"); + util.verifyRelPlan( + "SELECT * FROM TO_CHANGELOG(input => TABLE retract_source PARTITION BY id)", + CHANGELOG_MODE); + } } diff --git 
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ToChangelogTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ToChangelogTest.xml index 91c73153bee03..77133f4fe419c 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ToChangelogTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ToChangelogTest.xml @@ -51,6 +51,26 @@ LogicalProject(op=[$0], id=[$1], name=[$2]) + + + + + TABLE retract_source PARTITION BY id)]]> + + + + + + From bc7ca02841b165ec5a1dbf65526733f6a6e1836e Mon Sep 17 00:00:00 2001 From: Gustavo de Morais Date: Tue, 21 Apr 2026 13:40:37 +0200 Subject: [PATCH 02/10] [FLINK-39392][table] Rename isPtfUpsert to ptfRequiresUpdateBefore --- .../FlinkChangelogModeInferenceProgram.scala | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala index a621b3b3a660c..177a691bb0634 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala @@ -852,10 +852,10 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti !modifyKindSet.isInsertOnly && tableArg.is( StaticArgumentTrait.SUPPORT_UPDATES) ) { - if (isPtfUpsert(tableArg, tableArgCall, child)) { - UpdateKindTrait.ONLY_UPDATE_AFTER - } else { + if (ptfRequiresUpdateBefore(tableArg, tableArgCall, child)) { 
UpdateKindTrait.BEFORE_AND_AFTER + } else { + UpdateKindTrait.ONLY_UPDATE_AFTER } } else { UpdateKindTrait.NONE @@ -1272,7 +1272,7 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti extractPtfTableArgComponents(process, child, inputArg) if ( tableArg.is(StaticArgumentTrait.SUPPORT_UPDATES) - && isPtfUpsert(tableArg, tableArgCall, child) + && !ptfRequiresUpdateBefore(tableArg, tableArgCall, child) && !tableArg.is(StaticArgumentTrait.REQUIRE_FULL_DELETE) ) { this @@ -1641,24 +1641,18 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti } /** - * Whether the planner can skip generating UPDATE_BEFORE for this PTF's input. Requires partition - * keys that cover the upsert keys so related events are co-located. + * Whether the PTF requires UPDATE_BEFORE from its input. Returns true unless partition keys + * cover the upsert keys (co-located) and the argument doesn't explicitly require UPDATE_BEFORE. */ - private def isPtfUpsert( + private def ptfRequiresUpdateBefore( tableArg: StaticArgument, tableArgCall: RexTableArgCall, input: StreamPhysicalRel): Boolean = { val partitionKeys = ImmutableBitSet.of(tableArgCall.getPartitionKeys: _*) val fmq = FlinkRelMetadataQuery.reuseOrCreate(input.getCluster.getMetadataQuery) val upsertKeys = fmq.getUpsertKeys(input) - if ( - upsertKeys == null || partitionKeys.isEmpty || !upsertKeys.contains(partitionKeys) - || tableArg.is(StaticArgumentTrait.REQUIRE_UPDATE_BEFORE) - ) { - false - } else { - true - } + upsertKeys == null || partitionKeys.isEmpty || !upsertKeys.contains(partitionKeys) + || tableArg.is(StaticArgumentTrait.REQUIRE_UPDATE_BEFORE) } private def extractPtfTableArgComponents( From d03839757c1810fa272ffb3cb9683008e339b8af Mon Sep 17 00:00:00 2001 From: Gustavo de Morais Date: Tue, 21 Apr 2026 16:00:46 +0200 Subject: [PATCH 03/10] [FLINK-39392][table] Improve documentation and simplify code --- .../docs/sql/reference/queries/changelog.md | 2 +- 
.../table/types/inference/StaticArgument.java | 3 +- .../table/types/inference/TraitCondition.java | 7 +- .../bridging/BridgingSqlFunction.java | 3 +- .../StreamPhysicalProcessTableFunction.java | 81 ++++--------------- .../FlinkChangelogModeInferenceProgram.scala | 9 ++- 6 files changed, 28 insertions(+), 77 deletions(-) diff --git a/docs/content/docs/sql/reference/queries/changelog.md b/docs/content/docs/sql/reference/queries/changelog.md index 2b66ded50d013..494dcc60be892 100644 --- a/docs/content/docs/sql/reference/queries/changelog.md +++ b/docs/content/docs/sql/reference/queries/changelog.md @@ -187,7 +187,7 @@ SELECT * FROM TO_CHANGELOG( | Parameter | Required | Description | |:-------------|:---------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `input` | Yes | The input table. With `PARTITION BY`, rows with the same key are co-located and run in the same operator instance. Without `PARTITION BY`, each row is processed independently. Accepts insert-only, retract, and upsert tables. For upsert tables, providing `PARTITION BY` is recommended for better performance. | +| `input` | Yes | The input table. With `PARTITION BY`, rows with the same key are co-located and run in the same operator instance. Without `PARTITION BY`, each row is processed independently. Accepts insert-only, retract, and upsert tables. For upsert tables, the provided `PARTITION BY` key should match or be a subset of the upsert key of the subquery. | | `op` | No | A `DESCRIPTOR` with a single column name for the operation code column. Defaults to `op`. | | `op_mapping` | No | A `MAP` mapping change operation names to custom output codes. 
Keys can contain comma-separated names to map multiple operations to the same code (e.g., `'INSERT, UPDATE_AFTER'`). When provided, only mapped operations are forwarded - unmapped events are dropped. Each change operation may appear at most once across all entries. | diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java index 1a027d430b222..6f10c6dedb427 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java @@ -31,7 +31,6 @@ import javax.annotation.Nullable; -import java.io.Serializable; import java.util.ArrayList; import java.util.EnumSet; import java.util.List; @@ -460,7 +459,7 @@ private void checkModelNotOptional() { } /** A trait that is conditionally added based on a {@link TraitCondition}. */ - private static final class ConditionalTrait implements Serializable { + private static final class ConditionalTrait { private final TraitCondition condition; private final StaticArgumentTrait trait; diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java index 61db5a2f3e614..6189c3f4c1d6a 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java @@ -20,8 +20,6 @@ import org.apache.flink.annotation.PublicEvolving; -import java.io.Serializable; - /** * A condition that determines whether a conditional trait on a {@link StaticArgument} should be * active for a given call. @@ -29,6 +27,9 @@ *

Conditions are evaluated at planning time using the {@link TraitContext} which provides access * to the SQL call's properties (PARTITION BY presence, scalar literal values, etc.). * + *

Implementations must implement {@code hashCode} and {@code equals} for {@link + * StaticArgument#equals}/{@link StaticArgument#hashCode} to work correctly. + * *

Use the static factory methods for common conditions: * *

{@code
@@ -40,7 +41,7 @@
  */
 @PublicEvolving
 @FunctionalInterface
-public interface TraitCondition extends Serializable {
+public interface TraitCondition {
 
     /** Evaluates this condition against the given context. */
     boolean test(TraitContext ctx);
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java
index e5080ad94fcbe..235ef4f319b1a 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java
@@ -273,8 +273,7 @@ public SqlReturnTypeInference getRowTypeInference() {
             }
             final StaticArgument arg = args.get(ordinal);
             final TableCharacteristic.Semantics semantics;
-            // Report SET if it may apply - which allows the use of Partition BY
-            // actual semantics resolved later via resolveTraits
+            // Report SET semantics if it may apply - which allows the use of PARTITION BY
             if (arg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)
                     || arg.hasConditionalTrait(StaticArgumentTrait.SET_SEMANTIC_TABLE)) {
                 semantics = TableCharacteristic.Semantics.SET;
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java
index 2122d516596a1..c654f850e1feb 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java
@@ -41,7 +41,6 @@
 import org.apache.flink.table.types.inference.StaticArgumentTrait;
 import org.apache.flink.table.types.inference.SystemTypeInference;
 import org.apache.flink.table.types.inference.TraitContext;
-import org.apache.flink.types.ColumnList;
 import org.apache.flink.types.RowKind;
 
 import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableSet;
@@ -70,9 +69,7 @@
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
-import java.util.LinkedHashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
@@ -389,8 +386,8 @@ public static List> getProvidedInputArgs(RexCall call) {
      */
     public static TraitContext buildTraitContext(
             final RexCall call, final RexTableArgCall tableArgCall) {
+        final CallContext callContext = toCallContext(call);
         final List declaredArgs = getStaticArguments(call);
-        final List operands = call.getOperands();
 
         return new TraitContext() {
             @Override
@@ -400,7 +397,13 @@ public boolean hasPartitionBy() {
 
             @Override
             public  Optional getScalarArgument(final String name, final Class clazz) {
-                return findScalarLiteral(declaredArgs, operands, name, clazz);
+                for (int i = 0; i < declaredArgs.size(); i++) {
+                    final StaticArgument arg = declaredArgs.get(i);
+                    if (arg.is(StaticArgumentTrait.SCALAR) && arg.getName().equals(name)) {
+                        return callContext.getArgumentValue(i, clazz);
+                    }
+                }
+                return Optional.empty();
             }
         };
     }
@@ -435,66 +438,6 @@ public static List getStaticArguments(final RexCall call) {
                 .orElseThrow(IllegalStateException::new);
     }
 
-    /**
-     * Extracts a scalar argument value by name. Handles NULL, DEFAULT, DESCRIPTOR, MAP, and
-     * standard literal values (Boolean, String, Integer, etc.) following the same rules as {@link
-     * CallContext#getArgumentValue}. Does not support time types (Instant, Duration, LocalDate)
-     * which require the full CallContext bridge.
-     */
-    @SuppressWarnings("unchecked")
-    private static  Optional findScalarLiteral(
-            final List declaredArgs,
-            final List operands,
-            final String name,
-            final Class clazz) {
-        for (int i = 0; i < declaredArgs.size(); i++) {
-            final StaticArgument arg = declaredArgs.get(i);
-            if (!arg.is(StaticArgumentTrait.SCALAR) || !arg.getName().equals(name)) {
-                continue;
-            }
-            final RexNode operand = operands.get(i);
-            // NULL and DEFAULT produce empty
-            if (operand.getKind() == SqlKind.DEFAULT || RexUtil.isNullLiteral(operand, true)) {
-                return Optional.empty();
-            }
-            // DESCRIPTOR and MAP are RexCalls, not RexLiterals
-            if (operand.getKind() == SqlKind.DESCRIPTOR && clazz == ColumnList.class) {
-                final List columns =
-                        ((RexCall) operand)
-                                .getOperands().stream()
-                                        .map(RexLiteral::stringValue)
-                                        .collect(Collectors.toList());
-                return Optional.of((T) ColumnList.of(columns));
-            }
-            if (operand.getKind() == SqlKind.MAP_VALUE_CONSTRUCTOR && clazz == Map.class) {
-                return Optional.ofNullable((T) extractMap((RexCall) operand));
-            }
-            // Standard literals
-            if (operand instanceof RexLiteral) {
-                try {
-                    return Optional.ofNullable(((RexLiteral) operand).getValueAs(clazz));
-                } catch (IllegalArgumentException e) {
-                    return Optional.empty();
-                }
-            }
-            return Optional.empty();
-        }
-        return Optional.empty();
-    }
-
-    private static @Nullable Map extractMap(final RexCall mapCall) {
-        final List operands = mapCall.getOperands();
-        final Map map = new LinkedHashMap<>();
-        for (int i = 0; i < operands.size(); i += 2) {
-            final @Nullable String key = RexLiteral.stringValue(operands.get(i));
-            final @Nullable String value = RexLiteral.stringValue(operands.get(i + 1));
-            if (key != null) {
-                map.put(key, value);
-            }
-        }
-        return map;
-    }
-
     public static Set deriveOnTimeFields(RexCall call) {
         final List operands = call.getOperands();
         final RexCall onTimeOperand =
@@ -629,6 +572,14 @@ public static Set toPartitionColumns(RexCall call) {
         return ImmutableSet.copyOf(partitionColumnsPerArg);
     }
 
+    /**
+     * Creates a CallContext for argument value extraction only, (no changelog and no input time
+     * columns).
+     */
+    public static CallContext toCallContext(RexCall udfCall) {
+        return toCallContext(udfCall, null, null, null);
+    }
+
     public static CallContext toCallContext(
             RexCall udfCall,
             List inputTimeColumns,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala
index 177a691bb0634..1b8f864b6fef8 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala
@@ -1641,8 +1641,8 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti
   }
 
   /**
-   * Whether the PTF requires UPDATE_BEFORE from its input. Returns true unless partition keys
-   * cover the upsert keys (co-located) and the argument doesn't explicitly require UPDATE_BEFORE.
+   * Whether the PTF requires UPDATE_BEFORE from its input. Returns true unless partition keys cover
+   * the upsert keys (co-located) and the argument doesn't explicitly require UPDATE_BEFORE.
    */
   private def ptfRequiresUpdateBefore(
       tableArg: StaticArgument,
@@ -1651,8 +1651,9 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti
     val partitionKeys = ImmutableBitSet.of(tableArgCall.getPartitionKeys: _*)
     val fmq = FlinkRelMetadataQuery.reuseOrCreate(input.getCluster.getMetadataQuery)
     val upsertKeys = fmq.getUpsertKeys(input)
-    upsertKeys == null || partitionKeys.isEmpty || !upsertKeys.contains(partitionKeys)
-    || tableArg.is(StaticArgumentTrait.REQUIRE_UPDATE_BEFORE)
+    upsertKeys == null || partitionKeys.isEmpty ||
+    !upsertKeys.contains(partitionKeys) ||
+    tableArg.is(StaticArgumentTrait.REQUIRE_UPDATE_BEFORE)
   }
 
   private def extractPtfTableArgComponents(

From 1ecf4314f4471e2787ff368c2a38d0e03f257601 Mon Sep 17 00:00:00 2001
From: Gustavo de Morais 
Date: Tue, 21 Apr 2026 16:23:01 +0200
Subject: [PATCH 04/10] [FLINK-39392][table] Resolve traits once in
 StreamPhysicalProcessTableFunction and reuse it

---
 .../StreamExecProcessTableFunction.java       | 18 +++---
 .../StreamPhysicalProcessTableFunction.java   | 60 ++++++++++---------
 2 files changed, 40 insertions(+), 38 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java
index 9fe3467ec4c80..add2368a529c3 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java
@@ -186,9 +186,13 @@ protected Transformation translateToPlanInternal(
                                 providedInputArg -> {
                                     final RexTableArgCall tableArgCall =
                                             (RexTableArgCall) operands.get(providedInputArg.i);
-                                    final StaticArgument tableArg = providedInputArg.e;
+                                    final StaticArgument resolvedArg =
+                                            providedInputArg.e.applyConditionalTraits(
+                                                    StreamPhysicalProcessTableFunction
+                                                            .buildTraitContext(
+                                                                    invocation, tableArgCall));
                                     return createRuntimeTableSemantics(
-                                            invocation, tableArg, tableArgCall, inputTimeColumns);
+                                            resolvedArg, tableArgCall, inputTimeColumns);
                                 })
                         .collect(Collectors.toList());
 
@@ -293,24 +297,20 @@ protected Transformation translateToPlanInternal(
     }
 
     private RuntimeTableSemantics createRuntimeTableSemantics(
-            RexCall call,
-            StaticArgument tableArg,
+            StaticArgument resolvedArg,
             RexTableArgCall tableArgCall,
             List inputTimeColumns) {
         final RuntimeChangelogMode consumedChangelogMode =
                 RuntimeChangelogMode.serialize(
                         inputChangelogModes.get(tableArgCall.getInputIndex()));
         final DataType dataType;
-        if (tableArg.getDataType().isPresent()) {
-            dataType = tableArg.getDataType().get();
+        if (resolvedArg.getDataType().isPresent()) {
+            dataType = resolvedArg.getDataType().get();
         } else {
             dataType = DataTypes.of(FlinkTypeFactory.toLogicalRowType(tableArgCall.type));
         }
 
         final int timeColumn = inputTimeColumns.get(tableArgCall.getInputIndex());
-        final StaticArgument resolvedArg =
-                tableArg.applyConditionalTraits(
-                        StreamPhysicalProcessTableFunction.buildTraitContext(call, tableArgCall));
 
         return new RuntimeTableSemantics(
                 resolvedArg.getName(),
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java
index c654f850e1feb..4e088c14b2756 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java
@@ -92,6 +92,7 @@ public class StreamPhysicalProcessTableFunction extends AbstractRelNode
 
     private final FlinkLogicalTableFunctionScan scan;
     private final @Nullable String uid;
+    private final List resolvedStaticArgs;
 
     private List inputs;
 
@@ -107,10 +108,15 @@ public StreamPhysicalProcessTableFunction(
         this.scan = scan;
         final RexCall call = (RexCall) scan.getCall();
         validateAllowSystemArgs(call);
-        this.uid = deriveUniqueIdentifier(scan);
+        this.resolvedStaticArgs = resolveStaticArgs(call);
+        this.uid = deriveUniqueIdentifier(scan, resolvedStaticArgs);
         verifyInputSize(ShortcutUtils.unwrapTableConfig(cluster), inputs.size());
     }
 
+    public List getResolvedStaticArgs() {
+        return resolvedStaticArgs;
+    }
+
     public StreamPhysicalProcessTableFunction(
             RelOptCluster cluster,
             RelTraitSet traitSet,
@@ -208,21 +214,39 @@ protected RelDataType deriveRowType() {
      *
      * @see SystemTypeInference
      */
-    private static @Nullable String deriveUniqueIdentifier(FlinkLogicalTableFunctionScan scan) {
+    /** Resolves conditional traits for all static args at construction time. */
+    private static List resolveStaticArgs(final RexCall call) {
+        final List staticArgs = getStaticArguments(call);
+        final List operands = call.getOperands();
+        return IntStream.range(0, staticArgs.size())
+                .mapToObj(
+                        i -> {
+                            final StaticArgument arg = staticArgs.get(i);
+                            if (!arg.is(StaticArgumentTrait.TABLE) || !arg.hasConditionalTraits()) {
+                                return arg;
+                            }
+                            final RexTableArgCall tableArgCall = (RexTableArgCall) operands.get(i);
+                            return arg.applyConditionalTraits(
+                                    buildTraitContext(call, tableArgCall));
+                        })
+                .collect(Collectors.toList());
+    }
+
+    private static @Nullable String deriveUniqueIdentifier(
+            FlinkLogicalTableFunctionScan scan, List resolvedStaticArgs) {
         final RexCall rexCall = (RexCall) scan.getCall();
         final BridgingSqlFunction.WithTableFunction function =
                 (BridgingSqlFunction.WithTableFunction) rexCall.getOperator();
-        final List staticArgs =
-                function.getTypeInference()
-                        .getStaticArguments()
-                        .orElseThrow(IllegalStateException::new);
         final ContextResolvedFunction resolvedFunction = function.getResolvedFunction();
         final List operands = rexCall.getOperands();
         // Type inference ensures that uid is always added at the end
         final RexNode uidRexNode = operands.get(operands.size() - 1);
         if (uidRexNode.getKind() == SqlKind.DEFAULT) {
             // Optional for constant or row semantics functions
-            if (!hasResolvedSetSemantics(staticArgs, operands, rexCall)) {
+            final boolean hasSetSemantics =
+                    resolvedStaticArgs.stream()
+                            .anyMatch(arg -> arg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE));
+            if (!hasSetSemantics) {
                 return null;
             }
             final String uid =
@@ -408,28 +432,6 @@ public  Optional getScalarArgument(final String name, final Class clazz
         };
     }
 
-    /** Checks if any table argument resolves to SET_SEMANTIC_TABLE after applying conditions. */
-    private static boolean hasResolvedSetSemantics(
-            final List staticArgs,
-            final List operands,
-            final RexCall rexCall) {
-        for (int i = 0; i < staticArgs.size(); i++) {
-            final StaticArgument arg = staticArgs.get(i);
-            if (!arg.is(StaticArgumentTrait.TABLE)) {
-                continue;
-            }
-
-            final RexTableArgCall tableArgCall = (RexTableArgCall) operands.get(i);
-
-            final StaticArgument resolvedArg =
-                    arg.applyConditionalTraits(buildTraitContext(rexCall, tableArgCall));
-            if (resolvedArg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)) {
-                return true;
-            }
-        }
-        return false;
-    }
-
     public static List getStaticArguments(final RexCall call) {
         final BridgingSqlFunction.WithTableFunction function =
                 (BridgingSqlFunction.WithTableFunction) call.getOperator();

From 785947a572d27aa9723e0c73b402a8d7e214d378 Mon Sep 17 00:00:00 2001
From: Gustavo de Morais 
Date: Wed, 22 Apr 2026 16:12:49 +0200
Subject: [PATCH 05/10] [FLINK-39392][table] Simplify trait resolution in
 TableFunctionScan and ExecNode deserialization

---
 .../table/types/inference/StaticArgument.java |  11 +-
 .../types/inference/StaticArgumentTrait.java  |  24 ++++
 .../types/inference/SystemTypeInference.java  |  92 +++++++-------
 .../table/types/inference/TraitCondition.java | 103 +++++++++++++--
 .../table/types/inference/TraitContext.java   |  37 ++++++
 .../bridging/BridgingSqlFunction.java         | 120 ++++++++++++++++++
 .../StreamExecProcessTableFunction.java       |  28 ++--
 .../FlinkLogicalTableFunctionScan.java        |  10 +-
 .../StreamPhysicalProcessTableFunction.java   |  85 ++-----------
 ...treamPhysicalProcessTableFunctionRule.java |  38 +++---
 .../FlinkChangelogModeInferenceProgram.scala  |   9 +-
 11 files changed, 376 insertions(+), 181 deletions(-)

diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java
index 6f10c6dedb427..546f30a7f7ea8 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java
@@ -236,9 +236,7 @@ public boolean is(StaticArgumentTrait trait, TraitContext ctx) {
      */
     public StaticArgument withConditionalTrait(
             final StaticArgumentTrait trait, final TraitCondition condition) {
-        if (trait == StaticArgumentTrait.SCALAR
-                || trait == StaticArgumentTrait.TABLE
-                || trait == StaticArgumentTrait.MODEL) {
+        if (trait.isRoot()) {
             throw new IllegalArgumentException(
                     "Root traits (SCALAR, TABLE, MODEL) cannot be conditional.");
         }
@@ -286,14 +284,9 @@ public EnumSet resolveTraits(final TraitContext ctx) {
         return resolved;
     }
 
-    /** ROW and SET semantics are mutually exclusive - adding one removes the other. */
     private static void removeMutuallyExclusiveTraits(
             final EnumSet traits, final StaticArgumentTrait adding) {
-        if (adding == StaticArgumentTrait.SET_SEMANTIC_TABLE) {
-            traits.remove(StaticArgumentTrait.ROW_SEMANTIC_TABLE);
-        } else if (adding == StaticArgumentTrait.ROW_SEMANTIC_TABLE) {
-            traits.remove(StaticArgumentTrait.SET_SEMANTIC_TABLE);
-        }
+        traits.removeAll(adding.getIncompatibleWith());
     }
 
     @Override
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgumentTrait.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgumentTrait.java
index 647248ada2df4..7b0083ed2dcac 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgumentTrait.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgumentTrait.java
@@ -21,6 +21,8 @@
 import org.apache.flink.annotation.PublicEvolving;
 
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.EnumSet;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -51,6 +53,8 @@ public enum StaticArgumentTrait {
     REQUIRE_UPDATE_BEFORE(SUPPORT_UPDATES),
     REQUIRE_FULL_DELETE(SUPPORT_UPDATES);
 
+    private static final Set ROOTS = EnumSet.of(SCALAR, TABLE, MODEL);
+
     private final Set requirements;
 
     StaticArgumentTrait(StaticArgumentTrait... requirements) {
@@ -60,4 +64,24 @@ public enum StaticArgumentTrait {
     public Set getRequirements() {
         return requirements;
     }
+
+    /** Whether this trait is one of the top-level roots (SCALAR, TABLE, MODEL). */
+    public boolean isRoot() {
+        return ROOTS.contains(this);
+    }
+
+    /**
+     * Returns the traits that are mutually exclusive with this one. Adding this trait to a set
+     * implies removing all returned traits. Empty by default.
+     */
+    public Set getIncompatibleWith() {
+        switch (this) {
+            case SET_SEMANTIC_TABLE:
+                return Collections.singleton(ROW_SEMANTIC_TABLE);
+            case ROW_SEMANTIC_TABLE:
+                return Collections.singleton(SET_SEMANTIC_TABLE);
+            default:
+                return Collections.emptySet();
+        }
+    }
 }
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java
index 0d9641b96adda..1ce06e1e36777 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java
@@ -182,30 +182,27 @@ private static void checkReservedArgs(List staticArgs) {
         }
     }
 
-    static TraitContext buildTraitContext(
-            @Nullable final TableSemantics semantics,
-            final CallContext callContext,
-            final List staticArgs) {
-        return new TraitContext() {
-            @Override
-            public boolean hasPartitionBy() {
-                return semantics != null && semantics.partitionByColumns().length > 0;
-            }
-
-            @Override
-            public  Optional getScalarArgument(final String name, final Class clazz) {
-                for (int i = 0; i < staticArgs.size(); i++) {
-                    final StaticArgument arg = staticArgs.get(i);
-                    if (arg.is(StaticArgumentTrait.SCALAR) && arg.getName().equals(name)) {
-                        if (!callContext.isArgumentLiteral(i)) {
-                            return Optional.empty();
-                        }
-                        return callContext.getArgumentValue(i, clazz);
-                    }
-                }
-                return Optional.empty();
-            }
-        };
+    /**
+     * Resolves conditional traits (see {@link StaticArgument#withConditionalTrait}) on every static
+     * arg using the call's semantics and operands. Called once at the top of {@link
+     * SystemInputStrategy#inferInputTypes} and {@link SystemOutputStrategy#inferType}; downstream
+     * helpers receive the resolved list and iterate it directly.
+     */
+    private static List resolveStaticArgs(
+            final CallContext callContext, final List staticArgs) {
+        return IntStream.range(0, staticArgs.size())
+                .mapToObj(
+                        pos -> {
+                            final StaticArgument arg = staticArgs.get(pos);
+                            if (!arg.hasConditionalTraits()) {
+                                return arg;
+                            }
+                            final TableSemantics semantics =
+                                    callContext.getTableSemantics(pos).orElse(null);
+                            return arg.applyConditionalTraits(
+                                    TraitContext.of(semantics, callContext, staticArgs));
+                        })
+                .collect(Collectors.toList());
     }
 
     private static void checkMultipleTableArgs(List staticArgs) {
@@ -288,6 +285,10 @@ public Optional inferType(CallContext callContext) {
             return origin.inferType(callContext)
                     .map(
                             functionDataType -> {
+                                // Resolve once so all helpers see the same effective signature
+                                // (PARTITION BY / scalar literals applied to conditional traits).
+                                final List resolvedArgs =
+                                        resolveStaticArgs(callContext, staticArgs);
                                 final List fields = new ArrayList<>();
 
                                 // According to the SQL standard, pass-through columns should
@@ -299,11 +300,11 @@ public Optional inferType(CallContext callContext) {
                                 // - Flink SESSION windows add pass-through columns at the beginning
                                 // - Oracle adds pass-through columns for all ROW semantics args, so
                                 // this whole topic is kind of vendor specific already
-                                fields.addAll(derivePassThroughFields(callContext));
+                                fields.addAll(derivePassThroughFields(callContext, resolvedArgs));
                                 fields.addAll(deriveFunctionOutputFields(functionDataType));
 
                                 if (!disableSystemArgs) {
-                                    fields.addAll(deriveRowtimeField(callContext));
+                                    fields.addAll(deriveRowtimeField(callContext, resolvedArgs));
                                 }
 
                                 final List uniqueFields = makeFieldNamesUnique(fields);
@@ -329,23 +330,21 @@ private List makeFieldNamesUnique(List fields) {
                     .collect(Collectors.toList());
         }
 
-        private List derivePassThroughFields(CallContext callContext) {
+        private List derivePassThroughFields(
+                CallContext callContext, List resolvedArgs) {
             if (functionKind != FunctionKind.PROCESS_TABLE) {
                 return List.of();
             }
             final List argDataTypes = callContext.getArgumentDataTypes();
-            return IntStream.range(0, staticArgs.size())
+            return IntStream.range(0, resolvedArgs.size())
                     .mapToObj(
                             pos -> {
-                                final TableSemantics semantics =
-                                        callContext.getTableSemantics(pos).orElse(null);
-                                final TraitContext traitCtx =
-                                        buildTraitContext(semantics, callContext, staticArgs);
-                                final StaticArgument resolvedArg =
-                                        staticArgs.get(pos).applyConditionalTraits(traitCtx);
+                                final StaticArgument resolvedArg = resolvedArgs.get(pos);
                                 if (resolvedArg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) {
                                     return DataType.getFields(argDataTypes.get(pos)).stream();
                                 }
+                                final TableSemantics semantics =
+                                        callContext.getTableSemantics(pos).orElse(null);
                                 if (semantics == null) {
                                     return Stream.empty();
                                 }
@@ -379,7 +378,8 @@ private List deriveFunctionOutputFields(DataType functionDataType) {
                     .collect(Collectors.toList());
         }
 
-        private List deriveRowtimeField(CallContext callContext) {
+        private List deriveRowtimeField(
+                CallContext callContext, List resolvedArgs) {
             if (this.functionKind != FunctionKind.PROCESS_TABLE) {
                 return List.of();
             }
@@ -398,10 +398,10 @@ private List deriveRowtimeField(CallContext callContext) {
 
             final List onTimeColumns = new ArrayList<>();
             final List missingOnTimeColumns = new ArrayList<>();
-            IntStream.range(0, staticArgs.size())
+            IntStream.range(0, resolvedArgs.size())
                     .forEach(
                             pos -> {
-                                final StaticArgument staticArg = staticArgs.get(pos);
+                                final StaticArgument staticArg = resolvedArgs.get(pos);
                                 if (!staticArg.is(StaticArgumentTrait.TABLE)) {
                                     return;
                                 }
@@ -596,8 +596,10 @@ public Optional> inferInputTypes(
                                 + "that is not overloaded and doesn't contain varargs.");
             }
 
+            // Resolve once so the rest of validation iterates the effective signature.
+            final List resolvedArgs = resolveStaticArgs(callContext, staticArgs);
             try {
-                checkTableArgs(staticArgs, callContext);
+                checkTableArgs(resolvedArgs, callContext);
                 if (!disableSystemArgs) {
                     checkUidArg(callContext);
                 }
@@ -633,13 +635,13 @@ private static void checkUidArg(CallContext callContext) {
         }
 
         private static void checkTableArgs(
-                List staticArgs, CallContext callContext) {
+                List resolvedArgs, CallContext callContext) {
             final List tableSemantics = new ArrayList<>();
-            IntStream.range(0, staticArgs.size())
+            IntStream.range(0, resolvedArgs.size())
                     .forEach(
                             pos -> {
-                                final StaticArgument staticArg = staticArgs.get(pos);
-                                if (!staticArg.is(StaticArgumentTrait.TABLE)) {
+                                final StaticArgument resolvedArg = resolvedArgs.get(pos);
+                                if (!resolvedArg.is(StaticArgumentTrait.TABLE)) {
                                     return;
                                 }
                                 final TableSemantics semantics =
@@ -648,12 +650,8 @@ private static void checkTableArgs(
                                     throw new ValidationException(
                                             String.format(
                                                     "Table expected for argument '%s'.",
-                                                    staticArg.getName()));
+                                                    resolvedArg.getName()));
                                 }
-                                final TraitContext traitCtx =
-                                        buildTraitContext(semantics, callContext, staticArgs);
-                                final StaticArgument resolvedArg =
-                                        staticArg.applyConditionalTraits(traitCtx);
                                 checkRowSemantics(resolvedArg, semantics);
                                 checkSetSemantics(resolvedArg, semantics);
                                 tableSemantics.add(semantics);
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java
index 6189c3f4c1d6a..5d0e9b84b5cc8 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java
@@ -20,6 +20,8 @@
 
 import org.apache.flink.annotation.PublicEvolving;
 
+import java.util.Objects;
+
 /**
  * A condition that determines whether a conditional trait on a {@link StaticArgument} should be
  * active for a given call.
@@ -28,7 +30,9 @@
  * to the SQL call's properties (PARTITION BY presence, scalar literal values, etc.).
  *
  * 

Implementations must implement {@code hashCode} and {@code equals} for {@link - * StaticArgument#equals}/{@link StaticArgument#hashCode} to work correctly. + * StaticArgument#equals}/{@link StaticArgument#hashCode} to work correctly. The built-in factories + * below return value-comparable instances; user-supplied lambdas do not - prefer the factories or + * named classes. * *

Use the static factory methods for common conditions: * @@ -48,20 +52,103 @@ public interface TraitCondition { /** True when PARTITION BY is provided on the table argument. */ static TraitCondition hasPartitionBy() { - return TraitContext::hasPartitionBy; + return HasPartitionByCondition.INSTANCE; } /** True when the named scalar argument equals the expected value. */ - @SuppressWarnings("unchecked") static TraitCondition argIsEqualTo(final String name, final T expected) { - return ctx -> - ctx.getScalarArgument(name, (Class) expected.getClass()) - .map(expected::equals) - .orElse(false); + return new ArgIsEqualToCondition<>(name, expected); } /** Negates the given condition. */ static TraitCondition not(final TraitCondition condition) { - return ctx -> !condition.test(ctx); + return new NotCondition(condition); + } + + // -------------------------------------------------------------------------------------------- + // Built-in implementations - named so that StaticArgument equality cascades correctly. + // -------------------------------------------------------------------------------------------- + + /** Singleton condition that is true when PARTITION BY is provided on the table argument. */ + final class HasPartitionByCondition implements TraitCondition { + + private static final HasPartitionByCondition INSTANCE = new HasPartitionByCondition(); + + private HasPartitionByCondition() {} + + @Override + public boolean test(final TraitContext ctx) { + return ctx.hasPartitionBy(); + } + + // equals/hashCode by identity - safe because there is exactly one instance. + } + + /** Condition that is true when the named scalar argument equals the expected value. 
*/ + final class ArgIsEqualToCondition implements TraitCondition { + + private final String name; + private final T expected; + private final Class clazz; + + @SuppressWarnings("unchecked") + ArgIsEqualToCondition(final String name, final T expected) { + this.name = name; + this.expected = expected; + this.clazz = (Class) expected.getClass(); + } + + @Override + public boolean test(final TraitContext ctx) { + return ctx.getScalarArgument(name, clazz).map(expected::equals).orElse(false); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ArgIsEqualToCondition)) { + return false; + } + final ArgIsEqualToCondition that = (ArgIsEqualToCondition) o; + return name.equals(that.name) && expected.equals(that.expected); + } + + @Override + public int hashCode() { + return Objects.hash(name, expected); + } + } + + /** Condition that negates another condition. */ + final class NotCondition implements TraitCondition { + + private final TraitCondition condition; + + NotCondition(final TraitCondition condition) { + this.condition = condition; + } + + @Override + public boolean test(final TraitContext ctx) { + return !condition.test(ctx); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (!(o instanceof NotCondition)) { + return false; + } + return condition.equals(((NotCondition) o).condition); + } + + @Override + public int hashCode() { + return Objects.hash(NotCondition.class, condition); + } } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitContext.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitContext.java index 317a9a0ece054..e58a4b36dceea 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitContext.java +++ 
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitContext.java @@ -19,7 +19,11 @@ package org.apache.flink.table.types.inference; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.functions.TableSemantics; +import javax.annotation.Nullable; + +import java.util.List; import java.util.Optional; /** @@ -41,4 +45,37 @@ public interface TraitContext { * cannot be converted to the requested type */ Optional getScalarArgument(String name, Class clazz); + + /** + * Builds a {@link TraitContext} from validation-time inputs. + * + *

Used by {@code SystemTypeInference} when wrapping a function's strategies. Planner-side + * code that has a {@code RexCall} should use the planner adapter in {@code BridgingSqlFunction} + * instead. + */ + static TraitContext of( + @Nullable final TableSemantics semantics, + final CallContext callContext, + final List staticArgs) { + return new TraitContext() { + @Override + public boolean hasPartitionBy() { + return semantics != null && semantics.partitionByColumns().length > 0; + } + + @Override + public Optional getScalarArgument(final String name, final Class clazz) { + for (int i = 0; i < staticArgs.size(); i++) { + final StaticArgument arg = staticArgs.get(i); + if (arg.is(StaticArgumentTrait.SCALAR) && arg.getName().equals(name)) { + if (!callContext.isArgumentLiteral(i)) { + return Optional.empty(); + } + return callContext.getArgumentValue(i, clazz); + } + } + return Optional.empty(); + } + }; + } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java index 235ef4f319b1a..e62b44830ecc6 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java @@ -30,17 +30,23 @@ import org.apache.flink.table.planner.calcite.FlinkRelBuilder; import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.calcite.RexFactory; +import org.apache.flink.table.planner.calcite.RexTableArgCall; import org.apache.flink.table.planner.utils.ShortcutUtils; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.inference.StaticArgument; import org.apache.flink.table.types.inference.StaticArgumentTrait; import 
org.apache.flink.table.types.inference.SystemTypeInference; +import org.apache.flink.table.types.inference.TraitContext; import org.apache.flink.table.types.inference.TypeInference; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rel.type.StructKind; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; @@ -51,7 +57,9 @@ import org.checkerframework.checker.nullness.qual.Nullable; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.apache.flink.table.planner.functions.bridging.BridgingUtils.createName; import static org.apache.flink.table.planner.functions.bridging.BridgingUtils.createSqlFunctionCategory; @@ -216,6 +224,118 @@ public boolean isDeterministic() { return resolvedFunction.getDefinition().isDeterministic(); } + // -------------------------------------------------------------------------------------------- + // Conditional trait resolution + // -------------------------------------------------------------------------------------------- + + /** + * Rewrites {@code call} so that the operator's {@link StaticArgument}s have any conditional + * traits (see {@link StaticArgument#withConditionalTrait}) applied against the call site + * (PARTITION BY, scalar literals). Downstream consumers can then treat the operator's static + * arguments as the effective signature and use plain {@code arg.is(SET_SEMANTIC_TABLE)} checks. + * + *

Called from the two places where a planner-level {@link RexCall} for a PTF is first built + * for downstream consumption: {@code FlinkLogicalTableFunctionScan} converter (fresh planning) + * and {@code StreamExecProcessTableFunction.@JsonCreator} (compiled-plan restore). A no-op for + * non-PTF calls and for PTFs that declare no conditional traits. + */ + public static RexCall resolveCallTraits(RexCall call) { + if (!(call.getOperator() instanceof BridgingSqlFunction)) { + return call; + } + final BridgingSqlFunction function = (BridgingSqlFunction) call.getOperator(); + final List declared = + function.typeInference.getStaticArguments().orElse(null); + if (declared == null || declared.stream().noneMatch(StaticArgument::hasConditionalTraits)) { + return call; + } + final List operands = call.getOperands(); + final List resolved = + IntStream.range(0, declared.size()) + .mapToObj(i -> resolveArg(declared.get(i), declared, operands, i)) + .collect(Collectors.toList()); + if (resolved.equals(declared)) { + return call; + } + final BridgingSqlFunction rewritten = function.withStaticArguments(resolved); + // Use a fresh RexBuilder from the function's own type factory so this can run from a + // Jackson @JsonCreator that has no planner context. + return (RexCall) + new RexBuilder(function.typeFactory).makeCall(call.getType(), rewritten, operands); + } + + private static StaticArgument resolveArg( + StaticArgument declaredArg, + List declared, + List operands, + int index) { + // We only resolve conditional traits for the Table Argument with conditional traits + if (!declaredArg.hasConditionalTraits() + || !(operands.get(index) instanceof RexTableArgCall)) { + return declaredArg; + } + return declaredArg.applyConditionalTraits( + buildTraitContext((RexTableArgCall) operands.get(index), declared, operands)); + } + + /** + * Planner-side adapter to {@link TraitContext}. 
Sourced from a {@link RexCall} (PARTITION BY + * via {@link RexTableArgCall}, scalar literals via the operand list) instead of a {@link + * org.apache.flink.table.types.inference.CallContext}, since the planner doesn't carry one. The + * validation-time equivalent is {@link TraitContext#of}. + */ + private static TraitContext buildTraitContext( + RexTableArgCall tableArgCall, List declared, List operands) { + return new TraitContext() { + @Override + public boolean hasPartitionBy() { + return tableArgCall.getPartitionKeys().length > 0; + } + + @Override + public Optional getScalarArgument(String name, Class clazz) { + for (int i = 0; i < declared.size(); i++) { + final StaticArgument arg = declared.get(i); + if (!arg.is(StaticArgumentTrait.SCALAR) || !arg.getName().equals(name)) { + continue; + } + if (i >= operands.size() || !(operands.get(i) instanceof RexLiteral)) { + return Optional.empty(); + } + return Optional.ofNullable(((RexLiteral) operands.get(i)).getValueAs(clazz)); + } + return Optional.empty(); + } + }; + } + + /** + * Returns a copy of this function whose {@link TypeInference} reports the given static + * arguments. The wrapped input/output strategies are reused unchanged - they ran at validation + * time and aren't invoked again afterwards. 
+ */ + private BridgingSqlFunction withStaticArguments(List staticArguments) { + final TypeInference rewritten = + TypeInference.newBuilder() + .staticArguments(staticArguments) + .inputTypeStrategy(typeInference.getInputTypeStrategy()) + .stateTypeStrategies(typeInference.getStateTypeStrategies()) + .outputTypeStrategy(typeInference.getOutputTypeStrategy()) + .disableSystemArguments(typeInference.disableSystemArguments()) + .build(); + if (this instanceof WithTableFunction) { + return new WithTableFunction( + dataTypeFactory, + typeFactory, + rexFactory, + getKind(), + resolvedFunction, + rewritten); + } + return new BridgingSqlFunction( + dataTypeFactory, typeFactory, rexFactory, getKind(), resolvedFunction, rewritten); + } + // -------------------------------------------------------------------------------------------- // Table function extension // -------------------------------------------------------------------------------------------- diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java index add2368a529c3..3973329af7484 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecProcessTableFunction.java @@ -38,6 +38,7 @@ import org.apache.flink.table.planner.codegen.ProcessTableRunnerGenerator; import org.apache.flink.table.planner.codegen.sort.SortCodeGenerator; import org.apache.flink.table.planner.delegation.PlannerBase; +import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase; import 
org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig; import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext; @@ -157,7 +158,10 @@ public StreamExecProcessTableFunction( @JsonProperty(FIELD_NAME_OUTPUT_CHANGELOG_MODE) ChangelogMode outputChangelogMode) { super(id, context, persistedConfig, inputProperties, outputType, description); this.uid = uid; - this.invocation = (RexCall) invocation; + // Mirror the FlinkLogicalTableFunctionScan converter for the compiled-plan restore path: + // bake StaticArgument#withConditionalTrait rules into the operator's static args so + // downstream code can use plain arg.is(SET_SEMANTIC_TABLE) checks. + this.invocation = BridgingSqlFunction.resolveCallTraits((RexCall) invocation); this.inputChangelogModes = inputChangelogModes; this.outputChangelogMode = outputChangelogMode; } @@ -186,13 +190,9 @@ protected Transformation translateToPlanInternal( providedInputArg -> { final RexTableArgCall tableArgCall = (RexTableArgCall) operands.get(providedInputArg.i); - final StaticArgument resolvedArg = - providedInputArg.e.applyConditionalTraits( - StreamPhysicalProcessTableFunction - .buildTraitContext( - invocation, tableArgCall)); + final StaticArgument tableArg = providedInputArg.e; return createRuntimeTableSemantics( - resolvedArg, tableArgCall, inputTimeColumns); + tableArg, tableArgCall, inputTimeColumns); }) .collect(Collectors.toList()); @@ -297,15 +297,13 @@ protected Transformation translateToPlanInternal( } private RuntimeTableSemantics createRuntimeTableSemantics( - StaticArgument resolvedArg, - RexTableArgCall tableArgCall, - List inputTimeColumns) { + StaticArgument tableArg, RexTableArgCall tableArgCall, List inputTimeColumns) { final RuntimeChangelogMode consumedChangelogMode = RuntimeChangelogMode.serialize( inputChangelogModes.get(tableArgCall.getInputIndex())); final DataType dataType; - if (resolvedArg.getDataType().isPresent()) { - dataType = resolvedArg.getDataType().get(); + if 
(tableArg.getDataType().isPresent()) { + dataType = tableArg.getDataType().get(); } else { dataType = DataTypes.of(FlinkTypeFactory.toLogicalRowType(tableArgCall.type)); } @@ -313,15 +311,15 @@ private RuntimeTableSemantics createRuntimeTableSemantics( final int timeColumn = inputTimeColumns.get(tableArgCall.getInputIndex()); return new RuntimeTableSemantics( - resolvedArg.getName(), + tableArg.getName(), tableArgCall.getInputIndex(), dataType, tableArgCall.getPartitionKeys(), tableArgCall.getOrderKeys(), RexTableArgCall.toSortDirections(tableArgCall.getSortOrder()), consumedChangelogMode, - resolvedArg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH), - resolvedArg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE), + tableArg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH), + tableArg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE), timeColumn); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalTableFunctionScan.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalTableFunctionScan.java index d478b3afcf072..17473ab19a7d7 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalTableFunctionScan.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalTableFunctionScan.java @@ -22,6 +22,7 @@ import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionKind; import org.apache.flink.table.functions.TemporalTableFunction; +import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; import org.apache.flink.table.planner.plan.nodes.FlinkConventions; import org.apache.flink.table.planner.utils.ShortcutUtils; @@ -115,7 +116,14 @@ public boolean matches(RelOptRuleCall call) { functionScan.getInputs().stream() .map(input -> RelOptRule.convert(input, 
FlinkConventions.LOGICAL())) .collect(Collectors.toList()); - final RexCall rexCall = (RexCall) functionScan.getCall(); + + // Resolve any StaticArgument#withConditionalTrait rules on the operator against this + // call site (PARTITION BY, scalar literals). After this rewrite, downstream code sees a + // BridgingSqlFunction whose getStaticArguments() reports the effective signature, so + // simple staticArg.is(SET_SEMANTIC_TABLE) checks suffice. + final RexCall rexCall = + BridgingSqlFunction.resolveCallTraits((RexCall) functionScan.getCall()); + return new FlinkLogicalTableFunctionScan( functionScan.getCluster(), traitSet, diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java index 4e088c14b2756..754fe1b328492 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java @@ -40,7 +40,6 @@ import org.apache.flink.table.types.inference.StaticArgument; import org.apache.flink.table.types.inference.StaticArgumentTrait; import org.apache.flink.table.types.inference.SystemTypeInference; -import org.apache.flink.table.types.inference.TraitContext; import org.apache.flink.types.RowKind; import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableSet; @@ -71,7 +70,6 @@ import java.util.HashSet; import java.util.List; import java.util.Objects; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -92,7 +90,6 @@ public class StreamPhysicalProcessTableFunction extends AbstractRelNode private final 
FlinkLogicalTableFunctionScan scan; private final @Nullable String uid; - private final List resolvedStaticArgs; private List inputs; @@ -108,15 +105,10 @@ public StreamPhysicalProcessTableFunction( this.scan = scan; final RexCall call = (RexCall) scan.getCall(); validateAllowSystemArgs(call); - this.resolvedStaticArgs = resolveStaticArgs(call); - this.uid = deriveUniqueIdentifier(scan, resolvedStaticArgs); + this.uid = deriveUniqueIdentifier(scan); verifyInputSize(ShortcutUtils.unwrapTableConfig(cluster), inputs.size()); } - public List getResolvedStaticArgs() { - return resolvedStaticArgs; - } - public StreamPhysicalProcessTableFunction( RelOptCluster cluster, RelTraitSet traitSet, @@ -214,39 +206,22 @@ protected RelDataType deriveRowType() { * * @see SystemTypeInference */ - /** Resolves conditional traits for all static args at construction time. */ - private static List resolveStaticArgs(final RexCall call) { - final List staticArgs = getStaticArguments(call); - final List operands = call.getOperands(); - return IntStream.range(0, staticArgs.size()) - .mapToObj( - i -> { - final StaticArgument arg = staticArgs.get(i); - if (!arg.is(StaticArgumentTrait.TABLE) || !arg.hasConditionalTraits()) { - return arg; - } - final RexTableArgCall tableArgCall = (RexTableArgCall) operands.get(i); - return arg.applyConditionalTraits( - buildTraitContext(call, tableArgCall)); - }) - .collect(Collectors.toList()); - } - - private static @Nullable String deriveUniqueIdentifier( - FlinkLogicalTableFunctionScan scan, List resolvedStaticArgs) { + private static @Nullable String deriveUniqueIdentifier(FlinkLogicalTableFunctionScan scan) { final RexCall rexCall = (RexCall) scan.getCall(); final BridgingSqlFunction.WithTableFunction function = (BridgingSqlFunction.WithTableFunction) rexCall.getOperator(); + final List staticArgs = + function.getTypeInference() + .getStaticArguments() + .orElseThrow(IllegalStateException::new); final ContextResolvedFunction resolvedFunction = 
function.getResolvedFunction(); final List operands = rexCall.getOperands(); // Type inference ensures that uid is always added at the end final RexNode uidRexNode = operands.get(operands.size() - 1); if (uidRexNode.getKind() == SqlKind.DEFAULT) { // Optional for constant or row semantics functions - final boolean hasSetSemantics = - resolvedStaticArgs.stream() - .anyMatch(arg -> arg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)); - if (!hasSetSemantics) { + if (staticArgs.stream() + .noneMatch(arg -> arg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE))) { return null; } final String uid = @@ -404,42 +379,6 @@ public static List> getProvidedInputArgs(RexCall call) { .collect(Collectors.toList()); } - /** - * Builds a {@link TraitContext} for resolving conditional traits on a table argument at - * planning time. - */ - public static TraitContext buildTraitContext( - final RexCall call, final RexTableArgCall tableArgCall) { - final CallContext callContext = toCallContext(call); - final List declaredArgs = getStaticArguments(call); - - return new TraitContext() { - @Override - public boolean hasPartitionBy() { - return tableArgCall.getPartitionKeys().length > 0; - } - - @Override - public Optional getScalarArgument(final String name, final Class clazz) { - for (int i = 0; i < declaredArgs.size(); i++) { - final StaticArgument arg = declaredArgs.get(i); - if (arg.is(StaticArgumentTrait.SCALAR) && arg.getName().equals(name)) { - return callContext.getArgumentValue(i, clazz); - } - } - return Optional.empty(); - } - }; - } - - public static List getStaticArguments(final RexCall call) { - final BridgingSqlFunction.WithTableFunction function = - (BridgingSqlFunction.WithTableFunction) call.getOperator(); - return function.getTypeInference() - .getStaticArguments() - .orElseThrow(IllegalStateException::new); - } - public static Set deriveOnTimeFields(RexCall call) { final List operands = call.getOperands(); final RexCall onTimeOperand = @@ -574,14 +513,6 @@ public static Set 
toPartitionColumns(RexCall call) { return ImmutableSet.copyOf(partitionColumnsPerArg); } - /** - * Creates a CallContext for argument value extraction only, (no changelog and no input time - * columns). - */ - public static CallContext toCallContext(RexCall udfCall) { - return toCallContext(udfCall, null, null, null); - } - public static CallContext toCallContext( RexCall udfCall, List inputTimeColumns, diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalProcessTableFunctionRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalProcessTableFunctionRule.java index e79acfc8b211f..0785b1527bf88 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalProcessTableFunctionRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalProcessTableFunctionRule.java @@ -21,14 +21,13 @@ import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionKind; import org.apache.flink.table.planner.calcite.RexTableArgCall; +import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; import org.apache.flink.table.planner.plan.nodes.FlinkConventions; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan; import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalProcessTableFunction; import org.apache.flink.table.planner.plan.rules.physical.common.PhysicalMLPredictTableFunctionRule; import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution; import org.apache.flink.table.planner.utils.ShortcutUtils; -import org.apache.flink.table.types.inference.StaticArgument; -import org.apache.flink.table.types.inference.StaticArgumentTrait; import org.apache.calcite.linq4j.Ord; 
import org.apache.calcite.plan.RelOptRule; @@ -38,6 +37,8 @@ import org.apache.calcite.rel.convert.ConverterRule; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.TableCharacteristic; +import org.apache.calcite.sql.TableCharacteristic.Semantics; import org.checkerframework.checker.nullness.qual.Nullable; import java.util.List; @@ -80,9 +81,11 @@ public boolean matches(RelOptRuleCall call) { final FlinkLogicalTableFunctionScan scan = (FlinkLogicalTableFunctionScan) rel; final RexCall rexCall = (RexCall) scan.getCall(); + final BridgingSqlFunction.WithTableFunction function = + (BridgingSqlFunction.WithTableFunction) rexCall.getOperator(); final List operands = rexCall.getOperands(); final List newInputs = - applyDistributionOnInputs(rexCall, operands, rel.getInputs()); + applyDistributionOnInputs(function, operands, rel.getInputs()); final RelTraitSet providedTraitSet = rel.getTraitSet().replace(FlinkConventions.STREAM_PHYSICAL()); return new StreamPhysicalProcessTableFunction( @@ -90,43 +93,42 @@ public boolean matches(RelOptRuleCall call) { } private static List applyDistributionOnInputs( - RexCall rexCall, List operands, List inputs) { - final List staticArgs = - StreamPhysicalProcessTableFunction.getStaticArguments(rexCall); + BridgingSqlFunction.WithTableFunction function, + List operands, + List inputs) { return Ord.zip(operands).stream() .filter(operand -> operand.e instanceof RexTableArgCall) .map( tableOperand -> { + final int pos = tableOperand.i; final RexTableArgCall tableArgCall = (RexTableArgCall) tableOperand.e; - final StaticArgument tableArg = staticArgs.get(tableOperand.i); - final StaticArgument resolvedTableArg = - tableArg.applyConditionalTraits( - StreamPhysicalProcessTableFunction.buildTraitContext( - rexCall, tableArgCall)); + final TableCharacteristic tableCharacteristic = + function.tableCharacteristic(pos); + assert tableCharacteristic != null; return applyDistributionOnInput( 
tableArgCall, - resolvedTableArg, + tableCharacteristic, inputs.get(tableArgCall.getInputIndex())); }) .collect(Collectors.toList()); } private static RelNode applyDistributionOnInput( - RexTableArgCall tableOperand, StaticArgument resolvedTableArg, RelNode input) { - final FlinkRelDistribution distribution = - deriveDistribution(tableOperand, resolvedTableArg); + RexTableArgCall tableOperand, TableCharacteristic tableCharacteristic, RelNode input) { + final FlinkRelDistribution requiredDistribution = + deriveDistribution(tableOperand, tableCharacteristic); final RelTraitSet requiredTraitSet = input.getCluster() .getPlanner() .emptyTraitSet() - .replace(distribution) + .replace(requiredDistribution) .replace(FlinkConventions.STREAM_PHYSICAL()); return RelOptRule.convert(input, requiredTraitSet); } private static FlinkRelDistribution deriveDistribution( - RexTableArgCall tableOperand, StaticArgument resolvedTableArg) { - if (resolvedTableArg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)) { + RexTableArgCall tableOperand, TableCharacteristic tableCharacteristic) { + if (tableCharacteristic.semantics == Semantics.SET) { final int[] partitionKeys = tableOperand.getPartitionKeys(); if (partitionKeys.length == 0) { return FlinkRelDistribution.SINGLETON(); diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala index 1b8f864b6fef8..82c80aa1dffdb 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala @@ -1742,13 +1742,10 @@ class FlinkChangelogModeInferenceProgram extends 
FlinkOptimizeProgram[StreamOpti } StreamPhysicalProcessTableFunction .getProvidedInputArgs(call) + .map(_.e) .foreach { - arg => - val tableArg = arg.e - val tableArgCall = call.operands.get(arg.i).asInstanceOf[RexTableArgCall] - val traitCtx = StreamPhysicalProcessTableFunction - .buildTraitContext(call, tableArgCall) - if (tableArg.is(StaticArgumentTrait.ROW_SEMANTIC_TABLE, traitCtx)) { + tableArg => + if (tableArg.is(StaticArgumentTrait.ROW_SEMANTIC_TABLE)) { throw new ValidationException( s"PTFs that take table arguments with row semantics don't support upsert output. " + s"Table argument '${tableArg.getName}' of function '${call.getOperator.toString}' " + From 83e33eec3b8ce1053c0331eb9638fe059316be24 Mon Sep 17 00:00:00 2001 From: Gustavo de Morais Date: Wed, 22 Apr 2026 18:02:27 +0200 Subject: [PATCH 06/10] [FLINK-39392][table] Implement hash and equals via BuiltInCondition for TraitCondition --- .../types/inference/SystemTypeInference.java | 39 +++---- .../table/types/inference/TraitCondition.java | 109 ++++++------------ 2 files changed, 57 insertions(+), 91 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java index 1ce06e1e36777..c4c95a09127b2 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/SystemTypeInference.java @@ -331,26 +331,25 @@ private List makeFieldNamesUnique(List fields) { } private List derivePassThroughFields( - CallContext callContext, List resolvedArgs) { + CallContext callContext, List staticArgs) { if (functionKind != FunctionKind.PROCESS_TABLE) { return List.of(); } final List argDataTypes = callContext.getArgumentDataTypes(); - return IntStream.range(0, resolvedArgs.size()) + return 
IntStream.range(0, staticArgs.size()) .mapToObj( pos -> { - final StaticArgument resolvedArg = resolvedArgs.get(pos); - if (resolvedArg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) { + final StaticArgument arg = staticArgs.get(pos); + if (arg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) { return DataType.getFields(argDataTypes.get(pos)).stream(); } - final TableSemantics semantics = - callContext.getTableSemantics(pos).orElse(null); - if (semantics == null) { - return Stream.empty(); - } - if (!resolvedArg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)) { + if (!arg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)) { return Stream.empty(); } + final TableSemantics semantics = + callContext + .getTableSemantics(pos) + .orElseThrow(IllegalStateException::new); final DataType rowDataType = DataTypes.ROW(DataType.getFields(argDataTypes.get(pos))); final DataType projectedRow = @@ -379,7 +378,7 @@ private List deriveFunctionOutputFields(DataType functionDataType) { } private List deriveRowtimeField( - CallContext callContext, List resolvedArgs) { + CallContext callContext, List staticArgs) { if (this.functionKind != FunctionKind.PROCESS_TABLE) { return List.of(); } @@ -398,10 +397,10 @@ private List deriveRowtimeField( final List onTimeColumns = new ArrayList<>(); final List missingOnTimeColumns = new ArrayList<>(); - IntStream.range(0, resolvedArgs.size()) + IntStream.range(0, staticArgs.size()) .forEach( pos -> { - final StaticArgument staticArg = resolvedArgs.get(pos); + final StaticArgument staticArg = staticArgs.get(pos); if (!staticArg.is(StaticArgumentTrait.TABLE)) { return; } @@ -635,13 +634,13 @@ private static void checkUidArg(CallContext callContext) { } private static void checkTableArgs( - List resolvedArgs, CallContext callContext) { + List staticArgs, CallContext callContext) { final List tableSemantics = new ArrayList<>(); - IntStream.range(0, resolvedArgs.size()) + IntStream.range(0, staticArgs.size()) .forEach( pos -> { - final StaticArgument resolvedArg = 
resolvedArgs.get(pos); - if (!resolvedArg.is(StaticArgumentTrait.TABLE)) { + final StaticArgument staticArg = staticArgs.get(pos); + if (!staticArg.is(StaticArgumentTrait.TABLE)) { return; } final TableSemantics semantics = @@ -650,10 +649,10 @@ private static void checkTableArgs( throw new ValidationException( String.format( "Table expected for argument '%s'.", - resolvedArg.getName())); + staticArg.getName())); } - checkRowSemantics(resolvedArg, semantics); - checkSetSemantics(resolvedArg, semantics); + checkRowSemantics(staticArg, semantics); + checkSetSemantics(staticArg, semantics); tableSemantics.add(semantics); }); checkCoPartitioning(tableSemantics); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java index 5d0e9b84b5cc8..2118e3acd8611 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java @@ -20,7 +20,9 @@ import org.apache.flink.annotation.PublicEvolving; +import java.util.List; import java.util.Objects; +import java.util.function.Predicate; /** * A condition that determines whether a conditional trait on a {@link StaticArgument} should be @@ -31,10 +33,7 @@ * *

Implementations must implement {@code hashCode} and {@code equals} for {@link * StaticArgument#equals}/{@link StaticArgument#hashCode} to work correctly. The built-in factories - * below return value-comparable instances; user-supplied lambdas do not - prefer the factories or - * named classes. - * - *

Use the static factory methods for common conditions: + * below return value-comparable instances; user-supplied lambdas do not - prefer the factories. * *

{@code
  * import static org.apache.flink.table.types.inference.TraitCondition.*;
@@ -52,87 +51,54 @@ public interface TraitCondition {
 
     /** True when PARTITION BY is provided on the table argument. */
     static TraitCondition hasPartitionBy() {
-        return HasPartitionByCondition.INSTANCE;
+        return new BuiltInCondition(
+                BuiltInCondition.Kind.HAS_PARTITION_BY, List.of(), TraitContext::hasPartitionBy);
     }
 
     /** True when the named scalar argument equals the expected value. */
+    @SuppressWarnings("unchecked")
     static  TraitCondition argIsEqualTo(final String name, final T expected) {
-        return new ArgIsEqualToCondition<>(name, expected);
+        final Class clazz = (Class) expected.getClass();
+        return new BuiltInCondition(
+                BuiltInCondition.Kind.ARG_IS_EQUAL_TO,
+                List.of(name, expected),
+                ctx -> ctx.getScalarArgument(name, clazz).map(expected::equals).orElse(false));
     }
 
     /** Negates the given condition. */
     static TraitCondition not(final TraitCondition condition) {
-        return new NotCondition(condition);
-    }
-
-    // --------------------------------------------------------------------------------------------
-    // Built-in implementations - named so that StaticArgument equality cascades correctly.
-    // --------------------------------------------------------------------------------------------
-
-    /** Singleton condition that is true when PARTITION BY is provided on the table argument. */
-    final class HasPartitionByCondition implements TraitCondition {
-
-        private static final HasPartitionByCondition INSTANCE = new HasPartitionByCondition();
-
-        private HasPartitionByCondition() {}
-
-        @Override
-        public boolean test(final TraitContext ctx) {
-            return ctx.hasPartitionBy();
-        }
-
-        // equals/hashCode by identity - safe because there is exactly one instance.
+        return new BuiltInCondition(
+                BuiltInCondition.Kind.NOT, List.of(condition), ctx -> !condition.test(ctx));
     }
 
-    /** Condition that is true when the named scalar argument equals the expected value. */
-    final class ArgIsEqualToCondition implements TraitCondition {
-
-        private final String name;
-        private final T expected;
-        private final Class clazz;
-
-        @SuppressWarnings("unchecked")
-        ArgIsEqualToCondition(final String name, final T expected) {
-            this.name = name;
-            this.expected = expected;
-            this.clazz = (Class) expected.getClass();
-        }
-
-        @Override
-        public boolean test(final TraitContext ctx) {
-            return ctx.getScalarArgument(name, clazz).map(expected::equals).orElse(false);
-        }
-
-        @Override
-        public boolean equals(final Object o) {
-            if (this == o) {
-                return true;
-            }
-            if (!(o instanceof ArgIsEqualToCondition)) {
-                return false;
-            }
-            final ArgIsEqualToCondition that = (ArgIsEqualToCondition) o;
-            return name.equals(that.name) && expected.equals(that.expected);
-        }
-
-        @Override
-        public int hashCode() {
-            return Objects.hash(name, expected);
+    /**
+     * Internal value-comparable wrapper used by all built-in factories. Equality is keyed by {@code
+     * kind + args}; the {@code impl} predicate is reused but never compared, so two conditions
+     * built from the same factory inputs are equal.
+     */
+    final class BuiltInCondition implements TraitCondition {
+
+        /** Tag identifying which factory produced the condition. */
+        enum Kind {
+            HAS_PARTITION_BY,
+            ARG_IS_EQUAL_TO,
+            NOT
         }
-    }
-
-    /** Condition that negates another condition. */
-    final class NotCondition implements TraitCondition {
 
-        private final TraitCondition condition;
+        private final Kind kind;
+        private final List args;
+        private final Predicate impl;
 
-        NotCondition(final TraitCondition condition) {
-            this.condition = condition;
+        BuiltInCondition(
+                final Kind kind, final List args, final Predicate impl) {
+            this.kind = kind;
+            this.args = args;
+            this.impl = impl;
         }
 
         @Override
         public boolean test(final TraitContext ctx) {
-            return !condition.test(ctx);
+            return impl.test(ctx);
         }
 
         @Override
@@ -140,15 +106,16 @@ public boolean equals(final Object o) {
             if (this == o) {
                 return true;
             }
-            if (!(o instanceof NotCondition)) {
+            if (!(o instanceof BuiltInCondition)) {
                 return false;
             }
-            return condition.equals(((NotCondition) o).condition);
+            final BuiltInCondition that = (BuiltInCondition) o;
+            return kind == that.kind && args.equals(that.args);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(NotCondition.class, condition);
+            return Objects.hash(kind, args);
         }
     }
 }

From 6d6b4a1d51378ea51217dfd4e17a44deff8b095c Mon Sep 17 00:00:00 2001
From: Gustavo de Morais 
Date: Thu, 23 Apr 2026 14:52:20 +0200
Subject: [PATCH 07/10] [FLINK-39392][table] Move BuiltInCondition out of
 TraitCondition for stricter encapsulation

---
 .../types/inference/BuiltInCondition.java     | 74 +++++++++++++++++++
 .../table/types/inference/TraitCondition.java | 50 -------------
 2 files changed, 74 insertions(+), 50 deletions(-)
 create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/BuiltInCondition.java

diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/BuiltInCondition.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/BuiltInCondition.java
new file mode 100644
index 0000000000000..2a06191ad13ab
--- /dev/null
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/BuiltInCondition.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.types.inference;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.function.Predicate;
+
+/**
+ * Internal value-comparable wrapper used by all built-in {@link TraitCondition} factories. Equality
+ * is keyed by {@code kind + args}; the {@code impl} predicate is used only for evaluation and never
+ * compared, so any two conditions built from the same factory inputs are equal.
+ *
+ * 

Lives outside {@link TraitCondition} because Java forbids {@code private} nested types in + * interfaces (they are implicitly {@code public static}); top-level package-private gives the same + * encapsulation. + */ +final class BuiltInCondition implements TraitCondition { + + /** Tag identifying which factory produced the condition. */ + enum Kind { + HAS_PARTITION_BY, + ARG_IS_EQUAL_TO, + NOT + } + + private final Kind kind; + private final List args; + private final Predicate impl; + + BuiltInCondition(final Kind kind, final List args, final Predicate impl) { + this.kind = kind; + this.args = args; + this.impl = impl; + } + + @Override + public boolean test(final TraitContext ctx) { + return impl.test(ctx); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (!(o instanceof BuiltInCondition)) { + return false; + } + final BuiltInCondition that = (BuiltInCondition) o; + return kind == that.kind && args.equals(that.args); + } + + @Override + public int hashCode() { + return Objects.hash(kind, args); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java index 2118e3acd8611..93bdc5c8ce479 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TraitCondition.java @@ -21,8 +21,6 @@ import org.apache.flink.annotation.PublicEvolving; import java.util.List; -import java.util.Objects; -import java.util.function.Predicate; /** * A condition that determines whether a conditional trait on a {@link StaticArgument} should be @@ -70,52 +68,4 @@ static TraitCondition not(final TraitCondition condition) { return new BuiltInCondition( BuiltInCondition.Kind.NOT, List.of(condition), ctx -> !condition.test(ctx)); 
} - - /** - * Internal value-comparable wrapper used by all built-in factories. Equality is keyed by {@code - * kind + args}; the {@code impl} predicate is reused but never compared, so two conditions - * built from the same factory inputs are equal. - */ - final class BuiltInCondition implements TraitCondition { - - /** Tag identifying which factory produced the condition. */ - enum Kind { - HAS_PARTITION_BY, - ARG_IS_EQUAL_TO, - NOT - } - - private final Kind kind; - private final List args; - private final Predicate impl; - - BuiltInCondition( - final Kind kind, final List args, final Predicate impl) { - this.kind = kind; - this.args = args; - this.impl = impl; - } - - @Override - public boolean test(final TraitContext ctx) { - return impl.test(ctx); - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (!(o instanceof BuiltInCondition)) { - return false; - } - final BuiltInCondition that = (BuiltInCondition) o; - return kind == that.kind && args.equals(that.args); - } - - @Override - public int hashCode() { - return Objects.hash(kind, args); - } - } } From cd9e2de6e001cac16cbeb0d9a0498272d10601c4 Mon Sep 17 00:00:00 2001 From: Gustavo de Morais Date: Thu, 23 Apr 2026 14:52:26 +0200 Subject: [PATCH 08/10] [FLINK-39392][table] Resolve scalar args via CallContext wrapper and dedup toCallContext --- .../bridging/BridgingSqlFunction.java | 68 +++++++++++++++---- .../StreamPhysicalProcessTableFunction.java | 23 ------- .../codegen/ProcessTableRunnerGenerator.scala | 8 +-- .../FlinkChangelogModeInferenceProgram.scala | 9 ++- 4 files changed, 59 insertions(+), 49 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java index e62b44830ecc6..059fcd6184d69 100644 --- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java @@ -21,6 +21,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.table.catalog.ContextResolvedFunction; import org.apache.flink.table.catalog.DataTypeFactory; +import org.apache.flink.table.connector.ChangelogMode; import org.apache.flink.table.functions.BuiltInFunctionDefinition; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionIdentifier; @@ -31,8 +32,10 @@ import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.calcite.RexFactory; import org.apache.flink.table.planner.calcite.RexTableArgCall; +import org.apache.flink.table.planner.functions.inference.OperatorBindingCallContext; import org.apache.flink.table.planner.utils.ShortcutUtils; import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; import org.apache.flink.table.types.inference.StaticArgument; import org.apache.flink.table.types.inference.StaticArgumentTrait; import org.apache.flink.table.types.inference.SystemTypeInference; @@ -45,7 +48,7 @@ import org.apache.calcite.rel.type.StructKind; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexCallBinding; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlKind; @@ -56,6 +59,7 @@ import org.apache.calcite.tools.RelBuilder; import org.checkerframework.checker.nullness.qual.Nullable; +import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; @@ -250,9 +254,17 @@ public static RexCall resolveCallTraits(RexCall call) { 
return call; } final List operands = call.getOperands(); + final CallContext callContext = function.toCallContext(call); final List resolved = IntStream.range(0, declared.size()) - .mapToObj(i -> resolveArg(declared.get(i), declared, operands, i)) + .mapToObj( + i -> + resolveArg( + declared.get(i), + declared, + operands, + i, + callContext)) .collect(Collectors.toList()); if (resolved.equals(declared)) { return call; @@ -268,24 +280,25 @@ private static StaticArgument resolveArg( StaticArgument declaredArg, List declared, List operands, - int index) { + int index, + CallContext callContext) { // We only resolve conditional traits for the Table Argument with conditional traits if (!declaredArg.hasConditionalTraits() || !(operands.get(index) instanceof RexTableArgCall)) { return declaredArg; } return declaredArg.applyConditionalTraits( - buildTraitContext((RexTableArgCall) operands.get(index), declared, operands)); + buildTraitContext((RexTableArgCall) operands.get(index), declared, callContext)); } /** - * Planner-side adapter to {@link TraitContext}. Sourced from a {@link RexCall} (PARTITION BY - * via {@link RexTableArgCall}, scalar literals via the operand list) instead of a {@link - * org.apache.flink.table.types.inference.CallContext}, since the planner doesn't carry one. The - * validation-time equivalent is {@link TraitContext#of}. + * Planner-side adapter to {@link TraitContext}. Sourced from a {@link RexCall}: PARTITION BY + * via the {@link RexTableArgCall} operand, scalar literals via the {@link CallContext} wrapper. + * Equivalent to {@link TraitContext#of} but takes its inputs from a planner-side call instead + * of validation-side {@link org.apache.flink.table.functions.TableSemantics}. 
*/ private static TraitContext buildTraitContext( - RexTableArgCall tableArgCall, List declared, List operands) { + RexTableArgCall tableArgCall, List declared, CallContext callContext) { return new TraitContext() { @Override public boolean hasPartitionBy() { @@ -296,19 +309,44 @@ public boolean hasPartitionBy() { public Optional getScalarArgument(String name, Class clazz) { for (int i = 0; i < declared.size(); i++) { final StaticArgument arg = declared.get(i); - if (!arg.is(StaticArgumentTrait.SCALAR) || !arg.getName().equals(name)) { - continue; + if (arg.is(StaticArgumentTrait.SCALAR) && arg.getName().equals(name)) { + return callContext.getArgumentValue(i, clazz); } - if (i >= operands.size() || !(operands.get(i) instanceof RexLiteral)) { - return Optional.empty(); - } - return Optional.ofNullable(((RexLiteral) operands.get(i)).getValueAs(clazz)); } return Optional.empty(); } }; } + /** + * Builds a {@link CallContext} from the given {@link RexCall} for this function. Wraps the call + * in an {@link OperatorBindingCallContext} so consumers (trait resolution, codegen, etc.) read + * scalar arguments through the same coercion path as validation. + */ + public CallContext toCallContext(RexCall call) { + return toCallContext(call, null, null, null); + } + + /** + * Variant of {@link #toCallContext(RexCall)} that additionally exposes the call's input time + * columns and changelog modes - needed by the streaming codegen path so PTFs can specialize + * themselves to the exact call. 
+ */ + public CallContext toCallContext( + RexCall call, + @Nullable List inputTimeColumns, + @Nullable List inputChangelogModes, + @Nullable ChangelogMode outputChangelogMode) { + return new OperatorBindingCallContext( + dataTypeFactory, + getDefinition(), + RexCallBinding.create(typeFactory, call, Collections.emptyList()), + call.getType(), + inputTimeColumns, + inputChangelogModes, + outputChangelogMode); + } + /** * Returns a copy of this function whose {@link TypeInference} reports the given static * arguments. The wrapped input/output strategies are reused unchanged - they ran at validation diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java index 754fe1b328492..56a2bf1f6fcab 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalProcessTableFunction.java @@ -22,13 +22,11 @@ import org.apache.flink.table.api.config.OptimizerConfigOptions; import org.apache.flink.table.catalog.ContextResolvedFunction; import org.apache.flink.table.connector.ChangelogMode; -import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionIdentifier; import org.apache.flink.table.functions.ProcessTableFunction; import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.calcite.RexTableArgCall; import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction; -import org.apache.flink.table.planner.functions.inference.OperatorBindingCallContext; import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; import 
org.apache.flink.table.planner.plan.nodes.exec.InputProperty; import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecProcessTableFunction; @@ -36,7 +34,6 @@ import org.apache.flink.table.planner.plan.utils.ChangelogPlanUtils; import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; import org.apache.flink.table.planner.utils.ShortcutUtils; -import org.apache.flink.table.types.inference.CallContext; import org.apache.flink.table.types.inference.StaticArgument; import org.apache.flink.table.types.inference.StaticArgumentTrait; import org.apache.flink.table.types.inference.SystemTypeInference; @@ -56,7 +53,6 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexCallBinding; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexUtil; @@ -66,7 +62,6 @@ import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Objects; @@ -512,22 +507,4 @@ public static Set toPartitionColumns(RexCall call) { } return ImmutableSet.copyOf(partitionColumnsPerArg); } - - public static CallContext toCallContext( - RexCall udfCall, - List inputTimeColumns, - List inputChangelogModes, - @Nullable ChangelogMode outputChangelogMode) { - final BridgingSqlFunction function = ShortcutUtils.unwrapBridgingSqlFunction(udfCall); - assert function != null; - final FunctionDefinition definition = ShortcutUtils.unwrapFunctionDefinition(udfCall); - return new OperatorBindingCallContext( - function.getDataTypeFactory(), - definition, - RexCallBinding.create(function.getTypeFactory(), udfCall, Collections.emptyList()), - udfCall.getType(), - inputTimeColumns, - inputChangelogModes, - outputChangelogMode); - } } diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala index 4f1d8d65d419f..52df803d5c8f8 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala @@ -34,7 +34,6 @@ import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil.{verifyFunctionAwareOutputType, DefaultExpressionEvaluatorFactory} import org.apache.flink.table.planner.delegation.PlannerBase import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction -import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalProcessTableFunction import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala import org.apache.flink.table.runtime.dataview.DataViewUtils import org.apache.flink.table.runtime.dataview.StateListView.KeyedStateListView @@ -77,11 +76,8 @@ object ProcessTableRunnerGenerator { // For specialized functions, this call context is able to provide the final changelog modes. // Thus, functions can reconfigure themselves for the exact use case. // Including updating their state layout. 
- val callContext = StreamPhysicalProcessTableFunction.toCallContext( - udfCall, - inputTimeColumns, - inputChangelogModes, - outputChangelogMode) + val callContext = + function.toCallContext(udfCall, inputTimeColumns, inputChangelogModes, outputChangelogMode) // Create the final UDF for runtime val udf = UserDefinedFunctionHelper.createSpecializedFunction( diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala index 82c80aa1dffdb..6839cc1c28c13 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala @@ -26,6 +26,7 @@ import org.apache.flink.table.connector.ChangelogMode import org.apache.flink.table.functions.{BuiltInFunctionDefinition, ChangelogFunction} import org.apache.flink.table.functions.ChangelogFunction.ChangelogContext import org.apache.flink.table.planner.calcite.{FlinkTypeFactory, RexTableArgCall} +import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction import org.apache.flink.table.planner.plan.`trait`._ import org.apache.flink.table.planner.plan.`trait`.DeleteKindTrait.{deleteOnKeyOrNone, fullDeleteOrNone, DELETE_BY_KEY} import org.apache.flink.table.planner.plan.`trait`.UpdateKindTrait.{beforeAfterOrNone, onlyAfterOrNone, BEFORE_AND_AFTER, ONLY_UPDATE_AFTER} @@ -1673,11 +1674,9 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti outputChangelogMode: ChangelogMode): ChangelogContext = { val udfCall = StreamPhysicalProcessTableFunction.toUdfCall(process.getCall) val inputTimeColumns = 
StreamPhysicalProcessTableFunction.toInputTimeColumns(process.getCall) - val callContext = StreamPhysicalProcessTableFunction.toCallContext( - udfCall, - inputTimeColumns, - inputChangelogModes, - outputChangelogMode) + val function = udfCall.getOperator.asInstanceOf[BridgingSqlFunction] + val callContext = + function.toCallContext(udfCall, inputTimeColumns, inputChangelogModes, outputChangelogMode) // Expose a simplified context to let users focus on important characteristics. // If necessary, we can expose the full CallContext in the future. From 298397ba72035b867fc0fddf044b0517cd9982e4 Mon Sep 17 00:00:00 2001 From: Gustavo de Morais Date: Thu, 23 Apr 2026 15:23:33 +0200 Subject: [PATCH 09/10] [FLINK-39392][table] Document PTF conditional traits in planner AGENTS.md --- flink-table/flink-table-planner/AGENTS.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/flink-table/flink-table-planner/AGENTS.md b/flink-table/flink-table-planner/AGENTS.md index 969c5b366569d..8b45353ecbf6f 100644 --- a/flink-table/flink-table-planner/AGENTS.md +++ b/flink-table/flink-table-planner/AGENTS.md @@ -106,6 +106,22 @@ When bumping an ExecNode version, update the `@ExecNodeMetadata` annotation's `v New features often introduce `ExecutionConfigOptions` entries (in `flink-table-api-java`) for runtime tunability (e.g., cache sizes, timeouts, batch sizes). +### PTF conditional traits + +A *conditional trait* lets a PTF's table-argument traits depend on the call site instead of being fixed at declaration. Example for `TO_CHANGELOG`: the `input` argument is row-semantic by default (single stream, no PARTITION BY), but switches to set-semantic when the user writes `PARTITION BY` so the runtime can co-locate state per key. One declaration, two effective signatures depending on the call. + +**Declaration.** Built-in functions add conditional rules in `BuiltInFunctionDefinitions` via `StaticArgument.withConditionalTrait(trait, condition)`. 
The condition (a `TraitCondition`) is a small value-comparable predicate evaluated against a `TraitContext`. Built-in factories live on `TraitCondition` (`hasPartitionBy()`, `argIsEqualTo(name, value)`, `not(c)`); under the hood they wrap into the package-private `BuiltInCondition` so equality cascades correctly through `StaticArgument.equals`. + +**Evaluation.** A `TraitCondition` reads two things: whether `PARTITION BY` is present on this table arg, and the literal value of named scalar args. Both come through `TraitContext`. There are two factories: `TraitContext.of(TableSemantics, CallContext, declared)` for the validation side (called from `SystemTypeInference.resolveStaticArgs`) and a planner-side adapter inside `BridgingSqlFunction.buildTraitContext` that sources the same data from a `RexCall` + `RexTableArgCall`. Same logical context, different inputs because the two layers don't share types. + +**Resolution.** Three call sites bake conditional traits into the operator's effective signature: + +1. **Validation** — `SystemTypeInference.resolveStaticArgs` runs once each from `inferInputTypes` and `inferType`. Twice per validation pass; can't dedupe across Calcite hooks because each gets a different `CallContext` instance. +2. **Planning** — `BridgingSqlFunction.resolveCallTraits` is called from `FlinkLogicalTableFunctionScan.Converter.convert`. It rewrites the operator on the `RexCall` so all downstream readers see the resolved view via plain `function.getTypeInference().getStaticArguments()`. +3. **Compiled-plan restore** — `BridgingSqlFunction.resolveCallTraits` is called again from `StreamExecProcessTableFunction.@JsonCreator`, because the JSON path skips the logical converter. Without this hook, restore would silently produce wrong results for any conditional-trait PTF. + +The payoff: downstream rules, exec nodes, codegen, and changelog inference all use ordinary `staticArg.is(SET_SEMANTIC_TABLE)` checks. 
No consumer needs to know that conditional traits exist. Why three sites and not one. The three resolution points exist because they sit in different lifecycles that can't share state. + ## Testing Patterns Choose test types based on what you're changing: From dbdf51054e436fba9b5749ffe12594e40c884a5e Mon Sep 17 00:00:00 2001 From: Ramin Gharib Date: Fri, 24 Apr 2026 15:14:02 +0200 Subject: [PATCH 10/10] [FLINK-39537][table] Apply conditional SET_SEMANTIC_TABLE trait to FROM_CHANGELOG --- .../docs/sql/reference/queries/changelog.md | 12 +++++- .../org/apache/flink/table/api/Table.java | 11 ++++++ .../functions/BuiltInFunctionDefinitions.java | 18 ++++++--- .../stream/FromChangelogSemanticTests.java | 1 + .../stream/FromChangelogTestPrograms.java | 38 +++++++++++++++++++ .../plan/stream/sql/FromChangelogTest.java | 14 +++++++ .../plan/stream/sql/FromChangelogTest.xml | 20 ++++++++++ 7 files changed, 106 insertions(+), 8 deletions(-) diff --git a/docs/content/docs/sql/reference/queries/changelog.md b/docs/content/docs/sql/reference/queries/changelog.md index 494dcc60be892..c15ff255cfd29 100644 --- a/docs/content/docs/sql/reference/queries/changelog.md +++ b/docs/content/docs/sql/reference/queries/changelog.md @@ -45,7 +45,7 @@ Note: This version requires that your CDC data encodes updates using a full imag ```sql SELECT * FROM FROM_CHANGELOG( - input => TABLE source_table, + input => TABLE source_table [PARTITION BY key_col], [op => DESCRIPTOR(op_column_name),] [op_mapping => MAP[ 'c, r', 'INSERT', @@ -61,7 +61,7 @@ SELECT * FROM FROM_CHANGELOG( | Parameter | Required | Description | 
|:-------------|:---------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `input` | Yes | The input table. Must be append-only. | +| `input` | Yes | The input table. Must be append-only. Use `PARTITION BY` to ensure rows for the same key are processed together. This is required when downstream operators are keyed on that column. | | `op` | No | A `DESCRIPTOR` with a single column name for the operation code column. Defaults to `op`. The column must exist in the input table and be of type STRING. | | `op_mapping` | No | A `MAP` mapping user-defined codes to Flink change operation names. Keys are user-defined codes (e.g., `'c'`, `'u'`, `'d'`), values are Flink change operation names (`INSERT`, `UPDATE_BEFORE`, `UPDATE_AFTER`, `DELETE`). Keys can contain comma-separated codes to map multiple codes to the same operation (e.g., `'c, r'`). Each change operation may appear at most once across all entries. | | `error_handling` | No | Controls behavior when an input row's operation code is `NULL` or not present in the `op_mapping`. Valid values: `FAIL` (default) — throw a `TableRuntimeException`, `SKIP` — silently drop the row. | @@ -127,6 +127,14 @@ SELECT * FROM FROM_CHANGELOG( -- The operation column named 'operation' is used instead of 'op' ``` +#### Partitioning by a key + +```sql +SELECT * FROM FROM_CHANGELOG( + input => TABLE cdc_stream PARTITION BY id +) +``` + #### Invalid operation code handling Two `error_handling` modes are supported. 
The job can either fail upon an invalid or unknown op code, or skip the row and continue processing. diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Table.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Table.java index eb61371329c2f..aa2417d4531ce 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Table.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Table.java @@ -1467,6 +1467,17 @@ default TableResult executeInsert( * TableRuntimeException} when an input row's op code is {@code NULL} or not present in the * mapping; pass {@code error_handling => 'SKIP'} to silently drop those rows instead. * + *

By default, the input is processed with row semantics (each row independently). To + * co-locate rows with the same key in the same parallel operator instance, partition the input + * first via {@link #partitionBy(Expression...)} and invoke the function via {@link + * PartitionedTable#process(String, Object...)}: + * + *

{@code
+     * Table result = cdcStream
+     *     .partitionBy($("id"))
+     *     .process("FROM_CHANGELOG");
+     * }
+ * *

Optional arguments can be passed using named expressions: * *

{@code
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
index cd25361923c7e..8ecc43aa9827b 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
@@ -818,13 +818,19 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL)
                     .name("FROM_CHANGELOG")
                     .kind(PROCESS_TABLE)
                     .staticArguments(
+                            // The input defaults to row semantics (no PARTITION BY).
+                            // A PARTITION BY clause switches it to set semantics,
+                            // so rows with the same key run in the same operator instance.
                             StaticArgument.table(
-                                    "input",
-                                    Row.class,
-                                    false,
-                                    EnumSet.of(
-                                            StaticArgumentTrait.TABLE,
-                                            StaticArgumentTrait.ROW_SEMANTIC_TABLE)),
+                                            "input",
+                                            Row.class,
+                                            false,
+                                            EnumSet.of(
+                                                    StaticArgumentTrait.TABLE,
+                                                    StaticArgumentTrait.ROW_SEMANTIC_TABLE))
+                                    .withConditionalTrait(
+                                            StaticArgumentTrait.SET_SEMANTIC_TABLE,
+                                            TraitCondition.hasPartitionBy()),
                             StaticArgument.scalar("op", DataTypes.DESCRIPTOR(), true),
                             StaticArgument.scalar(
                                     "op_mapping",
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/FromChangelogSemanticTests.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/FromChangelogSemanticTests.java
index 4643a51be52e0..dc62330b257da 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/FromChangelogSemanticTests.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/FromChangelogSemanticTests.java
@@ -42,6 +42,7 @@ public List programs() {
                 FromChangelogTestPrograms.DEFAULT_OP_MAPPING,
                 FromChangelogTestPrograms.CUSTOM_OP_MAPPING,
                 FromChangelogTestPrograms.CUSTOM_OP_NAME,
+                FromChangelogTestPrograms.SET_SEMANTICS_PARTITION_BY,
                 FromChangelogTestPrograms.SKIP_INVALID_OP_HANDLING,
                 FromChangelogTestPrograms.SKIP_NULL_OP_CODE,
                 FromChangelogTestPrograms.TABLE_API_DEFAULT,
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/FromChangelogTestPrograms.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/FromChangelogTestPrograms.java
index 1858aabeed019..2cd785a46b695 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/FromChangelogTestPrograms.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/FromChangelogTestPrograms.java
@@ -172,6 +172,44 @@ public class FromChangelogTestPrograms {
                                     + "op => DESCRIPTOR(operation))")
                     .build();
 
+    // --------------------------------------------------------------------------------------------
+    // Set semantics with PARTITION BY
+    // --------------------------------------------------------------------------------------------
+
+    /**
+     * Verifies that {@code FROM_CHANGELOG(TABLE t PARTITION BY id)} produces the same logical
+     * output as the row-semantic call. The conditional {@code SET_SEMANTIC_TABLE} trait switches
+     * the execution to a co-located parallel mode but must not change row-level semantics.
+     */
+    public static final TableTestProgram SET_SEMANTICS_PARTITION_BY =
+            TableTestProgram.of(
+                            "from-changelog-set-semantics-partition-by",
+                            "PARTITION BY enables set semantics without altering output rows")
+                    .setupTableSource(
+                            SourceTestStep.newBuilder("cdc_stream")
+                                    .addSchema(SIMPLE_CDC_SCHEMA)
+                                    .producedValues(
+                                            Row.of(1, "INSERT", "Alice"),
+                                            Row.of(2, "INSERT", "Bob"),
+                                            Row.of(1, "UPDATE_BEFORE", "Alice"),
+                                            Row.of(1, "UPDATE_AFTER", "Alice2"),
+                                            Row.of(2, "DELETE", "Bob"))
+                                    .build())
+                    .setupTableSink(
+                            SinkTestStep.newBuilder("sink")
+                                    .addSchema("id INT", "name STRING")
+                                    .consumedValues(
+                                            Row.ofKind(RowKind.INSERT, 1, "Alice"),
+                                            Row.ofKind(RowKind.INSERT, 2, "Bob"),
+                                            Row.ofKind(RowKind.UPDATE_BEFORE, 1, "Alice"),
+                                            Row.ofKind(RowKind.UPDATE_AFTER, 1, "Alice2"),
+                                            Row.ofKind(RowKind.DELETE, 2, "Bob"))
+                                    .build())
+                    .runSql(
+                            "INSERT INTO sink SELECT * FROM FROM_CHANGELOG("
+                                    + "input => TABLE cdc_stream PARTITION BY id)")
+                    .build();
+
     // --------------------------------------------------------------------------------------------
     // Table API test
     // --------------------------------------------------------------------------------------------
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/FromChangelogTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/FromChangelogTest.java
index ba2c1a5690cb6..392abda3cca4d 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/FromChangelogTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/FromChangelogTest.java
@@ -75,4 +75,18 @@ void testCustomOpMapping() {
                         + "error_handling => 'SKIP')",
                 CHANGELOG_MODE);
     }
+
+    @Test
+    void testSetSemanticsWithPartitionBy() {
+        util.tableEnv()
+                .executeSql(
+                        "CREATE TABLE cdc_stream ("
+                                + "  id INT,"
+                                + "  op STRING,"
+                                + "  name STRING"
+                                + ") WITH ('connector' = 'values')");
+        util.verifyRelPlan(
+                "SELECT * FROM FROM_CHANGELOG(input => TABLE cdc_stream PARTITION BY id)",
+                CHANGELOG_MODE);
+    }
 }
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/FromChangelogTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/FromChangelogTest.xml
index 614eb5456b9aa..a1ab91c2998fc 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/FromChangelogTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/FromChangelogTest.xml
@@ -51,6 +51,26 @@ LogicalProject(id=[$0], name=[$1])
       
+    
+  
+  
+    
+       TABLE cdc_stream PARTITION BY id)]]>
+    
+    
+      
+    
+    
+