From 59e5527eb0fb4211f87477fccb7cbfaa9e220969 Mon Sep 17 00:00:00 2001 From: jakevin Date: Tue, 30 Aug 2022 21:28:17 +0800 Subject: [PATCH] [feature](Nereids)enable CBO optimize stage in Nereids (#12008) - enable CBO stage in Nereids - use the `chooseBestPlan()` to get the best plan - add a new rule JoinCommuteProject - test the stage by JoinCommute rule --- .../apache/doris/nereids/NereidsPlanner.java | 5 +- .../doris/nereids/cost/CostCalculator.java | 10 +-- .../jobs/cascades/CostAndEnforcerJob.java | 58 +++++++++------- .../apache/doris/nereids/rules/RuleSet.java | 2 + .../rules/exploration/join/JoinCommute.java | 50 +++----------- .../exploration/join/JoinCommuteHelper.java | 48 ++++++++++++++ .../exploration/join/JoinCommuteProject.java | 66 +++++++++++++++++++ .../doris/nereids/stats/JoinEstimation.java | 15 ++++- .../apache/doris/nereids/util/JoinUtils.java | 20 ++++++ 9 files changed, 199 insertions(+), 75 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java index f30dd3f2f7..34a798cb0e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java @@ -114,10 +114,7 @@ public class NereidsPlanner extends Planner { // cost-based optimize and explode plan space optimize(); - // Get plan directly. Just for SSB. - PhysicalPlan physicalPlan = getRoot().extractPlan(); - // TODO: remove above - // PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY); + PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY); // post-process physical plan out of memo, just for future use. return postprocess(physicalPlan); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java index f697e934ed..b34e94a313 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java @@ -72,13 +72,12 @@ public class CostCalculator { } @Override - public CostEstimate visitPhysicalProject(PhysicalProject physicalProject, PlanContext context) { - StatsDeriveResult statistics = context.getStatisticsWithCheck(); - return CostEstimate.ofCpu(statistics.computeSize()); + public CostEstimate visitPhysicalProject(PhysicalProject physicalProject, PlanContext context) { + return CostEstimate.ofCpu(1); } @Override - public CostEstimate visitPhysicalQuickSort(PhysicalQuickSort physicalQuickSort, PlanContext context) { + public CostEstimate visitPhysicalQuickSort(PhysicalQuickSort physicalQuickSort, PlanContext context) { // TODO: consider two-phase sort and enforcer. StatsDeriveResult statistics = context.getStatisticsWithCheck(); StatsDeriveResult childStatistics = context.getChildStatistics(0); @@ -102,7 +101,8 @@ public class CostCalculator { } @Override - public CostEstimate visitPhysicalDistribution(PhysicalDistribution physicalDistribution, PlanContext context) { + public CostEstimate visitPhysicalDistribution(PhysicalDistribution physicalDistribution, + PlanContext context) { StatsDeriveResult statistics = context.getStatisticsWithCheck(); StatsDeriveResult childStatistics = context.getChildStatistics(0); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java index df30c74897..6e1c844778 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java @@ -42,25 +42,28 @@ import java.util.Optional; public class CostAndEnforcerJob extends Job implements Cloneable { // GroupExpression to optimize private final GroupExpression groupExpression; - // Current total cost - private double curTotalCost; - // Children properties from parent plan node. + // cost of current plan tree + private double curTotalCost; + // cost of current plan node + private double curNodeCost; + + // List of request property to children // Example: Physical Hash Join // [ child item: [leftProperties, rightPropertie]] // [ [Properties {"", ANY}, Properties {"", BROADCAST}], // [Properties {"", SHUFFLE_JOIN}, Properties {"", SHUFFLE_JOIN}]] private List> requestChildrenPropertyList; + // index of List + private int requestPropertyIndex = 0; - private List childrenBestGroupExprList; + private List childrenBestGroupExprList = Lists.newArrayList(); private final List childrenOutputProperty = Lists.newArrayList(); - // Current stage of enumeration through child groups + // current child index of travsing all children private int curChildIndex = -1; - // Indicator of last child group that we waited for optimization + // child index in the last time of travsing all children private int prevChildIndex = -1; - // Current stage of enumeration through outputInputProperties - private int curPropertyPairIndex = 0; public CostAndEnforcerJob(GroupExpression groupExpression, JobContext context) { super(JobType.OPTIMIZE_CHILDREN, context); @@ -99,22 +102,26 @@ public class CostAndEnforcerJob extends Job implements Cloneable { public void execute() { // Do init logic of root plan/groupExpr of `subplan`, only run once per task. if (curChildIndex == -1) { + curNodeCost = 0; curTotalCost = 0; - - // Get property from groupExpression plan (it's root of subplan). + curChildIndex = 0; + // List + // [ child item: [leftProperties, rightPropertie]] + // like :[ [Properties {"", ANY}, Properties {"", BROADCAST}], + // [Properties {"", SHUFFLE_JOIN}, Properties {"", SHUFFLE_JOIN}] ] RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(context); requestChildrenPropertyList = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression); - - curChildIndex = 0; } - for (; curPropertyPairIndex < requestChildrenPropertyList.size(); curPropertyPairIndex++) { - // children input properties - List requestChildrenProperty = requestChildrenPropertyList.get(curPropertyPairIndex); + for (; requestPropertyIndex < requestChildrenPropertyList.size(); requestPropertyIndex++) { + // Get one from List + // like: [ Properties {"", ANY}, Properties {"", BROADCAST} ], + List requestChildrenProperty = requestChildrenPropertyList.get(requestPropertyIndex); - // Calculate cost of groupExpression and update total cost + // Calculate cost if (curChildIndex == 0 && prevChildIndex == -1) { - curTotalCost += CostCalculator.calculateCost(groupExpression); + curNodeCost = CostCalculator.calculateCost(groupExpression); + curTotalCost += curNodeCost; } // Handle all child plannode. @@ -177,6 +184,9 @@ public class CostAndEnforcerJob extends Job implements Cloneable { } StatsCalculator.estimate(groupExpression); + curTotalCost -= curNodeCost; + curNodeCost = CostCalculator.calculateCost(groupExpression); + curTotalCost += curNodeCost; // record map { outputProperty -> outputProperty }, { ANY -> outputProperty }, recordPropertyAndCost(groupExpression, outputProperty, outputProperty, requestChildrenProperty); recordPropertyAndCost(groupExpression, outputProperty, PhysicalProperties.ANY, requestChildrenProperty); @@ -188,11 +198,7 @@ public class CostAndEnforcerJob extends Job implements Cloneable { } } - // Reset child idx and total cost - childrenOutputProperty.clear(); - prevChildIndex = -1; - curChildIndex = 0; - curTotalCost = 0; + clear(); } } @@ -230,6 +236,14 @@ public class CostAndEnforcerJob extends Job implements Cloneable { curTotalCost, requestProperty); } + private void clear() { + childrenOutputProperty.clear(); + childrenBestGroupExprList.clear(); + prevChildIndex = -1; + curChildIndex = 0; + curTotalCost = 0; + curNodeCost = 0; + } /** * Shallow clone (ignore clone propertiesListList and groupExpression). diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index 91b487f934..d744c12064 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules; import org.apache.doris.nereids.rules.exploration.join.JoinCommute; +import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject; import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg; import org.apache.doris.nereids.rules.implementation.LogicalFilterToPhysicalFilter; import org.apache.doris.nereids.rules.implementation.LogicalJoinToHashJoin; @@ -40,6 +41,7 @@ import java.util.List; public class RuleSet { public static final List EXPLORATION_RULES = planRuleFactories() .add(JoinCommute.SWAP_OUTER_SWAP_ZIG_ZAG) + .add(JoinCommuteProject.SWAP_OUTER_SWAP_ZIG_ZAG) .build(); public static final List REWRITE_RULES = planRuleFactories() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java index 303eb856c8..64ebe5171f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java @@ -21,22 +21,21 @@ import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType; +import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; /** - * rule factory for exchange inner join's children. + * Join Commute */ @Developing public class JoinCommute extends OneExplorationRuleFactory { public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN); - public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG); - private final SwapType swapType; private final boolean swapOuter; + private final SwapType swapType; public JoinCommute(boolean swapOuter) { this.swapOuter = swapOuter; @@ -48,14 +47,12 @@ public class JoinCommute extends OneExplorationRuleFactory { this.swapType = swapType; } - enum SwapType { - BOTTOM_JOIN, ZIG_ZAG, ALL - } - @Override public Rule build() { - return innerLogicalJoin().when(this::check).then(join -> { - LogicalJoin newJoin = new LogicalJoin( + return innerLogicalJoin().when(JoinCommuteHelper::check).then(join -> { + // TODO: add project for mapping column output. + // List newOutput = new ArrayList<>(join.getOutput()); + LogicalJoin newJoin = new LogicalJoin<>( join.getJoinType(), join.getHashJoinConjuncts(), join.getOtherJoinCondition(), @@ -66,37 +63,8 @@ public class JoinCommute extends OneExplorationRuleFactory { // newJoin.getJoinReorderContext().setHasCommuteZigZag(true); // } + // LogicalProject project = new LogicalProject<>(newOutput, newJoin); return newJoin; }).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE); } - - - private boolean check(LogicalJoin join) { - if (!(join.left() instanceof LogicalPlan) || !(join.right() instanceof LogicalPlan)) { - return false; - } - - if (swapType == SwapType.BOTTOM_JOIN && !isBottomJoin(join)) { - return false; - } - - if (join.getJoinReorderContext().hasCommute() || join.getJoinReorderContext().hasExchange()) { - return false; - } - return true; - } - - private boolean isBottomJoin(LogicalJoin join) { - // TODO: filter need to be considered? - if (join.left() instanceof LogicalProject && ((LogicalProject) join.left()).child() instanceof LogicalJoin) { - return false; - } - if (join.right() instanceof LogicalProject && ((LogicalProject) join.right()).child() instanceof LogicalJoin) { - return false; - } - if (join.left() instanceof LogicalJoin || join.right() instanceof LogicalJoin) { - return false; - } - return true; - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java new file mode 100644 index 0000000000..47da5030e1 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; + +/** + * Common function for JoinCommute + */ +public class JoinCommuteHelper { + + enum SwapType { + BOTTOM_JOIN, ZIG_ZAG, ALL + } + + private final boolean swapOuter; + private final SwapType swapType; + + public JoinCommuteHelper(boolean swapOuter, SwapType swapType) { + this.swapOuter = swapOuter; + this.swapType = swapType; + } + + public static boolean check(LogicalJoin join) { + return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange(); + } + + public static boolean check(LogicalProject> project) { + return check(project.child()); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java new file mode 100644 index 0000000000..07464275a1 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; + +/** + * Project-Join commute + */ +public class JoinCommuteProject extends OneExplorationRuleFactory { + + public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN); + public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG); + + private final SwapType swapType; + private final boolean swapOuter; + + public JoinCommuteProject(boolean swapOuter) { + this.swapOuter = swapOuter; + this.swapType = SwapType.ALL; + } + + public JoinCommuteProject(boolean swapOuter, SwapType swapType) { + this.swapOuter = swapOuter; + this.swapType = swapType; + } + + @Override + public Rule build() { + return logicalProject(innerLogicalJoin()).when(JoinCommuteHelper::check).then(project -> { + LogicalJoin join = project.child(); + LogicalJoin newJoin = new LogicalJoin<>( + join.getJoinType(), + join.getHashJoinConjuncts(), + join.getOtherJoinCondition(), + join.right(), join.left(), + join.getJoinReorderContext()); + newJoin.getJoinReorderContext().setHasCommute(true); + // if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) { + // newJoin.getJoinReorderContext().setHasCommuteZigZag(true); + // } + + return newJoin; + }).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java index e6f5470d29..2e315c5604 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java @@ -18,18 +18,22 @@ package org.apache.doris.nereids.stats; import org.apache.doris.common.CheckedMath; +import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.algebra.Join; +import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.statistics.ColumnStats; import org.apache.doris.statistics.StatsDeriveResult; import com.google.common.base.Preconditions; +import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; /** * Estimate hash join stats. @@ -44,12 +48,17 @@ public class JoinEstimation { JoinType joinType = join.getJoinType(); StatsDeriveResult statsDeriveResult = new StatsDeriveResult(leftStats); statsDeriveResult.merge(rightStats); - List eqConjunctList = join.getHashJoinConjuncts(); + // TODO: normalize join hashConjuncts. + List hashJoinConjuncts = join.getHashJoinConjuncts(); + List leftSlot = new ArrayList<>(leftStats.getSlotToColumnStats().keySet()); + List normalizedConjuncts = hashJoinConjuncts.stream().map(EqualTo.class::cast) + .map(e -> JoinUtils.swapEqualToForChildrenOrder(e, leftSlot)) + .collect(Collectors.toList()); long rowCount = -1; if (joinType.isSemiOrAntiJoin()) { - rowCount = getSemiJoinRowCount(leftStats, rightStats, eqConjunctList, joinType); + rowCount = getSemiJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType); } else if (joinType.isInnerJoin() || joinType.isOuterJoin()) { - rowCount = getJoinRowCount(leftStats, rightStats, eqConjunctList, joinType); + rowCount = getJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType); } else if (joinType.isCrossJoin()) { rowCount = CheckedMath.checkedMultiply(leftStats.getRowCount(), rightStats.getRowCount()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java index 6ba3629495..2c1ded556a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java @@ -21,6 +21,8 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; @@ -202,4 +204,22 @@ public class JoinUtils { JoinType joinType = join.getJoinType(); return (joinType.isInnerJoin() && join.getHashJoinConjuncts().isEmpty()) || joinType.isCrossJoin(); } + + /** + * The left and right child of origin predicates need to be swap sometimes. + * Case A: + * select * from t1 join t2 on t2.id=t1.id + * The left plan node is t1 and the right plan node is t2. + * The left child of origin predicate is t2.id and the right child of origin predicate is t1.id. + * In this situation, the children of predicate need to be swap => t1.id=t2.id. + */ + public static Expression swapEqualToForChildrenOrder(EqualTo equalTo, List leftOutput) { + Set leftSlots = SlotExtractor.extractSlot(equalTo.left()).stream() + .map(NamedExpression::getExprId).collect(Collectors.toSet()); + if (leftOutput.stream().map(NamedExpression::getExprId).collect(Collectors.toSet()).containsAll(leftSlots)) { + return equalTo; + } else { + return new EqualTo(equalTo.right(), equalTo.left()); + } + } }