Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,12 @@ private boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAg

private boolean canMergeAggregateWithoutProject(LogicalAggregate<LogicalAggregate<Plan>> outerAgg) {
LogicalAggregate<Plan> innerAgg = outerAgg.child();
if (!new HashSet<>(innerAgg.getGroupByExpressions()).containsAll(outerAgg.getGroupByExpressions())) {
Set<Expression> innerGroupByExpressions = new HashSet<>(innerAgg.getGroupByExpressions());
Set<Expression> outerGroupByExpressions = new HashSet<>(outerAgg.getGroupByExpressions());
if (!innerGroupByExpressions.containsAll(outerGroupByExpressions)) {
return false;
}
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
boolean sameGroupBy = innerGroupByExpressions.equals(outerGroupByExpressions);

return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.empty());
}
Expand All @@ -210,15 +212,17 @@ private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<Log

List<Expression> outerAggGroupByKeys = PlanUtils.replaceExpressionByProjections(project.getProjects(),
outerAgg.getGroupByExpressions());
if (!new HashSet<>(innerAgg.getGroupByExpressions()).containsAll(outerAggGroupByKeys)) {
Set<Expression> innerGroupByExpressions = new HashSet<>(innerAgg.getGroupByExpressions());
Set<Expression> outerGroupByExpressions = new HashSet<>(outerAggGroupByKeys);
if (!innerGroupByExpressions.containsAll(outerGroupByExpressions)) {
return false;
}
// project cannot have expressions like a+1
if (ExpressionUtils.deapAnyMatch(project.getProjects(),
expr -> !(expr instanceof SlotReference) && !(expr instanceof Alias))) {
return false;
}
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
boolean sameGroupBy = innerGroupByExpressions.equals(outerGroupByExpressions);
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;

import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
Expand All @@ -40,10 +43,21 @@
* Unit tests for {@link MergeAggregate}, specifically testing the fix for filtering
* aggregate functions in mergeAggProjectAgg method.
*/
public class MergeAggregateTest {
public class MergeAggregateTest extends TestWithFeService implements MemoPatternMatchSupported {

private MergeAggregate mergeAggregate;

@Override
protected void runBeforeAll() throws Exception {
createDatabase("merge_aggregate_test");
connectContext.setDatabase("merge_aggregate_test");
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
createTable("CREATE TABLE merge_aggregate_test.duplicate_alias_table ("
+ "a INT NOT NULL, b INT NOT NULL, c INT NULL) "
+ "DUPLICATE KEY(a, b, c) DISTRIBUTED BY HASH(a) BUCKETS 1 "
+ "PROPERTIES('replication_num' = '1')");
}

@BeforeEach
public void setUp() {
mergeAggregate = new MergeAggregate();
Expand Down Expand Up @@ -121,4 +135,16 @@ public void testMergeAggProjectAggWithMixedExpressions() throws Exception {
LogicalAggregate<Plan> aggregate = (LogicalAggregate<Plan>) resultProject.child(0);
Assertions.assertEquals(aggregate.getOutput().size(), 2);
}

@Test
public void testDoNotMergeDistinctAggregateWithDuplicateProjectedGroupBy() {
String sql = "SELECT g1, g2, SUM(s) FROM ("
+ "SELECT a AS g1, a AS g2, COUNT(DISTINCT c) AS s "
+ "FROM duplicate_alias_table GROUP BY a, b) t GROUP BY g1, g2";

PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(logicalAggregate(logicalProject(logicalAggregate())));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,8 @@ PhysicalResultSink
8 5 5
9 3 3

-- !duplicate_alias_distinct_result --
1 1 3
2 2 3
3 3 0

Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,34 @@ suite("merge_aggregate") {
(select a,max(b) as col1, count(b) as col4, a as col10, a as col11
from mal_test1 group by a) t group by col10, col11 order by 1,2,3;
"""

sql "drop table if exists mal_duplicate_alias_distinct"
sql """
create table mal_duplicate_alias_distinct (
a int not null,
b int not null,
c int null
)
duplicate key (a, b, c)
distributed by hash(a) buckets 1
properties("replication_num" = "1");
"""
sql """
insert into mal_duplicate_alias_distinct values
(1, 10, 100), (1, 10, 101), (1, 20, 100),
(2, 10, 200), (2, 20, 200), (2, 30, 201),
(3, 30, null), (3, 40, null);
"""
sql "sync"

order_qt_duplicate_alias_distinct_result """
select g1, g2, sum(s) as total_s
from (
select a as g1, a as g2, count(distinct c) as s
from mal_duplicate_alias_distinct
group by a, b
) t
group by g1, g2
order by g1, g2;
"""
}
Loading