diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index db46335412..c4fa3abd0b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -81,6 +81,7 @@ import org.apache.doris.nereids.rules.rewrite.InferPredicates; import org.apache.doris.nereids.rules.rewrite.InferSetOperatorDistinct; import org.apache.doris.nereids.rules.rewrite.InlineLogicalView; import org.apache.doris.nereids.rules.rewrite.LimitSortToTopN; +import org.apache.doris.nereids.rules.rewrite.MergeAggregate; import org.apache.doris.nereids.rules.rewrite.MergeFilters; import org.apache.doris.nereids.rules.rewrite.MergeOneRowRelationIntoUnion; import org.apache.doris.nereids.rules.rewrite.MergeProjects; @@ -341,7 +342,8 @@ public class Rewriter extends AbstractBatchJobExecutor { ), topic("Eliminate GroupBy", - topDown(new EliminateGroupBy()) + topDown(new EliminateGroupBy(), + new MergeAggregate()) ), topic("Eager aggregation", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 1ec52fe8e1..1fda32a400 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -196,6 +196,7 @@ public enum RuleType { MERGE_LIMITS(RuleTypeClass.REWRITE), MERGE_GENERATES(RuleTypeClass.REWRITE), // Eliminate plan + MERGE_AGGREGATE(RuleTypeClass.REWRITE), ELIMINATE_AGGREGATE(RuleTypeClass.REWRITE), ELIMINATE_LIMIT(RuleTypeClass.REWRITE), ELIMINATE_LIMIT_ON_ONE_ROW_RELATION(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java index 
cf94caa25c..4d0c9be368 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java @@ -151,9 +151,7 @@ public class ColumnPruning extends DefaultPlanRewriter implements if (union.getQualifier() == Qualifier.DISTINCT) { return skipPruneThisAndFirstLevelChildren(union); } - - LogicalUnion prunedOutputUnion = pruneOutput(union, union.getOutputs(), union::pruneOutputs, context); - + LogicalUnion prunedOutputUnion = pruneUnionOutput(union, context); // start prune children of union List originOutput = union.getOutput(); Set prunedOutput = prunedOutputUnion.getOutputSet(); @@ -303,6 +301,48 @@ public class ColumnPruning extends DefaultPlanRewriter implements } } + private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context) { + List originOutput = union.getOutputs(); + if (originOutput.isEmpty()) { + return union; + } + List prunedOutputs = Lists.newArrayList(); + List> constantExprsList = union.getConstantExprsList(); + List> prunedConstantExprsList = Lists.newArrayList(); + List extractColumnIndex = Lists.newArrayList(); + for (int i = 0; i < originOutput.size(); i++) { + NamedExpression output = originOutput.get(i); + if (context.requiredSlots.contains(output.toSlot())) { + prunedOutputs.add(output); + extractColumnIndex.add(i); + } + } + int len = extractColumnIndex.size(); + for (List row : constantExprsList) { + ArrayList newRow = new ArrayList<>(len); + for (int idx : extractColumnIndex) { + newRow.add(row.get(idx)); + } + prunedConstantExprsList.add(newRow); + } + + if (prunedOutputs.isEmpty()) { + List candidates = Lists.newArrayList(originOutput); + candidates.retainAll(keys); + if (candidates.isEmpty()) { + candidates = originOutput; + } + NamedExpression minimumColumn = ExpressionUtils.selectMinimumColumn(candidates); + prunedOutputs = ImmutableList.of(minimumColumn); + } + + if (prunedOutputs.equals(originOutput)) { 
+ return union; + } else { + return union.withNewOutputsAndConstExprsList(prunedOutputs, prunedConstantExprsList); + } + } + private

