[feature](Nereids)enable CBO optimize stage in Nereids (#12008)
- enable CBO stage in Nereids - use the `chooseBestPlan()` to get the best plan - add a new rule JoinCommuteProject - test the stage by JoinCommute rule
This commit is contained in:
@ -114,10 +114,7 @@ public class NereidsPlanner extends Planner {
|
||||
// cost-based optimize and explode plan space
|
||||
optimize();
|
||||
|
||||
// Get plan directly. Just for SSB.
|
||||
PhysicalPlan physicalPlan = getRoot().extractPlan();
|
||||
// TODO: remove above
|
||||
// PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY);
|
||||
PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY);
|
||||
|
||||
// post-process physical plan out of memo, just for future use.
|
||||
return postprocess(physicalPlan);
|
||||
|
||||
@ -72,13 +72,12 @@ public class CostCalculator {
|
||||
}
|
||||
|
||||
@Override
|
||||
public CostEstimate visitPhysicalProject(PhysicalProject physicalProject, PlanContext context) {
|
||||
StatsDeriveResult statistics = context.getStatisticsWithCheck();
|
||||
return CostEstimate.ofCpu(statistics.computeSize());
|
||||
public CostEstimate visitPhysicalProject(PhysicalProject<? extends Plan> physicalProject, PlanContext context) {
|
||||
return CostEstimate.ofCpu(1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public CostEstimate visitPhysicalQuickSort(PhysicalQuickSort physicalQuickSort, PlanContext context) {
|
||||
public CostEstimate visitPhysicalQuickSort(PhysicalQuickSort<Plan> physicalQuickSort, PlanContext context) {
|
||||
// TODO: consider two-phase sort and enforcer.
|
||||
StatsDeriveResult statistics = context.getStatisticsWithCheck();
|
||||
StatsDeriveResult childStatistics = context.getChildStatistics(0);
|
||||
@ -102,7 +101,8 @@ public class CostCalculator {
|
||||
}
|
||||
|
||||
@Override
|
||||
public CostEstimate visitPhysicalDistribution(PhysicalDistribution physicalDistribution, PlanContext context) {
|
||||
public CostEstimate visitPhysicalDistribution(PhysicalDistribution<Plan> physicalDistribution,
|
||||
PlanContext context) {
|
||||
StatsDeriveResult statistics = context.getStatisticsWithCheck();
|
||||
StatsDeriveResult childStatistics = context.getChildStatistics(0);
|
||||
|
||||
|
||||
@ -42,25 +42,28 @@ import java.util.Optional;
|
||||
public class CostAndEnforcerJob extends Job implements Cloneable {
|
||||
// GroupExpression to optimize
|
||||
private final GroupExpression groupExpression;
|
||||
// Current total cost
|
||||
private double curTotalCost;
|
||||
|
||||
// Children properties from parent plan node.
|
||||
// cost of current plan tree
|
||||
private double curTotalCost;
|
||||
// cost of current plan node
|
||||
private double curNodeCost;
|
||||
|
||||
// List of request property to children
|
||||
// Example: Physical Hash Join
|
||||
// [ child item: [leftProperties, rightPropertie]]
|
||||
// [ [Properties {"", ANY}, Properties {"", BROADCAST}],
|
||||
// [Properties {"", SHUFFLE_JOIN}, Properties {"", SHUFFLE_JOIN}]]
|
||||
private List<List<PhysicalProperties>> requestChildrenPropertyList;
|
||||
// index of List<request property to children>
|
||||
private int requestPropertyIndex = 0;
|
||||
|
||||
private List<GroupExpression> childrenBestGroupExprList;
|
||||
private List<GroupExpression> childrenBestGroupExprList = Lists.newArrayList();
|
||||
private final List<PhysicalProperties> childrenOutputProperty = Lists.newArrayList();
|
||||
|
||||
// Current stage of enumeration through child groups
|
||||
// current child index of travsing all children
|
||||
private int curChildIndex = -1;
|
||||
// Indicator of last child group that we waited for optimization
|
||||
// child index in the last time of travsing all children
|
||||
private int prevChildIndex = -1;
|
||||
// Current stage of enumeration through outputInputProperties
|
||||
private int curPropertyPairIndex = 0;
|
||||
|
||||
public CostAndEnforcerJob(GroupExpression groupExpression, JobContext context) {
|
||||
super(JobType.OPTIMIZE_CHILDREN, context);
|
||||
@ -99,22 +102,26 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
|
||||
public void execute() {
|
||||
// Do init logic of root plan/groupExpr of `subplan`, only run once per task.
|
||||
if (curChildIndex == -1) {
|
||||
curNodeCost = 0;
|
||||
curTotalCost = 0;
|
||||
|
||||
// Get property from groupExpression plan (it's root of subplan).
|
||||
curChildIndex = 0;
|
||||
// List<request property to children>
|
||||
// [ child item: [leftProperties, rightPropertie]]
|
||||
// like :[ [Properties {"", ANY}, Properties {"", BROADCAST}],
|
||||
// [Properties {"", SHUFFLE_JOIN}, Properties {"", SHUFFLE_JOIN}] ]
|
||||
RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(context);
|
||||
requestChildrenPropertyList = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression);
|
||||
|
||||
curChildIndex = 0;
|
||||
}
|
||||
|
||||
for (; curPropertyPairIndex < requestChildrenPropertyList.size(); curPropertyPairIndex++) {
|
||||
// children input properties
|
||||
List<PhysicalProperties> requestChildrenProperty = requestChildrenPropertyList.get(curPropertyPairIndex);
|
||||
for (; requestPropertyIndex < requestChildrenPropertyList.size(); requestPropertyIndex++) {
|
||||
// Get one from List<request property to children>
|
||||
// like: [ Properties {"", ANY}, Properties {"", BROADCAST} ],
|
||||
List<PhysicalProperties> requestChildrenProperty = requestChildrenPropertyList.get(requestPropertyIndex);
|
||||
|
||||
// Calculate cost of groupExpression and update total cost
|
||||
// Calculate cost
|
||||
if (curChildIndex == 0 && prevChildIndex == -1) {
|
||||
curTotalCost += CostCalculator.calculateCost(groupExpression);
|
||||
curNodeCost = CostCalculator.calculateCost(groupExpression);
|
||||
curTotalCost += curNodeCost;
|
||||
}
|
||||
|
||||
// Handle all child plannode.
|
||||
@ -177,6 +184,9 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
|
||||
}
|
||||
StatsCalculator.estimate(groupExpression);
|
||||
|
||||
curTotalCost -= curNodeCost;
|
||||
curNodeCost = CostCalculator.calculateCost(groupExpression);
|
||||
curTotalCost += curNodeCost;
|
||||
// record map { outputProperty -> outputProperty }, { ANY -> outputProperty },
|
||||
recordPropertyAndCost(groupExpression, outputProperty, outputProperty, requestChildrenProperty);
|
||||
recordPropertyAndCost(groupExpression, outputProperty, PhysicalProperties.ANY, requestChildrenProperty);
|
||||
@ -188,11 +198,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
|
||||
}
|
||||
}
|
||||
|
||||
// Reset child idx and total cost
|
||||
childrenOutputProperty.clear();
|
||||
prevChildIndex = -1;
|
||||
curChildIndex = 0;
|
||||
curTotalCost = 0;
|
||||
clear();
|
||||
}
|
||||
}
|
||||
|
||||
@ -230,6 +236,14 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
|
||||
curTotalCost, requestProperty);
|
||||
}
|
||||
|
||||
private void clear() {
|
||||
childrenOutputProperty.clear();
|
||||
childrenBestGroupExprList.clear();
|
||||
prevChildIndex = -1;
|
||||
curChildIndex = 0;
|
||||
curTotalCost = 0;
|
||||
curNodeCost = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Shallow clone (ignore clone propertiesListList and groupExpression).
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommute;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalFilterToPhysicalFilter;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalJoinToHashJoin;
|
||||
@ -40,6 +41,7 @@ import java.util.List;
|
||||
public class RuleSet {
|
||||
public static final List<Rule> EXPLORATION_RULES = planRuleFactories()
|
||||
.add(JoinCommute.SWAP_OUTER_SWAP_ZIG_ZAG)
|
||||
.add(JoinCommuteProject.SWAP_OUTER_SWAP_ZIG_ZAG)
|
||||
.build();
|
||||
|
||||
public static final List<Rule> REWRITE_RULES = planRuleFactories()
|
||||
|
||||
@ -21,22 +21,21 @@ import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
/**
|
||||
* rule factory for exchange inner join's children.
|
||||
* Join Commute
|
||||
*/
|
||||
@Developing
|
||||
public class JoinCommute extends OneExplorationRuleFactory {
|
||||
|
||||
public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN);
|
||||
|
||||
public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG);
|
||||
|
||||
private final SwapType swapType;
|
||||
private final boolean swapOuter;
|
||||
private final SwapType swapType;
|
||||
|
||||
public JoinCommute(boolean swapOuter) {
|
||||
this.swapOuter = swapOuter;
|
||||
@ -48,14 +47,12 @@ public class JoinCommute extends OneExplorationRuleFactory {
|
||||
this.swapType = swapType;
|
||||
}
|
||||
|
||||
enum SwapType {
|
||||
BOTTOM_JOIN, ZIG_ZAG, ALL
|
||||
}
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return innerLogicalJoin().when(this::check).then(join -> {
|
||||
LogicalJoin newJoin = new LogicalJoin(
|
||||
return innerLogicalJoin().when(JoinCommuteHelper::check).then(join -> {
|
||||
// TODO: add project for mapping column output.
|
||||
// List<NamedExpression> newOutput = new ArrayList<>(join.getOutput());
|
||||
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
|
||||
join.getJoinType(),
|
||||
join.getHashJoinConjuncts(),
|
||||
join.getOtherJoinCondition(),
|
||||
@ -66,37 +63,8 @@ public class JoinCommute extends OneExplorationRuleFactory {
|
||||
// newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
|
||||
// }
|
||||
|
||||
// LogicalProject<LogicalJoin> project = new LogicalProject<>(newOutput, newJoin);
|
||||
return newJoin;
|
||||
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
|
||||
}
|
||||
|
||||
|
||||
private boolean check(LogicalJoin join) {
|
||||
if (!(join.left() instanceof LogicalPlan) || !(join.right() instanceof LogicalPlan)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (swapType == SwapType.BOTTOM_JOIN && !isBottomJoin(join)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (join.getJoinReorderContext().hasCommute() || join.getJoinReorderContext().hasExchange()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private boolean isBottomJoin(LogicalJoin join) {
|
||||
// TODO: filter need to be considered?
|
||||
if (join.left() instanceof LogicalProject && ((LogicalProject) join.left()).child() instanceof LogicalJoin) {
|
||||
return false;
|
||||
}
|
||||
if (join.right() instanceof LogicalProject && ((LogicalProject) join.right()).child() instanceof LogicalJoin) {
|
||||
return false;
|
||||
}
|
||||
if (join.left() instanceof LogicalJoin || join.right() instanceof LogicalJoin) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,48 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
/**
|
||||
* Common function for JoinCommute
|
||||
*/
|
||||
public class JoinCommuteHelper {
|
||||
|
||||
enum SwapType {
|
||||
BOTTOM_JOIN, ZIG_ZAG, ALL
|
||||
}
|
||||
|
||||
private final boolean swapOuter;
|
||||
private final SwapType swapType;
|
||||
|
||||
public JoinCommuteHelper(boolean swapOuter, SwapType swapType) {
|
||||
this.swapOuter = swapOuter;
|
||||
this.swapType = swapType;
|
||||
}
|
||||
|
||||
public static boolean check(LogicalJoin<GroupPlan, GroupPlan> join) {
|
||||
return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange();
|
||||
}
|
||||
|
||||
public static boolean check(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
|
||||
return check(project.child());
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,66 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
|
||||
/**
|
||||
* Project-Join commute
|
||||
*/
|
||||
public class JoinCommuteProject extends OneExplorationRuleFactory {
|
||||
|
||||
public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN);
|
||||
public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG);
|
||||
|
||||
private final SwapType swapType;
|
||||
private final boolean swapOuter;
|
||||
|
||||
public JoinCommuteProject(boolean swapOuter) {
|
||||
this.swapOuter = swapOuter;
|
||||
this.swapType = SwapType.ALL;
|
||||
}
|
||||
|
||||
public JoinCommuteProject(boolean swapOuter, SwapType swapType) {
|
||||
this.swapOuter = swapOuter;
|
||||
this.swapType = swapType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalProject(innerLogicalJoin()).when(JoinCommuteHelper::check).then(project -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
|
||||
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
|
||||
join.getJoinType(),
|
||||
join.getHashJoinConjuncts(),
|
||||
join.getOtherJoinCondition(),
|
||||
join.right(), join.left(),
|
||||
join.getJoinReorderContext());
|
||||
newJoin.getJoinReorderContext().setHasCommute(true);
|
||||
// if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) {
|
||||
// newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
|
||||
// }
|
||||
|
||||
return newJoin;
|
||||
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
|
||||
}
|
||||
}
|
||||
@ -18,18 +18,22 @@
|
||||
package org.apache.doris.nereids.stats;
|
||||
|
||||
import org.apache.doris.common.CheckedMath;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.Join;
|
||||
import org.apache.doris.nereids.util.JoinUtils;
|
||||
import org.apache.doris.statistics.ColumnStats;
|
||||
import org.apache.doris.statistics.StatsDeriveResult;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Estimate hash join stats.
|
||||
@ -44,12 +48,17 @@ public class JoinEstimation {
|
||||
JoinType joinType = join.getJoinType();
|
||||
StatsDeriveResult statsDeriveResult = new StatsDeriveResult(leftStats);
|
||||
statsDeriveResult.merge(rightStats);
|
||||
List<Expression> eqConjunctList = join.getHashJoinConjuncts();
|
||||
// TODO: normalize join hashConjuncts.
|
||||
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
|
||||
List<Slot> leftSlot = new ArrayList<>(leftStats.getSlotToColumnStats().keySet());
|
||||
List<Expression> normalizedConjuncts = hashJoinConjuncts.stream().map(EqualTo.class::cast)
|
||||
.map(e -> JoinUtils.swapEqualToForChildrenOrder(e, leftSlot))
|
||||
.collect(Collectors.toList());
|
||||
long rowCount = -1;
|
||||
if (joinType.isSemiOrAntiJoin()) {
|
||||
rowCount = getSemiJoinRowCount(leftStats, rightStats, eqConjunctList, joinType);
|
||||
rowCount = getSemiJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType);
|
||||
} else if (joinType.isInnerJoin() || joinType.isOuterJoin()) {
|
||||
rowCount = getJoinRowCount(leftStats, rightStats, eqConjunctList, joinType);
|
||||
rowCount = getJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType);
|
||||
} else if (joinType.isCrossJoin()) {
|
||||
rowCount = CheckedMath.checkedMultiply(leftStats.getRowCount(),
|
||||
rightStats.getRowCount());
|
||||
|
||||
@ -21,6 +21,8 @@ import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
@ -202,4 +204,22 @@ public class JoinUtils {
|
||||
JoinType joinType = join.getJoinType();
|
||||
return (joinType.isInnerJoin() && join.getHashJoinConjuncts().isEmpty()) || joinType.isCrossJoin();
|
||||
}
|
||||
|
||||
/**
|
||||
* The left and right child of origin predicates need to be swap sometimes.
|
||||
* Case A:
|
||||
* select * from t1 join t2 on t2.id=t1.id
|
||||
* The left plan node is t1 and the right plan node is t2.
|
||||
* The left child of origin predicate is t2.id and the right child of origin predicate is t1.id.
|
||||
* In this situation, the children of predicate need to be swap => t1.id=t2.id.
|
||||
*/
|
||||
public static Expression swapEqualToForChildrenOrder(EqualTo equalTo, List<Slot> leftOutput) {
|
||||
Set<ExprId> leftSlots = SlotExtractor.extractSlot(equalTo.left()).stream()
|
||||
.map(NamedExpression::getExprId).collect(Collectors.toSet());
|
||||
if (leftOutput.stream().map(NamedExpression::getExprId).collect(Collectors.toSet()).containsAll(leftSlots)) {
|
||||
return equalTo;
|
||||
} else {
|
||||
return new EqualTo(equalTo.right(), equalTo.left());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user