[opt](Nereids)(WIP) optimize agg and window normalization step 1 (#19168)

1. move SimplifyAggGroupBy behind NormalizeAggregate
2. fix the project-to-aggregate rule for projects that contain window expressions
This commit is contained in:
morrySnow
2023-04-27 21:42:23 +08:00
committed by GitHub
parent 027bd3f998
commit 0f895640d9
7 changed files with 44 additions and 28 deletions

View File

@ -59,7 +59,7 @@ public class NereidsAnalyzer extends BatchRewriteJob {
),
bottomUp(
new ProjectToGlobalAggregate(),
// this rule check's the logicalProject node's isDisinct property
// this rule check's the logicalProject node's isDistinct property
// and replace the logicalProject node with a LogicalAggregate node
// so any rule before this, if create a new logicalProject node
// should make sure isDistinct property is correctly passed around.

View File

@ -145,19 +145,14 @@ public class NereidsRewriter extends BatchRewriteJob {
// but when normalizeAggregate/normalizeSort is performed, the members in apply cannot be obtained,
// resulting in inconsistent output results and results in apply
topDown(
new SimplifyAggGroupBy(),
new NormalizeAggregate(),
new NormalizeSort()
),
topic("Window analysis",
topDown(
new SimplifyAggGroupBy()
),
topDown(
new ExtractAndNormalizeWindowExpression(),
// execute NormalizeAggregate() again to resolve nested AggregateFunctions in WindowExpression,
// e.g. sum(sum(c1)) over(partition by avg(c1))
new NormalizeAggregate(),
new CheckAndStandardizeWindowFunctionAndFrame()
)
),

View File

@ -66,7 +66,6 @@ public class TwoPhaseReadOpt extends PlanPostProcessor {
@Override
public PhysicalTopN visitPhysicalTopN(PhysicalTopN<? extends Plan> mergeTopN, CascadesContext ctx) {
mergeTopN.child().accept(this, ctx);
Plan child = mergeTopN.child();
if (mergeTopN.getSortPhase() != SortPhase.MERGE_SORT || !(mergeTopN.child() instanceof PhysicalDistribute)) {
return mergeTopN;
}
@ -92,7 +91,7 @@ public class TwoPhaseReadOpt extends PlanPostProcessor {
PhysicalOlapScan olapScan;
PhysicalProject<Plan> project = null;
PhysicalFilter<Plan> filter = null;
child = localTopN.child();
Plan child = localTopN.child();
while (child instanceof Project || child instanceof Filter) {
if (child instanceof Filter) {
filter = (PhysicalFilter<Plan>) child;

View File

@ -22,13 +22,14 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import com.google.common.collect.ImmutableList;
/**
* ProjectToGlobalAggregate.
*
* <p>
* example sql:
* <pre>
* select sum(value)
@ -36,7 +37,7 @@ import com.google.common.collect.ImmutableList;
* </pre>
*
* origin plan: transformed plan:
*
* <p>
* LogicalProject(projects=[sum(value)]) LogicalAggregate(groupBy=[], output=[sum(value)])
* | => |
* LogicalOlapScan(table=tbl) LogicalOlapScan(table=tbl)
@ -48,7 +49,7 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory {
logicalProject().then(project -> {
boolean needGlobalAggregate = project.getProjects()
.stream()
.anyMatch(this::hasNonWindowedAggregateFunction);
.anyMatch(p -> p.accept(NeedAggregateChecker.INSTANCE, null));
if (needGlobalAggregate) {
return new LogicalAggregate<>(ImmutableList.of(), project.getProjects(), project.child());
@ -59,9 +60,31 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory {
);
}
private boolean hasNonWindowedAggregateFunction(Expression expression) {
return expression.anyMatch(WindowExpression.class::isInstance)
? false
: expression.anyMatch(AggregateFunction.class::isInstance);
/**
 * Expression visitor that decides whether an expression requires an aggregate.
 *
 * <p>An {@link AggregateFunction} always yields {@code true}. For a {@link WindowExpression}
 * only the expressions returned by {@code getExpressionsInWindowSpec()} are inspected.
 * Any other expression yields {@code true} as soon as one of its children does.
 */
private static class NeedAggregateChecker extends DefaultExpressionVisitor<Boolean, Void> {
    private static final NeedAggregateChecker INSTANCE = new NeedAggregateChecker();

    /** Generic case: the answer is derived from the expression's children. */
    @Override
    public Boolean visit(Expression expr, Void context) {
        return anyNeedsAggregate(expr.children(), context);
    }

    /** Window case: only the expressions of the window spec are examined. */
    @Override
    public Boolean visitWindow(WindowExpression windowExpression, Void context) {
        return anyNeedsAggregate(windowExpression.getExpressionsInWindowSpec(), context);
    }

    /** Reaching an aggregate function through this visitor means an aggregate is needed. */
    @Override
    public Boolean visitAggregateFunction(AggregateFunction aggregateFunction, Void context) {
        return true;
    }

    /**
     * Visits the candidates in iteration order and stops at the first one that needs an aggregate.
     *
     * @param candidates expressions to check
     * @param context visitor context, passed through unchanged
     * @return {@code true} if any candidate needs an aggregate, {@code false} otherwise
     */
    private boolean anyNeedsAggregate(Iterable<? extends Expression> candidates, Void context) {
        for (Expression candidate : candidates) {
            if (candidate.accept(this, context)) {
                return true;
            }
        }
        return false;
    }
}
}

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -99,7 +100,7 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
aggregate.getOutputExpressions(), Alias.class::isInstance);
Set<AggregateFunction> aggregateFunctionsInWindow = collectAggregateFunctionsInWindow(
aggregate.getOutputExpressions());
Set<Expression> existsAggAlias = existsAliases.stream().map(alias -> alias.child())
Set<Expression> existsAggAlias = existsAliases.stream().map(UnaryNode::child)
.filter(AggregateFunction.class::isInstance)
.collect(Collectors.toSet());

View File

@ -33,6 +33,7 @@ import java.util.Objects;
* e.g. group_concat(id, ',' order by num desc), where "num desc" is the order expression
*/
public class OrderExpression extends Expression implements UnaryExpression, PropagateNullable {
private final OrderKey orderKey;
public OrderExpression(OrderKey orderKey) {

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
@ -93,7 +94,7 @@ public class WindowExpression extends Expression {
expressions.addAll(function.children());
expressions.addAll(partitionKeys);
expressions.addAll(orderKeys.stream()
.map(orderExpression -> orderExpression.child())
.map(UnaryNode::child)
.collect(Collectors.toList()));
return expressions;
}
@ -115,17 +116,13 @@ public class WindowExpression extends Expression {
}
public WindowExpression withOrderKeyList(List<OrderExpression> orderKeyList) {
if (windowFrame.isPresent()) {
return new WindowExpression(function, partitionKeys, orderKeyList, windowFrame.get());
}
return new WindowExpression(function, partitionKeys, orderKeyList);
return windowFrame.map(frame -> new WindowExpression(function, partitionKeys, orderKeyList, frame))
.orElseGet(() -> new WindowExpression(function, partitionKeys, orderKeyList));
}
public WindowExpression withFunction(Expression function) {
if (windowFrame.isPresent()) {
return new WindowExpression(function, partitionKeys, orderKeys, windowFrame.get());
}
return new WindowExpression(function, partitionKeys, orderKeys);
return windowFrame.map(frame -> new WindowExpression(function, partitionKeys, orderKeys, frame))
.orElseGet(() -> new WindowExpression(function, partitionKeys, orderKeys));
}
@Override
@ -177,7 +174,7 @@ public class WindowExpression extends Expression {
@Override
public String toSql() {
StringBuilder sb = new StringBuilder();
sb.append(function.toSql() + " OVER(");
sb.append(function.toSql()).append(" OVER(");
if (!partitionKeys.isEmpty()) {
sb.append("PARTITION BY ").append(partitionKeys.stream()
.map(Expression::toSql)
@ -195,7 +192,7 @@ public class WindowExpression extends Expression {
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(function + " WindowSpec(");
sb.append(function).append(" WindowSpec(");
if (!partitionKeys.isEmpty()) {
sb.append("PARTITION BY ").append(partitionKeys.stream()
.map(Expression::toString)