[Fix](nereids) fix merge aggregate rule, rules should not have mutable members (#36223)

cherry-pick #36145  to branch-2.1
This commit is contained in:
feiniaofeiafei
2024-06-13 17:49:57 +08:00
committed by GitHub
parent d70751a808
commit e2f7e0da0a

View File

@ -52,7 +52,6 @@ import java.util.stream.Collectors;
public class MergeAggregate implements RewriteRuleFactory {
private static final ImmutableSet<String> ALLOW_MERGE_AGGREGATE_FUNCTIONS =
ImmutableSet.of("min", "max", "sum", "any_value");
private Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = new HashMap<>();
@Override
public List<Rule> buildRules() {
@ -75,7 +74,7 @@ public class MergeAggregate implements RewriteRuleFactory {
*/
private Plan mergeTwoAggregate(LogicalAggregate<LogicalAggregate<Plan>> outerAgg) {
LogicalAggregate<Plan> innerAgg = outerAgg.child();
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
List<NamedExpression> newOutputExpressions = outerAgg.getOutputExpressions().stream()
.map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
.collect(Collectors.toList());
@ -97,6 +96,7 @@ public class MergeAggregate implements RewriteRuleFactory {
List<NamedExpression> outputExpressions = outerAgg.getOutputExpressions();
List<NamedExpression> replacedOutputExpressions = PlanUtils.replaceExpressionByProjections(
project.getProjects(), (List) outputExpressions);
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
// rewrite agg function. e.g. max(max)
List<NamedExpression> replacedAggFunc = replacedOutputExpressions.stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
@ -152,10 +152,7 @@ public class MergeAggregate implements RewriteRuleFactory {
private boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
boolean sameGroupBy, Optional<LogicalProject> projectOptional) {
innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
(existValue, newValue) -> existValue));
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions();
List<AggregateFunction> replacedAggFunctions = projectOptional.map(project ->
(List<AggregateFunction>) (List) PlanUtils.replaceExpressionByProjections(
@ -225,4 +222,11 @@ public class MergeAggregate implements RewriteRuleFactory {
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project));
}
private Map<ExprId, AggregateFunction> getInnerAggExprIdToAggFuncMap(LogicalAggregate<Plan> innerAgg) {
return innerAgg.getOutputExpressions().stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
(existValue, newValue) -> existValue));
}
}