[fix](nereids)only push down subquery in non-window agg functions (#26034)

This commit is contained in:
starocean999
2023-10-30 11:32:10 +08:00
committed by GitHub
parent f6a2faf967
commit 00c30f075f
3 changed files with 33 additions and 9 deletions

View File

@ -104,11 +104,13 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
// we need push down subquery exprs in side non-distinct agg functions
Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(
Lists.newArrayList(ExpressionUtils.mutableCollect(aggregateOutput,
expr -> expr instanceof AggregateFunction
&& !((AggregateFunction) expr).isDistinct())),
List<AggregateFunction> aggFuncs = Lists.newArrayList();
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
// we need push down subquery exprs inside non-window and non-distinct agg functions
Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(aggFuncs.stream()
.filter(aggFunc -> !aggFunc.isDistinct()).collect(Collectors.toList()),
SubqueryExpr.class::isInstance);
Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
NormalizeToSlotContext bottomSlotContext =
@ -116,8 +118,6 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
Set<NamedExpression> bottomOutputs =
bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, subqueryExprs));
List<AggregateFunction> aggFuncs = Lists.newArrayList();
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
// use group by context to normalize agg functions to process
// sql like: select sum(a + 1) from t group by a + 1
//

View File

@ -54,3 +54,15 @@ true
-- !sql16 --
12
-- !sql17 --
12
12
-- !sql18 --
12
12
-- !sql20 --
5
7

View File

@ -117,11 +117,23 @@ suite("test_subquery_in_project") {
"""
qt_sql15 """
select sum(age + (select sum(age) from test_sql)) from test_sql;
select sum(age + (select sum(age) from test_sql)) from test_sql order by 1;
"""
qt_sql16 """
select sum(distinct age + (select sum(age) from test_sql)) from test_sql;
select sum(distinct age + (select sum(age) from test_sql)) from test_sql order by 1;
"""
qt_sql17 """
select sum(age + (select sum(age) from test_sql)) over() from test_sql order by 1;
"""
qt_sql18 """
select sum(age + (select sum(age) from test_sql)) over() from test_sql group by dt, age order by 1;
"""
qt_sql20 """
select sum(age + (select sum(age) from test_sql)) from test_sql group by dt, age order by 1;
"""
sql """drop table if exists test_sql;"""