[fix](nereids) push down subquery exprs in non-distinct agg functions (#25955)
This commit is contained in:
@ -20,12 +20,14 @@ package org.apache.doris.nereids.rules.analysis;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
|
||||
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
|
||||
import org.apache.doris.nereids.trees.expressions.WindowExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
|
||||
@ -101,11 +103,17 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
|
||||
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
|
||||
Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
|
||||
// we need push down subquery exprs in side non-distinct agg functions
|
||||
Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(
|
||||
Lists.newArrayList(ExpressionUtils.mutableCollect(aggregateOutput,
|
||||
expr -> expr instanceof AggregateFunction
|
||||
&& !((AggregateFunction) expr).isDistinct())),
|
||||
SubqueryExpr.class::isInstance);
|
||||
Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
|
||||
NormalizeToSlotContext groupByToSlotContext =
|
||||
NormalizeToSlotContext.buildContext(existsAlias, groupingByExprs);
|
||||
Set<NamedExpression> bottomGroupByProjects =
|
||||
groupByToSlotContext.pushDownToNamedExpression(groupingByExprs);
|
||||
NormalizeToSlotContext bottomSlotContext =
|
||||
NormalizeToSlotContext.buildContext(existsAlias, Sets.union(groupingByExprs, subqueryExprs));
|
||||
Set<NamedExpression> bottomOutputs =
|
||||
bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, subqueryExprs));
|
||||
|
||||
List<AggregateFunction> aggFuncs = Lists.newArrayList();
|
||||
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
|
||||
@ -119,8 +127,8 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
// after normalize:
|
||||
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
|
||||
// +-- project((a[#0] + 1)[#1])
|
||||
List<AggregateFunction> normalizedAggFuncs = groupByToSlotContext.normalizeToUseSlotRef(aggFuncs);
|
||||
Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomGroupByProjects);
|
||||
List<AggregateFunction> normalizedAggFuncs = bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
|
||||
Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomOutputs);
|
||||
// TODO: if we have distinct agg, we must push down its children,
|
||||
// because need use it to generate distribution enforce
|
||||
// step 1: split agg functions into 2 parts: distinct and not distinct
|
||||
@ -174,7 +182,7 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
|
||||
// agg output include 2 part, normalized group by slots and normalized agg functions
|
||||
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
|
||||
.addAll(bottomGroupByProjects.stream().map(NamedExpression::toSlot).iterator())
|
||||
.addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator())
|
||||
.addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
|
||||
.build();
|
||||
// add normalized agg's input slots to bottom projects
|
||||
@ -188,7 +196,7 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
.collect(Collectors.toSet());
|
||||
bottomProjects.addAll(aggInputSlots);
|
||||
// build group by exprs
|
||||
List<Expression> normalizedGroupExprs = groupByToSlotContext.normalizeToUseSlotRef(groupingByExprs);
|
||||
List<Expression> normalizedGroupExprs = bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
|
||||
|
||||
Plan bottomPlan;
|
||||
if (!bottomProjects.isEmpty()) {
|
||||
@ -198,7 +206,7 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
}
|
||||
|
||||
List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
|
||||
groupByToSlotContext, normalizedAggFuncsToSlotContext);
|
||||
bottomSlotContext, normalizedAggFuncsToSlotContext);
|
||||
|
||||
return new LogicalProject<>(upperProjects,
|
||||
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan));
|
||||
|
||||
Reference in New Issue
Block a user