Skip to content

Commit 0ee34eb

Browse files
authored
[KQP RBO] Add group by on expressions (#28922)
1 parent 98de507 commit 0ee34eb

File tree

2 files changed

+90
-24
lines changed

2 files changed

+90
-24
lines changed

ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp

Lines changed: 71 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,9 @@ void BuildSimpleMapElementLambda(TExprNode::TPtr resultExpr, const TVector<std::
151151
TExprNode::TPtr BuildExpressionMap(TExprNode::TPtr resultExpr,
152152
const TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& aggFieldsExpressionsMap,
153153
const TVector<std::pair<TInfoUnit, TInfoUnit>>& aggFieldsRenamesMap,
154-
const TVector<std::pair<TInfoUnit, TInfoUnit>>& groupByKeysRenamesMap, TExprContext& ctx,
155-
TPositionHandle pos) {
154+
const TVector<std::pair<TInfoUnit, TInfoUnit>>& groupByKeysRenamesMap,
155+
const THashMap<uint32_t, std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap,
156+
TExprContext& ctx, TPositionHandle pos) {
156157
// Add expressions
157158
TVector<TExprNode::TPtr> mapElements;
158159
for (const auto& [colName, expr] : aggFieldsExpressionsMap) {
@@ -171,6 +172,19 @@ TExprNode::TPtr BuildExpressionMap(TExprNode::TPtr resultExpr,
171172
BuildSimpleMapElementLambda(resultExpr, aggFieldsRenamesMap, mapElements, ctx, pos);
172173
BuildSimpleMapElementLambda(resultExpr, groupByKeysRenamesMap, mapElements, ctx, pos);
173174

175+
// Add expressions for group by keys.
176+
for (const auto& [_, pair] : groupByKeysExpressionsMap) {
177+
// clang-format off
178+
mapElements.push_back(Build<TKqpOpMapElementLambda>(ctx, pos)
179+
.Input(resultExpr)
180+
.Variable()
181+
.Value(pair.first.GetFullName())
182+
.Build()
183+
.Lambda(pair.second)
184+
.Done().Ptr());
185+
// clang-format on
186+
}
187+
174188
// clang-format off
175189
return Build<TKqpOpMap>(ctx, pos)
176190
.Input(resultExpr)
@@ -184,8 +198,8 @@ TExprNode::TPtr BuildExpressionMap(TExprNode::TPtr resultExpr,
184198
// clang-format on
185199
}
186200

187-
void BuildMapElementRename(TExprNode::TPtr resultExpr, const TVector<std::pair<TInfoUnit, TInfoUnit>>& renamesMap, TVector<TExprNode::TPtr>& mapElements,
188-
TExprContext& ctx, TPositionHandle pos) {
201+
void BuildMapElementRename(TExprNode::TPtr resultExpr, const TVector<std::pair<TInfoUnit, TInfoUnit>>& renamesMap,
202+
TVector<TExprNode::TPtr>& mapElements, TExprContext& ctx, TPositionHandle pos) {
189203
for (const auto& [colName, newColName] : renamesMap) {
190204
// clang-format off
191205
mapElements.push_back(Build<TKqpOpMapElementRename>(ctx, pos)
@@ -552,20 +566,38 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
552566
filterExpr = Build<TKqpOpEmptySource>(ctx, node->Pos()).Done().Ptr();
553567
}
554568

555-
// FIXME: Group by key can be an expression, we need to handle this case
569+
THashSet<TString> aggregationColumnsRequireCastToPgType;
556570
TVector<std::pair<TInfoUnit, TInfoUnit>> groupByKeysRenamesMap;
571+
THashMap<uint32_t, std::pair<TInfoUnit, TExprNode::TPtr>> groupByKeysExpressionsMap;
557572
TVector<TInfoUnit> groupByKeys;
558573
auto groupOps = GetSetting(setItem->Tail(), "group_exprs");
559574
if (groupOps) {
560575
const auto groupByList = groupOps->TailPtr();
561576
for (ui32 i = 0; i < groupByList->ChildrenSize(); ++i) {
562577
auto lambda = TCoLambda(ctx.DeepCopyLambda(*(groupByList->ChildPtr(i)->Child(1))));
563578
auto body = lambda.Body().Ptr();
564-
TVector<TInfoUnit> keys;
565-
GetAllMembers(body, keys);
566-
groupByKeys.insert(groupByKeys.end(), keys.begin(), keys.end());
567-
for (const auto &infoUnit : keys) {
568-
groupByKeysRenamesMap.push_back({infoUnit, infoUnit});
579+
auto pgResolvedOp = GetPgCallable(lambda.Body().Ptr(), "PgResolvedOp");
580+
// Expression for group by keys.
581+
if (pgResolvedOp) {
582+
auto fromPg = ctx.NewCallable(node->Pos(), "FromPg", {pgResolvedOp});
583+
584+
// clang-format off
585+
auto groupExprLambda = Build<TCoLambda>(ctx, node->Pos())
586+
.Args(lambda.Args())
587+
.Body(fromPg)
588+
.Done().Ptr();
589+
// clang-format on
590+
591+
const auto newColName = TInfoUnit(GenerateUniqueColumnName("_group_expr_"));
592+
groupByKeysExpressionsMap[i] = std::make_pair(newColName, groupExprLambda);
593+
groupByKeys.push_back(newColName);
594+
} else {
595+
TVector<TInfoUnit> keys;
596+
GetAllMembers(body, keys);
597+
Y_ENSURE(keys.size() == 1, "Invalid size of the group keys.");
598+
const auto groupKeyName = keys.front();
599+
groupByKeys.push_back(groupKeyName);
600+
groupByKeysRenamesMap.push_back({groupKeyName, groupKeyName});
569601
}
570602
}
571603
}
@@ -575,8 +607,6 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
575607
Y_ENSURE(result);
576608
auto finalType = node->GetTypeAnn()->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>();
577609

578-
// This is a hack to enable convertion for aggregation columns.
579-
THashSet<TString> aggregationColumns;
580610
THashSet<TString> columnNames;
581611
// Collect PgAgg for each result item at first pass.
582612
TVector<TExprNode::TPtr> aggTraitsList;
@@ -598,14 +628,18 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
598628
TVector<TInfoUnit> originalColNames;
599629
GetAllMembers(pgAgg, originalColNames);
600630
auto pgResolvedOp = GetPgCallable(lambda.Body().Ptr(), "PgResolvedOp");
601-
//Y_ENSURE(originalColNames.size() > 1 && pgResolvedOp, "Invalid column size for aggregation columns.");
602631

603632
auto originalColName = originalColNames.front();
604633
auto renamedColName = originalColName;
605634

606635
if (pgResolvedOp) {
607636
auto fromPg = ctx.NewCallable(node->Pos(), "FromPg", {pgResolvedOp});
608-
auto exprLambda = Build<TCoLambda>(ctx, node->Pos()).Args(lambda.Args()).Body(fromPg).Done().Ptr();
637+
// clang-format off
638+
auto exprLambda = Build<TCoLambda>(ctx, node->Pos())
639+
.Args(lambda.Args())
640+
.Body(fromPg)
641+
.Done().Ptr();
642+
// clang-format on
609643

610644
// Just any unique name for the expression result; the physical plan should be AsStruct(`unique_name` (expression))
611645
originalColName = TInfoUnit(GenerateUniqueColumnName("_expr_"));
@@ -623,7 +657,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
623657
columnNames.insert(renamedColName.GetFullName());
624658
Y_ENSURE(!GetAtom(pgAgg->ChildPtr(1), "distinct"), "Aggregation on distinct is not supported");
625659

626-
aggregationColumns.insert(resultColName);
660+
aggregationColumnsRequireCastToPgType.insert(resultColName);
627661
const TString aggFuncName = TString(pgAgg->ChildPtr(0)->Content());
628662
auto aggregationTraits = BuildAggregationTraits(renamedColName.GetFullName(), resultColName, aggFuncName,
629663
aggFuncResultType, ctx, node->Pos());
@@ -637,7 +671,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
637671
}
638672
// This case covers distinct all on just columns without aggregation functions.
639673
} else if (!pgAgg && distinctAll) {
640-
aggregationColumns.insert(resultColName);
674+
aggregationColumnsRequireCastToPgType.insert(resultColName);
641675
Y_ENSURE(aggFuncResultType, "Cannot find type for aggregation result.");
642676
TVector<TInfoUnit> originalColNames;
643677
GetAllMembers(resultItem->ChildPtr(2), originalColNames);
@@ -658,9 +692,10 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
658692
if (needRenameMap) {
659693
resultExpr = BuildRenameMap(resultExpr, aggFieldsRenamesMap, groupByKeysRenamesMap, ctx, node->Pos());
660694
}
661-
// In case we have an expression for aggregation - f(a + b ..)
662-
if (!aggFieldsExpressionsMap.empty()) {
663-
resultExpr = BuildExpressionMap(resultExpr, aggFieldsExpressionsMap, aggFieldsRenamesMap, groupByKeysRenamesMap, ctx, node->Pos());
695+
// In case we have an expression inside an aggregation, e.g. f(a + b ..), or a group by key that is an expression.
696+
if (!aggFieldsExpressionsMap.empty() || !groupByKeysExpressionsMap.empty()) {
697+
resultExpr = BuildExpressionMap(resultExpr, aggFieldsExpressionsMap, aggFieldsRenamesMap, groupByKeysRenamesMap,
698+
groupByKeysExpressionsMap, ctx, node->Pos());
664699
}
665700
resultExpr = BuildAggregate(resultExpr, groupByKeys, aggTraitsList, distinctAll, ctx, node->Pos());
666701
}
@@ -671,7 +706,6 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
671706
}
672707

673708
finalColumnOrder.clear();
674-
THashMap<TString, TExprNode::TPtr> aggProjectionMap;
675709

676710
for (auto resultItem : result->Child(1)->Children()) {
677711
auto column = resultItem->Child(0);
@@ -691,7 +725,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
691725

692726
bool needPgCast = (expectedType->GetId() != actualPgTypeId);
693727
auto lambda = TCoLambda(ctx.DeepCopyLambda(*(resultItem->Child(2))));
694-
bool needPgCastForAgg = aggregationColumns.count(columnName);
728+
bool needPgCastForAgg = aggregationColumnsRequireCastToPgType.count(columnName);
695729

696730
auto pgAgg = GetPgCallable(lambda.Body().Ptr(), "PgAgg");
697731
if (pgAgg) {
@@ -715,14 +749,28 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
715749
// Eliminate `PgGroupRef` from projection lambda.
716750
auto pgGroupRef = GetPgCallable(lambda.Body().Ptr(), "PgGroupRef");
717751
if (pgGroupRef) {
718-
Y_ENSURE(pgGroupRef->ChildrenSize() == 4);
752+
TString columnName;
753+
if (pgGroupRef->ChildrenSize() == 4) {
754+
columnName = TString(pgGroupRef->ChildPtr(3)->Content());
755+
} else if (pgGroupRef->ChildrenSize() == 3) {
756+
// In this case we can get a column name from group expr map
757+
const auto groupByKeyExprId = FromString<uint32_t>(TString(pgGroupRef->ChildPtr(2)->Content()));
758+
auto it = groupByKeysExpressionsMap.find(groupByKeyExprId);
759+
Y_ENSURE(it != groupByKeysExpressionsMap.end(), "Group by key expression has invalid content.");
760+
columnName = it->second.first.GetFullName();
761+
// Always need a pg cast for expressions.
762+
needPgCast = true;
763+
} else {
764+
Y_ENSURE(false, "Invalid children size for `pgGroupRef`");
765+
}
766+
719767
// clang-format off
720768
lambda = Build<TCoLambda>(ctx, node->Pos())
721769
.Args(lambda.Args())
722770
.Body<TCoMember>()
723771
.Struct(lambda.Args().Arg(0))
724772
.Name<TCoAtom>()
725-
.Value(pgGroupRef->ChildPtr(3)->Content())
773+
.Value(columnName)
726774
.Build()
727775
.Build()
728776
.Done();

ydb/core/kqp/ut/rbo/kqp_rbo_ut.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,21 @@ Y_UNIT_TEST_SUITE(KqpRbo) {
609609
SET TablePathPrefix = "/Root/";
610610
select sum(t1.a + 1 + t1.c) as sumExpr0, sum(t1.c + 2) as sumExpr1 from t1 group by t1.b order by sumExpr0;
611611
)",
612+
R"(
613+
--!syntax_pg
614+
SET TablePathPrefix = "/Root/";
615+
select sum(t1.c) as sum0, sum(t1.a + 3) as sum1 from t1 group by t1.b + 1 order by sum0;
616+
)",
617+
R"(
618+
--!syntax_pg
619+
SET TablePathPrefix = "/Root/";
620+
select sum(t1.c) as sum0, t1.b + 1, t1.c + 2 from t1 group by t1.b + 1, t1.c + 2 order by sum0;
621+
)",
622+
R"(
623+
--!syntax_pg
624+
SET TablePathPrefix = "/Root/";
625+
select sum(t1.c + 2) as sum0 from t1 group by t1.b + t1.a order by sum0;
626+
)",
612627
};
613628

614629
std::vector<std::string> results = {
@@ -625,7 +640,10 @@ Y_UNIT_TEST_SUITE(KqpRbo) {
625640
R"([["0";"2"];["1";"1"];["2";"2"];["3";"1"];["4";"2"]])",
626641
R"([["4";"4"];["6";"6"]])",
627642
R"([["0";"4"];["1";"3"]])",
628-
R"([["10";"8"];["15";"12"]])"
643+
R"([["10";"8"];["15";"12"]])",
644+
R"([["4";"10"];["6";"15"]])",
645+
R"([["4";"2";"4"];["6";"3";"4"]])",
646+
R"([["4"];["8"];["8"]])"
629647
};
630648

631649
for (ui32 i = 0; i < queries.size(); ++i) {

0 commit comments

Comments
 (0)