P pruneChildren(P plan) { return pruneChildren(plan, ImmutableSet.of()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java new file mode 100644 index 0000000000..3bdfbc582a --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -0,0 +1,211 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Merges two stacked aggregates (optionally separated by a slot/alias-only project) into one.
 *
 * <p>The merge is only attempted when the outer group-by keys are a subset of the inner ones
 * and every outer aggregate function is one of min / max / sum / any_value applied to a plain
 * slot that is either
 * <ul>
 *   <li>the result of the same inner function (max(max), min(min), sum(sum),
 *       any_value(any_value)), or of an inner count when the outer function is sum and the
 *       outer aggregate has group-by keys, or</li>
 *   <li>not an inner aggregate result at all, in which case only max / min / any_value
 *       are accepted.</li>
 * </ul>
 * DISTINCT functions (inner or outer) are only accepted when both aggregates have the same
 * number of group-by keys.
 *
 * <p>The rule keeps no state between the {@code when} check and the {@code then} rewrite:
 * everything derived from the matched plan is recomputed from that plan.
 */
public class MergeAggregate implements RewriteRuleFactory {
    private static final ImmutableSet<String> ALLOW_MERGE_AGGREGATE_FUNCTIONS =
            ImmutableSet.of("min", "max", "sum", "any_value");

    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                logicalAggregate(logicalAggregate()).when(this::canMergeAggregateWithoutProject)
                        .then(this::mergeTwoAggregate)
                        .toRule(RuleType.MERGE_AGGREGATE),
                logicalAggregate(logicalProject(logicalAggregate()))
                        .when(this::canMergeAggregateWithProject)
                        .then(this::mergeAggProjectAgg)
                        .toRule(RuleType.MERGE_AGGREGATE));
    }

    /**
     * before:
     *   LogicalAggregate
     *   +--LogicalAggregate
     * after:
     *   LogicalAggregate
     */
    private Plan mergeTwoAggregate(LogicalAggregate<LogicalAggregate<Plan>> outerAgg) {
        LogicalAggregate<Plan> innerAgg = outerAgg.child();
        Map<ExprId, AggregateFunction> innerAggFuncs = collectInnerAggFunctions(innerAgg);

        List<NamedExpression> newOutputExpressions = outerAgg.getOutputExpressions().stream()
                .map(e -> rewriteAggregateFunction(e, innerAggFuncs))
                .collect(Collectors.toList());
        return outerAgg.withAggOutput(newOutputExpressions).withChildren(innerAgg.children());
    }

    /**
     * before:
     *   LogicalAggregate (outputExpressions = [col2, sum(col1)], groupByKeys = [col2])
     *   +--LogicalProject (projects = [a as col2, col1])
     *      +--LogicalAggregate (outputExpressions = [a, b, sum(c) as col1], groupByKeys = [a,b])
     * after:
     *   LogicalProject (projects = [a as col2, sum(col1) as sum(col1)]
     *   +--LogicalAggregate (outputExpression = [a, sum(c) as sum(col1)], groupByKeys = [a])
     */
    private Plan mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
        LogicalProject<LogicalAggregate<Plan>> project = outerAgg.child();
        LogicalAggregate<Plan> innerAgg = project.child();
        Map<ExprId, AggregateFunction> innerAggFuncs = collectInnerAggFunctions(innerAgg);

        List<Expression> outerAggFuncAliases = outerAgg.getOutputExpressions().stream()
                .filter(MergeAggregate::isAggregateFunctionAlias)
                .collect(Collectors.toList());
        // Resolve the slots through the project FIRST, so that an inner aggregate result that
        // the project renamed is still recognised, then fold the inner function in
        // (e.g. max(max) -> max).
        List<Expression> replacedAggFunc = PlanUtils
                .replaceExpressionByProjections(project.getProjects(), outerAggFuncAliases).stream()
                .map(expr -> (Expression) rewriteAggregateFunction((NamedExpression) expr, innerAggFuncs))
                .collect(Collectors.toList());
        // replace groupByKeys directly refer to the slot below the project
        List<Expression> replacedGroupBy = PlanUtils.replaceExpressionByProjections(project.getProjects(),
                outerAgg.getGroupByExpressions());
        List<NamedExpression> newOutputExpressions = ImmutableList.<NamedExpression>builder()
                .addAll(replacedGroupBy.stream().map(slot -> (NamedExpression) slot).iterator())
                .addAll(replacedAggFunc.stream().map(alias -> (NamedExpression) alias).iterator())
                .build();
        // construct agg
        LogicalAggregate<Plan> resAgg = outerAgg.withGroupByAndOutput(replacedGroupBy, newOutputExpressions)
                .withChildren(innerAgg.children());

        // construct upper project: re-apply the aliases the removed project gave to the keys.
        // canMergeAggregateWithProject rejects projects aliasing one slot twice, so the keys
        // collected here are unique.
        Map<SlotReference, Alias> childToAlias = project.getProjects().stream()
                .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof SlotReference))
                .collect(Collectors.toMap(alias -> (SlotReference) alias.child(0), alias -> (Alias) alias));
        List<Expression> projectGroupBy = ExpressionUtils.replace(replacedGroupBy, childToAlias);
        List<NamedExpression> upperProjects = ImmutableList.<NamedExpression>builder()
                .addAll(projectGroupBy.stream().map(namedExpr -> (NamedExpression) namedExpr).iterator())
                .addAll(replacedAggFunc.stream().map(expr -> ((NamedExpression) expr).toSlot()).iterator())
                .build();
        return new LogicalProject<Plan>(upperProjects, resAgg);
    }

    /**
     * Replaces {@code alias(outerFunc(slot))} by {@code alias(innerFunc(...))} when
     * {@code slot} is the result of an inner aggregate function; the alias keeps its ExprId
     * and name. Anything else is returned unchanged.
     */
    private NamedExpression rewriteAggregateFunction(NamedExpression e,
            Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc) {
        return (NamedExpression) e.rewriteDownShortCircuit(expr -> {
            if (!isAggregateFunctionAlias(expr)) {
                return expr;
            }
            Alias alias = (Alias) expr;
            Expression argument = alias.child().child(0);
            // the applicability check only lets slot arguments through; stay defensive anyway
            if (!(argument instanceof SlotReference)) {
                return expr;
            }
            AggregateFunction innerFunc = innerAggExprIdToAggFunc.get(((SlotReference) argument).getExprId());
            if (innerFunc == null) {
                return expr;
            }
            return new Alias(alias.getExprId(), innerFunc, alias.getName());
        });
    }

    /**
     * Checks whether every aggregate function of {@code outerAgg} can be merged with
     * {@code innerAgg} when the outer aggregate reads the inner one directly.
     *
     * @param sameGroupBy whether both aggregates have the same number of group-by keys
     */
    boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
            boolean sameGroupBy) {
        return checkAggregateFunctions(outerAgg, innerAgg, sameGroupBy, ImmutableList.of());
    }

    /**
     * Same as {@link #commonCheck}, but first resolves every outer function argument through
     * {@code projects} (the projections between the two aggregates; empty when there are none).
     */
    private boolean checkAggregateFunctions(LogicalAggregate<? extends Plan> outerAgg,
            LogicalAggregate<? extends Plan> innerAgg, boolean sameGroupBy, List<NamedExpression> projects) {
        Map<ExprId, AggregateFunction> innerAggFuncs = collectInnerAggFunctions(innerAgg);
        boolean outerHasGroupBy = !outerAgg.getGroupByExpressions().isEmpty();
        Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions();
        for (AggregateFunction outerFunc : aggregateFunctions) {
            String outerName = outerFunc.getName();
            if (!ALLOW_MERGE_AGGREGATE_FUNCTIONS.contains(outerName)) {
                return false;
            }
            if (outerFunc.isDistinct() && !sameGroupBy) {
                return false;
            }
            // not support outerAggFunc: sum(a+1),sum(a+b)
            if (!(outerFunc.child(0) instanceof SlotReference)) {
                return false;
            }
            Expression argument = outerFunc.child(0);
            if (!projects.isEmpty()) {
                // look through the project so a renamed inner aggregate result is not
                // mistaken for a group-by column
                argument = PlanUtils.replaceExpressionByProjections(projects,
                        ImmutableList.of(argument)).get(0);
                if (!(argument instanceof SlotReference)) {
                    return false;
                }
            }
            AggregateFunction innerFunc = innerAggFuncs.get(((SlotReference) argument).getExprId());
            if (innerFunc != null) {
                if (innerFunc.isDistinct() && !sameGroupBy) {
                    return false;
                }
                // support sum(sum),min(min),max(max),any_value(any_value),sum(count)
                // sum(count) -> count() need outerAgg having group by keys (reason: nullable)
                boolean sameFunction = innerFunc.getName().equals(outerName);
                boolean sumOfCount = outerName.equals("sum") && innerFunc.getName().equals("count")
                        && outerHasGroupBy;
                if (!sameFunction && !sumOfCount) {
                    return false;
                }
            } else {
                // select a, max(b), min(b), any_value(b) from (select a,b from t1 group by a, b) group by a;
                // equals select a, max(b), min(b), any_value(b) from t1 group by a;
                if (!outerName.equals("max")
                        && !outerName.equals("min")
                        && !outerName.equals("any_value")) {
                    return false;
                }
            }
        }
        return true;
    }

    private boolean canMergeAggregateWithoutProject(LogicalAggregate<LogicalAggregate<Plan>> outerAgg) {
        LogicalAggregate<Plan> innerAgg = outerAgg.child();
        if (!new HashSet<>(innerAgg.getGroupByExpressions()).containsAll(outerAgg.getGroupByExpressions())) {
            return false;
        }
        boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());

        return commonCheck(outerAgg, innerAgg, sameGroupBy);
    }

    private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
        LogicalProject<LogicalAggregate<Plan>> project = outerAgg.child();
        LogicalAggregate<Plan> innerAgg = project.child();

        List<Expression> outerAggGroupByKeys = PlanUtils.replaceExpressionByProjections(project.getProjects(),
                outerAgg.getGroupByExpressions());
        if (!new HashSet<>(innerAgg.getGroupByExpressions()).containsAll(outerAggGroupByKeys)) {
            return false;
        }
        // project cannot have expressions like a+1
        if (ExpressionUtils.anyMatch(project.getProjects(),
                expr -> !(expr instanceof SlotReference) && !(expr instanceof Alias))) {
            return false;
        }
        // One slot exposed under two aliases cannot be mapped back to a single alias when the
        // upper project is rebuilt (and would make Collectors.toMap throw), so do not merge.
        if (aliasesSameSlotTwice(project.getProjects())) {
            return false;
        }
        boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
        return checkAggregateFunctions(outerAgg, innerAgg, sameGroupBy, project.getProjects());
    }

    /** Maps the ExprId of every {@code alias(aggregateFunction)} output of the aggregate to that function. */
    private static Map<ExprId, AggregateFunction> collectInnerAggFunctions(LogicalAggregate<? extends Plan> innerAgg) {
        Map<ExprId, AggregateFunction> exprIdToAggFunc = new HashMap<>();
        for (NamedExpression output : innerAgg.getOutputExpressions()) {
            if (isAggregateFunctionAlias(output)) {
                // first occurrence wins
                exprIdToAggFunc.putIfAbsent(output.getExprId(), (AggregateFunction) output.child(0));
            }
        }
        return exprIdToAggFunc;
    }

    /** Returns true for an {@link Alias} whose direct child is an {@link AggregateFunction}. */
    private static boolean isAggregateFunctionAlias(Expression expr) {
        return expr instanceof Alias && expr.child(0) instanceof AggregateFunction;
    }

    /** Returns true when two projections are aliases of the same slot. */
    private static boolean aliasesSameSlotTwice(List<NamedExpression> projects) {
        Set<Expression> aliasedSlots = new HashSet<>();
        for (NamedExpression projection : projects) {
            if (projection instanceof Alias && projection.child(0) instanceof SlotReference
                    && !aliasedSlots.add(projection.child(0))) {
                return true;
            }
        }
        return false;
    }
}
skipProjectFilterLimit(Plan plan) { if (plan instanceof LogicalProject && ((LogicalProject) plan).isAllSlots() || plan instanceof LogicalFilter || plan instanceof LogicalLimit) { diff --git a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out new file mode 100644 index 0000000000..ba5b127a56 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out @@ -0,0 +1,248 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sumCount_empty_table -- +\N + +-- !maxMax_minMin_sumSum_sumCount -- +1 1 1 1 +2 2 2 1 +6 6 6 1 +7 2 20 5 +8 6 26 4 +8 8 8 1 +9 5 20 3 + +-- !maxGroupKey_minGroupKey -- +\N \N 6 6 +1 1 2 1 +2 2 3 3 +3 3 2 1 +4 4 2 2 +5 5 4 3 +7 7 6 6 + +-- !agg_project_agg -- +\N \N \N 6 1 +1 1 1 20 5 +2 2 2 8 1 +3 3 3 20 3 +4 4 4 2 1 +5 5 5 26 4 +7 7 7 1 1 + +-- !upper_plan_can_use_name -- +2 +3 +7 +8 +9 +9 +10 + +-- !outer_agg_has_distinct_same_keys -- +1 1 1 1 +2 2 2 1 +4 2 6 2 +6 6 6 1 +6 6 6 1 +6 6 6 1 +7 3 14 3 +8 6 20 3 +8 8 8 1 +9 5 14 2 + +-- !inner_agg_has_distinct_same_keys -- +1 1 1 1 +2 2 2 1 +4 2 6 2 +6 6 6 1 +6 6 6 1 +6 6 6 1 +7 3 14 3 +8 6 14 3 +8 8 8 1 +9 5 14 2 + +-- !sumCount_empty_table_shape -- +PhysicalResultSink +--hashAgg[GLOBAL] +----PhysicalDistribute[DistributionSpecGather] +------hashAgg[LOCAL] +--------PhysicalProject +----------hashAgg[GLOBAL] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashAgg[LOCAL] +----------------PhysicalProject +------------------PhysicalOlapScan[mal_test2] + +-- !agg_project_agg_shape -- +PhysicalResultSink +--PhysicalQuickSort[MERGE_SORT] +----PhysicalDistribute[DistributionSpecGather] +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalProject +----------hashAgg[GLOBAL] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashAgg[LOCAL] +----------------PhysicalProject 
+------------------PhysicalOlapScan[mal_test1] + +-- !maxMax_minMin_sumSum_sumCount_shape -- +PhysicalResultSink +--PhysicalQuickSort[MERGE_SORT] +----PhysicalDistribute[DistributionSpecGather] +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalProject +----------hashAgg[GLOBAL] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashAgg[LOCAL] +----------------PhysicalProject +------------------PhysicalOlapScan[mal_test1] + +-- !maxGroupKey_minGroupKey_shape -- +PhysicalResultSink +--PhysicalQuickSort[MERGE_SORT] +----PhysicalDistribute[DistributionSpecGather] +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalProject +----------hashAgg[GLOBAL] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashAgg[LOCAL] +----------------PhysicalProject +------------------PhysicalOlapScan[mal_test1] + +-- !outer_agg_has_distinct_same_keys_shape -- +PhysicalResultSink +--PhysicalQuickSort[MERGE_SORT] +----PhysicalDistribute[DistributionSpecGather] +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalProject +----------hashAgg[LOCAL] +------------PhysicalOlapScan[mal_test1] + +-- !inner_agg_has_distinct_same_keys_shape -- +PhysicalResultSink +--PhysicalQuickSort[MERGE_SORT] +----PhysicalDistribute[DistributionSpecGather] +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalProject +----------hashAgg[DISTINCT_LOCAL] +------------hashAgg[GLOBAL] +--------------hashAgg[LOCAL] +----------------PhysicalOlapScan[mal_test1] + +-- !middle_project_has_expression_cannot_merge_shape1 -- +PhysicalResultSink +--PhysicalQuickSort[MERGE_SORT] +----PhysicalDistribute[DistributionSpecGather] +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalProject +----------hashAgg[GLOBAL] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashAgg[LOCAL] +----------------PhysicalProject +------------------hashAgg[LOCAL] +--------------------PhysicalProject +----------------------PhysicalOlapScan[mal_test1] + +-- 
!middle_project_has_expression_cannot_merge_shape2 -- +PhysicalResultSink +--PhysicalQuickSort[MERGE_SORT] +----PhysicalDistribute[DistributionSpecGather] +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalProject +----------hashAgg[GLOBAL] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashAgg[LOCAL] +----------------PhysicalProject +------------------hashAgg[LOCAL] +--------------------PhysicalOlapScan[mal_test1] + +-- !maxGroupKey_minGroupKey_sumGroupKey_cannot_merge_shape -- +PhysicalResultSink +--PhysicalDistribute[DistributionSpecGather] +----PhysicalProject +------hashAgg[GLOBAL] +--------PhysicalDistribute[DistributionSpecHash] +----------hashAgg[LOCAL] +------------hashAgg[LOCAL] +--------------PhysicalProject +----------------PhysicalOlapScan[mal_test1] + +-- !maxMin_cannot_merge_shape -- +PhysicalResultSink +--PhysicalDistribute[DistributionSpecGather] +----PhysicalProject +------hashAgg[GLOBAL] +--------PhysicalDistribute[DistributionSpecHash] +----------hashAgg[LOCAL] +------------PhysicalProject +--------------hashAgg[LOCAL] +----------------PhysicalOlapScan[mal_test1] + +-- !group_key_not_contain_cannot_merge_shape -- +PhysicalResultSink +--PhysicalDistribute[DistributionSpecGather] +----PhysicalProject +------hashAgg[GLOBAL] +--------PhysicalDistribute[DistributionSpecHash] +----------hashAgg[LOCAL] +------------PhysicalProject +--------------hashAgg[LOCAL] +----------------PhysicalOlapScan[mal_test1] + +-- !outer_agg_has_distinct_cannot_merge_shape -- +PhysicalResultSink +--PhysicalQuickSort[MERGE_SORT] +----PhysicalDistribute[DistributionSpecGather] +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalProject +----------hashAgg[GLOBAL] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashAgg[LOCAL] +----------------PhysicalProject +------------------hashAgg[LOCAL] +--------------------PhysicalOlapScan[mal_test1] + +-- !inner_agg_has_distinct_cannot_merge_shape -- +PhysicalResultSink 
+--PhysicalQuickSort[MERGE_SORT] +----PhysicalDistribute[DistributionSpecGather] +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalProject +----------hashAgg[GLOBAL] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashAgg[LOCAL] +----------------PhysicalProject +------------------hashAgg[DISTINCT_LOCAL] +--------------------hashAgg[GLOBAL] +----------------------hashAgg[LOCAL] +------------------------PhysicalOlapScan[mal_test1] + +-- !agg_with_expr_cannot_merge_shape1 -- +PhysicalResultSink +--PhysicalQuickSort[MERGE_SORT] +----PhysicalDistribute[DistributionSpecGather] +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalProject +----------hashAgg[GLOBAL] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashAgg[LOCAL] +----------------PhysicalProject +------------------hashAgg[LOCAL] +--------------------PhysicalProject +----------------------PhysicalOlapScan[mal_test1] + +-- !agg_with_expr_cannot_merge_shape2 -- +PhysicalResultSink +--PhysicalQuickSort[MERGE_SORT] +----PhysicalDistribute[DistributionSpecGather] +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalProject +----------hashAgg[GLOBAL] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashAgg[LOCAL] +----------------PhysicalProject +------------------hashAgg[LOCAL] +--------------------PhysicalProject +----------------------PhysicalOlapScan[mal_test1] + diff --git a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy new file mode 100644 index 0000000000..44c256e2f5 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
// The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

// Regression suite for nested aggregates (aggregate over aggregate, optionally through a
// subquery projection). The first group of cases compares query results, the second group
// compares `explain shape plan` output; cases named *_cannot_merge_* are, going by their
// names, expected to keep both aggregates.
suite("merge_aggregate") {
    // run on the Nereids planner only, without falling back to the legacy planner
    sql "SET enable_nereids_planner=true"
    sql "SET enable_fallback_to_original_planner=false"
    sql """
          DROP TABLE IF EXISTS mal_test1
         """

    sql """
         create table mal_test1(pk int, a int, b int) distributed by hash(pk) buckets 10
         properties('replication_num' = '1'); 
         """

    // fixture rows; includes NULLs in both a and b and duplicate (pk, a) pairs
    sql """
         insert into mal_test1 values(2,1,3),(1,1,2),(3,5,6),(6,null,6),(4,5,6),(2,1,4),(2,3,5),(1,1,4)
        ,(3,5,6),(3,5,null),(6,7,1),(2,1,7),(2,4,2),(2,3,9),(1,3,6),(3,5,8),(3,2,8);
      """
    // mal_test2 is created but never populated: it is the empty table of the sumCount cases
    sql "drop table if exists mal_test2"
    sql """
         create table mal_test2(pk int, a int, b int) distributed by hash(pk) buckets 10
         properties('replication_num' = '1'); 
         """

    sql "sync"


    // ---- result checks ----

    // sum over count on an empty table, outer aggregate without group by
    qt_sumCount_empty_table """
        select sum(col) from (select count(a) col from mal_test2 group by a) t;
    """

    // max(max), min(min), sum(sum), sum(count); outer keys are a subset of the inner keys
    qt_maxMax_minMin_sumSum_sumCount """
        select max(col1), min(col2), sum(col3), sum(col4) from 
        (select pk,a,max(b) as col1, min(b) as col2, sum(b) as col3, count(b) as col4 
        from mal_test1 group by pk,a) t group by a order by 1,2,3,4;
    """

    // max/min applied to inner group-by keys
    qt_maxGroupKey_minGroupKey """
        select max(a),min(a),max(pk),min(pk) from 
        (select pk,a from mal_test1 group by pk,a) t 
        group by a order by 1,2,3,4;
    """

    // inner columns are renamed by the subquery (pk as col1, a as col2)
    qt_agg_project_agg """
        select col2, max(col2),min(col2),sum(col3),sum(col4) from 
        (select pk as col1,a as col2,sum(b) col3, count(b) col4 from mal_test1 group by pk,a) t 
        group by col2 order by 1,2,3,4;
    """

    // a plan above the outer aggregate refers to the outer aggregate's output by name
    qt_upper_plan_can_use_name """
        select c1+1 from (
        select max(col1) c1, min(col2) c2, sum(col3) c3, sum(col4) c4 from 
        (select pk,a,max(b) as col1, min(b) as col2, sum(b) as col3, count(b) as col4 from mal_test1 group by pk,a) t 
        group by a order by 1,2,3,4) outert order by 1;
    """

    // DISTINCT in the outer aggregate, both aggregates grouped by (pk, a)
    qt_outer_agg_has_distinct_same_keys """
        select max(col1), min(col2), sum(col3), sum(DISTINCT col4) from 
        (select pk,a,max(b) as col1, min(b) as col2, sum(b) as col3, count(b) as col4 from mal_test1 group by pk,a) t 
        group by pk,a order by 1,2,3,4;
    """

    // DISTINCT in the inner aggregate, same keys written in a different order
    qt_inner_agg_has_distinct_same_keys """
        select max(col1), min(col2), sum(col3), sum(col4) from 
        (select pk,a,max(b) as col1, min(b) as col2, sum(distinct b) as col3, count(b) as col4 from mal_test1 group by pk,a) t 
        group by a,pk order by 1,2,3,4;
    """

    // ---- plan shape checks ----

    qt_sumCount_empty_table_shape """
        explain shape plan select sum(col) from (select count(a) col from mal_test2 group by a) t;
    """

    // NOTE(review): unlike qt_agg_project_agg above, this query does not select col2 itself
    qt_agg_project_agg_shape """
        explain shape plan select max(col2),min(col2),sum(col3),sum(col4) from 
        (select pk as col1,a as col2,sum(b) col3, count(b) col4 from mal_test1 group by pk,a) t 
        group by col2 order by 1,2,3,4;
    """

    qt_maxMax_minMin_sumSum_sumCount_shape """
        explain shape plan select max(col1), min(col2), sum(col3), sum(col4) from 
        (select pk,a,max(b) as col1, min(b) as col2, sum(b) as col3, count(b) as col4 
        from mal_test1 group by pk,a) t group by a order by 1,2,3,4;
    """

    qt_maxGroupKey_minGroupKey_shape """
        explain shape plan select max(a),min(a),max(pk),min(pk) from 
        (select pk,a from mal_test1 group by pk,a) t 
        group by a order by 1,2,3,4;
    """

    qt_outer_agg_has_distinct_same_keys_shape """
        explain shape plan
        select max(col1), min(col2), sum(col3), sum(DISTINCT col4) from 
        (select pk,a,max(b) as col1, min(b) as col2, sum(b) as col3, count(b) as col4 from mal_test1 group by pk,a) t 
        group by pk,a order by 1,2,3,4;
    """

    qt_inner_agg_has_distinct_same_keys_shape """
        explain shape plan
        select max(col1), min(col2), sum(col3), sum(col4) from 
        (select pk,a,max(b) as col1, min(b) as col2, sum(distinct b) as col3, count(b) as col4 from mal_test1 group by pk,a) t 
        group by a,pk order by 1,2,3,4;
    """

    // ---- plan shape checks for the *_cannot_merge_* cases ----

    // the subquery projects an expression over a group-by key (pk+1)
    qt_middle_project_has_expression_cannot_merge_shape1 """
        explain shape plan
        select max(col1),min(col1) from 
        (select pk+1 as col1,a from mal_test1 group by pk,a) t 
        group by a order by 1,2;
    """

    // the subquery projects an expression over an aggregate result (max(b)+1)
    qt_middle_project_has_expression_cannot_merge_shape2 """
        explain shape plan
        select max(col1), min(col2), sum(col3), sum(col4) from 
        (select pk,a,max(b)+1 as col1, min(b) as col2, sum(b) as col3, count(b) as col4 from mal_test1 group by pk,a) t 
        group by a order by 1,2,3,4;
    """

    // sum over an inner group-by key
    qt_maxGroupKey_minGroupKey_sumGroupKey_cannot_merge_shape """
        explain shape plan select max(a),min(a),max(pk),min(pk),sum(pk) from 
        (select pk,a from mal_test1 group by pk,a) t 
        group by a;
    """

    // outer and inner functions differ: max over min
    qt_maxMin_cannot_merge_shape """
        explain shape plan select max(col), max(col2) from 
        (select pk,a,min(b) col,max(b) col2 from mal_test1 group by pk,a) t 
        group by a;
    """

    // the outer aggregate groups by col2, which is not an inner group-by key
    qt_group_key_not_contain_cannot_merge_shape """
        explain shape plan select max(col2) from 
        (select pk,a,max(b) col2 from mal_test1 group by pk,a) t 
        group by a,col2;
    """

    // DISTINCT in the outer aggregate with different group-by keys
    qt_outer_agg_has_distinct_cannot_merge_shape """
        explain shape plan
        select max(col1), min(col2), sum(col3), sum(DISTINCT col4) from 
        (select pk,a,max(b) as col1, min(b) as col2, sum(b) as col3, count(b) as col4 from mal_test1 group by pk,a) t 
        group by a order by 1,2,3,4;
    """

    // DISTINCT in the inner aggregate with different group-by keys
    qt_inner_agg_has_distinct_cannot_merge_shape """
        explain shape plan
        select max(col1), min(col2), sum(col3), sum(col4) from 
        (select pk,a,max(b) as col1, min(b) as col2, sum(distinct b) as col3, count(b) as col4 from mal_test1 group by pk,a) t 
        group by a order by 1,2,3,4;
    """

    // the outer aggregate function takes an expression of two columns
    qt_agg_with_expr_cannot_merge_shape1 """
        explain shape plan select max(col1+a),min(col1) from 
        (select pk as col1, a from mal_test1 group by pk,a) t 
        group by a order by 1,2;
    """

    // the outer aggregate function takes an expression of a column and a literal
    qt_agg_with_expr_cannot_merge_shape2 """
        explain shape plan select max(col1+1),min(col1) from 
        (select pk as col1, a from mal_test1 group by pk,a) t 
        group by a order by 1,2;
    """

}