[fix](nereids) fix bugs in nereids window function (#17284)

fix two problems:

1. push agg-fun in windowExpression down to AggregateNode
for example, sql:
select sum(sum(a)) over (order by b)
Plan:
windowExpression( sum(y) over (order by b))
+--- Agg(sum(a) as y, b)

2. push other expr to upper proj
for example, sql:
select sum(a+1) over ()
Plan:
windowExpression(sum(y) over ())
+--- Project(a + 1 as y,...)
+--- Agg(a,...)
This commit is contained in:
minghong
2023-03-07 16:35:37 +08:00
committed by GitHub
parent 8ccc805cd0
commit fd8adb492d
12 changed files with 352 additions and 30 deletions

View File

@ -107,9 +107,10 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
Expr staleExpr = expression.accept(INSTANCE, context);
try {
staleExpr.finalizeForNereids();
} catch (org.apache.doris.common.AnalysisException e) {
} catch (Exception e) {
throw new AnalysisException(
"Translate Nereids expression to stale expression failed. " + e.getMessage(), e);
"Translate Nereids expression `" + expression.toSql()
+ "` to stale expression failed. " + e.getMessage(), e);
}
return staleExpr;
}

View File

@ -64,10 +64,10 @@ public class CheckAfterRewrite extends OneAnalysisRuleFactory {
.collect(Collectors.toSet());
notFromChildren = removeValidSlotsNotFromChildren(notFromChildren, childrenOutput);
if (!notFromChildren.isEmpty()) {
throw new AnalysisException(String.format("Input slot(s) not in child's output: %s",
throw new AnalysisException(String.format("Input slot(s) not in child's output: %s in plan: %s",
StringUtils.join(notFromChildren.stream()
.map(ExpressionTrait::toSql)
.collect(Collectors.toSet()), ", ")));
.collect(Collectors.toSet()), ", "), plan));
}
}

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
@ -38,6 +39,8 @@ import com.google.common.collect.Sets;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* extract window expressions from LogicalProject.projects and Normalize LogicalWindow
@ -94,14 +97,44 @@ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory i
// bottomProjects includes:
// 1. expressions from function and WindowSpec's partitionKeys and orderKeys
// 2. other slots of outputExpressions
/*
avg(c) / sum(a+1) over (order by avg(b)) group by a
win(x/sum(z) over y)
prj(x, y, a+1 as z)
agg(avg(c) x, avg(b) y, a)
proj(a b c)
toBePushDown = {avg(c), a+1, avg(b)}
*/
return expressions.stream()
.flatMap(expression -> {
if (expression.anyMatch(WindowExpression.class::isInstance)) {
Set<Slot> inputSlots = expression.getInputSlots().stream().collect(Collectors.toSet());
Set<WindowExpression> collects = expression.collect(WindowExpression.class::isInstance);
return collects.stream().flatMap(windowExpression ->
windowExpression.getExpressionsInWindowSpec().stream()
// constant arguments may in WindowFunctions(e.g. Lead, Lag), which shouldn't be pushed down
.filter(expr -> !expr.isConstant())
Set<Slot> windowInputSlots = collects.stream().flatMap(
win -> win.getInputSlots().stream()
).collect(Collectors.toSet());
/*
substr(
ref_1.cp_type,
max(
cast(ref_1.`cp_catalog_page_number` as int)) over (...)
),
1)
in above case, ref_1.cp_type should be pushed down. ref_1.cp_type is in
substr.inputSlots, but not in windowExpression.inputSlots
inputSlots= {ref_1.cp_type}
*/
inputSlots.removeAll(windowInputSlots);
return Stream.concat(
collects.stream().flatMap(windowExpression ->
windowExpression.getExpressionsInWindowSpec().stream()
// constant arguments may in WindowFunctions(e.g. Lead, Lag)
// which shouldn't be pushed down
.filter(expr -> !expr.isConstant())
),
inputSlots.stream()
);
}
return ImmutableList.of(expression).stream();

View File

@ -34,10 +34,13 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* normalize aggregate's group keys and AggregateFunction's child to SlotReference
@ -57,6 +60,31 @@ import java.util.stream.Collectors;
* +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10, Alias(SUM(v1#3 + 1))#11])
* +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3)
* <p>
*
* Note: window function will be moved to upper project
* all agg functions except the top agg should be pushed to Aggregate node.
* example 1:
* select min(x), sum(x) over () ...
* the 'sum(x)' is top agg of window function, it should be moved to upper project
* plan:
* project(sum(x) over())
* Aggregate(min(x), x)
*
* example 2:
* select min(x), avg(sum(x)) over() ...
* the 'sum(x)' should be moved to Aggregate
* plan:
* project(avg(y) over())
* Aggregate(min(x), sum(x) as y)
* example 3:
* select sum(x+1), x+1, sum(x+1) over() ...
* window function should use x instead of x+1
* plan:
* project(sum(x+1) over())
* Agg(sum(y), x)
* project(x+1 as y)
*
*
* More example could get from UT {NormalizeAggregateTest}
*/
public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot {
@ -64,8 +92,37 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
public Rule build() {
return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
// push expression to bottom project
Set<Alias> existsAliases = ExpressionUtils.collect(
Set<Alias> existsAliases = ExpressionUtils.mutableCollect(
aggregate.getOutputExpressions(), Alias.class::isInstance);
Set<AggregateFunction> aggregateFunctionsInWindow = collectAggregateFunctionsInWindow(
aggregate.getOutputExpressions());
Set<Expression> existsAggAlias = existsAliases.stream().map(alias -> alias.child())
.filter(AggregateFunction.class::isInstance)
.collect(Collectors.toSet());
/*
* agg-functions inside window function is regarded as an output of aggregate.
* select sum(avg(c)) over ...
* is regarded as
* select avg(c), sum(avg(c)) over ...
*
* the plan:
* project(sum(y) over)
* Aggregate(avg(c) as y)
*
* after Aggregate, the 'y' is removed by upper project.
*
* aliasOfAggFunInWindowUsedAsAggOutput = {alias(avg(c))}
*/
List<Alias> aliasOfAggFunInWindowUsedAsAggOutput = Lists.newArrayList();
for (AggregateFunction aggFun : aggregateFunctionsInWindow) {
if (!existsAggAlias.contains(aggFun)) {
Alias alias = new Alias(aggFun, aggFun.toSql());
existsAliases.add(alias);
aliasOfAggFunInWindowUsedAsAggOutput.add(alias);
}
}
Set<Expression> needToSlots = collectGroupByAndArgumentsOfAggregateFunctions(aggregate);
NormalizeToSlotContext groupByAndArgumentToSlotContext =
NormalizeToSlotContext.buildContext(existsAliases, needToSlots);
@ -80,13 +137,22 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
// replace groupBy and arguments of aggregate function to slot, may be this output contains
// some expression on the aggregate functions, e.g. `sum(value) + 1`, we should replace
// the sum(value) to slot and move the `slot + 1` to the upper project later.
List<NamedExpression> normalizeOutputPhase1 = aggregate.getOutputExpressions().stream()
.map(expr -> {
if (expr.anyMatch(WindowExpression.class::isInstance)) {
return expr;
}
return groupByAndArgumentToSlotContext.normalizeToUseSlotRef(expr);
}).collect(Collectors.toList());
List<NamedExpression> normalizeOutputPhase1 = Stream.concat(
aggregate.getOutputExpressions().stream(),
aliasOfAggFunInWindowUsedAsAggOutput.stream())
.map(expr -> groupByAndArgumentToSlotContext
.normalizeToUseSlotRefUp(expr, WindowExpression.class::isInstance))
.collect(Collectors.toList());
Set<Slot> windowInputSlots = collectWindowInputSlots(aggregate.getOutputExpressions());
Set<Expression> itemsInWindow = Sets.newHashSet(windowInputSlots);
itemsInWindow.addAll(aggregateFunctionsInWindow);
NormalizeToSlotContext windowToSlotContext =
NormalizeToSlotContext.buildContext(existsAliases, itemsInWindow);
normalizeOutputPhase1 = normalizeOutputPhase1.stream()
.map(expr -> windowToSlotContext
.normalizeToUseSlotRefDown(expr, WindowExpression.class::isInstance, true))
.collect(Collectors.toList());
Set<AggregateFunction> normalizedAggregateFunctions =
collectNonWindowedAggregateFunctions(normalizeOutputPhase1);
@ -115,14 +181,10 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
LogicalAggregate<Plan> normalizedAggregate = aggregate.withNormalized(
(List) normalizedGroupBy, normalizedAggregateOutput, normalizedChild);
normalizeOutputPhase1.removeAll(aliasOfAggFunInWindowUsedAsAggOutput);
// exclude same-name functions in WindowExpression
List<NamedExpression> upperProjects = normalizeOutputPhase1.stream()
.map(expr -> {
if (expr.anyMatch(WindowExpression.class::isInstance)) {
return expr;
}
return aggregateFunctionToSlotContext.normalizeToUseSlotRef(expr);
}).collect(Collectors.toList());
.map(aggregateFunctionToSlotContext::normalizeToUseSlotRef).collect(Collectors.toList());
return new LogicalProject<>(upperProjects, normalizedAggregate);
}).toRule(RuleType.NORMALIZE_AGGREGATE);
}
@ -146,21 +208,61 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
}))
.collect(ImmutableSet.toImmutableSet());
Set<Expression> windowFunctionKeys = collectWindowFunctionKeys(aggregate.getOutputExpressions());
ImmutableSet<Expression> needPushDown = ImmutableSet.<Expression>builder()
// group by should be pushed down, e.g. group by (k + 1),
// we should push down the `k + 1` to the bottom plan
.addAll(groupingByExpr)
// e.g. sum(k + 1), we should push down the `k + 1` to the bottom plan
.addAll(argumentsOfAggregateFunction)
.addAll(windowFunctionKeys)
.build();
return needPushDown;
}
private Set<AggregateFunction> collectNonWindowedAggregateFunctions(List<NamedExpression> aggOutput) {
List<Expression> expressionsWithoutWindow = aggOutput.stream()
.filter(expr -> !expr.anyMatch(WindowExpression.class::isInstance))
.collect(Collectors.toList());
private Set<Expression> collectWindowFunctionKeys(List<NamedExpression> aggOutput) {
Set<Expression> windowInputs = Sets.newHashSet();
for (Expression expr : aggOutput) {
Set<WindowExpression> windows = expr.collect(WindowExpression.class::isInstance);
for (WindowExpression win : windows) {
windowInputs.addAll(win.getPartitionKeys().stream().flatMap(pk -> pk.getInputSlots().stream()).collect(
Collectors.toList()));
windowInputs.addAll(win.getOrderKeys().stream().flatMap(ok -> ok.getInputSlots().stream()).collect(
Collectors.toList()));
}
}
return windowInputs;
}
return ExpressionUtils.collect(expressionsWithoutWindow, AggregateFunction.class::isInstance);
/**
* select sum(c2), avg(min(c2)) over (partition by max(c1) order by count(c1)) from T ...
* extract {sum, min, max, count}. avg is not extracted.
*/
private Set<AggregateFunction> collectNonWindowedAggregateFunctions(List<NamedExpression> aggOutput) {
return ExpressionUtils.collect(aggOutput, expr -> {
if (expr instanceof AggregateFunction) {
return !((AggregateFunction) expr).isWindowFunction();
}
return false;
});
}
private Set<AggregateFunction> collectAggregateFunctionsInWindow(List<NamedExpression> aggOutput) {
List<WindowExpression> windows = Lists.newArrayList(
ExpressionUtils.collect(aggOutput, WindowExpression.class::isInstance));
return ExpressionUtils.collect(windows, expr -> {
if (expr instanceof AggregateFunction) {
return !((AggregateFunction) expr).isWindowFunction();
}
return false;
});
}
private Set<Slot> collectWindowInputSlots(List<NamedExpression> aggOutput) {
List<WindowExpression> windows = Lists.newArrayList(
ExpressionUtils.collect(aggOutput, WindowExpression.class::isInstance));
return windows.stream().flatMap(win -> win.getInputSlots().stream()).collect(Collectors.toSet());
}
}

