diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java index 31867e82d4..c379e49b8f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java @@ -107,9 +107,10 @@ public class ExpressionTranslator extends DefaultExpressionVisitor { if (expression.anyMatch(WindowExpression.class::isInstance)) { + Set inputSlots = expression.getInputSlots().stream().collect(Collectors.toSet()); Set collects = expression.collect(WindowExpression.class::isInstance); - return collects.stream().flatMap(windowExpression -> - windowExpression.getExpressionsInWindowSpec().stream() - // constant arguments may in WindowFunctions(e.g. Lead, Lag), which shouldn't be pushed down - .filter(expr -> !expr.isConstant()) + Set windowInputSlots = collects.stream().flatMap( + win -> win.getInputSlots().stream() + ).collect(Collectors.toSet()); + /* + substr( + ref_1.cp_type, + max( + cast(ref_1.`cp_catalog_page_number` as int)) over (...) + ), + 1) + + in above case, ref_1.cp_type should be pushed down. ref_1.cp_type is in + substr.inputSlots, but not in windowExpression.inputSlots + + inputSlots= {ref_1.cp_type} + */ + inputSlots.removeAll(windowInputSlots); + return Stream.concat( + collects.stream().flatMap(windowExpression -> + windowExpression.getExpressionsInWindowSpec().stream() + // constant arguments may in WindowFunctions(e.g. Lead, Lag) + // which shouldn't be pushed down + .filter(expr -> !expr.isConstant()) + ), + inputSlots.stream() ); } return ImmutableList.of(expression).stream(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java index 07ddffeced..cf802245ca 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java @@ -34,10 +34,13 @@ import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; import java.util.List; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * normalize aggregate's group keys and AggregateFunction's child to SlotReference @@ -57,6 +60,31 @@ import java.util.stream.Collectors; * +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10, Alias(SUM(v1#3 + 1))#11]) * +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3) *

+ * + * Note: window function will be moved to upper project + * all agg functions except the top agg should be pushed to Aggregate node. + * example 1: + * select min(x), sum(x) over () ... + * the 'sum(x)' is top agg of window function, it should be moved to upper project + * plan: + * project(sum(x) over()) + * Aggregate(min(x), x) + * + * example 2: + * select min(x), avg(sum(x)) over() ... + * the 'sum(x)' should be moved to Aggregate + * plan: + * project(avg(y) over()) + * Aggregate(min(x), sum(x) as y) + * example 3: + * select sum(x+1), x+1, sum(x+1) over() ... + * window function should use x instead of x+1 + * plan: + * project(sum(x+1) over()) + * Agg(sum(y), x) + * project(x+1 as y) + * + * * More example could get from UT {NormalizeAggregateTest} */ public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot { @@ -64,8 +92,37 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali public Rule build() { return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> { // push expression to bottom project - Set existsAliases = ExpressionUtils.collect( + Set existsAliases = ExpressionUtils.mutableCollect( aggregate.getOutputExpressions(), Alias.class::isInstance); + Set aggregateFunctionsInWindow = collectAggregateFunctionsInWindow( + aggregate.getOutputExpressions()); + Set existsAggAlias = existsAliases.stream().map(alias -> alias.child()) + .filter(AggregateFunction.class::isInstance) + .collect(Collectors.toSet()); + + /* + * agg-functions inside window function is regarded as an output of aggregate. + * select sum(avg(c)) over ... + * is regarded as + * select avg(c), sum(avg(c)) over ... + * + * the plan: + * project(sum(y) over) + * Aggregate(avg(c) as y) + * + * after Aggregate, the 'y' is removed by upper project. + * + * aliasOfAggFunInWindowUsedAsAggOutput = {alias(avg(c))} + */ + List aliasOfAggFunInWindowUsedAsAggOutput = Lists.newArrayList(); + + for (AggregateFunction aggFun : aggregateFunctionsInWindow) { + if (!existsAggAlias.contains(aggFun)) { + Alias alias = new Alias(aggFun, aggFun.toSql()); + existsAliases.add(alias); + aliasOfAggFunInWindowUsedAsAggOutput.add(alias); + } + } Set needToSlots = collectGroupByAndArgumentsOfAggregateFunctions(aggregate); NormalizeToSlotContext groupByAndArgumentToSlotContext = NormalizeToSlotContext.buildContext(existsAliases, needToSlots); @@ -80,13 +137,22 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali // replace groupBy and arguments of aggregate function to slot, may be this output contains // some expression on the aggregate functions, e.g. `sum(value) + 1`, we should replace // the sum(value) to slot and move the `slot + 1` to the upper project later. - List normalizeOutputPhase1 = aggregate.getOutputExpressions().stream() - .map(expr -> { - if (expr.anyMatch(WindowExpression.class::isInstance)) { - return expr; - } - return groupByAndArgumentToSlotContext.normalizeToUseSlotRef(expr); - }).collect(Collectors.toList()); + List normalizeOutputPhase1 = Stream.concat( + aggregate.getOutputExpressions().stream(), + aliasOfAggFunInWindowUsedAsAggOutput.stream()) + .map(expr -> groupByAndArgumentToSlotContext + .normalizeToUseSlotRefUp(expr, WindowExpression.class::isInstance)) + .collect(Collectors.toList()); + + Set windowInputSlots = collectWindowInputSlots(aggregate.getOutputExpressions()); + Set itemsInWindow = Sets.newHashSet(windowInputSlots); + itemsInWindow.addAll(aggregateFunctionsInWindow); + NormalizeToSlotContext windowToSlotContext = + NormalizeToSlotContext.buildContext(existsAliases, itemsInWindow); + normalizeOutputPhase1 = normalizeOutputPhase1.stream() + .map(expr -> windowToSlotContext + .normalizeToUseSlotRefDown(expr, WindowExpression.class::isInstance, true)) + .collect(Collectors.toList()); Set normalizedAggregateFunctions = collectNonWindowedAggregateFunctions(normalizeOutputPhase1); @@ -115,14 +181,10 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali LogicalAggregate normalizedAggregate = aggregate.withNormalized( (List) normalizedGroupBy, normalizedAggregateOutput, normalizedChild); + normalizeOutputPhase1.removeAll(aliasOfAggFunInWindowUsedAsAggOutput); // exclude same-name functions in WindowExpression List upperProjects = normalizeOutputPhase1.stream() - .map(expr -> { - if (expr.anyMatch(WindowExpression.class::isInstance)) { - return expr; - } - return aggregateFunctionToSlotContext.normalizeToUseSlotRef(expr); - }).collect(Collectors.toList()); + .map(aggregateFunctionToSlotContext::normalizeToUseSlotRef).collect(Collectors.toList()); return new LogicalProject<>(upperProjects, normalizedAggregate); }).toRule(RuleType.NORMALIZE_AGGREGATE); } @@ -146,21 +208,61 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali })) .collect(ImmutableSet.toImmutableSet()); + Set windowFunctionKeys = collectWindowFunctionKeys(aggregate.getOutputExpressions()); + ImmutableSet needPushDown = ImmutableSet.builder() // group by should be pushed down, e.g. group by (k + 1), // we should push down the `k + 1` to the bottom plan .addAll(groupingByExpr) // e.g. sum(k + 1), we should push down the `k + 1` to the bottom plan .addAll(argumentsOfAggregateFunction) + .addAll(windowFunctionKeys) .build(); return needPushDown; } - private Set collectNonWindowedAggregateFunctions(List aggOutput) { - List expressionsWithoutWindow = aggOutput.stream() - .filter(expr -> !expr.anyMatch(WindowExpression.class::isInstance)) - .collect(Collectors.toList()); + private Set collectWindowFunctionKeys(List aggOutput) { + Set windowInputs = Sets.newHashSet(); + for (Expression expr : aggOutput) { + Set windows = expr.collect(WindowExpression.class::isInstance); + for (WindowExpression win : windows) { + windowInputs.addAll(win.getPartitionKeys().stream().flatMap(pk -> pk.getInputSlots().stream()).collect( + Collectors.toList())); + windowInputs.addAll(win.getOrderKeys().stream().flatMap(ok -> ok.getInputSlots().stream()).collect( + Collectors.toList())); + } + } + return windowInputs; + } - return ExpressionUtils.collect(expressionsWithoutWindow, AggregateFunction.class::isInstance); + /** + * select sum(c2), avg(min(c2)) over (partition by max(c1) order by count(c1)) from T ... + * extract {sum, min, max, count}. avg is not extracted. + */ + private Set collectNonWindowedAggregateFunctions(List aggOutput) { + return ExpressionUtils.collect(aggOutput, expr -> { + if (expr instanceof AggregateFunction) { + return !((AggregateFunction) expr).isWindowFunction(); + } + return false; + }); + } + + private Set collectAggregateFunctionsInWindow(List aggOutput) { + + List windows = Lists.newArrayList( + ExpressionUtils.collect(aggOutput, WindowExpression.class::isInstance)); + return ExpressionUtils.collect(windows, expr -> { + if (expr instanceof AggregateFunction) { + return !((AggregateFunction) expr).isWindowFunction(); + } + return false; + }); + } + + private Set collectWindowInputSlots(List aggOutput) { + List windows = Lists.newArrayList( + ExpressionUtils.collect(aggOutput, WindowExpression.class::isInstance)); + return windows.stream().flatMap(win -> win.getInputSlots().stream()).collect(Collectors.toSet()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java index aee929c8be..8ef966496e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java @@ -31,6 +31,7 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.BiFunction; +import java.util.function.Predicate; import javax.annotation.Nullable; /** NormalizeToSlot */ @@ -88,6 +89,24 @@ public interface NormalizeToSlot { })).collect(ImmutableList.toImmutableList()); } + public E normalizeToUseSlotRefUp(E expression, Predicate skip) { + return (E) expression.rewriteDownShortCircuitUp(child -> { + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); + return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr; + }, skip); + } + + /** + * rewrite subtrees whose root matches predicate border + * when we traverse to the node satisfies border predicate, aboveBorder becomes false + */ + public E normalizeToUseSlotRefDown(E expression, Predicate border, boolean aboveBorder) { + return (E) expression.rewriteDownShortCircuitDown(child -> { + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); + return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr; + }, border, aboveBorder); + } + /** * generate bottom projections with groupByExpressions. * eg: diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java index d8f10a362d..92b99ec68e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java @@ -96,6 +96,64 @@ public interface TreeNode> { return currentNode; } + /** + * same as rewriteDownShortCircuit, + * except that subtrees, whose root satisfies predicate is satisfied, are not rewritten + */ + default NODE_TYPE rewriteDownShortCircuitUp(Function rewriteFunction, Predicate skip) { + NODE_TYPE currentNode = rewriteFunction.apply((NODE_TYPE) this); + if (skip.test(currentNode)) { + return currentNode; + } + if (currentNode == this) { + Builder newChildren = ImmutableList.builderWithExpectedSize(arity()); + boolean changed = false; + for (NODE_TYPE child : children()) { + NODE_TYPE newChild = child.rewriteDownShortCircuitUp(rewriteFunction, skip); + if (child != newChild) { + changed = true; + } + newChildren.add(newChild); + } + + if (changed) { + currentNode = currentNode.withChildren(newChildren.build()); + } + } + return currentNode; + } + + /** + * similar to rewriteDownShortCircuit, except that only subtrees, whose root satisfies + * border predicate are rewritten. + */ + default NODE_TYPE rewriteDownShortCircuitDown(Function rewriteFunction, + Predicate border, boolean aboveBorder) { + NODE_TYPE currentNode = (NODE_TYPE) this; + if (border.test(this)) { + aboveBorder = false; + } + if (!aboveBorder) { + currentNode = rewriteFunction.apply((NODE_TYPE) this); + } + if (currentNode == this) { + Builder newChildren = ImmutableList.builderWithExpectedSize(arity()); + boolean changed = false; + for (NODE_TYPE child : children()) { + NODE_TYPE newChild = child.rewriteDownShortCircuitDown(rewriteFunction, border, aboveBorder); + if (child != newChild) { + changed = true; + } + newChildren.add(newChild); + } + + if (changed) { + currentNode = currentNode.withChildren(newChildren.build()); + } + } + return currentNode; + } + /** * bottom-up rewrite. * @param rewriteFunction rewrite function. diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java index c9e96da1ec..3da1610d9b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.nereids.exceptions.UnboundException; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; @@ -53,6 +54,9 @@ public class WindowExpression extends Expression { .addAll(orderKeys) .build().toArray(new Expression[0])); this.function = function; + if (function instanceof AggregateFunction) { + ((AggregateFunction) function).setWindowFunction(true); + } this.partitionKeys = ImmutableList.copyOf(partitionKeys); this.orderKeys = ImmutableList.copyOf(orderKeys); this.windowFrame = Optional.empty(); @@ -68,6 +72,9 @@ public class WindowExpression extends Expression { .add(windowFrame) .build().toArray(new Expression[0])); this.function = function; + if (function instanceof AggregateFunction) { + ((AggregateFunction) function).setWindowFunction(true); + } this.partitionKeys = ImmutableList.copyOf(partitionKeys); this.orderKeys = ImmutableList.copyOf(orderKeys); this.windowFrame = Optional.of(Objects.requireNonNull(windowFrame)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java index 7d4dd262a0..a170ae0dd5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java @@ -38,6 +38,7 @@ import java.util.stream.Collectors; public abstract class AggregateFunction extends BoundFunction implements ExpectsInputTypes { protected final boolean distinct; + protected boolean isWindowFunction = false; public AggregateFunction(String name, Expression... arguments) { this(name, false, arguments); @@ -77,6 +78,14 @@ public abstract class AggregateFunction extends BoundFunction implements Expects return distinct; } + public boolean isWindowFunction() { + return isWindowFunction; + } + + public void setWindowFunction(boolean windowFunction) { + isWindowFunction = windowFunction; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -86,7 +95,8 @@ public abstract class AggregateFunction extends BoundFunction implements Expects return false; } AggregateFunction that = (AggregateFunction) o; - return Objects.equals(distinct, that.distinct) + return isWindowFunction == that.isWindowFunction + && Objects.equals(distinct, that.distinct) && Objects.equals(getName(), that.getName()) && Objects.equals(children, that.children); } @@ -123,4 +133,5 @@ public abstract class AggregateFunction extends BoundFunction implements Expects .collect(Collectors.joining(", ")); return getName() + "(" + (distinct ? "DISTINCT " : "") + args + ")"; } + } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index f9cbd57a89..90e7c73c8e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -411,6 +411,13 @@ public class ExpressionUtils { .collect(ImmutableSet.toImmutableSet()); } + public static Set mutableCollect(List expressions, + Predicate> predicate) { + return expressions.stream() + .flatMap(expr -> expr.>collect(predicate).stream()) + .collect(Collectors.toSet()); + } + public static List collectAll(List expressions, Predicate> predicate) { return expressions.stream() diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/AnalyticEvalNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/AnalyticEvalNode.java index fb5a63b10f..ff218ce7e3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/AnalyticEvalNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/AnalyticEvalNode.java @@ -106,7 +106,11 @@ public class AnalyticEvalNode extends PlanNode { AnalyticWindow analyticWindow, TupleDescriptor intermediateTupleDesc, TupleDescriptor outputTupleDesc, Expr partitionByEq, Expr orderByEq, TupleDescriptor bufferedTupleDesc) { - super(id, input.getTupleIds(), "ANALYTIC", StatisticalType.ANALYTIC_EVAL_NODE); + super(id, + (input.getOutputTupleDesc() != null + ? Lists.newArrayList(input.getOutputTupleDesc().getId()) : + input.getTupleIds()), + "ANALYTIC", StatisticalType.ANALYTIC_EVAL_NODE); Preconditions.checkState(!tupleIds.contains(outputTupleDesc.getId())); // we're materializing the input row augmented with the analytic output tuple tupleIds.add(outputTupleDesc.getId()); diff --git a/regression-test/data/nereids_syntax_p0/window_function.out b/regression-test/data/nereids_syntax_p0/window_function.out index 921703d421..8434d65a73 100644 --- a/regression-test/data/nereids_syntax_p0/window_function.out +++ b/regression-test/data/nereids_syntax_p0/window_function.out @@ -389,3 +389,53 @@ -- !window_use_agg -- 20 +-- !winExpr_not_agg_expr -- +2 5 +3 5 +3 5 +4 5 +4 5 +6 5 +6 5 +6 5 + +-- !on_notgroupbycolumn -- +1.0 +2.0 +3.0 +3.0 +3.0 +6.0 +6.0 +6.0 + +-- !orderby -- +1 2 2 +1 4 2 +1 4 2 +1 6 2 +2 3 5 +2 3 5 +2 6 5 +2 6 5 + +-- !winExpr_with_others -- +0.4 +0.4 +0.5 +0.8 +0.8 +1.0 +1.0 +1.5 + +-- !winExpr_with_others2 -- +0.4 +0.4 +0.5 +0.8 +0.8 +1.0 +1.0 +1.5 + diff --git a/regression-test/suites/nereids_syntax_p0/window_function.groovy b/regression-test/suites/nereids_syntax_p0/window_function.groovy index 76318169b1..4c0509bbe6 100644 --- a/regression-test/suites/nereids_syntax_p0/window_function.groovy +++ b/regression-test/suites/nereids_syntax_p0/window_function.groovy @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -suite("test_window_function") { +suite("window_function") { sql "SET enable_nereids_planner=true" sql "DROP TABLE IF EXISTS window_test" @@ -118,4 +118,34 @@ suite("test_window_function") { SELECT sum(sum(c1)) over(partition by avg(c2)) FROM window_test """ + + order_qt_winExpr_not_agg_expr """ + select sum(c1+1), sum(c1+1) over (partition by avg(c2)) + from window_test + group by c1, c2 + """ + + order_qt_on_notgroupbycolumn """ + select sum(sum(c3)) over (partition by avg(c2) order by c1) + from window_test + group by c1, c2 + """ + + order_qt_orderby """ + select c1, sum(c1+1), sum(c1+1) over (partition by avg(c2) order by c1) + from window_test + group by c1, c2 + """ + + order_qt_winExpr_with_others """ + select sum(c1)/sum(c1+1) over (partition by c2 order by c1) + from window_test + group by c1, c2 + """ + + order_qt_winExpr_with_others2""" + select sum(c1)/sum(c1+1) over (partition by c2 order by c1) + from window_test + group by c1, c2 + """ }