From ac495762292d3efb0f563d07414624808c56a563 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei <53502832+feiniaofeiafei@users.noreply.github.com> Date: Mon, 27 May 2024 20:40:57 +0800 Subject: [PATCH] [Fix](nereids) fix merge aggregate setting top projection bug (#35348) introduced by #31811 sql like this: select col1, col2 from (select a as col1, a as col2 from mal_test1 group by a) t group by col1, col2 ; Transformation Description: In the process of optimizing the query, an agg-project-agg pattern is transformed into a project-agg pattern: Before Transformation: LogicalAggregate +-- LogicalProject +-- LogicalAggregate After Transformation: LogicalProject +-- LogicalAggregate Before the transformation, the projection in the LogicalProject was a AS col1, a AS col2, and the outer aggregate group by keys were col1, col2. After the transformation, the aggregate group by keys became a, a, and the projection remained a AS col1, a AS col2. Problem: When building the project projections, the group by keys a, a needed to be transformed to a AS col1, a AS col2. The old code had a bug where it used the slot as the map key and the alias in the projections as the map value. This approach did not account for the situation where multiple aliases might refer to the same slot. Solution: The new code fixes this issue by using the original outer aggregate group by expression's exprId. It searches within the original project projections to find the NamedExpression that has the same exprId. These expressions are then placed into the new projections. This method ensures that the correct aliases are maintained, resolving the bug. 
--- .../nereids/rules/rewrite/MergeAggregate.java | 20 +++++++++++++++---- .../merge_aggregate/merge_aggregate.out | 9 +++++++++ .../merge_aggregate/merge_aggregate.groovy | 6 ++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 8ea8a7f217..889adfb69f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -17,8 +17,10 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.annotation.DependsRules; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; @@ -44,6 +46,9 @@ import java.util.Set; import java.util.stream.Collectors; /**MergeAggregate*/ +@DependsRules({ + NormalizeAggregate.class +}) public class MergeAggregate implements RewriteRuleFactory { private static final ImmutableSet ALLOW_MERGE_AGGREGATE_FUNCTIONS = ImmutableSet.of("min", "max", "sum", "any_value"); @@ -108,10 +113,17 @@ public class MergeAggregate implements RewriteRuleFactory { .withChildren(innerAgg.children()); // construct upper project - Map childToAlias = project.getProjects().stream() - .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof SlotReference)) - .collect(Collectors.toMap(alias -> (SlotReference) alias.child(0), alias -> (Alias) alias)); - List projectGroupBy = ExpressionUtils.replace(replacedGroupBy, childToAlias); + Map exprIdToNameExpressionMap = new HashMap<>(); + for (NamedExpression pro : project.getProjects()) { + 
exprIdToNameExpressionMap.put(pro.getExprId(), pro); + } + List originOuterAggGroupBy = outerAgg.getGroupByExpressions(); + List projectGroupBy = new ArrayList<>(); + for (Expression expression : originOuterAggGroupBy) { + ExprId exprId = ((NamedExpression) expression).getExprId(); + NamedExpression namedExpression = exprIdToNameExpressionMap.get(exprId); + projectGroupBy.add(namedExpression); + } List upperProjects = ImmutableList.builder() .addAll(projectGroupBy.stream().map(namedExpr -> (NamedExpression) namedExpr).iterator()) .addAll(replacedAggFunc.stream().map(expr -> ((NamedExpression) expr).toSlot()).iterator()) diff --git a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out index fba17e8d7b..d7103bfed9 100644 --- a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out +++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out @@ -297,3 +297,12 @@ PhysicalResultSink ------hashAgg[LOCAL] --------PhysicalOlapScan[mal_test2] +-- !agg_project_agg_the_project_has_duplicate_slot_output -- +1 7 7 +2 4 4 +6 \N \N +7 1 1 +8 2 2 +8 5 5 +9 3 3 + diff --git a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy index 039f087c93..4a20cf4d68 100644 --- a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy +++ b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy @@ -256,4 +256,10 @@ suite("merge_aggregate") { explain shape plan select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t; """ + + qt_agg_project_agg_the_project_has_duplicate_slot_output """ + select max(col1), col10, col11 from + (select a,max(b) as col1, count(b) as col4, a as col10, a as col11 + from mal_test1 group by a) t group by col10, col11 order 
by 1,2,3; + """ }