[refactor](nereids) make NormalizeAggregate rule more clear and readable (#28607)
This commit is contained in:
@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.analysis;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
|
||||
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
@ -40,9 +41,9 @@ import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
@ -100,22 +101,94 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
|
||||
// The LogicalAggregate node may contain window agg functions and usual agg functions
|
||||
// we call window agg functions as window-agg and usual agg functions as trival-agg for short
|
||||
// This rule simplify LogicalAggregate node by:
|
||||
// 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
|
||||
// 2. create a new LogicalAggregate with normalized group by exprs and trival-aggs
|
||||
// 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
|
||||
// Push down exprs:
|
||||
// 1. all group by exprs
|
||||
// 2. child contains subquery expr in trival-agg
|
||||
// 3. child contains window expr in trival-agg
|
||||
// 4. all input slots of trival-agg
|
||||
// 5. expr(including subquery) in distinct trival-agg
|
||||
// Normalize LogicalAggregate's output.
|
||||
// 1. normalize group by exprs by outputs of bottom LogicalProject
|
||||
// 2. normalize trival-aggs by outputs of bottom LogicalProject
|
||||
// 3. build normalized agg outputs
|
||||
// Pull up exprs:
|
||||
// normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
|
||||
// 1. simple slots
|
||||
// 2. aliases
|
||||
// a. alias with no aggs child
|
||||
// b. alias with trival-agg child
|
||||
// c. alias with window-agg
|
||||
|
||||
// Push down exprs:
|
||||
// collect group by exprs
|
||||
Set<Expression> groupingByExprs =
|
||||
ImmutableSet.copyOf(aggregate.getGroupByExpressions());
|
||||
|
||||
// collect all trival-agg
|
||||
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
|
||||
Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
|
||||
|
||||
List<AggregateFunction> aggFuncs = Lists.newArrayList();
|
||||
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
|
||||
|
||||
// we need push down subquery exprs inside non-window and non-distinct agg functions
|
||||
Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(aggFuncs.stream()
|
||||
.filter(aggFunc -> !aggFunc.isDistinct()).collect(Collectors.toList()),
|
||||
SubqueryExpr.class::isInstance);
|
||||
Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
|
||||
// split non-distinct agg child as two part
|
||||
// TRUE part 1: need push down itself, if it contains subqury or window expression
|
||||
// FALSE part 2: need push down its input slots, if it DOES NOT contain subqury or window expression
|
||||
Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
|
||||
.filter(aggFunc -> !aggFunc.isDistinct())
|
||||
.flatMap(agg -> agg.children().stream())
|
||||
.collect(Collectors.groupingBy(
|
||||
child -> child.containsType(SubqueryExpr.class, WindowExpression.class),
|
||||
Collectors.toSet()));
|
||||
|
||||
// split distinct agg child as two parts
|
||||
// TRUE part 1: need push down itself, if it is NOT SlotReference or Literal
|
||||
// FALSE part 2: need push down its input slots, if it is SlotReference or Literal
|
||||
Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
|
||||
.filter(aggFunc -> aggFunc.isDistinct()).flatMap(agg -> agg.children().stream())
|
||||
.collect(Collectors.groupingBy(
|
||||
child -> !(child instanceof SlotReference || child instanceof Literal),
|
||||
Collectors.toSet()));
|
||||
|
||||
Set<Expression> needPushSelf = Sets.union(
|
||||
categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()),
|
||||
categorizedDistinctAggsChildren.getOrDefault(true, new HashSet<>()));
|
||||
Set<Slot> needPushInputSlots = ExpressionUtils.getInputSlotSet(Sets.union(
|
||||
categorizedNoDistinctAggsChildren.getOrDefault(false, new HashSet<>()),
|
||||
categorizedDistinctAggsChildren.getOrDefault(false, new HashSet<>())));
|
||||
|
||||
Set<Alias> existsAlias =
|
||||
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
|
||||
|
||||
// push down 3 kinds of exprs, these pushed exprs will be used to normalize agg output later
|
||||
// 1. group by exprs
|
||||
// 2. trivalAgg children
|
||||
// 3. trivalAgg input slots
|
||||
Set<Expression> allPushDownExprs =
|
||||
Sets.union(groupingByExprs, Sets.union(needPushSelf, needPushInputSlots));
|
||||
NormalizeToSlotContext bottomSlotContext =
|
||||
NormalizeToSlotContext.buildContext(existsAlias, Sets.union(groupingByExprs, subqueryExprs));
|
||||
Set<NamedExpression> bottomOutputs =
|
||||
bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, subqueryExprs));
|
||||
NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs);
|
||||
Set<NamedExpression> pushedGroupByExprs =
|
||||
bottomSlotContext.pushDownToNamedExpression(groupingByExprs);
|
||||
Set<NamedExpression> pushedTrivalAggChildren =
|
||||
bottomSlotContext.pushDownToNamedExpression(needPushSelf);
|
||||
Set<NamedExpression> pushedTrivalAggInputSlots =
|
||||
bottomSlotContext.pushDownToNamedExpression(needPushInputSlots);
|
||||
Set<NamedExpression> bottomProjects = Sets.union(pushedGroupByExprs,
|
||||
Sets.union(pushedTrivalAggChildren, pushedTrivalAggInputSlots));
|
||||
|
||||
// create bottom project
|
||||
Plan bottomPlan;
|
||||
if (!bottomProjects.isEmpty()) {
|
||||
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects),
|
||||
aggregate.child());
|
||||
} else {
|
||||
bottomPlan = aggregate.child();
|
||||
}
|
||||
|
||||
// use group by context to normalize agg functions to process
|
||||
// sql like: select sum(a + 1) from t group by a + 1
|
||||
@ -127,89 +200,37 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
// after normalize:
|
||||
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
|
||||
// +-- project((a[#0] + 1)[#1])
|
||||
List<AggregateFunction> normalizedAggFuncs = bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
|
||||
Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomOutputs);
|
||||
// TODO: if we have distinct agg, we must push down its children,
|
||||
// because need use it to generate distribution enforce
|
||||
// step 1: split agg functions into 2 parts: distinct and not distinct
|
||||
List<AggregateFunction> distinctAggFuncs = Lists.newArrayList();
|
||||
List<AggregateFunction> nonDistinctAggFuncs = Lists.newArrayList();
|
||||
for (AggregateFunction aggregateFunction : normalizedAggFuncs) {
|
||||
if (aggregateFunction.isDistinct()) {
|
||||
distinctAggFuncs.add(aggregateFunction);
|
||||
} else {
|
||||
nonDistinctAggFuncs.add(aggregateFunction);
|
||||
}
|
||||
}
|
||||
// step 2: if we only have one distinct agg function, we do push down for it
|
||||
if (!distinctAggFuncs.isEmpty()) {
|
||||
// process distinct normalize and put it back to normalizedAggFuncs
|
||||
List<AggregateFunction> newDistinctAggFuncs = Lists.newArrayList();
|
||||
Map<Expression, Expression> replaceMap = Maps.newHashMap();
|
||||
Map<Expression, NamedExpression> aliasCache = Maps.newHashMap();
|
||||
for (AggregateFunction distinctAggFunc : distinctAggFuncs) {
|
||||
List<Expression> newChildren = Lists.newArrayList();
|
||||
for (Expression child : distinctAggFunc.children()) {
|
||||
if (child instanceof SlotReference || child instanceof Literal) {
|
||||
newChildren.add(child);
|
||||
} else {
|
||||
NamedExpression alias;
|
||||
if (aliasCache.containsKey(child)) {
|
||||
alias = aliasCache.get(child);
|
||||
} else {
|
||||
alias = new Alias(child);
|
||||
aliasCache.put(child, alias);
|
||||
}
|
||||
bottomProjects.add(alias);
|
||||
newChildren.add(alias.toSlot());
|
||||
}
|
||||
}
|
||||
AggregateFunction newDistinctAggFunc = distinctAggFunc.withChildren(newChildren);
|
||||
replaceMap.put(distinctAggFunc, newDistinctAggFunc);
|
||||
newDistinctAggFuncs.add(newDistinctAggFunc);
|
||||
}
|
||||
aggregateOutput = aggregateOutput.stream()
|
||||
.map(e -> ExpressionUtils.replace(e, replaceMap))
|
||||
.map(NamedExpression.class::cast)
|
||||
.collect(Collectors.toList());
|
||||
distinctAggFuncs = newDistinctAggFuncs;
|
||||
}
|
||||
normalizedAggFuncs = Lists.newArrayList(nonDistinctAggFuncs);
|
||||
normalizedAggFuncs.addAll(distinctAggFuncs);
|
||||
// TODO: process redundant expressions in aggregate functions children
|
||||
|
||||
// normalize group by exprs by bottomProjects
|
||||
List<Expression> normalizedGroupExprs =
|
||||
bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
|
||||
|
||||
// normalize trival-aggs by bottomProjects
|
||||
List<AggregateFunction> normalizedAggFuncs =
|
||||
bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
|
||||
|
||||
// build normalized agg output
|
||||
NormalizeToSlotContext normalizedAggFuncsToSlotContext =
|
||||
NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
|
||||
// agg output include 2 part, normalized group by slots and normalized agg functions
|
||||
|
||||
// agg output include 2 parts
|
||||
// pushedGroupByExprs and normalized agg functions
|
||||
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
|
||||
.addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator())
|
||||
.addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
|
||||
.addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator())
|
||||
.addAll(normalizedAggFuncsToSlotContext
|
||||
.pushDownToNamedExpression(normalizedAggFuncs))
|
||||
.build();
|
||||
// add normalized agg's input slots to bottom projects
|
||||
Set<Slot> bottomProjectSlots = bottomProjects.stream()
|
||||
.map(NamedExpression::toSlot)
|
||||
.collect(Collectors.toSet());
|
||||
Set<NamedExpression> aggInputSlots = normalizedAggFuncs.stream()
|
||||
.map(Expression::getInputSlots)
|
||||
.flatMap(Set::stream)
|
||||
.filter(e -> !bottomProjectSlots.contains(e))
|
||||
.collect(Collectors.toSet());
|
||||
bottomProjects.addAll(aggInputSlots);
|
||||
// build group by exprs
|
||||
List<Expression> normalizedGroupExprs = bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
|
||||
|
||||
Plan bottomPlan;
|
||||
if (!bottomProjects.isEmpty()) {
|
||||
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
|
||||
} else {
|
||||
bottomPlan = aggregate.child();
|
||||
}
|
||||
// create new agg node
|
||||
LogicalAggregate newAggregate =
|
||||
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);
|
||||
|
||||
// create upper projects by normalize all output exprs in old LogicalAggregate
|
||||
List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
|
||||
bottomSlotContext, normalizedAggFuncsToSlotContext);
|
||||
|
||||
return new LogicalProject<>(upperProjects,
|
||||
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan));
|
||||
// create a parent project node
|
||||
return new LogicalProject<>(upperProjects, newAggregate);
|
||||
}).toRule(RuleType.NORMALIZE_AGGREGATE);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user