[opt](Nereids)(WIP) optimize agg and window normalization step 1 (#19168)
1. Move SimplifyAggGroupBy behind NormalizeAggregate. 2. Fix the project-to-aggregate rule for projects that contain window expressions.
This commit is contained in:
@ -59,7 +59,7 @@ public class NereidsAnalyzer extends BatchRewriteJob {
|
||||
),
|
||||
bottomUp(
|
||||
new ProjectToGlobalAggregate(),
|
||||
// this rule check's the logicalProject node's isDisinct property
|
||||
// this rule checks the logicalProject node's isDistinct property
|
||||
// and replace the logicalProject node with a LogicalAggregate node
|
||||
// so any rule before this that creates a new logicalProject node
|
||||
// should make sure isDistinct property is correctly passed around.
|
||||
|
||||
@ -145,19 +145,14 @@ public class NereidsRewriter extends BatchRewriteJob {
|
||||
// but when normalizeAggregate/normalizeSort is performed, the members in apply cannot be obtained,
|
||||
// resulting in inconsistent output results and results in apply
|
||||
topDown(
|
||||
new SimplifyAggGroupBy(),
|
||||
new NormalizeAggregate(),
|
||||
new NormalizeSort()
|
||||
),
|
||||
|
||||
topic("Window analysis",
|
||||
topDown(
|
||||
new SimplifyAggGroupBy()
|
||||
),
|
||||
topDown(
|
||||
new ExtractAndNormalizeWindowExpression(),
|
||||
// execute NormalizeAggregate() again to resolve nested AggregateFunctions in WindowExpression,
|
||||
// e.g. sum(sum(c1)) over(partition by avg(c1))
|
||||
new NormalizeAggregate(),
|
||||
new CheckAndStandardizeWindowFunctionAndFrame()
|
||||
)
|
||||
),
|
||||
|
||||
@ -66,7 +66,6 @@ public class TwoPhaseReadOpt extends PlanPostProcessor {
|
||||
@Override
|
||||
public PhysicalTopN visitPhysicalTopN(PhysicalTopN<? extends Plan> mergeTopN, CascadesContext ctx) {
|
||||
mergeTopN.child().accept(this, ctx);
|
||||
Plan child = mergeTopN.child();
|
||||
if (mergeTopN.getSortPhase() != SortPhase.MERGE_SORT || !(mergeTopN.child() instanceof PhysicalDistribute)) {
|
||||
return mergeTopN;
|
||||
}
|
||||
@ -92,7 +91,7 @@ public class TwoPhaseReadOpt extends PlanPostProcessor {
|
||||
PhysicalOlapScan olapScan;
|
||||
PhysicalProject<Plan> project = null;
|
||||
PhysicalFilter<Plan> filter = null;
|
||||
child = localTopN.child();
|
||||
Plan child = localTopN.child();
|
||||
while (child instanceof Project || child instanceof Filter) {
|
||||
if (child instanceof Filter) {
|
||||
filter = (PhysicalFilter<Plan>) child;
|
||||
|
||||
@ -22,13 +22,14 @@ import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.WindowExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
/**
|
||||
* ProjectToGlobalAggregate.
|
||||
*
|
||||
* <p>
|
||||
* example sql:
|
||||
* <pre>
|
||||
* select sum(value)
|
||||
@ -36,7 +37,7 @@ import com.google.common.collect.ImmutableList;
|
||||
* </pre>
|
||||
*
|
||||
* origin plan: transformed plan:
|
||||
*
|
||||
* <p>
|
||||
* LogicalProject(projects=[sum(value)]) LogicalAggregate(groupBy=[], output=[sum(value)])
|
||||
* | => |
|
||||
* LogicalOlapScan(table=tbl) LogicalOlapScan(table=tbl)
|
||||
@ -48,7 +49,7 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory {
|
||||
logicalProject().then(project -> {
|
||||
boolean needGlobalAggregate = project.getProjects()
|
||||
.stream()
|
||||
.anyMatch(this::hasNonWindowedAggregateFunction);
|
||||
.anyMatch(p -> p.accept(NeedAggregateChecker.INSTANCE, null));
|
||||
|
||||
if (needGlobalAggregate) {
|
||||
return new LogicalAggregate<>(ImmutableList.of(), project.getProjects(), project.child());
|
||||
@ -59,9 +60,31 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory {
|
||||
);
|
||||
}
|
||||
|
||||
private boolean hasNonWindowedAggregateFunction(Expression expression) {
|
||||
return expression.anyMatch(WindowExpression.class::isInstance)
|
||||
? false
|
||||
: expression.anyMatch(AggregateFunction.class::isInstance);
|
||||
private static class NeedAggregateChecker extends DefaultExpressionVisitor<Boolean, Void> {
|
||||
|
||||
private static final NeedAggregateChecker INSTANCE = new NeedAggregateChecker();
|
||||
|
||||
@Override
|
||||
public Boolean visit(Expression expr, Void context) {
|
||||
boolean needAggregate = false;
|
||||
for (Expression child : expr.children()) {
|
||||
needAggregate = needAggregate || child.accept(this, context);
|
||||
}
|
||||
return needAggregate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean visitWindow(WindowExpression windowExpression, Void context) {
|
||||
boolean needAggregate = false;
|
||||
for (Expression child : windowExpression.getExpressionsInWindowSpec()) {
|
||||
needAggregate = needAggregate || child.accept(this, context);
|
||||
}
|
||||
return needAggregate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean visitAggregateFunction(AggregateFunction aggregateFunction, Void context) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.UnaryNode;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
@ -99,7 +100,7 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
aggregate.getOutputExpressions(), Alias.class::isInstance);
|
||||
Set<AggregateFunction> aggregateFunctionsInWindow = collectAggregateFunctionsInWindow(
|
||||
aggregate.getOutputExpressions());
|
||||
Set<Expression> existsAggAlias = existsAliases.stream().map(alias -> alias.child())
|
||||
Set<Expression> existsAggAlias = existsAliases.stream().map(UnaryNode::child)
|
||||
.filter(AggregateFunction.class::isInstance)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
|
||||
@ -33,6 +33,7 @@ import java.util.Objects;
|
||||
* e.g. group_concat(id, ',' order by num desc), the num desc is order expression
|
||||
*/
|
||||
public class OrderExpression extends Expression implements UnaryExpression, PropagateNullable {
|
||||
|
||||
private final OrderKey orderKey;
|
||||
|
||||
public OrderExpression(OrderKey orderKey) {
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.UnaryNode;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
@ -93,7 +94,7 @@ public class WindowExpression extends Expression {
|
||||
expressions.addAll(function.children());
|
||||
expressions.addAll(partitionKeys);
|
||||
expressions.addAll(orderKeys.stream()
|
||||
.map(orderExpression -> orderExpression.child())
|
||||
.map(UnaryNode::child)
|
||||
.collect(Collectors.toList()));
|
||||
return expressions;
|
||||
}
|
||||
@ -115,17 +116,13 @@ public class WindowExpression extends Expression {
|
||||
}
|
||||
|
||||
public WindowExpression withOrderKeyList(List<OrderExpression> orderKeyList) {
|
||||
if (windowFrame.isPresent()) {
|
||||
return new WindowExpression(function, partitionKeys, orderKeyList, windowFrame.get());
|
||||
}
|
||||
return new WindowExpression(function, partitionKeys, orderKeyList);
|
||||
return windowFrame.map(frame -> new WindowExpression(function, partitionKeys, orderKeyList, frame))
|
||||
.orElseGet(() -> new WindowExpression(function, partitionKeys, orderKeyList));
|
||||
}
|
||||
|
||||
public WindowExpression withFunction(Expression function) {
|
||||
if (windowFrame.isPresent()) {
|
||||
return new WindowExpression(function, partitionKeys, orderKeys, windowFrame.get());
|
||||
}
|
||||
return new WindowExpression(function, partitionKeys, orderKeys);
|
||||
return windowFrame.map(frame -> new WindowExpression(function, partitionKeys, orderKeys, frame))
|
||||
.orElseGet(() -> new WindowExpression(function, partitionKeys, orderKeys));
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -177,7 +174,7 @@ public class WindowExpression extends Expression {
|
||||
@Override
|
||||
public String toSql() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append(function.toSql() + " OVER(");
|
||||
sb.append(function.toSql()).append(" OVER(");
|
||||
if (!partitionKeys.isEmpty()) {
|
||||
sb.append("PARTITION BY ").append(partitionKeys.stream()
|
||||
.map(Expression::toSql)
|
||||
@ -195,7 +192,7 @@ public class WindowExpression extends Expression {
|
||||
@Override
|
||||
public String toString() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append(function + " WindowSpec(");
|
||||
sb.append(function).append(" WindowSpec(");
|
||||
if (!partitionKeys.isEmpty()) {
|
||||
sb.append("PARTITION BY ").append(partitionKeys.stream()
|
||||
.map(Expression::toString)
|
||||
|
||||
Reference in New Issue
Block a user