View File

@ -31,6 +31,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import javax.annotation.Nullable;
/** NormalizeToSlot */
@ -88,6 +89,24 @@ public interface NormalizeToSlot {
})).collect(ImmutableList.toImmutableList());
}
public <E extends Expression> E normalizeToUseSlotRefUp(E expression, Predicate skip) {
return (E) expression.rewriteDownShortCircuitUp(child -> {
NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child);
return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr;
}, skip);
}
/**
* rewrite subtrees whose root matches predicate border
* when we traverse to the node satisfies border predicate, aboveBorder becomes false
*/
public <E extends Expression> E normalizeToUseSlotRefDown(E expression, Predicate border, boolean aboveBorder) {
return (E) expression.rewriteDownShortCircuitDown(child -> {
NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child);
return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr;
}, border, aboveBorder);
}
/**
* generate bottom projections with groupByExpressions.
* eg:

View File

@ -96,6 +96,64 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> {
return currentNode;
}
/**
* same as rewriteDownShortCircuit,
* except that subtrees, whose root satisfies predicate is satisfied, are not rewritten
*/
default NODE_TYPE rewriteDownShortCircuitUp(Function<NODE_TYPE, NODE_TYPE> rewriteFunction, Predicate skip) {
NODE_TYPE currentNode = rewriteFunction.apply((NODE_TYPE) this);
if (skip.test(currentNode)) {
return currentNode;
}
if (currentNode == this) {
Builder<NODE_TYPE> newChildren = ImmutableList.builderWithExpectedSize(arity());
boolean changed = false;
for (NODE_TYPE child : children()) {
NODE_TYPE newChild = child.rewriteDownShortCircuitUp(rewriteFunction, skip);
if (child != newChild) {
changed = true;
}
newChildren.add(newChild);
}
if (changed) {
currentNode = currentNode.withChildren(newChildren.build());
}
}
return currentNode;
}
/**
* similar to rewriteDownShortCircuit, except that only subtrees, whose root satisfies
* border predicate are rewritten.
*/
default NODE_TYPE rewriteDownShortCircuitDown(Function<NODE_TYPE, NODE_TYPE> rewriteFunction,
Predicate border, boolean aboveBorder) {
NODE_TYPE currentNode = (NODE_TYPE) this;
if (border.test(this)) {
aboveBorder = false;
}
if (!aboveBorder) {
currentNode = rewriteFunction.apply((NODE_TYPE) this);
}
if (currentNode == this) {
Builder<NODE_TYPE> newChildren = ImmutableList.builderWithExpectedSize(arity());
boolean changed = false;
for (NODE_TYPE child : children()) {
NODE_TYPE newChild = child.rewriteDownShortCircuitDown(rewriteFunction, border, aboveBorder);
if (child != newChild) {
changed = true;
}
newChildren.add(newChild);
}
if (changed) {
currentNode = currentNode.withChildren(newChildren.build());
}
}
return currentNode;
}
/**
* bottom-up rewrite.
* @param rewriteFunction rewrite function.

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
@ -53,6 +54,9 @@ public class WindowExpression extends Expression {
.addAll(orderKeys)
.build().toArray(new Expression[0]));
this.function = function;
if (function instanceof AggregateFunction) {
((AggregateFunction) function).setWindowFunction(true);
}
this.partitionKeys = ImmutableList.copyOf(partitionKeys);
this.orderKeys = ImmutableList.copyOf(orderKeys);
this.windowFrame = Optional.empty();
@ -68,6 +72,9 @@ public class WindowExpression extends Expression {
.add(windowFrame)
.build().toArray(new Expression[0]));
this.function = function;
if (function instanceof AggregateFunction) {
((AggregateFunction) function).setWindowFunction(true);
}
this.partitionKeys = ImmutableList.copyOf(partitionKeys);
this.orderKeys = ImmutableList.copyOf(orderKeys);
this.windowFrame = Optional.of(Objects.requireNonNull(windowFrame));

View File

@ -38,6 +38,7 @@ import java.util.stream.Collectors;
public abstract class AggregateFunction extends BoundFunction implements ExpectsInputTypes {
protected final boolean distinct;
protected boolean isWindowFunction = false;
public AggregateFunction(String name, Expression... arguments) {
this(name, false, arguments);
@ -77,6 +78,14 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
return distinct;
}
public boolean isWindowFunction() {
return isWindowFunction;
}
public void setWindowFunction(boolean windowFunction) {
isWindowFunction = windowFunction;
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -86,7 +95,8 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
return false;
}
AggregateFunction that = (AggregateFunction) o;
return Objects.equals(distinct, that.distinct)
return isWindowFunction == that.isWindowFunction
&& Objects.equals(distinct, that.distinct)
&& Objects.equals(getName(), that.getName())
&& Objects.equals(children, that.children);
}
@ -123,4 +133,5 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
.collect(Collectors.joining(", "));
return getName() + "(" + (distinct ? "DISTINCT " : "") + args + ")";
}
}

View File

@ -411,6 +411,13 @@ public class ExpressionUtils {
.collect(ImmutableSet.toImmutableSet());
}
public static <E> Set<E> mutableCollect(List<? extends Expression> expressions,
Predicate<TreeNode<Expression>> predicate) {
return expressions.stream()
.flatMap(expr -> expr.<Set<E>>collect(predicate).stream())
.collect(Collectors.toSet());
}
public static <E> List<E> collectAll(List<? extends Expression> expressions,
Predicate<TreeNode<Expression>> predicate) {
return expressions.stream()

View File

@ -106,7 +106,11 @@ public class AnalyticEvalNode extends PlanNode {
AnalyticWindow analyticWindow, TupleDescriptor intermediateTupleDesc,
TupleDescriptor outputTupleDesc, Expr partitionByEq, Expr orderByEq,
TupleDescriptor bufferedTupleDesc) {
super(id, input.getTupleIds(), "ANALYTIC", StatisticalType.ANALYTIC_EVAL_NODE);
super(id,
(input.getOutputTupleDesc() != null
? Lists.newArrayList(input.getOutputTupleDesc().getId()) :
input.getTupleIds()),
"ANALYTIC", StatisticalType.ANALYTIC_EVAL_NODE);
Preconditions.checkState(!tupleIds.contains(outputTupleDesc.getId()));
// we're materializing the input row augmented with the analytic output tuple
tupleIds.add(outputTupleDesc.getId());