[Fix](nereids) fix rule merge_aggregate when has project (#33892)
This commit is contained in:
@ -303,7 +303,9 @@ public class Rewriter extends AbstractBatchJobExecutor {
|
||||
|
||||
topic("Eliminate GroupBy",
|
||||
topDown(new EliminateGroupBy(),
|
||||
new MergeAggregate())
|
||||
new MergeAggregate(),
|
||||
// need to adjust min/max/sum nullable attribute after merge aggregate
|
||||
new AdjustAggregateNullableForEmptySet())
|
||||
),
|
||||
|
||||
topic("Eager aggregation",
|
||||
|
||||
@ -34,10 +34,12 @@ import org.apache.doris.nereids.util.PlanUtils;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@ -87,15 +89,14 @@ public class MergeAggregate implements RewriteRuleFactory {
|
||||
private Plan mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
|
||||
LogicalProject<LogicalAggregate<Plan>> project = outerAgg.child();
|
||||
LogicalAggregate<Plan> innerAgg = project.child();
|
||||
|
||||
List<NamedExpression> outputExpressions = outerAgg.getOutputExpressions();
|
||||
List<NamedExpression> replacedOutputExpressions = PlanUtils.replaceExpressionByProjections(
|
||||
project.getProjects(), (List) outputExpressions);
|
||||
// rewrite agg function. e.g. max(max)
|
||||
List<NamedExpression> aggFunc = outerAgg.getOutputExpressions().stream()
|
||||
List<NamedExpression> replacedAggFunc = replacedOutputExpressions.stream()
|
||||
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
|
||||
.map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
|
||||
.collect(Collectors.toList());
|
||||
// rewrite agg function directly refer to the slot below the project
|
||||
List<Expression> replacedAggFunc = PlanUtils.replaceExpressionByProjections(project.getProjects(),
|
||||
(List) aggFunc);
|
||||
// replace groupByKeys directly refer to the slot below the project
|
||||
List<Expression> replacedGroupBy = PlanUtils.replaceExpressionByProjections(project.getProjects(),
|
||||
outerAgg.getGroupByExpressions());
|
||||
@ -138,13 +139,17 @@ public class MergeAggregate implements RewriteRuleFactory {
|
||||
}
|
||||
|
||||
boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
|
||||
boolean sameGroupBy) {
|
||||
boolean sameGroupBy, Optional<LogicalProject> projectOptional) {
|
||||
innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream()
|
||||
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
|
||||
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
|
||||
(existValue, newValue) -> existValue));
|
||||
Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions();
|
||||
for (AggregateFunction outerFunc : aggregateFunctions) {
|
||||
List<AggregateFunction> replacedAggFunctions = projectOptional.map(project ->
|
||||
(List<AggregateFunction>) PlanUtils.replaceExpressionByProjections(
|
||||
projectOptional.get().getProjects(), new ArrayList<>(aggregateFunctions)))
|
||||
.orElse(new ArrayList<>(aggregateFunctions));
|
||||
for (AggregateFunction outerFunc : replacedAggFunctions) {
|
||||
if (!(ALLOW_MERGE_AGGREGATE_FUNCTIONS.contains(outerFunc.getName()))) {
|
||||
return false;
|
||||
}
|
||||
@ -188,7 +193,7 @@ public class MergeAggregate implements RewriteRuleFactory {
|
||||
}
|
||||
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
|
||||
|
||||
return commonCheck(outerAgg, innerAgg, sameGroupBy);
|
||||
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.empty());
|
||||
}
|
||||
|
||||
private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
|
||||
@ -206,6 +211,6 @@ public class MergeAggregate implements RewriteRuleFactory {
|
||||
return false;
|
||||
}
|
||||
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
|
||||
return commonCheck(outerAgg, innerAgg, sameGroupBy);
|
||||
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user