[opt](Nereids)(WIP) optimize agg and window normalization step 2 #19305

1. refactor aggregate normalization to avoid data amplification before aggregate
2. remove useless aggreagte processing in ExtractAndNormalizeWindowExpression
3. only push distinct aggregate function children

TODO:
1. push down redundant expression in aggregate functions
2. refactor normalize repeat rule
3. move expression normalization and optimization after plan normalization to avoid unexpected expression optimization.
This commit is contained in:
Zhang Wenxin
2023-05-12 14:00:13 +08:00
committed by GitHub
parent 0477a9f5de
commit a1da57c63e
14 changed files with 287 additions and 293 deletions

View File

@ -84,27 +84,27 @@ import java.util.List;
*/
public class NereidsRewriter extends BatchRewriteJob {
private static final List<RewriteJob> REWRITE_JOBS = jobs(
topic("Normalization",
topic("Plan Normalization",
topDown(
new EliminateOrderByConstant(),
new EliminateGroupByConstant(),
// MergeProjects depends on this rule
new LogicalSubQueryAliasToLogicalProject(),
// rewrite expressions, no depends
// TODO: we should do expression normalization after plan normalization
// because some rewritten depends on sub expression tree matching
// such as group by key matching and replaced
// but we need to do some normalization before subquery unnesting,
// such as extract common expression.
new ExpressionNormalization(),
new ExpressionOptimization(),
new AvgDistinctToSumDivCount(),
new CountDistinctRewrite(),
new ExtractFilterFromCrossJoin()
),
// ExtractSingleTableExpressionFromDisjunction conflict to InPredicateToEqualToRule
// in the ExpressionNormalization, so must invoke in another job, or else run into
// dead loop
topDown(
// ExtractSingleTableExpressionFromDisjunction conflict to InPredicateToEqualToRule
// in the ExpressionNormalization, so must invoke in another job, or else run into
// dead loop
new ExtractSingleTableExpressionFromDisjunction()
)
),
@ -131,15 +131,15 @@ public class NereidsRewriter extends BatchRewriteJob {
)
),
// we should eliminate hint again because some hint maybe exist in the CTE or subquery.
// so this rule should invoke after "Subquery unnesting"
custom(RuleType.ELIMINATE_HINT, EliminateLogicalSelectHint::new),
// please note: this rule must run before NormalizeAggregate
topDown(
new AdjustAggregateNullableForEmptySet()
),
// we should eliminate hint again because some hint maybe exist in the CTE or subquery.
// so this rule should invoke after "Subquery unnesting"
custom(RuleType.ELIMINATE_HINT, EliminateLogicalSelectHint::new),
// The rule modification needs to be done after the subquery is unnested,
// because for scalarSubQuery, the connection condition is stored in apply in the analyzer phase,
// but when normalizeAggregate/normalizeSort is performed, the members in apply cannot be obtained,

View File

@ -49,7 +49,7 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory {
logicalProject().then(project -> {
boolean needGlobalAggregate = project.getProjects()
.stream()
.anyMatch(p -> p.accept(NeedAggregateChecker.INSTANCE, null));
.anyMatch(p -> p.accept(ContainsAggregateChecker.INSTANCE, null));
if (needGlobalAggregate) {
return new LogicalAggregate<>(ImmutableList.of(), project.getProjects(), project.child());
@ -60,9 +60,9 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory {
);
}
private static class NeedAggregateChecker extends DefaultExpressionVisitor<Boolean, Void> {
private static class ContainsAggregateChecker extends DefaultExpressionVisitor<Boolean, Void> {
private static final NeedAggregateChecker INSTANCE = new NeedAggregateChecker();
private static final ContainsAggregateChecker INSTANCE = new ContainsAggregateChecker();
@Override
public Boolean visit(Expression expr, Void context) {

View File

@ -58,6 +58,7 @@ import java.util.stream.Collectors;
* function binder
*/
public class FunctionBinder extends AbstractExpressionRewriteRule {
public static final FunctionBinder INSTANCE = new FunctionBinder();
@Override

View File

@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
@ -215,6 +216,25 @@ public class AggregateStrategies implements ImplementationRuleFactory {
return canNotPush;
}
// TODO: refactor this to process slot reference or expression together
boolean onlyContainsSlotOrNumericCastSlot = aggregateFunctions.stream()
.map(ExpressionTrait::getArguments)
.flatMap(List::stream)
.allMatch(argument -> {
if (argument instanceof SlotReference) {
return true;
}
if (argument instanceof Cast) {
return argument.child(0) instanceof SlotReference
&& argument.getDataType().isNumericType()
&& argument.child(0).getDataType().isNumericType();
}
return false;
});
if (!onlyContainsSlotOrNumericCastSlot) {
return canNotPush;
}
// we already normalize the arguments to slotReference
List<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
.flatMap(aggregateFunction -> aggregateFunction.getArguments().stream())
@ -228,7 +248,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
.collect(ImmutableList.toImmutableList());
}
boolean onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction
onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction
.stream()
.allMatch(argument -> {
if (argument instanceof SlotReference) {

View File

@ -57,8 +57,12 @@ public class EliminateGroupByConstant extends OneRewriteRuleFactory {
Set<Expression> slotGroupByExprs = Sets.newLinkedHashSet();
Expression lit = null;
for (Expression expression : groupByExprs) {
expression = FoldConstantRule.INSTANCE.rewrite(expression, context);
if (!(expression instanceof Literal)) {
// NOTICE: we should not use the expression after fold as new aggregate's output or group expr
// because we rely on expression matching to replace subtree that same as group by expr in output
// if we do constant folding before normalize aggregate, the subtree will change and matching fail
// such as: select a + 1 + 2 + 3, sum(b) from t group by a + 1 + 2
Expression foldExpression = FoldConstantRule.INSTANCE.rewrite(expression, context);
if (!(foldExpression instanceof Literal)) {
slotGroupByExprs.add(expression);
} else {
lit = expression;

View File

@ -25,9 +25,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.util.ExpressionUtils;
@ -41,7 +39,7 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* extract window expressions from LogicalProject.projects and Normalize LogicalWindow
* extract window expressions from LogicalProject#projects and Normalize LogicalWindow
*/
public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory implements NormalizeToSlot {
@ -60,15 +58,8 @@ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory i
if (bottomProjects.isEmpty()) {
normalizedChild = project.child();
} else {
boolean needAggregate = bottomProjects.stream().anyMatch(expr ->
expr.anyMatch(AggregateFunction.class::isInstance));
if (needAggregate) {
normalizedChild = new LogicalAggregate<>(ImmutableList.of(),
ImmutableList.copyOf(bottomProjects), project.child());
} else {
normalizedChild = project.withProjectsAndChild(
ImmutableList.copyOf(bottomProjects), project.child());
}
normalizedChild = project.withProjectsAndChild(
ImmutableList.copyOf(bottomProjects), project.child());
}
// 2. handle window's outputs and windowExprs
@ -96,35 +87,32 @@ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory i
// bottomProjects includes:
// 1. expressions from function and WindowSpec's partitionKeys and orderKeys
// 2. other slots of outputExpressions
/*
avg(c) / sum(a+1) over (order by avg(b)) group by a
win(x/sum(z) over y)
prj(x, y, a+1 as z)
agg(avg(c) x, avg(b) y, a)
proj(a b c)
toBePushDown = {avg(c), a+1, avg(b)}
*/
//
// avg(c) / sum(a+1) over (order by avg(b)) group by a
// win(x/sum(z) over y)
// prj(x, y, a+1 as z)
// agg(avg(c) x, avg(b) y, a)
// proj(a b c)
// toBePushDown = {avg(c), a+1, avg(b)}
return expressions.stream()
.flatMap(expression -> {
if (expression.anyMatch(WindowExpression.class::isInstance)) {
Set<Slot> inputSlots = expression.getInputSlots().stream().collect(Collectors.toSet());
Set<Slot> inputSlots = Sets.newHashSet(expression.getInputSlots());
Set<WindowExpression> collects = expression.collect(WindowExpression.class::isInstance);
Set<Slot> windowInputSlots = collects.stream().flatMap(
win -> win.getInputSlots().stream()
).collect(Collectors.toSet());
/*
substr(
ref_1.cp_type,
max(
cast(ref_1.`cp_catalog_page_number` as int)) over (...)
),
1)
in above case, ref_1.cp_type should be pushed down. ref_1.cp_type is in
substr.inputSlots, but not in windowExpression.inputSlots
inputSlots= {ref_1.cp_type}
*/
Set<Slot> windowInputSlots = collects.stream()
.flatMap(win -> win.getInputSlots().stream())
.collect(Collectors.toSet());
// substr(
// ref_1.cp_type,
// max(
// cast(ref_1.`cp_catalog_page_number` as int)) over (...)
// ),
// 1)
//
// in above case, ref_1.cp_type should be pushed down. ref_1.cp_type is in
// substr.inputSlots, but not in windowExpression.inputSlots
//
// inputSlots= {ref_1.cp_type}
inputSlots.removeAll(windowInputSlots);
return Stream.concat(
collects.stream().flatMap(windowExpression ->

View File

@ -20,14 +20,14 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@ -36,12 +36,12 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.collect.Maps;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* normalize aggregate's group keys and AggregateFunction's child to SlotReference
@ -95,173 +95,144 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
@Override
public Rule build() {
return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
// push expression to bottom project
Set<Alias> existsAliases = ExpressionUtils.mutableCollect(
aggregate.getOutputExpressions(), Alias.class::isInstance);
Set<AggregateFunction> aggregateFunctionsInWindow = collectAggregateFunctionsInWindow(
aggregate.getOutputExpressions());
Set<Expression> existsAggAlias = existsAliases.stream().map(UnaryNode::child)
.filter(AggregateFunction.class::isInstance)
.collect(Collectors.toSet());
/*
* agg-functions inside window function is regarded as an output of aggregate.
* select sum(avg(c)) over ...
* is regarded as
* select avg(c), sum(avg(c)) over ...
*
* the plan:
* project(sum(y) over)
* Aggregate(avg(c) as y)
*
* after Aggregate, the 'y' is removed by upper project.
*
* aliasOfAggFunInWindowUsedAsAggOutput = {alias(avg(c))}
*/
List<Alias> aliasOfAggFunInWindowUsedAsAggOutput = Lists.newArrayList();
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
NormalizeToSlotContext groupByToSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, groupingByExprs);
Set<NamedExpression> bottomGroupByProjects =
groupByToSlotContext.pushDownToNamedExpression(groupingByExprs);
for (AggregateFunction aggFun : aggregateFunctionsInWindow) {
if (!existsAggAlias.contains(aggFun)) {
Alias alias = new Alias(aggFun, aggFun.toSql());
existsAliases.add(alias);
aliasOfAggFunInWindowUsedAsAggOutput.add(alias);
List<AggregateFunction> aggFuncs = Lists.newArrayList();
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
// use group by context to normalize agg functions to process
// sql like: select sum(a + 1) from t group by a + 1
//
// before normalize:
// agg(output: sum(a[#0] + 1)[#2], group_by: alias(a + 1)[#1])
// +-- project(a[#0], (a[#0] + 1)[#1])
//
// after normalize:
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
// +-- project((a[#0] + 1)[#1])
List<AggregateFunction> normalizedAggFuncs = groupByToSlotContext.normalizeToUseSlotRef(aggFuncs);
List<NamedExpression> bottomProjects = Lists.newArrayList(bottomGroupByProjects);
// TODO: if we have distinct agg, we must push down its children,
// because need use it to generate distribution enforce
// step 1: split agg functions into 2 parts: distinct and not distinct
List<AggregateFunction> distinctAggFuncs = Lists.newArrayList();
List<AggregateFunction> nonDistinctAggFuncs = Lists.newArrayList();
for (AggregateFunction aggregateFunction : normalizedAggFuncs) {
if (aggregateFunction.isDistinct()) {
distinctAggFuncs.add(aggregateFunction);
} else {
nonDistinctAggFuncs.add(aggregateFunction);
}
}
Set<Expression> needToSlots = collectGroupByAndArgumentsOfAggregateFunctions(aggregate);
NormalizeToSlotContext groupByAndArgumentToSlotContext =
NormalizeToSlotContext.buildContext(existsAliases, needToSlots);
Set<NamedExpression> bottomProjects =
groupByAndArgumentToSlotContext.pushDownToNamedExpression(needToSlots);
Plan normalizedChild = bottomProjects.isEmpty()
? aggregate.child()
: new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
// begin normalize aggregate
// replace groupBy and arguments of aggregate function to slot, may be this output contains
// some expression on the aggregate functions, e.g. `sum(value) + 1`, we should replace
// the sum(value) to slot and move the `slot + 1` to the upper project later.
List<NamedExpression> normalizeOutputPhase1 = Stream.concat(
aggregate.getOutputExpressions().stream(),
aliasOfAggFunInWindowUsedAsAggOutput.stream())
.map(expr -> groupByAndArgumentToSlotContext
.normalizeToUseSlotRefUp(expr, WindowExpression.class::isInstance))
.collect(Collectors.toList());
Set<Slot> windowInputSlots = collectWindowInputSlots(aggregate.getOutputExpressions());
Set<Expression> itemsInWindow = Sets.newHashSet(windowInputSlots);
itemsInWindow.addAll(aggregateFunctionsInWindow);
NormalizeToSlotContext windowToSlotContext =
NormalizeToSlotContext.buildContext(existsAliases, itemsInWindow);
normalizeOutputPhase1 = normalizeOutputPhase1.stream()
.map(expr -> windowToSlotContext
.normalizeToUseSlotRefDown(expr, WindowExpression.class::isInstance, true))
.collect(Collectors.toList());
Set<AggregateFunction> normalizedAggregateFunctions =
collectNonWindowedAggregateFunctions(normalizeOutputPhase1);
existsAliases = ExpressionUtils.collect(normalizeOutputPhase1, Alias.class::isInstance);
// now reuse the exists alias for the aggregate functions,
// or create new alias for the aggregate functions
NormalizeToSlotContext aggregateFunctionToSlotContext =
NormalizeToSlotContext.buildContext(existsAliases, normalizedAggregateFunctions);
Set<NamedExpression> normalizedAggregateFunctionsWithAlias =
aggregateFunctionToSlotContext.pushDownToNamedExpression(normalizedAggregateFunctions);
List<Slot> normalizedGroupBy =
(List) groupByAndArgumentToSlotContext
.normalizeToUseSlotRef(aggregate.getGroupByExpressions());
// we can safely add all groupBy and aggregate functions to output, because we will
// add a project on it, and the upper project can protect the scope of visible of slot
List<NamedExpression> normalizedAggregateOutput = ImmutableList.<NamedExpression>builder()
.addAll(normalizedGroupBy)
.addAll(normalizedAggregateFunctionsWithAlias)
// step 2: if we only have one distinct agg function, we do push down for it
if (!distinctAggFuncs.isEmpty()) {
// process distinct normalize and put it back to normalizedAggFuncs
List<AggregateFunction> newDistinctAggFuncs = Lists.newArrayList();
Map<Expression, Expression> replaceMap = Maps.newHashMap();
Map<Expression, NamedExpression> aliasCache = Maps.newHashMap();
for (AggregateFunction distinctAggFunc : distinctAggFuncs) {
List<Expression> newChildren = Lists.newArrayList();
for (Expression child : distinctAggFunc.children()) {
if (child instanceof SlotReference) {
newChildren.add(child);
} else {
NamedExpression alias;
if (aliasCache.containsKey(child)) {
alias = aliasCache.get(child);
} else {
alias = new Alias(child, child.toSql());
aliasCache.put(child, alias);
}
bottomProjects.add(alias);
newChildren.add(alias.toSlot());
}
}
AggregateFunction newDistinctAggFunc = distinctAggFunc.withChildren(newChildren);
replaceMap.put(distinctAggFunc, newDistinctAggFunc);
newDistinctAggFuncs.add(newDistinctAggFunc);
}
aggregateOutput = aggregateOutput.stream()
.map(e -> ExpressionUtils.replace(e, replaceMap))
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
distinctAggFuncs = newDistinctAggFuncs;
}
normalizedAggFuncs = Lists.newArrayList(nonDistinctAggFuncs);
normalizedAggFuncs.addAll(distinctAggFuncs);
// TODO: process redundant expressions in aggregate functions children
// build normalized agg output
NormalizeToSlotContext normalizedAggFuncsToSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
// agg output include 2 part, normalized group by slots and normalized agg functions
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(bottomGroupByProjects.stream().map(NamedExpression::toSlot).iterator())
.addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
.build();
// add normalized agg's input slots to bottom projects
Set<Slot> bottomProjectSlots = bottomProjects.stream()
.map(NamedExpression::toSlot)
.collect(Collectors.toSet());
Set<NamedExpression> aggInputSlots = normalizedAggFuncs.stream()
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.filter(e -> !bottomProjectSlots.contains(e))
.collect(Collectors.toSet());
bottomProjects.addAll(aggInputSlots);
// build group by exprs
List<Expression> normalizedGroupExprs = groupByToSlotContext.normalizeToUseSlotRef(groupingByExprs);
// build upper project, use two context to do pop up, because agg output maybe contain two part:
// group by keys and agg expressions
List<NamedExpression> upperProjects = groupByToSlotContext
.normalizeToUseSlotRefWithoutWindowFunction(aggregateOutput);
upperProjects = normalizedAggFuncsToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(upperProjects);
// process Expression like Alias(SlotReference#0)#0
upperProjects = upperProjects.stream().map(e -> {
if (e instanceof Alias) {
Alias alias = (Alias) e;
if (alias.child() instanceof SlotReference) {
SlotReference slotReference = (SlotReference) alias.child();
if (slotReference.getExprId().equals(alias.getExprId())) {
return slotReference;
}
}
}
return e;
}).collect(Collectors.toList());
LogicalAggregate<Plan> normalizedAggregate = aggregate.withNormalized(
(List) normalizedGroupBy, normalizedAggregateOutput, normalizedChild);
Plan bottomPlan;
if (!bottomProjects.isEmpty()) {
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
} else {
bottomPlan = aggregate.child();
}
normalizeOutputPhase1.removeAll(aliasOfAggFunInWindowUsedAsAggOutput);
// exclude same-name functions in WindowExpression
List<NamedExpression> upperProjects = normalizeOutputPhase1.stream()
.map(aggregateFunctionToSlotContext::normalizeToUseSlotRef).collect(Collectors.toList());
return new LogicalProject<>(upperProjects, normalizedAggregate);
return new LogicalProject<>(upperProjects,
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan));
}).toRule(RuleType.NORMALIZE_AGGREGATE);
}
private Set<Expression> collectGroupByAndArgumentsOfAggregateFunctions(LogicalAggregate<? extends Plan> aggregate) {
// 2 parts need push down:
// groupingByExpr, argumentsOfAggregateFunction
private static class CollectNonWindowedAggFuncs extends DefaultExpressionVisitor<Void, List<AggregateFunction>> {
Set<Expression> groupingByExpr = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
private static final CollectNonWindowedAggFuncs INSTANCE = new CollectNonWindowedAggFuncs();
Set<AggregateFunction> aggregateFunctions = collectNonWindowedAggregateFunctions(
aggregate.getOutputExpressions());
Set<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
.flatMap(function -> function.getArguments().stream()
.map(expr -> expr instanceof OrderExpression ? expr.child(0) : expr))
.collect(ImmutableSet.toImmutableSet());
Set<Expression> windowFunctionKeys = collectWindowFunctionKeys(aggregate.getOutputExpressions());
Set<Expression> needPushDown = ImmutableSet.<Expression>builder()
// group by should be pushed down, e.g. group by (k + 1),
// we should push down the `k + 1` to the bottom plan
.addAll(groupingByExpr)
// e.g. sum(k + 1), we should push down the `k + 1` to the bottom plan
.addAll(argumentsOfAggregateFunction)
.addAll(windowFunctionKeys)
.build();
return needPushDown;
}
private Set<Expression> collectWindowFunctionKeys(List<NamedExpression> aggOutput) {
Set<Expression> windowInputs = Sets.newHashSet();
for (Expression expr : aggOutput) {
Set<WindowExpression> windows = expr.collect(WindowExpression.class::isInstance);
for (WindowExpression win : windows) {
windowInputs.addAll(win.getPartitionKeys().stream().flatMap(pk -> pk.getInputSlots().stream()).collect(
Collectors.toList()));
windowInputs.addAll(win.getOrderKeys().stream().flatMap(ok -> ok.getInputSlots().stream()).collect(
Collectors.toList()));
@Override
public Void visitWindow(WindowExpression windowExpression, List<AggregateFunction> context) {
for (Expression child : windowExpression.getExpressionsInWindowSpec()) {
child.accept(this, context);
}
return null;
}
return windowInputs;
}
/**
* select sum(c2), avg(min(c2)) over (partition by max(c1) order by count(c1)) from T ...
* extract {sum, min, max, count}. avg is not extracted.
*/
private Set<AggregateFunction> collectNonWindowedAggregateFunctions(List<NamedExpression> aggOutput) {
return ExpressionUtils.collect(aggOutput, expr -> {
if (expr instanceof AggregateFunction) {
return !((AggregateFunction) expr).isWindowFunction();
}
return false;
});
}
private Set<AggregateFunction> collectAggregateFunctionsInWindow(List<NamedExpression> aggOutput) {
List<WindowExpression> windows = Lists.newArrayList(
ExpressionUtils.collect(aggOutput, WindowExpression.class::isInstance));
return ExpressionUtils.collect(windows, expr -> {
if (expr instanceof AggregateFunction) {
return !((AggregateFunction) expr).isWindowFunction();
}
return false;
});
}
private Set<Slot> collectWindowInputSlots(List<NamedExpression> aggOutput) {
List<WindowExpression> windows = Lists.newArrayList(
ExpressionUtils.collect(aggOutput, WindowExpression.class::isInstance));
return windows.stream().flatMap(win -> win.getInputSlots().stream()).collect(Collectors.toSet());
@Override
public Void visitAggregateFunction(AggregateFunction aggregateFunction, List<AggregateFunction> context) {
context.add(aggregateFunction);
return null;
}
}
}

View File

@ -21,17 +21,20 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/** NormalizeToSlot */
@ -45,9 +48,16 @@ public interface NormalizeToSlot {
this.normalizeToSlotMap = normalizeToSlotMap;
}
/** buildContext */
/**
* build normalization context by follow step.
* 1. collect all exists alias by input parameters existsAliases build a reverted map: expr -> alias
* 2. for all input source expressions, use existsAliasMap to construct triple:
* origin expr, pushed expr and alias to replace origin expr,
* see more detail in {@link NormalizeToSlotTriplet}
* 3. construct a map: original expr -> triple constructed by step 2
*/
public static NormalizeToSlotContext buildContext(
Set<Alias> existsAliases, Set<? extends Expression> sourceExpressions) {
Set<Alias> existsAliases, Collection<? extends Expression> sourceExpressions) {
Map<Expression, NormalizeToSlotTriplet> normalizeToSlotMap = Maps.newLinkedHashMap();
Map<Expression, Alias> existsAliasMap = Maps.newLinkedHashMap();
@ -70,13 +80,21 @@ public interface NormalizeToSlot {
return normalizeToUseSlotRef(ImmutableList.of(expression)).get(0);
}
/** normalizeToUseSlotRef, no custom normalize */
public <E extends Expression> List<E> normalizeToUseSlotRef(List<E> expressions) {
/**
* normalizeToUseSlotRef, no custom normalize.
* This function use a lambda that always return original expression as customNormalize
* So always use normalizeToSlotMap to process normalization when we call this function
*/
public <E extends Expression> List<E> normalizeToUseSlotRef(Collection<E> expressions) {
return normalizeToUseSlotRef(expressions, (context, expr) -> expr);
}
/** normalizeToUseSlotRef */
public <E extends Expression> List<E> normalizeToUseSlotRef(List<E> expressions,
/**
* normalizeToUseSlotRef.
* try to use customNormalize do normalization first. if customNormalize cannot handle current expression,
* use normalizeToSlotMap to get the default replaced expression.
*/
public <E extends Expression> List<E> normalizeToUseSlotRef(Collection<E> expressions,
BiFunction<NormalizeToSlotContext, Expression, Expression> customNormalize) {
return expressions.stream()
.map(expr -> (E) expr.rewriteDownShortCircuit(child -> {
@ -89,22 +107,11 @@ public interface NormalizeToSlot {
})).collect(ImmutableList.toImmutableList());
}
public <E extends Expression> E normalizeToUseSlotRefUp(E expression, Predicate skip) {
return (E) expression.rewriteDownShortCircuitUp(child -> {
NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child);
return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr;
}, skip);
}
/**
* rewrite subtrees whose root matches predicate border
* when we traverse to the node satisfies border predicate, aboveBorder becomes false
*/
public <E extends Expression> E normalizeToUseSlotRefDown(E expression, Predicate border, boolean aboveBorder) {
return (E) expression.rewriteDownShortCircuitDown(child -> {
NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child);
return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr;
}, border, aboveBorder);
public <E extends Expression> List<E> normalizeToUseSlotRefWithoutWindowFunction(
Collection<E> expressions) {
return expressions.stream()
.map(e -> (E) e.accept(NormalizeWithoutWindowFunction.INSTANCE, normalizeToSlotMap))
.collect(Collectors.toList());
}
/**
@ -124,6 +131,54 @@ public interface NormalizeToSlot {
}
}
/**
* replace any expression except window function.
* because the window function could be same with aggregate function and should never be replaced.
*/
class NormalizeWithoutWindowFunction
extends DefaultExpressionRewriter<Map<Expression, NormalizeToSlotTriplet>> {
public static final NormalizeWithoutWindowFunction INSTANCE = new NormalizeWithoutWindowFunction();
private NormalizeWithoutWindowFunction() {
}
@Override
public Expression visit(Expression expr, Map<Expression, NormalizeToSlotTriplet> replaceMap) {
if (replaceMap.containsKey(expr)) {
return replaceMap.get(expr).remainExpr;
}
return super.visit(expr, replaceMap);
}
@Override
public Expression visitWindow(WindowExpression windowExpression,
Map<Expression, NormalizeToSlotTriplet> replaceMap) {
if (replaceMap.containsKey(windowExpression)) {
return replaceMap.get(windowExpression).remainExpr;
}
List<Expression> newChildren = new ArrayList<>();
Expression function = super.visit(windowExpression.getFunction(), replaceMap);
newChildren.add(function);
boolean hasNewChildren = function != windowExpression.getFunction();
for (Expression partitionKey : windowExpression.getPartitionKeys()) {
Expression newChild = partitionKey.accept(this, replaceMap);
if (newChild != partitionKey) {
hasNewChildren = true;
}
newChildren.add(newChild);
}
for (Expression orderKey : windowExpression.getOrderKeys()) {
Expression newChild = orderKey.accept(this, replaceMap);
if (newChild != orderKey) {
hasNewChildren = true;
}
newChildren.add(newChild);
}
return hasNewChildren ? windowExpression.withChildren(newChildren) : windowExpression;
}
}
/** NormalizeToSlotTriplet */
class NormalizeToSlotTriplet {
// which expression need to normalized to slot?
@ -142,7 +197,12 @@ public interface NormalizeToSlot {
this.pushedExpr = pushedExpr;
}
/** toTriplet */
/**
* construct triplet by three conditions.
* 1. already has exists alias: use this alias as pushed expr
* 2. expression is {@link NamedExpression}, use itself as pushed expr
* 3. other expression, construct a new Alias contains current expression as pushed expr
*/
public static NormalizeToSlotTriplet toTriplet(Expression expression, @Nullable Alias existsAlias) {
if (existsAlias != null) {
return new NormalizeToSlotTriplet(expression, existsAlias.toSlot(), existsAlias);
@ -150,9 +210,7 @@ public interface NormalizeToSlot {
if (expression instanceof NamedExpression) {
NamedExpression namedExpression = (NamedExpression) expression;
NormalizeToSlotTriplet normalizeToSlotTriplet =
new NormalizeToSlotTriplet(expression, namedExpression.toSlot(), namedExpression);
return normalizeToSlotTriplet;
return new NormalizeToSlotTriplet(expression, namedExpression.toSlot(), namedExpression);
}
Alias alias = new Alias(expression, expression.toSql());

View File

@ -96,33 +96,6 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> {
return currentNode;
}
/**
* same as rewriteDownShortCircuit,
* except that subtrees, whose root satisfies predicate is satisfied, are not rewritten
*/
default NODE_TYPE rewriteDownShortCircuitUp(Function<NODE_TYPE, NODE_TYPE> rewriteFunction, Predicate skip) {
NODE_TYPE currentNode = rewriteFunction.apply((NODE_TYPE) this);
if (skip.test(currentNode)) {
return currentNode;
}
if (currentNode == this) {
Builder<NODE_TYPE> newChildren = ImmutableList.builderWithExpectedSize(arity());
boolean changed = false;
for (NODE_TYPE child : children()) {
NODE_TYPE newChild = child.rewriteDownShortCircuitUp(rewriteFunction, skip);
if (child != newChild) {
changed = true;
}
newChildren.add(newChild);
}
if (changed) {
currentNode = currentNode.withChildren(newChildren.build());
}
}
return currentNode;
}
/**
* similar to rewriteDownShortCircuit, except that only subtrees, whose root satisfies
* border predicate are rewritten.

View File

@ -19,7 +19,6 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
@ -55,9 +54,6 @@ public class WindowExpression extends Expression {
.addAll(orderKeys)
.build().toArray(new Expression[0]));
this.function = function;
if (function instanceof AggregateFunction) {
((AggregateFunction) function).setWindowFunction(true);
}
this.partitionKeys = ImmutableList.copyOf(partitionKeys);
this.orderKeys = ImmutableList.copyOf(orderKeys);
this.windowFrame = Optional.empty();
@ -73,9 +69,6 @@ public class WindowExpression extends Expression {
.add(windowFrame)
.build().toArray(new Expression[0]));
this.function = function;
if (function instanceof AggregateFunction) {
((AggregateFunction) function).setWindowFunction(true);
}
this.partitionKeys = ImmutableList.copyOf(partitionKeys);
this.orderKeys = ImmutableList.copyOf(orderKeys);
this.windowFrame = Optional.of(Objects.requireNonNull(windowFrame));

View File

@ -38,7 +38,6 @@ import java.util.stream.Collectors;
public abstract class AggregateFunction extends BoundFunction implements ExpectsInputTypes {
protected final boolean distinct;
protected boolean isWindowFunction = false;
public AggregateFunction(String name, Expression... arguments) {
this(name, false, arguments);
@ -78,14 +77,6 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
return distinct;
}
public boolean isWindowFunction() {
return isWindowFunction;
}
public void setWindowFunction(boolean windowFunction) {
isWindowFunction = windowFunction;
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -95,8 +86,7 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
return false;
}
AggregateFunction that = (AggregateFunction) o;
return isWindowFunction == that.isWindowFunction
&& Objects.equals(distinct, that.distinct)
return Objects.equals(distinct, that.distinct)
&& Objects.equals(getName(), that.getName())
&& Objects.equals(children, that.children);
}