From 00c30f075ffdde6c7232dec7019fcf12bbe724f8 Mon Sep 17 00:00:00 2001 From: starocean999 <40539150+starocean999@users.noreply.github.com> Date: Mon, 30 Oct 2023 11:32:10 +0800 Subject: [PATCH] [fix](nereids)only push down subquery in non-window agg functions (#26034) --- .../rules/analysis/NormalizeAggregate.java | 14 +++++++------- .../subquery/test_subquery_in_project.out | 12 ++++++++++++ .../subquery/test_subquery_in_project.groovy | 16 ++++++++++++++-- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index 0acd366f1c..203a7f0969 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -104,11 +104,13 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali List aggregateOutput = aggregate.getOutputExpressions(); Set existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance); - // we need push down subquery exprs in side non-distinct agg functions - Set subqueryExprs = ExpressionUtils.mutableCollect( - Lists.newArrayList(ExpressionUtils.mutableCollect(aggregateOutput, - expr -> expr instanceof AggregateFunction - && !((AggregateFunction) expr).isDistinct())), + + List aggFuncs = Lists.newArrayList(); + aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs)); + + // we need push down subquery exprs inside non-window and non-distinct agg functions + Set subqueryExprs = ExpressionUtils.mutableCollect(aggFuncs.stream() + .filter(aggFunc -> !aggFunc.isDistinct()).collect(Collectors.toList()), SubqueryExpr.class::isInstance); Set groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions()); NormalizeToSlotContext bottomSlotContext = @@ -116,8 +118,6 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali Set bottomOutputs = bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, subqueryExprs)); - List aggFuncs = Lists.newArrayList(); - aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs)); // use group by context to normalize agg functions to process // sql like: select sum(a + 1) from t group by a + 1 // diff --git a/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out index 4d8bd4c736..c0d289b51f 100644 --- a/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out +++ b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out @@ -54,3 +54,15 @@ true -- !sql16 -- 12 +-- !sql17 -- +12 +12 + +-- !sql18 -- +12 +12 + +-- !sql20 -- +5 +7 + diff --git a/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy index b9de14e530..25920848c6 100644 --- a/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy +++ b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy @@ -117,11 +117,23 @@ suite("test_subquery_in_project") { """ qt_sql15 """ - select sum(age + (select sum(age) from test_sql)) from test_sql; + select sum(age + (select sum(age) from test_sql)) from test_sql order by 1; """ qt_sql16 """ - select sum(distinct age + (select sum(age) from test_sql)) from test_sql; + select sum(distinct age + (select sum(age) from test_sql)) from test_sql order by 1; + """ + + qt_sql17 """ + select sum(age + (select sum(age) from test_sql)) over() from test_sql order by 1; + """ + + qt_sql18 """ + select sum(age + (select sum(age) from test_sql)) over() from test_sql group by dt, age order by 1; + """ + + qt_sql20 """ + select sum(age + (select sum(age) from test_sql)) from test_sql group by dt, age order by 1; """ sql """drop table if exists test_sql;"""