From def6f5568e690221e01a8c176c1d5cdd322f6b62 Mon Sep 17 00:00:00 2001 From: jakevin Date: Mon, 22 Aug 2022 23:38:17 +0800 Subject: [PATCH] [feature](nereids): enable exploration job (#11867) Enable the exploration job, and fix related problem. correct the join reorder --- .../apache/doris/nereids/NereidsPlanner.java | 2 + .../doris/nereids/cost/CostEstimate.java | 2 +- .../translator/PhysicalPlanTranslator.java | 7 +++ .../cascades/ExploreGroupExpressionJob.java | 5 +- .../cascades/OptimizeGroupExpressionJob.java | 5 +- .../apache/doris/nereids/rules/RuleSet.java | 2 +- .../rules/exploration/join/JoinCommute.java | 46 +++++++++---------- .../trees/plans/logical/LogicalJoin.java | 16 +++++-- 8 files changed, 47 insertions(+), 38 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java index a0a24f9f66..5be733f91c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java @@ -121,6 +121,8 @@ public class NereidsPlanner extends Planner { // Get plan directly. Just for SSB. PhysicalPlan physicalPlan = getRoot().extractPlan(); + // TODO: remove above + // PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY); // post-process physical plan out of memo, just for future use. return postprocess(physicalPlan); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostEstimate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostEstimate.java index da60840877..a25d725706 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostEstimate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostEstimate.java @@ -35,7 +35,7 @@ public final class CostEstimate { * Constructor of CostEstimate. */ public CostEstimate(double cpuCost, double memoryCost, double networkCost) { - // TODO: remove them after finish statistics. + // TODO: fix stats if (cpuCost < 0) { cpuCost = 0; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index f560adf8eb..56490dd0a0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -45,6 +45,7 @@ import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort; import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; +import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribution; import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit; @@ -493,6 +494,12 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor distribution, + PlanTranslatorContext context) { + return distribution.child().accept(this, context); + } + private void extractExecSlot(Expr root, Set slotRefList) { if (root instanceof SlotRef) { slotRefList.add(((SlotRef) root).getDesc().getId().asInt()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ExploreGroupExpressionJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ExploreGroupExpressionJob.java index a74b588ff7..6a7a5e59d0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ExploreGroupExpressionJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ExploreGroupExpressionJob.java @@ -25,8 +25,6 @@ import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.pattern.Pattern; import org.apache.doris.nereids.rules.Rule; -import com.google.common.collect.Lists; - import java.util.Comparator; import java.util.List; @@ -50,8 +48,7 @@ public class ExploreGroupExpressionJob extends Job { @Override public void execute() { // TODO: enable exploration job after we test it - // List> explorationRules = getRuleSet().getExplorationRules(); - List explorationRules = Lists.newArrayList(); + List explorationRules = getRuleSet().getExplorationRules(); List validRules = getValidRules(groupExpression, explorationRules); validRules.sort(Comparator.comparingInt(o -> o.getRulePromise().promise())); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/OptimizeGroupExpressionJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/OptimizeGroupExpressionJob.java index 4a2a876452..89ae3114f6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/OptimizeGroupExpressionJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/OptimizeGroupExpressionJob.java @@ -44,9 +44,8 @@ public class OptimizeGroupExpressionJob extends Job { public void execute() { List validRules = new ArrayList<>(); List implementationRules = getRuleSet().getImplementationRules(); - // TODO: enable exploration job after we test it - // List> explorationRules = getRuleSet().getExplorationRules(); - // validRules.addAll(getValidRules(groupExpression, explorationRules)); + List explorationRules = getRuleSet().getExplorationRules(); + validRules.addAll(getValidRules(groupExpression, explorationRules)); validRules.addAll(getValidRules(groupExpression, implementationRules)); validRules.sort(Comparator.comparingInt(o -> o.getRulePromise().promise())); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index 692c50c3d4..91b487f934 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -39,7 +39,7 @@ import java.util.List; */ public class RuleSet { public static final List EXPLORATION_RULES = planRuleFactories() - .add(new JoinCommute(true)) + .add(JoinCommute.SWAP_OUTER_SWAP_ZIG_ZAG) .build(); public static final List REWRITE_RULES = planRuleFactories() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java index 19bf4491b9..99bac270e3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; /** @@ -32,6 +33,8 @@ public class JoinCommute extends OneExplorationRuleFactory { public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN); + public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG); + private final SwapType swapType; private final boolean swapOuter; @@ -51,25 +54,16 @@ public class JoinCommute extends OneExplorationRuleFactory { @Override public Rule build() { - return innerLogicalJoin(any(), any()).then(join -> { - if (!check(join)) { - return null; - } - boolean isBottomJoin = isBottomJoin(join); - if (swapType == SwapType.BOTTOM_JOIN && !isBottomJoin) { - return null; - } - + return innerLogicalJoin().when(this::check).then(join -> { LogicalJoin newJoin = new LogicalJoin( join.getJoinType(), join.getCondition(), join.right(), join.left(), - join.getJoinReorderContext() - ); + join.getJoinReorderContext()); newJoin.getJoinReorderContext().setHasCommute(true); - if (swapType == SwapType.ZIG_ZAG && !isBottomJoin) { - newJoin.getJoinReorderContext().setHasCommuteZigZag(true); - } + // if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) { + // newJoin.getJoinReorderContext().setHasCommuteZigZag(true); + // } return newJoin; }).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE); @@ -77,6 +71,14 @@ public class JoinCommute extends OneExplorationRuleFactory { private boolean check(LogicalJoin join) { + if (!(join.left() instanceof LogicalPlan) || !(join.right() instanceof LogicalPlan)) { + return false; + } + + if (swapType == SwapType.BOTTOM_JOIN && !isBottomJoin(join)) { + return false; + } + if (join.getJoinReorderContext().hasCommute() || join.getJoinReorderContext().hasExchange()) { return false; } @@ -84,18 +86,12 @@ public class JoinCommute extends OneExplorationRuleFactory { } private boolean isBottomJoin(LogicalJoin join) { - // TODO: wait for tree model of pattern-match. - if (join.left() instanceof LogicalProject) { - LogicalProject project = (LogicalProject) join.left(); - if (project.child() instanceof LogicalJoin) { - return false; - } + // TODO: filter need to be considered? + if (join.left() instanceof LogicalProject && ((LogicalProject) join.left()).child() instanceof LogicalJoin) { + return false; } - if (join.right() instanceof LogicalProject) { - LogicalProject project = (LogicalProject) join.left(); - if (project.child() instanceof LogicalJoin) { - return false; - } + if (join.right() instanceof LogicalProject && ((LogicalProject) join.right()).child() instanceof LogicalJoin) { + return false; } if (join.left() instanceof LogicalJoin || join.right() instanceof LogicalJoin) { return false; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index 81fbafdde6..77e9bea685 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -71,7 +71,7 @@ public class LogicalJoin condition, @@ -82,6 +82,13 @@ public class LogicalJoin condition, + Optional groupExpression, Optional logicalProperties, + LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild, JoinReorderContext joinReorderContext) { + this(joinType, condition, groupExpression, logicalProperties, leftChild, rightChild); + this.joinReorderContext.copyFrom(joinReorderContext); + } + public Optional getCondition() { return condition; } @@ -170,17 +177,18 @@ public class LogicalJoin withChildren(List children) { Preconditions.checkArgument(children.size() == 2); - return new LogicalJoin<>(joinType, condition, children.get(0), children.get(1)); + return new LogicalJoin<>(joinType, condition, children.get(0), children.get(1), joinReorderContext); } @Override public Plan withGroupExpression(Optional groupExpression) { return new LogicalJoin<>(joinType, condition, groupExpression, - Optional.of(logicalProperties), left(), right()); + Optional.of(logicalProperties), left(), right(), joinReorderContext); } @Override public Plan withLogicalProperties(Optional logicalProperties) { - return new LogicalJoin<>(joinType, condition, Optional.empty(), logicalProperties, left(), right()); + return new LogicalJoin<>(joinType, condition, Optional.empty(), logicalProperties, left(), right(), + joinReorderContext); } }