diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index c4e67d6bc1..c356d80263 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -303,7 +303,9 @@ public class Rewriter extends AbstractBatchJobExecutor { topic("Eliminate GroupBy", topDown(new EliminateGroupBy(), - new MergeAggregate()) + new MergeAggregate(), + // need to adjust min/max/sum nullable attribute after merge aggregate + new AdjustAggregateNullableForEmptySet()) ), topic("Eager aggregation", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 9a0b9f8b5e..a2c23dd9b4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -34,10 +34,12 @@ import org.apache.doris.nereids.util.PlanUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -87,15 +89,14 @@ public class MergeAggregate implements RewriteRuleFactory { private Plan mergeAggProjectAgg(LogicalAggregate>> outerAgg) { LogicalProject> project = outerAgg.child(); LogicalAggregate innerAgg = project.child(); - + List outputExpressions = outerAgg.getOutputExpressions(); + List replacedOutputExpressions = PlanUtils.replaceExpressionByProjections( + project.getProjects(), (List) outputExpressions); // rewrite agg function. e.g. 
max(max) - List aggFunc = outerAgg.getOutputExpressions().stream() + List replacedAggFunc = replacedOutputExpressions.stream() .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) .map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc)) .collect(Collectors.toList()); - // rewrite agg function directly refer to the slot below the project - List replacedAggFunc = PlanUtils.replaceExpressionByProjections(project.getProjects(), - (List) aggFunc); // replace groupByKeys directly refer to the slot below the project List replacedGroupBy = PlanUtils.replaceExpressionByProjections(project.getProjects(), outerAgg.getGroupByExpressions()); @@ -138,13 +139,17 @@ public class MergeAggregate implements RewriteRuleFactory { } boolean commonCheck(LogicalAggregate outerAgg, LogicalAggregate innerAgg, - boolean sameGroupBy) { + boolean sameGroupBy, Optional projectOptional) { innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream() .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) .collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0), (existValue, newValue) -> existValue)); Set aggregateFunctions = outerAgg.getAggregateFunctions(); - for (AggregateFunction outerFunc : aggregateFunctions) { + List replacedAggFunctions = projectOptional.map(project -> + (List) PlanUtils.replaceExpressionByProjections( + project.getProjects(), new ArrayList<>(aggregateFunctions))) + .orElseGet(() -> new ArrayList<>(aggregateFunctions)); + for (AggregateFunction outerFunc : replacedAggFunctions) { if (!(ALLOW_MERGE_AGGREGATE_FUNCTIONS.contains(outerFunc.getName()))) { return false; } @@ -188,7 +193,7 @@ public class MergeAggregate implements RewriteRuleFactory { } boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size()); - return commonCheck(outerAgg, innerAgg, 
sameGroupBy, Optional.empty()); } private boolean canMergeAggregateWithProject(LogicalAggregate>> outerAgg) { @@ -206,6 +211,6 @@ public class MergeAggregate implements RewriteRuleFactory { return false; } boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size()); - return commonCheck(outerAgg, innerAgg, sameGroupBy); + return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project)); } } diff --git a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out index ba5b127a56..fba17e8d7b 100644 --- a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out +++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out @@ -246,3 +246,54 @@ PhysicalResultSink --------------------PhysicalProject ----------------------PhysicalOlapScan[mal_test1] +-- !test_has_project_distinct_cant_transform -- +1 + +-- !test_has_project_distinct_cant_transform_shape -- +PhysicalResultSink +--hashAgg[GLOBAL] +----PhysicalDistribute[DistributionSpecGather] +------hashAgg[LOCAL] +--------PhysicalProject +----------hashAgg[GLOBAL] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashAgg[LOCAL] +----------------PhysicalProject +------------------PhysicalOlapScan[mal_test_merge_agg] + +-- !test_distinct_expr_transform -- +-1 + +-- !test_distinct_expr_transform_shape -- +PhysicalResultSink +--hashAgg[GLOBAL] +----PhysicalDistribute[DistributionSpecGather] +------hashAgg[LOCAL] +--------PhysicalProject +----------PhysicalOlapScan[mal_test_merge_agg] + +-- !test_has_project_distinct_expr_transform -- +1 +1 +1 + +-- !test_has_project_distinct_expr_transform_shape -- +PhysicalResultSink +--PhysicalDistribute[DistributionSpecGather] +----PhysicalProject +------hashAgg[GLOBAL] +--------PhysicalDistribute[DistributionSpecHash] +----------hashAgg[LOCAL] +------------PhysicalProject 
+--------------PhysicalOlapScan[mal_test_merge_agg] + +-- !test_sum_empty_table -- +\N \N \N + +-- !test_sum_empty_table_shape -- +PhysicalResultSink +--hashAgg[GLOBAL] +----PhysicalDistribute[DistributionSpecGather] +------hashAgg[LOCAL] +--------PhysicalOlapScan[mal_test2] + diff --git a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy index 44c256e2f5..46cd4a0a9b 100644 --- a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy +++ b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy @@ -174,4 +174,84 @@ suite("merge_aggregate") { group by a order by 1,2; """ + sql "drop table if exists mal_test_merge_agg" + sql """ + create table mal_test_merge_agg( + k1 int null, + k2 int not null, + k3 string null, + k4 varchar(100) null + ) + duplicate key (k1,k2) + distributed BY hash(k1) buckets 3 + properties("replication_num" = "1"); + """ + sql "insert into mal_test_merge_agg select 1,1,'1','a';" + sql "insert into mal_test_merge_agg select 2,2,'2','b';" + sql "insert into mal_test_merge_agg select 3,-3,null,'c';" + sql "sync" + + qt_test_has_project_distinct_cant_transform """ + select max(count_col) + from ( + select k4, + count(distinct case when k3 is null then 1 else 0 end) as count_col + from mal_test_merge_agg group by k4 + ) t ; + """ + qt_test_has_project_distinct_cant_transform_shape """ + explain shape plan + select max(count_col) + from ( + select k4, + count(distinct case when k3 is null then 1 else 0 end) as count_col + from mal_test_merge_agg group by k4 + ) t ; + """ + + qt_test_distinct_expr_transform """ + select max(count_col) + from ( + select k4, + max(-abs(k1)) as count_col + from mal_test_merge_agg group by k4 + ) t ; + """ + qt_test_distinct_expr_transform_shape """ + explain shape plan + select max(count_col) + from ( + select k4, + max(-abs(k1)) as count_col + from 
mal_test_merge_agg group by k4 + ) t ; + """ + + qt_test_has_project_distinct_expr_transform """ + select sum(count_col) + from ( + select k4, + count(distinct case when k3 is null then 1 else 0 end) as count_col + from mal_test_merge_agg group by k4 + ) t group by k4; + """ + + qt_test_has_project_distinct_expr_transform_shape """ + explain shape plan + select sum(count_col) + from ( + select k4, + count(distinct case when k3 is null then 1 else 0 end) as count_col + from mal_test_merge_agg group by k4 + ) t group by k4; + """ + + qt_test_sum_empty_table """ + select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t; + """ + + qt_test_sum_empty_table_shape """ + explain shape plan + select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t; + """ }