[enhancement](Nereids) reduce CostAndEnforcerJob call times (#14442)

Record the pruned plan's cost to avoid optimizing the same GroupExpression more than once.
This commit is contained in:
morrySnow
2022-11-23 16:57:41 +08:00
committed by GitHub
parent 388f067300
commit 8d5eabb64f
8 changed files with 174 additions and 190 deletions

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids;
import org.apache.doris.common.Id;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -26,7 +25,6 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.List;
@ -37,8 +35,6 @@ import java.util.List;
* Inspired by GPORCA-CExpressionHandle.
*/
public class PlanContext {
// array of children's derived stats
private final List<StatsDeriveResult> childrenStats;
// attached group expression
private final GroupExpression groupExpression;
@ -47,21 +43,12 @@ public class PlanContext {
*/
public PlanContext(GroupExpression groupExpression) {
this.groupExpression = groupExpression;
childrenStats = Lists.newArrayListWithCapacity(groupExpression.children().size());
for (Group group : groupExpression.children()) {
childrenStats.add(group.getStatistics());
}
}
public GroupExpression getGroupExpression() {
return groupExpression;
}
public List<StatsDeriveResult> getChildrenStats() {
return childrenStats;
}
public StatsDeriveResult getStatisticsWithCheck() {
StatsDeriveResult statistics = groupExpression.getOwnerGroup().getStatistics();
Preconditions.checkNotNull(statistics);
@ -84,7 +71,7 @@ public class PlanContext {
* Get child statistics.
*/
public StatsDeriveResult getChildStatistics(int index) {
StatsDeriveResult statistics = childrenStats.get(index);
StatsDeriveResult statistics = groupExpression.child(index).getStatistics();
Preconditions.checkNotNull(statistics);
return statistics;
}

View File

@ -17,7 +17,6 @@
package org.apache.doris.nereids.cost;
import org.apache.doris.common.Id;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.DistributionSpec;
@ -41,8 +40,6 @@ import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.base.Preconditions;
import java.util.List;
/**
* Calculate the cost of a plan.
* Inspired by Presto.
@ -187,25 +184,22 @@ public class CostCalculator {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
StatsDeriveResult inputStatistics = context.getChildStatistics(0);
return CostEstimate.of(inputStatistics.computeSize(), statistics.computeSize(), 0);
return CostEstimate.of(inputStatistics.getRowCount(), statistics.computeSize(), 0);
}
@Override
public CostEstimate visitPhysicalHashJoin(
PhysicalHashJoin<? extends Plan, ? extends Plan> physicalHashJoin, PlanContext context) {
Preconditions.checkState(context.getGroupExpression().arity() == 2);
Preconditions.checkState(context.getChildrenStats().size() == 2);
StatsDeriveResult outputStats = physicalHashJoin.getGroupExpression().get().getOwnerGroup().getStatistics();
double outputRowCount = outputStats.computeSize();
double outputRowCount = outputStats.getRowCount();
StatsDeriveResult probeStats = context.getChildStatistics(0);
StatsDeriveResult buildStats = context.getChildStatistics(1);
List<Id> leftIds = context.getChildOutputIds(0);
List<Id> rightIds = context.getChildOutputIds(1);
double leftRowCount = probeStats.computeColumnSize(leftIds);
double rightRowCount = buildStats.computeColumnSize(rightIds);
double leftRowCount = probeStats.getRowCount();
double rightRowCount = buildStats.getRowCount();
/*
pattern1: L join1 (Agg1() join2 Agg2())
result number of join2 may much less than Agg1.
@ -240,7 +234,6 @@ public class CostCalculator {
PlanContext context) {
// TODO: copy from physicalHashJoin, should update according to physical nested loop join properties.
Preconditions.checkState(context.getGroupExpression().arity() == 2);
Preconditions.checkState(context.getChildrenStats().size() == 2);
StatsDeriveResult leftStatistics = context.getChildStatistics(0);
StatsDeriveResult rightStatistics = context.getChildStatistics(1);

View File

@ -80,8 +80,7 @@ public abstract class Job {
* @param candidateRules rules to be applied
* @return all rules that can be applied on this group expression
*/
public List<Rule> getValidRules(GroupExpression groupExpression,
List<Rule> candidateRules) {
public List<Rule> getValidRules(GroupExpression groupExpression, List<Rule> candidateRules) {
return candidateRules.stream()
.filter(rule -> Objects.nonNull(rule) && rule.getPattern().matchRoot(groupExpression.getPlan())
&& groupExpression.notApplied(rule)).collect(Collectors.toList());

View File

@ -41,6 +41,7 @@ import java.util.Optional;
* Inspired by NoisePage and ORCA-Paper.
*/
public class CostAndEnforcerJob extends Job implements Cloneable {
// GroupExpression to optimize
private final GroupExpression groupExpression;
@ -165,7 +166,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
curTotalCost += lowestCostExpr.getLowestCostTable().get(requestChildProperty).first;
if (curTotalCost > context.getCostUpperBound()) {
break;
curTotalCost = Double.POSITIVE_INFINITY;
}
// the request child properties will be covered by the output properties
// that corresponding to the request properties. so if we run a costAndEnforceJob of the same

View File

@ -53,7 +53,7 @@ public class Group {
// Map of cost lower bounds
// Map required plan props to cost lower bound of corresponding plan
private final Map<PhysicalProperties, Pair<Double, GroupExpression>> lowestCostPlans = Maps.newHashMap();
private double costLowerBound = -1;
private boolean isExplored = false;
private StatsDeriveResult statistics;
@ -101,87 +101,11 @@ public class Group {
return groupExpression;
}
/**
* Remove groupExpression from this group.
* <p>
* The expression is removed from the logical expression list when its plan is a {@link LogicalPlan},
* otherwise from the physical expression list. Its owner group is then reset to {@code null}.
*
* @param groupExpression to be removed
* @return removed {@link GroupExpression}
*/
public GroupExpression removeGroupExpression(GroupExpression groupExpression) {
if (groupExpression.getPlan() instanceof LogicalPlan) {
logicalExpressions.remove(groupExpression);
} else {
physicalExpressions.remove(groupExpression);
}
// The owner is cleared unconditionally, whether or not the expression was found in either list.
groupExpression.setOwnerGroup(null);
return groupExpression;
}
/**
* Add a logical expression to this group and make this group its owner.
*
* @param groupExpression the expression to add to the logical expression list
*/
public void addLogicalExpression(GroupExpression groupExpression) {
groupExpression.setOwnerGroup(this);
logicalExpressions.add(groupExpression);
}
/**
* Detach every logical expression from this group and empty the logical expression list.
*
* @return a new list holding the expressions that were removed; each has had its owner group set to {@code null}
*/
public List<GroupExpression> clearLogicalExpressions() {
List<GroupExpression> move = logicalExpressions.stream()
// peek is used for its side effect here: the owner is cleared as collect() consumes each element.
.peek(groupExpr -> groupExpr.setOwnerGroup(null))
.collect(Collectors.toList());
logicalExpressions.clear();
return move;
}
/**
* Detach every physical expression from this group and empty the physical expression list.
*
* @return a new list holding the expressions that were removed; each has had its owner group set to {@code null}
*/
public List<GroupExpression> clearPhysicalExpressions() {
List<GroupExpression> move = physicalExpressions.stream()
// peek is used for its side effect here: the owner is cleared as collect() consumes each element.
.peek(groupExpr -> groupExpr.setOwnerGroup(null))
.collect(Collectors.toList());
physicalExpressions.clear();
return move;
}
public double getCostLowerBound() {
return costLowerBound;
}
/**
* Set or update lowestCostPlans: properties --> Pair.of(cost, expression)
* <p>
* When an entry already exists for {@code properties} it is replaced only if its recorded cost is
* greater than or equal to {@code cost}; on a tie the new expression therefore wins.
*
* @param expression candidate best expression
* @param cost cost of the candidate
* @param properties physical properties the candidate is recorded under
*/
public void setBestPlan(GroupExpression expression, double cost, PhysicalProperties properties) {
if (lowestCostPlans.containsKey(properties)) {
// Keep the existing entry when it is strictly cheaper than the candidate.
if (lowestCostPlans.get(properties).first >= cost) {
lowestCostPlans.put(properties, Pair.of(cost, expression));
}
} else {
lowestCostPlans.put(properties, Pair.of(cost, expression));
}
}
/**
* Look up the best expression recorded in lowestCostPlans for the given physical properties.
*
* @param properties key into lowestCostPlans
* @return the recorded expression, or {@code null} when no entry exists for {@code properties}
*/
public GroupExpression getBestPlan(PhysicalProperties properties) {
if (lowestCostPlans.containsKey(properties)) {
return lowestCostPlans.get(properties).second;
}
return null;
}
/**
* replace best plan with new properties
* <p>
* The entry stored under {@code oldProperty} is re-keyed to {@code newProperty}. The same pair object is
* reused, so the cost held in lowestCostPlans is not changed here; {@code cost} is only passed on to the
* best expression's own lowest cost table.
*
* @param oldProperty key of the existing entry
* @param newProperty key the entry is moved to
* @param cost cost recorded in the best expression's lowest cost table for {@code newProperty}
*/
public void replaceBestPlan(PhysicalProperties oldProperty, PhysicalProperties newProperty, double cost) {
// NOTE(review): pair is null when nothing is recorded for oldProperty, and the next line would then throw.
// Presumably callers guarantee the entry exists -- confirm.
Pair<Double, GroupExpression> pair = lowestCostPlans.get(oldProperty);
GroupExpression lowestGroupExpr = pair.second;
// The input properties recorded for oldProperty are carried over to newProperty.
lowestGroupExpr.updateLowestCostTable(newProperty,
lowestGroupExpr.getInputPropertiesList(oldProperty), cost);
lowestCostPlans.remove(oldProperty);
lowestCostPlans.put(newProperty, pair);
}
public StatsDeriveResult getStatistics() {
return statistics;
}
public void setStatistics(StatsDeriveResult statistics) {
this.statistics = statistics;
}
public List<GroupExpression> getLogicalExpressions() {
return logicalExpressions;
}
@ -208,20 +132,40 @@ public class Group {
return physicalExpressions;
}
public LogicalProperties getLogicalProperties() {
return logicalProperties;
/**
* Remove groupExpression from this group.
*
* @param groupExpression to be removed
* @return removed {@link GroupExpression}
*/
public GroupExpression removeGroupExpression(GroupExpression groupExpression) {
if (groupExpression.getPlan() instanceof LogicalPlan) {
logicalExpressions.remove(groupExpression);
} else {
physicalExpressions.remove(groupExpression);
}
groupExpression.setOwnerGroup(null);
return groupExpression;
}
public void setLogicalProperties(LogicalProperties logicalProperties) {
this.logicalProperties = logicalProperties;
public List<GroupExpression> clearLogicalExpressions() {
List<GroupExpression> move = logicalExpressions.stream()
.peek(groupExpr -> groupExpr.setOwnerGroup(null))
.collect(Collectors.toList());
logicalExpressions.clear();
return move;
}
public boolean isExplored() {
return isExplored;
public List<GroupExpression> clearPhysicalExpressions() {
List<GroupExpression> move = physicalExpressions.stream()
.peek(groupExpr -> groupExpr.setOwnerGroup(null))
.collect(Collectors.toList());
physicalExpressions.clear();
return move;
}
public void setExplored(boolean explored) {
isExplored = explored;
public double getCostLowerBound() {
return -1D;
}
/**
@ -238,6 +182,62 @@ public class Group {
return Optional.ofNullable(lowestCostPlans.get(physicalProperties));
}
/**
* Look up the best expression recorded in lowestCostPlans for the given physical properties.
*
* @param properties key into lowestCostPlans
* @return the recorded expression, or {@code null} when no entry exists for {@code properties}
*/
public GroupExpression getBestPlan(PhysicalProperties properties) {
if (lowestCostPlans.containsKey(properties)) {
return lowestCostPlans.get(properties).second;
}
return null;
}
/**
* Set or update lowestCostPlans: properties --> Pair.of(cost, expression)
* <p>
* When an entry already exists for {@code properties} it is replaced only if its recorded cost is
* greater than or equal to {@code cost}; on a tie the new expression therefore wins.
*
* @param expression candidate best expression
* @param cost cost of the candidate
* @param properties physical properties the candidate is recorded under
*/
public void setBestPlan(GroupExpression expression, double cost, PhysicalProperties properties) {
if (lowestCostPlans.containsKey(properties)) {
// Keep the existing entry when it is strictly cheaper than the candidate.
if (lowestCostPlans.get(properties).first >= cost) {
lowestCostPlans.put(properties, Pair.of(cost, expression));
}
} else {
lowestCostPlans.put(properties, Pair.of(cost, expression));
}
}
/**
* replace best plan with new properties
* <p>
* The entry stored under {@code oldProperty} is re-keyed to {@code newProperty}. The same pair object is
* reused, so the cost held in lowestCostPlans is not changed here; {@code cost} is only passed on to the
* best expression's own lowest cost table.
*
* @param oldProperty key of the existing entry
* @param newProperty key the entry is moved to
* @param cost cost recorded in the best expression's lowest cost table for {@code newProperty}
*/
public void replaceBestPlan(PhysicalProperties oldProperty, PhysicalProperties newProperty, double cost) {
// NOTE(review): pair is null when nothing is recorded for oldProperty, and the next line would then throw.
// Presumably callers guarantee the entry exists -- confirm.
Pair<Double, GroupExpression> pair = lowestCostPlans.get(oldProperty);
GroupExpression lowestGroupExpr = pair.second;
// The input properties recorded for oldProperty are carried over to newProperty.
lowestGroupExpr.updateLowestCostTable(newProperty,
lowestGroupExpr.getInputPropertiesList(oldProperty), cost);
lowestCostPlans.remove(oldProperty);
lowestCostPlans.put(newProperty, pair);
}
public StatsDeriveResult getStatistics() {
return statistics;
}
public void setStatistics(StatsDeriveResult statistics) {
this.statistics = statistics;
}
public LogicalProperties getLogicalProperties() {
return logicalProperties;
}
public void setLogicalProperties(LogicalProperties logicalProperties) {
this.logicalProperties = logicalProperties;
}
public boolean isExplored() {
return isExplored;
}
public void setExplored(boolean explored) {
isExplored = explored;
}
public List<GroupExpression> getParentGroupExpressions() {
return ImmutableList.copyOf(parentExpressions.keySet());
}
@ -257,6 +257,65 @@ public class Group {
return parentExpressions.size();
}
/**
* move the ownerGroup of all logical expressions to target group
* if this.equals(target), do nothing.
* <p>
* Each logical expression is handed to {@code target.addGroupExpression}, after which this group's
* logical expression list is emptied.
*
* @param target the new owner group of expressions
*/
public void moveLogicalExpressionOwnership(Group target) {
if (equals(target)) {
return;
}
for (GroupExpression expression : logicalExpressions) {
target.addGroupExpression(expression);
}
// The expressions now live in target; drop this group's references to them.
logicalExpressions.clear();
}
/**
* move the ownerGroup of all physical expressions to target group
* if this.equals(target), do nothing.
* <p>
* Each physical expression is handed to {@code target.addGroupExpression}, after which this group's
* physical expression list is emptied.
*
* @param target the new owner group of expressions
*/
public void movePhysicalExpressionOwnership(Group target) {
if (equals(target)) {
return;
}
for (GroupExpression expression : physicalExpressions) {
target.addGroupExpression(expression);
}
// The expressions now live in target; drop this group's references to them.
physicalExpressions.clear();
}
/**
* move the ownerGroup of all lowestCostPlans to target group
* if this.equals(target), do nothing.
* <p>
* For each physical property, target ends up with whichever plan has the strictly lower cost; on a tie
* target's existing entry is kept. This group's lowestCostPlans is emptied afterwards.
*
* @param target the new owner group of expressions
*/
public void moveLowestCostPlansOwnership(Group target) {
if (equals(target)) {
return;
}
lowestCostPlans.forEach((physicalProperties, costAndGroupExpr) -> {
GroupExpression bestGroupExpression = costAndGroupExpr.second;
// change into target group.
// Only expressions owned by this group, or by no group, are re-owned. This happens before the cost
// comparison below, so it applies even when the plan does not end up in target's table.
if (bestGroupExpression.getOwnerGroup() == this || bestGroupExpression.getOwnerGroup() == null) {
bestGroupExpression.setOwnerGroup(target);
}
if (!target.lowestCostPlans.containsKey(physicalProperties)) {
target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
} else {
// Overwrite target's entry only when this group's plan is strictly cheaper.
if (costAndGroupExpr.first < target.lowestCostPlans.get(physicalProperties).first) {
target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
}
}
});
lowestCostPlans.clear();
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -322,63 +381,4 @@ public class Group {
};
return TreeStringUtils.treeString(this, toString, getChildren);
}
/**
* move the ownerGroup of all logical expressions to target group
* if this.equals(target), do nothing.
* <p>
* Each logical expression is handed to {@code target.addGroupExpression}, after which this group's
* logical expression list is emptied.
*
* @param target the new owner group of expressions
*/
public void moveLogicalExpressionOwnership(Group target) {
if (equals(target)) {
return;
}
for (GroupExpression expression : logicalExpressions) {
target.addGroupExpression(expression);
}
// The expressions now live in target; drop this group's references to them.
logicalExpressions.clear();
}
/**
* move the ownerGroup of all physical expressions to target group
* if this.equals(target), do nothing.
* <p>
* Each physical expression is handed to {@code target.addGroupExpression}, after which this group's
* physical expression list is emptied.
*
* @param target the new owner group of expressions
*/
public void movePhysicalExpressionOwnership(Group target) {
if (equals(target)) {
return;
}
for (GroupExpression expression : physicalExpressions) {
target.addGroupExpression(expression);
}
// The expressions now live in target; drop this group's references to them.
physicalExpressions.clear();
}
/**
* move the ownerGroup of all lowestCostPlans to target group
* if this.equals(target), do nothing.
* <p>
* For each physical property, target ends up with whichever plan has the strictly lower cost; on a tie
* target's existing entry is kept. This group's lowestCostPlans is emptied afterwards.
*
* @param target the new owner group of expressions
*/
public void moveLowestCostPlansOwnership(Group target) {
if (equals(target)) {
return;
}
lowestCostPlans.forEach((physicalProperties, costAndGroupExpr) -> {
GroupExpression bestGroupExpression = costAndGroupExpr.second;
// change into target group.
// Only expressions owned by this group, or by no group, are re-owned. This happens before the cost
// comparison below, so it applies even when the plan does not end up in target's table.
if (bestGroupExpression.getOwnerGroup() == this || bestGroupExpression.getOwnerGroup() == null) {
bestGroupExpression.setOwnerGroup(target);
}
if (!target.lowestCostPlans.containsKey(physicalProperties)) {
target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
} else {
// Overwrite target's entry only when this group's plan is strictly cheaper.
if (costAndGroupExpr.first < target.lowestCostPlans.get(physicalProperties).first) {
target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
}
}
});
lowestCostPlans.clear();
}
}

View File

@ -85,17 +85,17 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
public Void visitPhysicalAggregate(PhysicalAggregate<? extends Plan> agg, PlanContext context) {
// 1. first phase agg just return any
if (agg.getAggPhase().isLocal() && !agg.isFinalPhase()) {
addToRequestPropertyToChildren(PhysicalProperties.ANY);
addRequestPropertyToChildren(PhysicalProperties.ANY);
return null;
}
if (agg.getAggPhase() == AggPhase.GLOBAL && !agg.isFinalPhase()) {
addToRequestPropertyToChildren(requestPropertyFromParent);
addRequestPropertyToChildren(requestPropertyFromParent);
return null;
}
// 2. second phase agg, need to return shuffle with partition key
List<Expression> partitionExpressions = agg.getPartitionExpressions();
if (partitionExpressions.isEmpty()) {
addToRequestPropertyToChildren(PhysicalProperties.GATHER);
addRequestPropertyToChildren(PhysicalProperties.GATHER);
return null;
}
// TODO: when parent is a join node,
@ -105,7 +105,7 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
.map(SlotReference.class::cast)
.map(SlotReference::getExprId)
.collect(Collectors.toList());
addToRequestPropertyToChildren(
addRequestPropertyToChildren(
PhysicalProperties.createHash(new DistributionSpecHash(partitionedSlots, ShuffleType.AGGREGATE)));
return null;
}
@ -118,34 +118,33 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
@Override
public Void visitPhysicalQuickSort(PhysicalQuickSort<? extends Plan> sort, PlanContext context) {
addToRequestPropertyToChildren(PhysicalProperties.ANY);
addRequestPropertyToChildren(PhysicalProperties.ANY);
return null;
}
@Override
public Void visitPhysicalLocalQuickSort(PhysicalLocalQuickSort<? extends Plan> sort, PlanContext context) {
// TODO: rethink here, should we throw exception directly?
addToRequestPropertyToChildren(PhysicalProperties.ANY);
addRequestPropertyToChildren(PhysicalProperties.ANY);
return null;
}
@Override
public Void visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ? extends Plan> hashJoin, PlanContext context) {
// for broadcast join
if (JoinUtils.couldBroadcast(hashJoin)) {
addToRequestPropertyToChildren(PhysicalProperties.ANY, PhysicalProperties.REPLICATED);
}
// for shuffle join
if (JoinUtils.couldShuffle(hashJoin)) {
Pair<List<ExprId>, List<ExprId>> onClauseUsedSlots = JoinUtils.getOnClauseUsedSlots(hashJoin);
// shuffle join
addToRequestPropertyToChildren(
addRequestPropertyToChildren(
PhysicalProperties.createHash(
new DistributionSpecHash(onClauseUsedSlots.first, ShuffleType.JOIN)),
PhysicalProperties.createHash(
new DistributionSpecHash(onClauseUsedSlots.second, ShuffleType.JOIN)));
}
// for broadcast join
if (JoinUtils.couldBroadcast(hashJoin)) {
addRequestPropertyToChildren(PhysicalProperties.ANY, PhysicalProperties.REPLICATED);
}
return null;
}
@ -154,13 +153,13 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
public Void visitPhysicalNestedLoopJoin(
PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin, PlanContext context) {
// TODO: currently doris only use NLJ to do cross join, update this if we use NLJ to do other joins.
addToRequestPropertyToChildren(PhysicalProperties.ANY, PhysicalProperties.REPLICATED);
addRequestPropertyToChildren(PhysicalProperties.ANY, PhysicalProperties.REPLICATED);
return null;
}
@Override
public Void visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows, PlanContext context) {
addToRequestPropertyToChildren(PhysicalProperties.GATHER);
addRequestPropertyToChildren(PhysicalProperties.GATHER);
return null;
}
@ -168,7 +167,7 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
* helper function to assemble request children physical properties
* @param physicalProperties one set request properties for children
*/
private void addToRequestPropertyToChildren(PhysicalProperties... physicalProperties) {
private void addRequestPropertyToChildren(PhysicalProperties... physicalProperties) {
requestPropertyToChildren.add(Lists.newArrayList(physicalProperties));
}
}

View File

@ -31,6 +31,7 @@ import java.util.Set;
*/
public class StatsDeriveResult {
private final double rowCount;
private double computeSize = -1D;
private int width = 1;
private double penalty = 0.0;
@ -79,8 +80,12 @@ public class StatsDeriveResult {
}
public double computeSize() {
return Math.max(1, slotIdToColumnStats.values().stream().map(s -> s.dataSize).reduce(0D, Double::sum))
* rowCount;
if (computeSize < 0) {
computeSize = Math.max(1, slotIdToColumnStats.values().stream()
.map(s -> s.dataSize).reduce(0D, Double::sum)
) * rowCount;
}
return computeSize;
}
/**

View File

@ -132,11 +132,11 @@ public class RequestPropertyDeriverTest {
= requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression);
List<List<PhysicalProperties>> expected = Lists.newArrayList();
expected.add(Lists.newArrayList(PhysicalProperties.ANY, PhysicalProperties.REPLICATED));
expected.add(Lists.newArrayList(
new PhysicalProperties(new DistributionSpecHash(Lists.newArrayList(new ExprId(0)), ShuffleType.JOIN)),
new PhysicalProperties(new DistributionSpecHash(Lists.newArrayList(new ExprId(1)), ShuffleType.JOIN))
));
expected.add(Lists.newArrayList(PhysicalProperties.ANY, PhysicalProperties.REPLICATED));
Assertions.assertEquals(expected, actual);
}