[Fix](nereids) fix merge aggregate rule, rules should not have mutable members (#36223)
cherry-pick #36145 to branch-2.1
This commit is contained in:
@ -52,7 +52,6 @@ import java.util.stream.Collectors;
|
||||
public class MergeAggregate implements RewriteRuleFactory {
|
||||
private static final ImmutableSet<String> ALLOW_MERGE_AGGREGATE_FUNCTIONS =
|
||||
ImmutableSet.of("min", "max", "sum", "any_value");
|
||||
private Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = new HashMap<>();
|
||||
|
||||
@Override
|
||||
public List<Rule> buildRules() {
|
||||
@ -75,7 +74,7 @@ public class MergeAggregate implements RewriteRuleFactory {
|
||||
*/
|
||||
private Plan mergeTwoAggregate(LogicalAggregate<LogicalAggregate<Plan>> outerAgg) {
|
||||
LogicalAggregate<Plan> innerAgg = outerAgg.child();
|
||||
|
||||
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
|
||||
List<NamedExpression> newOutputExpressions = outerAgg.getOutputExpressions().stream()
|
||||
.map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
|
||||
.collect(Collectors.toList());
|
||||
@ -97,6 +96,7 @@ public class MergeAggregate implements RewriteRuleFactory {
|
||||
List<NamedExpression> outputExpressions = outerAgg.getOutputExpressions();
|
||||
List<NamedExpression> replacedOutputExpressions = PlanUtils.replaceExpressionByProjections(
|
||||
project.getProjects(), (List) outputExpressions);
|
||||
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
|
||||
// rewrite agg function. e.g. max(max)
|
||||
List<NamedExpression> replacedAggFunc = replacedOutputExpressions.stream()
|
||||
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
|
||||
@ -152,10 +152,7 @@ public class MergeAggregate implements RewriteRuleFactory {
|
||||
|
||||
private boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
|
||||
boolean sameGroupBy, Optional<LogicalProject> projectOptional) {
|
||||
innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream()
|
||||
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
|
||||
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
|
||||
(existValue, newValue) -> existValue));
|
||||
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
|
||||
Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions();
|
||||
List<AggregateFunction> replacedAggFunctions = projectOptional.map(project ->
|
||||
(List<AggregateFunction>) (List) PlanUtils.replaceExpressionByProjections(
|
||||
@ -225,4 +222,11 @@ public class MergeAggregate implements RewriteRuleFactory {
|
||||
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
|
||||
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project));
|
||||
}
|
||||
|
||||
private Map<ExprId, AggregateFunction> getInnerAggExprIdToAggFuncMap(LogicalAggregate<Plan> innerAgg) {
|
||||
return innerAgg.getOutputExpressions().stream()
|
||||
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
|
||||
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
|
||||
(existValue, newValue) -> existValue));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user