[Fix](nereids) fix merge aggregate setting top projection bug (#35348)

introduced by #31811

sql like this:

    select col1, col2 from  (select a as col1, a as col2 from mal_test1 group by a) t group by col1, col2 ;

Transformation Description:
In the process of optimizing the query, an agg-project-agg pattern is transformed into a project-agg pattern:
Before Transformation:

LogicalAggregate
+-- LogicalPrject
    +-- LogicalAggregate

After Transformation:

LogicalProject
+-- LogicalAggregate

Before the transformation, the projection in the LogicalProject was a AS col1, a AS col2, and the outer aggregate group by keys were col1, col2. After the transformation, the aggregate group by keys became a, a, and the projection remained a AS col1, a AS col2.

Problem:
When building the project projections, the group by key a, a needed to be transformed to a AS col1, a AS col2. The old code had a bug where it used the slot as the map key and the alias in the projections as the map value. This approach did not account for the situation where aliases might have the same slot.

Solution:
The new code fixes this issue by using the original outer aggregate group by expression's exprId. It searches within the original project projections to find the NamedExpression that has the same exprId. These expressions are then placed into the new projections. This method ensures that the correct aliases are maintained, resolving the bug.
This commit is contained in:
feiniaofeiafei
2024-05-27 20:40:57 +08:00
committed by yiguolei
parent 7c808fcecf
commit ac49576229
3 changed files with 31 additions and 4 deletions

View File

@ -17,8 +17,10 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -44,6 +46,9 @@ import java.util.Set;
import java.util.stream.Collectors;
/**MergeAggregate*/
@DependsRules({
NormalizeAggregate.class
})
public class MergeAggregate implements RewriteRuleFactory {
private static final ImmutableSet<String> ALLOW_MERGE_AGGREGATE_FUNCTIONS =
ImmutableSet.of("min", "max", "sum", "any_value");
@ -108,10 +113,17 @@ public class MergeAggregate implements RewriteRuleFactory {
.withChildren(innerAgg.children());
// construct upper project
Map<SlotReference, Alias> childToAlias = project.getProjects().stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof SlotReference))
.collect(Collectors.toMap(alias -> (SlotReference) alias.child(0), alias -> (Alias) alias));
List<Expression> projectGroupBy = ExpressionUtils.replace(replacedGroupBy, childToAlias);
Map<ExprId, NamedExpression> exprIdToNameExpressionMap = new HashMap<>();
for (NamedExpression pro : project.getProjects()) {
exprIdToNameExpressionMap.put(pro.getExprId(), pro);
}
List<Expression> originOuterAggGroupBy = outerAgg.getGroupByExpressions();
List<Expression> projectGroupBy = new ArrayList<>();
for (Expression expression : originOuterAggGroupBy) {
ExprId exprId = ((NamedExpression) expression).getExprId();
NamedExpression namedExpression = exprIdToNameExpressionMap.get(exprId);
projectGroupBy.add(namedExpression);
}
List<NamedExpression> upperProjects = ImmutableList.<NamedExpression>builder()
.addAll(projectGroupBy.stream().map(namedExpr -> (NamedExpression) namedExpr).iterator())
.addAll(replacedAggFunc.stream().map(expr -> ((NamedExpression) expr).toSlot()).iterator())