@@ -151,8 +151,9 @@ void BuildSimpleMapElementLambda(TExprNode::TPtr resultExpr, const TVector<std::
151151TExprNode::TPtr BuildExpressionMap (TExprNode::TPtr resultExpr,
152152 const TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& aggFieldsExpressionsMap,
153153 const TVector<std::pair<TInfoUnit, TInfoUnit>>& aggFieldsRenamesMap,
154- const TVector<std::pair<TInfoUnit, TInfoUnit>>& groupByKeysRenamesMap, TExprContext& ctx,
155- TPositionHandle pos) {
154+ const TVector<std::pair<TInfoUnit, TInfoUnit>>& groupByKeysRenamesMap,
155+ const THashMap<uint32_t , std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap,
156+ TExprContext& ctx, TPositionHandle pos) {
156157 // Add expressions
157158 TVector<TExprNode::TPtr> mapElements;
158159 for (const auto & [colName, expr] : aggFieldsExpressionsMap) {
@@ -171,6 +172,19 @@ TExprNode::TPtr BuildExpressionMap(TExprNode::TPtr resultExpr,
171172 BuildSimpleMapElementLambda (resultExpr, aggFieldsRenamesMap, mapElements, ctx, pos);
172173 BuildSimpleMapElementLambda (resultExpr, groupByKeysRenamesMap, mapElements, ctx, pos);
173174
175+ // Add expressions for group by keys.
176+ for (const auto & [_, pair] : groupByKeysExpressionsMap) {
177+ // clang-format off
178+ mapElements.push_back (Build<TKqpOpMapElementLambda>(ctx, pos)
179+ .Input (resultExpr)
180+ .Variable ()
181+ .Value (pair.first .GetFullName ())
182+ .Build ()
183+ .Lambda (pair.second )
184+ .Done ().Ptr ());
185+ // clang-format on
186+ }
187+
174188 // clang-format off
175189 return Build<TKqpOpMap>(ctx, pos)
176190 .Input (resultExpr)
@@ -184,8 +198,8 @@ TExprNode::TPtr BuildExpressionMap(TExprNode::TPtr resultExpr,
184198 // clang-format on
185199}
186200
187- void BuildMapElementRename (TExprNode::TPtr resultExpr, const TVector<std::pair<TInfoUnit, TInfoUnit>>& renamesMap, TVector<TExprNode::TPtr>& mapElements,
188- TExprContext& ctx, TPositionHandle pos) {
201+ void BuildMapElementRename (TExprNode::TPtr resultExpr, const TVector<std::pair<TInfoUnit, TInfoUnit>>& renamesMap,
202+ TVector<TExprNode::TPtr>& mapElements, TExprContext& ctx, TPositionHandle pos) {
189203 for (const auto & [colName, newColName] : renamesMap) {
190204 // clang-format off
191205 mapElements.push_back (Build<TKqpOpMapElementRename>(ctx, pos)
@@ -552,20 +566,38 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
552566 filterExpr = Build<TKqpOpEmptySource>(ctx, node->Pos ()).Done ().Ptr ();
553567 }
554568
555- // FIXME: Group by key can be an expression, we need to handle this case
569+ THashSet<TString> aggregationColumnsRequireCastToPgType;
556570 TVector<std::pair<TInfoUnit, TInfoUnit>> groupByKeysRenamesMap;
571+ THashMap<uint32_t , std::pair<TInfoUnit, TExprNode::TPtr>> groupByKeysExpressionsMap;
557572 TVector<TInfoUnit> groupByKeys;
558573 auto groupOps = GetSetting (setItem->Tail (), " group_exprs" );
559574 if (groupOps) {
560575 const auto groupByList = groupOps->TailPtr ();
561576 for (ui32 i = 0 ; i < groupByList->ChildrenSize (); ++i) {
562577 auto lambda = TCoLambda (ctx.DeepCopyLambda (*(groupByList->ChildPtr (i)->Child (1 ))));
563578 auto body = lambda.Body ().Ptr ();
564- TVector<TInfoUnit> keys;
565- GetAllMembers (body, keys);
566- groupByKeys.insert (groupByKeys.end (), keys.begin (), keys.end ());
567- for (const auto &infoUnit : keys) {
568- groupByKeysRenamesMap.push_back ({infoUnit, infoUnit});
579+ auto pgResolvedOp = GetPgCallable (lambda.Body ().Ptr (), " PgResolvedOp" );
580+ // Expression for group by keys.
581+ if (pgResolvedOp) {
582+ auto fromPg = ctx.NewCallable (node->Pos (), " FromPg" , {pgResolvedOp});
583+
584+ // clang-format off
585+ auto groupExprLambda = Build<TCoLambda>(ctx, node->Pos ())
586+ .Args (lambda.Args ())
587+ .Body (fromPg)
588+ .Done ().Ptr ();
589+ // clang-format on
590+
591+ const auto newColName = TInfoUnit (GenerateUniqueColumnName (" _group_expr_" ));
592+ groupByKeysExpressionsMap[i] = std::make_pair (newColName, groupExprLambda);
593+ groupByKeys.push_back (newColName);
594+ } else {
595+ TVector<TInfoUnit> keys;
596+ GetAllMembers (body, keys);
597+ Y_ENSURE (keys.size () == 1 , " Invalid size of the group keys." );
598+ const auto groupKeyName = keys.front ();
599+ groupByKeys.push_back (groupKeyName);
600+ groupByKeysRenamesMap.push_back ({groupKeyName, groupKeyName});
569601 }
570602 }
571603 }
@@ -575,8 +607,6 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
575607 Y_ENSURE (result);
576608 auto finalType = node->GetTypeAnn ()->Cast <TListExprType>()->GetItemType ()->Cast <TStructExprType>();
577609
578- // This is a hack to enable convertion for aggregation columns.
579- THashSet<TString> aggregationColumns;
580610 THashSet<TString> columnNames;
581611 // Collect PgAgg for each result item at first pass.
582612 TVector<TExprNode::TPtr> aggTraitsList;
@@ -598,14 +628,18 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
598628 TVector<TInfoUnit> originalColNames;
599629 GetAllMembers (pgAgg, originalColNames);
600630 auto pgResolvedOp = GetPgCallable (lambda.Body ().Ptr (), " PgResolvedOp" );
601- // Y_ENSURE(originalColNames.size() > 1 && pgResolvedOp, "Invalid column size for aggregation columns.");
602631
603632 auto originalColName = originalColNames.front ();
604633 auto renamedColName = originalColName;
605634
606635 if (pgResolvedOp) {
607636 auto fromPg = ctx.NewCallable (node->Pos (), " FromPg" , {pgResolvedOp});
608- auto exprLambda = Build<TCoLambda>(ctx, node->Pos ()).Args (lambda.Args ()).Body (fromPg).Done ().Ptr ();
637+ // clang-format off
638+ auto exprLambda = Build<TCoLambda>(ctx, node->Pos ())
639+ .Args (lambda.Args ())
640+ .Body (fromPg)
641+ .Done ().Ptr ();
642+ // clang-format on
609643
610644 // Just any unique name for expression result, physical plan should be AsSturct(`unique_name (expression))
611645 originalColName = TInfoUnit (GenerateUniqueColumnName (" _expr_" ));
@@ -623,7 +657,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
623657 columnNames.insert (renamedColName.GetFullName ());
624658 Y_ENSURE (!GetAtom (pgAgg->ChildPtr (1 ), " distinct" ), " Aggregation on distinct is not supported" );
625659
626- aggregationColumns .insert (resultColName);
660+ aggregationColumnsRequireCastToPgType .insert (resultColName);
627661 const TString aggFuncName = TString (pgAgg->ChildPtr (0 )->Content ());
628662 auto aggregationTraits = BuildAggregationTraits (renamedColName.GetFullName (), resultColName, aggFuncName,
629663 aggFuncResultType, ctx, node->Pos ());
@@ -637,7 +671,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
637671 }
638672 // This case covers distinct all on just columns without aggregation functions.
639673 } else if (!pgAgg && distinctAll) {
640- aggregationColumns .insert (resultColName);
674+ aggregationColumnsRequireCastToPgType .insert (resultColName);
641675 Y_ENSURE (aggFuncResultType, " Cannot find type for aggregation result." );
642676 TVector<TInfoUnit> originalColNames;
643677 GetAllMembers (resultItem->ChildPtr (2 ), originalColNames);
@@ -658,9 +692,10 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
658692 if (needRenameMap) {
659693 resultExpr = BuildRenameMap (resultExpr, aggFieldsRenamesMap, groupByKeysRenamesMap, ctx, node->Pos ());
660694 }
661- // In case we have an expression for aggregation - f(a + b ..)
662- if (!aggFieldsExpressionsMap.empty ()) {
663- resultExpr = BuildExpressionMap (resultExpr, aggFieldsExpressionsMap, aggFieldsRenamesMap, groupByKeysRenamesMap, ctx, node->Pos ());
695+ // In case we have an expression for aggregation - f(a + b ..) or group by.
696+ if (!aggFieldsExpressionsMap.empty () || !groupByKeysExpressionsMap.empty ()) {
697+ resultExpr = BuildExpressionMap (resultExpr, aggFieldsExpressionsMap, aggFieldsRenamesMap, groupByKeysRenamesMap,
698+ groupByKeysExpressionsMap, ctx, node->Pos ());
664699 }
665700 resultExpr = BuildAggregate (resultExpr, groupByKeys, aggTraitsList, distinctAll, ctx, node->Pos ());
666701 }
@@ -671,7 +706,6 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
671706 }
672707
673708 finalColumnOrder.clear ();
674- THashMap<TString, TExprNode::TPtr> aggProjectionMap;
675709
676710 for (auto resultItem : result->Child (1 )->Children ()) {
677711 auto column = resultItem->Child (0 );
@@ -691,7 +725,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
691725
692726 bool needPgCast = (expectedType->GetId () != actualPgTypeId);
693727 auto lambda = TCoLambda (ctx.DeepCopyLambda (*(resultItem->Child (2 ))));
694- bool needPgCastForAgg = aggregationColumns .count (columnName);
728+ bool needPgCastForAgg = aggregationColumnsRequireCastToPgType .count (columnName);
695729
696730 auto pgAgg = GetPgCallable (lambda.Body ().Ptr (), " PgAgg" );
697731 if (pgAgg) {
@@ -715,14 +749,28 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
715749 // Eliminate `PgGroupRef` from projection lambda.
716750 auto pgGroupRef = GetPgCallable (lambda.Body ().Ptr (), " PgGroupRef" );
717751 if (pgGroupRef) {
718- Y_ENSURE (pgGroupRef->ChildrenSize () == 4 );
752+ TString columnName;
753+ if (pgGroupRef->ChildrenSize () == 4 ) {
754+ columnName = TString (pgGroupRef->ChildPtr (3 )->Content ());
755+ } else if (pgGroupRef->ChildrenSize () == 3 ) {
756+ // In this case we can get a column name from group expr map
757+ const auto groupByKeyExprId = FromString<uint32_t >(TString (pgGroupRef->ChildPtr (2 )->Content ()));
758+ auto it = groupByKeysExpressionsMap.find (groupByKeyExprId);
759+ Y_ENSURE (it != groupByKeysExpressionsMap.end (), " Group by key expression has invalid content." );
760+ columnName = it->second .first .GetFullName ();
761+ // Always need a pg cast for expressions.
762+ needPgCast = true ;
763+ } else {
764+ Y_ENSURE (false , " Invalid children size for `pgGroupRef`" );
765+ }
766+
719767 // clang-format off
720768 lambda = Build<TCoLambda>(ctx, node->Pos ())
721769 .Args (lambda.Args ())
722770 .Body <TCoMember>()
723771 .Struct (lambda.Args ().Arg (0 ))
724772 .Name <TCoAtom>()
725- .Value (pgGroupRef-> ChildPtr ( 3 )-> Content () )
773+ .Value (columnName )
726774 .Build ()
727775 .Build ()
728776 .Done ();
0 commit comments