[Fix](nereids) fix rule merge_aggregate when there is a project between the two aggregates (#33892)

This commit is contained in:
feiniaofeiafei
2024-04-26 12:34:24 +08:00
committed by yiguolei
parent a34ed4643a
commit b41a5339d3
4 changed files with 148 additions and 10 deletions

View File

@ -303,7 +303,9 @@ public class Rewriter extends AbstractBatchJobExecutor {
topic("Eliminate GroupBy",
topDown(new EliminateGroupBy(),
new MergeAggregate())
new MergeAggregate(),
// the nullable attribute of min/max/sum needs to be adjusted after aggregates are merged
new AdjustAggregateNullableForEmptySet())
),
topic("Eager aggregation",

View File

@ -34,10 +34,12 @@ import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@ -87,15 +89,14 @@ public class MergeAggregate implements RewriteRuleFactory {
private Plan mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
LogicalProject<LogicalAggregate<Plan>> project = outerAgg.child();
LogicalAggregate<Plan> innerAgg = project.child();
List<NamedExpression> outputExpressions = outerAgg.getOutputExpressions();
List<NamedExpression> replacedOutputExpressions = PlanUtils.replaceExpressionByProjections(
project.getProjects(), (List) outputExpressions);
// rewrite agg function. e.g. max(max)
List<NamedExpression> aggFunc = outerAgg.getOutputExpressions().stream()
List<NamedExpression> replacedAggFunc = replacedOutputExpressions.stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
.collect(Collectors.toList());
// rewrite agg function directly refer to the slot below the project
List<Expression> replacedAggFunc = PlanUtils.replaceExpressionByProjections(project.getProjects(),
(List) aggFunc);
// replace groupByKeys so that they directly refer to the slots below the project
List<Expression> replacedGroupBy = PlanUtils.replaceExpressionByProjections(project.getProjects(),
outerAgg.getGroupByExpressions());
@ -138,13 +139,17 @@ public class MergeAggregate implements RewriteRuleFactory {
}
boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
boolean sameGroupBy) {
boolean sameGroupBy, Optional<LogicalProject> projectOptional) {
innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
(existValue, newValue) -> existValue));
Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions();
for (AggregateFunction outerFunc : aggregateFunctions) {
List<AggregateFunction> replacedAggFunctions = projectOptional.map(project ->
(List<AggregateFunction>) PlanUtils.replaceExpressionByProjections(
projectOptional.get().getProjects(), new ArrayList<>(aggregateFunctions)))
.orElse(new ArrayList<>(aggregateFunctions));
for (AggregateFunction outerFunc : replacedAggFunctions) {
if (!(ALLOW_MERGE_AGGREGATE_FUNCTIONS.contains(outerFunc.getName()))) {
return false;
}
@ -188,7 +193,7 @@ public class MergeAggregate implements RewriteRuleFactory {
}
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
return commonCheck(outerAgg, innerAgg, sameGroupBy);
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.empty());
}
private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
@ -206,6 +211,6 @@ public class MergeAggregate implements RewriteRuleFactory {
return false;
}
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
return commonCheck(outerAgg, innerAgg, sameGroupBy);
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project));
}
}