[feature](Nereids)enable CBO optimize stage in Nereids (#12008)

- enable CBO stage in Nereids
- use the `chooseBestPlan()` to get the best plan
- add a new rule JoinCommuteProject
- test the stage by JoinCommute rule
This commit is contained in:
jakevin
2022-08-30 21:28:17 +08:00
committed by GitHub
parent dda446f490
commit 59e5527eb0
9 changed files with 199 additions and 75 deletions

View File

@ -114,10 +114,7 @@ public class NereidsPlanner extends Planner {
// cost-based optimize and explode plan space
optimize();
// Get plan directly. Just for SSB.
PhysicalPlan physicalPlan = getRoot().extractPlan();
// TODO: remove above
// PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY);
PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY);
// post-process physical plan out of memo, just for future use.
return postprocess(physicalPlan);

View File

@ -72,13 +72,12 @@ public class CostCalculator {
}
@Override
public CostEstimate visitPhysicalProject(PhysicalProject physicalProject, PlanContext context) {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
return CostEstimate.ofCpu(statistics.computeSize());
public CostEstimate visitPhysicalProject(PhysicalProject<? extends Plan> physicalProject, PlanContext context) {
return CostEstimate.ofCpu(1);
}
@Override
public CostEstimate visitPhysicalQuickSort(PhysicalQuickSort physicalQuickSort, PlanContext context) {
public CostEstimate visitPhysicalQuickSort(PhysicalQuickSort<Plan> physicalQuickSort, PlanContext context) {
// TODO: consider two-phase sort and enforcer.
StatsDeriveResult statistics = context.getStatisticsWithCheck();
StatsDeriveResult childStatistics = context.getChildStatistics(0);
@ -102,7 +101,8 @@ public class CostCalculator {
}
@Override
public CostEstimate visitPhysicalDistribution(PhysicalDistribution physicalDistribution, PlanContext context) {
public CostEstimate visitPhysicalDistribution(PhysicalDistribution<Plan> physicalDistribution,
PlanContext context) {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
StatsDeriveResult childStatistics = context.getChildStatistics(0);

View File

@ -42,25 +42,28 @@ import java.util.Optional;
public class CostAndEnforcerJob extends Job implements Cloneable {
// GroupExpression to optimize
private final GroupExpression groupExpression;
// Current total cost
private double curTotalCost;
// Children properties from parent plan node.
// cost of current plan tree
private double curTotalCost;
// cost of current plan node
private double curNodeCost;
// List of request property to children
// Example: Physical Hash Join
// [ child item: [leftProperties, rightPropertie]]
// [ [Properties {"", ANY}, Properties {"", BROADCAST}],
// [Properties {"", SHUFFLE_JOIN}, Properties {"", SHUFFLE_JOIN}]]
private List<List<PhysicalProperties>> requestChildrenPropertyList;
// index of List<request property to children>
private int requestPropertyIndex = 0;
private List<GroupExpression> childrenBestGroupExprList;
private List<GroupExpression> childrenBestGroupExprList = Lists.newArrayList();
private final List<PhysicalProperties> childrenOutputProperty = Lists.newArrayList();
// Current stage of enumeration through child groups
// current child index of travsing all children
private int curChildIndex = -1;
// Indicator of last child group that we waited for optimization
// child index in the last time of travsing all children
private int prevChildIndex = -1;
// Current stage of enumeration through outputInputProperties
private int curPropertyPairIndex = 0;
public CostAndEnforcerJob(GroupExpression groupExpression, JobContext context) {
super(JobType.OPTIMIZE_CHILDREN, context);
@ -99,22 +102,26 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
public void execute() {
// Do init logic of root plan/groupExpr of `subplan`, only run once per task.
if (curChildIndex == -1) {
curNodeCost = 0;
curTotalCost = 0;
// Get property from groupExpression plan (it's root of subplan).
curChildIndex = 0;
// List<request property to children>
// [ child item: [leftProperties, rightPropertie]]
// like :[ [Properties {"", ANY}, Properties {"", BROADCAST}],
// [Properties {"", SHUFFLE_JOIN}, Properties {"", SHUFFLE_JOIN}] ]
RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(context);
requestChildrenPropertyList = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression);
curChildIndex = 0;
}
for (; curPropertyPairIndex < requestChildrenPropertyList.size(); curPropertyPairIndex++) {
// children input properties
List<PhysicalProperties> requestChildrenProperty = requestChildrenPropertyList.get(curPropertyPairIndex);
for (; requestPropertyIndex < requestChildrenPropertyList.size(); requestPropertyIndex++) {
// Get one from List<request property to children>
// like: [ Properties {"", ANY}, Properties {"", BROADCAST} ],
List<PhysicalProperties> requestChildrenProperty = requestChildrenPropertyList.get(requestPropertyIndex);
// Calculate cost of groupExpression and update total cost
// Calculate cost
if (curChildIndex == 0 && prevChildIndex == -1) {
curTotalCost += CostCalculator.calculateCost(groupExpression);
curNodeCost = CostCalculator.calculateCost(groupExpression);
curTotalCost += curNodeCost;
}
// Handle all child plannode.
@ -177,6 +184,9 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
}
StatsCalculator.estimate(groupExpression);
curTotalCost -= curNodeCost;
curNodeCost = CostCalculator.calculateCost(groupExpression);
curTotalCost += curNodeCost;
// record map { outputProperty -> outputProperty }, { ANY -> outputProperty },
recordPropertyAndCost(groupExpression, outputProperty, outputProperty, requestChildrenProperty);
recordPropertyAndCost(groupExpression, outputProperty, PhysicalProperties.ANY, requestChildrenProperty);
@ -188,11 +198,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
}
}
// Reset child idx and total cost
childrenOutputProperty.clear();
prevChildIndex = -1;
curChildIndex = 0;
curTotalCost = 0;
clear();
}
}
@ -230,6 +236,14 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
curTotalCost, requestProperty);
}
private void clear() {
childrenOutputProperty.clear();
childrenBestGroupExprList.clear();
prevChildIndex = -1;
curChildIndex = 0;
curTotalCost = 0;
curNodeCost = 0;
}
/**
* Shallow clone (ignore clone propertiesListList and groupExpression).

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules;
import org.apache.doris.nereids.rules.exploration.join.JoinCommute;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject;
import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
import org.apache.doris.nereids.rules.implementation.LogicalFilterToPhysicalFilter;
import org.apache.doris.nereids.rules.implementation.LogicalJoinToHashJoin;
@ -40,6 +41,7 @@ import java.util.List;
public class RuleSet {
public static final List<Rule> EXPLORATION_RULES = planRuleFactories()
.add(JoinCommute.SWAP_OUTER_SWAP_ZIG_ZAG)
.add(JoinCommuteProject.SWAP_OUTER_SWAP_ZIG_ZAG)
.build();
public static final List<Rule> REWRITE_RULES = planRuleFactories()

View File

@ -21,22 +21,21 @@ import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
/**
* rule factory for exchange inner join's children.
* Join Commute
*/
@Developing
public class JoinCommute extends OneExplorationRuleFactory {
public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN);
public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG);
private final SwapType swapType;
private final boolean swapOuter;
private final SwapType swapType;
public JoinCommute(boolean swapOuter) {
this.swapOuter = swapOuter;
@ -48,14 +47,12 @@ public class JoinCommute extends OneExplorationRuleFactory {
this.swapType = swapType;
}
enum SwapType {
BOTTOM_JOIN, ZIG_ZAG, ALL
}
@Override
public Rule build() {
return innerLogicalJoin().when(this::check).then(join -> {
LogicalJoin newJoin = new LogicalJoin(
return innerLogicalJoin().when(JoinCommuteHelper::check).then(join -> {
// TODO: add project for mapping column output.
// List<NamedExpression> newOutput = new ArrayList<>(join.getOutput());
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
join.getJoinType(),
join.getHashJoinConjuncts(),
join.getOtherJoinCondition(),
@ -66,37 +63,8 @@ public class JoinCommute extends OneExplorationRuleFactory {
// newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
// }
// LogicalProject<LogicalJoin> project = new LogicalProject<>(newOutput, newJoin);
return newJoin;
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
}
private boolean check(LogicalJoin join) {
if (!(join.left() instanceof LogicalPlan) || !(join.right() instanceof LogicalPlan)) {
return false;
}
if (swapType == SwapType.BOTTOM_JOIN && !isBottomJoin(join)) {
return false;
}
if (join.getJoinReorderContext().hasCommute() || join.getJoinReorderContext().hasExchange()) {
return false;
}
return true;
}
private boolean isBottomJoin(LogicalJoin join) {
// TODO: filter need to be considered?
if (join.left() instanceof LogicalProject && ((LogicalProject) join.left()).child() instanceof LogicalJoin) {
return false;
}
if (join.right() instanceof LogicalProject && ((LogicalProject) join.right()).child() instanceof LogicalJoin) {
return false;
}
if (join.left() instanceof LogicalJoin || join.right() instanceof LogicalJoin) {
return false;
}
return true;
}
}

View File

@ -0,0 +1,48 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
/**
* Common function for JoinCommute
*/
public class JoinCommuteHelper {
enum SwapType {
BOTTOM_JOIN, ZIG_ZAG, ALL
}
private final boolean swapOuter;
private final SwapType swapType;
public JoinCommuteHelper(boolean swapOuter, SwapType swapType) {
this.swapOuter = swapOuter;
this.swapType = swapType;
}
public static boolean check(LogicalJoin<GroupPlan, GroupPlan> join) {
return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange();
}
public static boolean check(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
return check(project.child());
}
}

View File

@ -0,0 +1,66 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
/**
* Project-Join commute
*/
public class JoinCommuteProject extends OneExplorationRuleFactory {
public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN);
public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG);
private final SwapType swapType;
private final boolean swapOuter;
public JoinCommuteProject(boolean swapOuter) {
this.swapOuter = swapOuter;
this.swapType = SwapType.ALL;
}
public JoinCommuteProject(boolean swapOuter, SwapType swapType) {
this.swapOuter = swapOuter;
this.swapType = swapType;
}
@Override
public Rule build() {
return logicalProject(innerLogicalJoin()).when(JoinCommuteHelper::check).then(project -> {
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
join.getJoinType(),
join.getHashJoinConjuncts(),
join.getOtherJoinCondition(),
join.right(), join.left(),
join.getJoinReorderContext());
newJoin.getJoinReorderContext().setHasCommute(true);
// if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) {
// newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
// }
return newJoin;
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
}
}

View File

@ -18,18 +18,22 @@
package org.apache.doris.nereids.stats;
import org.apache.doris.common.CheckedMath;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.statistics.ColumnStats;
import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* Estimate hash join stats.
@ -44,12 +48,17 @@ public class JoinEstimation {
JoinType joinType = join.getJoinType();
StatsDeriveResult statsDeriveResult = new StatsDeriveResult(leftStats);
statsDeriveResult.merge(rightStats);
List<Expression> eqConjunctList = join.getHashJoinConjuncts();
// TODO: normalize join hashConjuncts.
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
List<Slot> leftSlot = new ArrayList<>(leftStats.getSlotToColumnStats().keySet());
List<Expression> normalizedConjuncts = hashJoinConjuncts.stream().map(EqualTo.class::cast)
.map(e -> JoinUtils.swapEqualToForChildrenOrder(e, leftSlot))
.collect(Collectors.toList());
long rowCount = -1;
if (joinType.isSemiOrAntiJoin()) {
rowCount = getSemiJoinRowCount(leftStats, rightStats, eqConjunctList, joinType);
rowCount = getSemiJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType);
} else if (joinType.isInnerJoin() || joinType.isOuterJoin()) {
rowCount = getJoinRowCount(leftStats, rightStats, eqConjunctList, joinType);
rowCount = getJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType);
} else if (joinType.isCrossJoin()) {
rowCount = CheckedMath.checkedMultiply(leftStats.getRowCount(),
rightStats.getRowCount());

View File

@ -21,6 +21,8 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -202,4 +204,22 @@ public class JoinUtils {
JoinType joinType = join.getJoinType();
return (joinType.isInnerJoin() && join.getHashJoinConjuncts().isEmpty()) || joinType.isCrossJoin();
}
/**
* The left and right child of origin predicates need to be swap sometimes.
* Case A:
* select * from t1 join t2 on t2.id=t1.id
* The left plan node is t1 and the right plan node is t2.
* The left child of origin predicate is t2.id and the right child of origin predicate is t1.id.
* In this situation, the children of predicate need to be swap => t1.id=t2.id.
*/
public static Expression swapEqualToForChildrenOrder(EqualTo equalTo, List<Slot> leftOutput) {
Set<ExprId> leftSlots = SlotExtractor.extractSlot(equalTo.left()).stream()
.map(NamedExpression::getExprId).collect(Collectors.toSet());
if (leftOutput.stream().map(NamedExpression::getExprId).collect(Collectors.toSet()).containsAll(leftSlots)) {
return equalTo;
} else {
return new EqualTo(equalTo.right(), equalTo.left());
}
}
}