diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index e920036247..d09204403f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -190,10 +190,14 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot { // 1. group by exprs // 2. trivialAgg children // 3. trivialAgg input slots - Set allPushDownExprs = - Sets.union(groupingByExprs, Sets.union(needPushSelf, needPushInputSlots)); - NormalizeToSlotContext bottomSlotContext = - NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs); + // We need to distinguish between expressions in aggregate function arguments and group by expressions. + NormalizeToSlotContext groupByExprContext = NormalizeToSlotContext.buildContext(existsAlias, groupingByExprs); + Set existsAliasAndGroupByAlias = getExistsAlias(existsAlias, groupByExprContext.getNormalizeToSlotMap()); + Set argsOfAggFuncNeedPushDown = Sets.union(needPushSelf, needPushInputSlots); + NormalizeToSlotContext argsOfAggFuncNeedPushDownContext = NormalizeToSlotContext + .buildContext(existsAliasAndGroupByAlias, argsOfAggFuncNeedPushDown); + NormalizeToSlotContext bottomSlotContext = argsOfAggFuncNeedPushDownContext.mergeContext(groupByExprContext); + Set pushedGroupByExprs = bottomSlotContext.pushDownToNamedExpression(groupingByExprs); Set pushedTrivialAggChildren = @@ -256,8 +260,12 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot { aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan); // create upper projects by normalize all output exprs in old LogicalAggregate + // In aggregateOutput, the expressions inside the agg function can be rewritten + // with expressions in aggregate function arguments and group by expressions, + // but the ones outside the agg function can only be rewritten with group by expressions. + // After the above two rewrites are completed, use aggregate output agg functions to rewrite. List upperProjects = normalizeOutput(aggregateOutput, - bottomSlotContext, normalizedAggFuncsToSlotContext); + groupByExprContext, argsOfAggFuncNeedPushDownContext, normalizedAggFuncsToSlotContext); // create a parent project node LogicalProject project = new LogicalProject<>(upperProjects, newAggregate); @@ -302,11 +310,18 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot { } private List normalizeOutput(List aggregateOutput, - NormalizeToSlotContext groupByToSlotContext, NormalizeToSlotContext normalizedAggFuncsToSlotContext) { + NormalizeToSlotContext groupByToSlotContext, NormalizeToSlotContext argsOfAggFuncNeedPushDownContext, + NormalizeToSlotContext normalizedAggFuncsToSlotContext) { // build upper project, use two context to do pop up, because agg output maybe contain two part: - // group by keys and agg expressions - List upperProjects = groupByToSlotContext - .normalizeToUseSlotRefWithoutWindowFunction(aggregateOutput); + // group by keys and agg expressions + List upperProjects = new ArrayList<>(); + for (Expression expr : aggregateOutput) { + Expression rewrittenExpr = expr.rewriteDownShortCircuit( + e -> normalizeAggFuncChildren( + argsOfAggFuncNeedPushDownContext, e)); + upperProjects.add((NamedExpression) rewrittenExpr); + } + upperProjects = groupByToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(upperProjects); upperProjects = normalizedAggFuncsToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(upperProjects); Builder builder = new ImmutableList.Builder<>(); @@ -338,4 +353,28 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot { slots.addAll(ExpressionUtils.getInputSlotSet(expressions)); return slots; } + + private Set getExistsAlias(Set originAliases, + Map groupingExprMap) { + Set existsAlias = Sets.newHashSet(); + existsAlias.addAll(originAliases); + for (NormalizeToSlotTriplet triplet : groupingExprMap.values()) { + if (triplet.pushedExpr instanceof Alias) { + Alias alias = (Alias) triplet.pushedExpr; + existsAlias.add(alias); + } + } + return existsAlias; + } + + private Expression normalizeAggFuncChildren(NormalizeToSlotContext context, Expression expr) { + if (expr instanceof AggregateFunction) { + AggregateFunction function = (AggregateFunction) expr; + List normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments()); + function = function.withChildren(normalizedRealExpressions); + return function; + } else { + return expr; + } + } } diff --git a/regression-test/data/nereids_rules_p0/normalize_aggregate/normalize_aggregate_test.out b/regression-test/data/nereids_rules_p0/normalize_aggregate/normalize_aggregate_test.out new file mode 100644 index 0000000000..50c132b2f7 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/normalize_aggregate/normalize_aggregate_test.out @@ -0,0 +1,3 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !test_upper_project_projections_rewrite2 -- + diff --git a/regression-test/suites/nereids_rules_p0/normalize_aggregate/normalize_aggregate_test.groovy b/regression-test/suites/nereids_rules_p0/normalize_aggregate/normalize_aggregate_test.groovy new file mode 100644 index 0000000000..f3c732c067 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/normalize_aggregate/normalize_aggregate_test.groovy @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +suite("normalize_aggregate") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + sql "drop table if exists normalize_aggregate_tab" + sql """CREATE TABLE normalize_aggregate_tab(col0 INTEGER, col1 INTEGER, col2 INTEGER) distributed by hash(col0) buckets 10 + properties('replication_num' = '1'); """ + qt_test_upper_project_projections_rewrite2 """ + SELECT - + AVG ( DISTINCT - col0 ) * - col0 FROM + normalize_aggregate_tab WHERE + - col0 IS NULL GROUP BY col0 HAVING NULL IS NULL;""" +} \ No newline at end of file