From 0f895640d9c1ab871a803c4f15b9451067862f2f Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Thu, 27 Apr 2023 21:42:23 +0800 Subject: [PATCH] [opt](Nereids)(WIP) optimize agg and window normalization step 1 (#19168) 1. move SimplifyAggGroupBy behind NormalizeAggregate 2. fix project to agg rule for the project containing window expression --- .../nereids/analyzer/NereidsAnalyzer.java | 2 +- .../nereids/jobs/batch/NereidsRewriter.java | 7 +--- .../processor/post/TwoPhaseReadOpt.java | 3 +- .../analysis/ProjectToGlobalAggregate.java | 37 +++++++++++++++---- .../rewrite/logical/NormalizeAggregate.java | 3 +- .../trees/expressions/OrderExpression.java | 1 + .../trees/expressions/WindowExpression.java | 19 ++++------ 7 files changed, 44 insertions(+), 28 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/NereidsAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/NereidsAnalyzer.java index 2413e775c1..bc80cef12a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/NereidsAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/NereidsAnalyzer.java @@ -59,7 +59,7 @@ public class NereidsAnalyzer extends BatchRewriteJob { ), bottomUp( new ProjectToGlobalAggregate(), - // this rule check's the logicalProject node's isDisinct property + // this rule check's the logicalProject node's isDistinct property // and replace the logicalProject node with a LogicalAggregate node // so any rule before this, if create a new logicalProject node // should make sure isDistinct property is correctly passed around. 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java index 4668a771cb..07c8903334 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java @@ -145,19 +145,14 @@ public class NereidsRewriter extends BatchRewriteJob { // but when normalizeAggregate/normalizeSort is performed, the members in apply cannot be obtained, // resulting in inconsistent output results and results in apply topDown( + new SimplifyAggGroupBy(), new NormalizeAggregate(), new NormalizeSort() ), topic("Window analysis", - topDown( - new SimplifyAggGroupBy() - ), topDown( new ExtractAndNormalizeWindowExpression(), - // execute NormalizeAggregate() again to resolve nested AggregateFunctions in WindowExpression, - // e.g. sum(sum(c1)) over(partition by avg(c1)) - new NormalizeAggregate(), new CheckAndStandardizeWindowFunctionAndFrame() ) ), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TwoPhaseReadOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TwoPhaseReadOpt.java index 18b7dbdf7f..72d1a7d2db 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TwoPhaseReadOpt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TwoPhaseReadOpt.java @@ -66,7 +66,6 @@ public class TwoPhaseReadOpt extends PlanPostProcessor { @Override public PhysicalTopN visitPhysicalTopN(PhysicalTopN mergeTopN, CascadesContext ctx) { mergeTopN.child().accept(this, ctx); - Plan child = mergeTopN.child(); if (mergeTopN.getSortPhase() != SortPhase.MERGE_SORT || !(mergeTopN.child() instanceof PhysicalDistribute)) { return mergeTopN; } @@ -92,7 +91,7 @@ public class TwoPhaseReadOpt extends PlanPostProcessor { PhysicalOlapScan olapScan; PhysicalProject project = null; PhysicalFilter 
filter = null; - child = localTopN.child(); + Plan child = localTopN.child(); while (child instanceof Project || child instanceof Filter) { if (child instanceof Filter) { filter = (PhysicalFilter) child; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java index 5a798cf7bc..a4cf1d1a8c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java @@ -22,13 +22,14 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import com.google.common.collect.ImmutableList; /** * ProjectToGlobalAggregate. - * + *

<p>
* example sql: * <pre>
  * select sum(value)
@@ -36,7 +37,7 @@ import com.google.common.collect.ImmutableList;
  * </pre>
* * origin plan: transformed plan: - * + * <p>

* LogicalProject(projects=[sum(value)]) LogicalAggregate(groupBy=[], output=[sum(value)]) * | => | * LogicalOlapScan(table=tbl) LogicalOlapScan(table=tbl) @@ -48,7 +49,7 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory { logicalProject().then(project -> { boolean needGlobalAggregate = project.getProjects() .stream() - .anyMatch(this::hasNonWindowedAggregateFunction); + .anyMatch(p -> p.accept(NeedAggregateChecker.INSTANCE, null)); if (needGlobalAggregate) { return new LogicalAggregate<>(ImmutableList.of(), project.getProjects(), project.child()); @@ -59,9 +60,31 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory { ); } - private boolean hasNonWindowedAggregateFunction(Expression expression) { - return expression.anyMatch(WindowExpression.class::isInstance) - ? false - : expression.anyMatch(AggregateFunction.class::isInstance); + private static class NeedAggregateChecker extends DefaultExpressionVisitor { + + private static final NeedAggregateChecker INSTANCE = new NeedAggregateChecker(); + + @Override + public Boolean visit(Expression expr, Void context) { + boolean needAggregate = false; + for (Expression child : expr.children()) { + needAggregate = needAggregate || child.accept(this, context); + } + return needAggregate; + } + + @Override + public Boolean visitWindow(WindowExpression windowExpression, Void context) { + boolean needAggregate = false; + for (Expression child : windowExpression.getExpressionsInWindowSpec()) { + needAggregate = needAggregate || child.accept(this, context); + } + return needAggregate; + } + + @Override + public Boolean visitAggregateFunction(AggregateFunction aggregateFunction, Void context) { + return true; + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java index 7938cf8769..fccd933094 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite.logical; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.UnaryNode; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -99,7 +100,7 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali aggregate.getOutputExpressions(), Alias.class::isInstance); Set aggregateFunctionsInWindow = collectAggregateFunctionsInWindow( aggregate.getOutputExpressions()); - Set existsAggAlias = existsAliases.stream().map(alias -> alias.child()) + Set existsAggAlias = existsAliases.stream().map(UnaryNode::child) .filter(AggregateFunction.class::isInstance) .collect(Collectors.toSet()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/OrderExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/OrderExpression.java index da80d2e7a3..5f0a4fb21d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/OrderExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/OrderExpression.java @@ -33,6 +33,7 @@ import java.util.Objects; * e.g. 
group_concat(id, ',' order by num desc), the num desc is order expression */ public class OrderExpression extends Expression implements UnaryExpression, PropagateNullable { + private final OrderKey orderKey; public OrderExpression(OrderKey orderKey) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java index 3da1610d9b..ffc0522498 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.nereids.exceptions.UnboundException; +import org.apache.doris.nereids.trees.UnaryNode; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; @@ -93,7 +94,7 @@ public class WindowExpression extends Expression { expressions.addAll(function.children()); expressions.addAll(partitionKeys); expressions.addAll(orderKeys.stream() - .map(orderExpression -> orderExpression.child()) + .map(UnaryNode::child) .collect(Collectors.toList())); return expressions; } @@ -115,17 +116,13 @@ public class WindowExpression extends Expression { } public WindowExpression withOrderKeyList(List orderKeyList) { - if (windowFrame.isPresent()) { - return new WindowExpression(function, partitionKeys, orderKeyList, windowFrame.get()); - } - return new WindowExpression(function, partitionKeys, orderKeyList); + return windowFrame.map(frame -> new WindowExpression(function, partitionKeys, orderKeyList, frame)) + .orElseGet(() -> new WindowExpression(function, partitionKeys, orderKeyList)); } public WindowExpression withFunction(Expression function) { - if (windowFrame.isPresent()) { - return 
new WindowExpression(function, partitionKeys, orderKeys, windowFrame.get()); - } - return new WindowExpression(function, partitionKeys, orderKeys); + return windowFrame.map(frame -> new WindowExpression(function, partitionKeys, orderKeys, frame)) + .orElseGet(() -> new WindowExpression(function, partitionKeys, orderKeys)); } @Override @@ -177,7 +174,7 @@ public class WindowExpression extends Expression { @Override public String toSql() { StringBuilder sb = new StringBuilder(); - sb.append(function.toSql() + " OVER("); + sb.append(function.toSql()).append(" OVER("); if (!partitionKeys.isEmpty()) { sb.append("PARTITION BY ").append(partitionKeys.stream() .map(Expression::toSql) @@ -195,7 +192,7 @@ public class WindowExpression extends Expression { @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append(function + " WindowSpec("); + sb.append(function).append(" WindowSpec("); if (!partitionKeys.isEmpty()) { sb.append("PARTITION BY ").append(partitionKeys.stream() .map(Expression::toString)