[feature](nereids): enable exploration job (#11867)
Enable the exploration job, and fix related problem. correct the join reorder
This commit is contained in:
@ -121,6 +121,8 @@ public class NereidsPlanner extends Planner {
|
||||
|
||||
// Get plan directly. Just for SSB.
|
||||
PhysicalPlan physicalPlan = getRoot().extractPlan();
|
||||
// TODO: remove above
|
||||
// PhysicalPlan physicalPlan = chooseBestPlan(getRoot(), PhysicalProperties.ANY);
|
||||
|
||||
// post-process physical plan out of memo, just for future use.
|
||||
return postprocess(physicalPlan);
|
||||
|
||||
@ -35,7 +35,7 @@ public final class CostEstimate {
|
||||
* Constructor of CostEstimate.
|
||||
*/
|
||||
public CostEstimate(double cpuCost, double memoryCost, double networkCost) {
|
||||
// TODO: remove them after finish statistics.
|
||||
// TODO: fix stats
|
||||
if (cpuCost < 0) {
|
||||
cpuCost = 0;
|
||||
}
|
||||
|
||||
@ -45,6 +45,7 @@ import org.apache.doris.nereids.trees.plans.PlanType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
|
||||
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribution;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
|
||||
@ -493,6 +494,12 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
|
||||
return inputFragment;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PlanFragment visitPhysicalDistribution(PhysicalDistribution<Plan> distribution,
|
||||
PlanTranslatorContext context) {
|
||||
return distribution.child().accept(this, context);
|
||||
}
|
||||
|
||||
private void extractExecSlot(Expr root, Set<Integer> slotRefList) {
|
||||
if (root instanceof SlotRef) {
|
||||
slotRefList.add(((SlotRef) root).getDesc().getId().asInt());
|
||||
|
||||
@ -25,8 +25,6 @@ import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.pattern.Pattern;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
@ -50,8 +48,7 @@ public class ExploreGroupExpressionJob extends Job {
|
||||
@Override
|
||||
public void execute() {
|
||||
// TODO: enable exploration job after we test it
|
||||
// List<Rule<Plan>> explorationRules = getRuleSet().getExplorationRules();
|
||||
List<Rule> explorationRules = Lists.newArrayList();
|
||||
List<Rule> explorationRules = getRuleSet().getExplorationRules();
|
||||
List<Rule> validRules = getValidRules(groupExpression, explorationRules);
|
||||
validRules.sort(Comparator.comparingInt(o -> o.getRulePromise().promise()));
|
||||
|
||||
|
||||
@ -44,9 +44,8 @@ public class OptimizeGroupExpressionJob extends Job {
|
||||
public void execute() {
|
||||
List<Rule> validRules = new ArrayList<>();
|
||||
List<Rule> implementationRules = getRuleSet().getImplementationRules();
|
||||
// TODO: enable exploration job after we test it
|
||||
// List<Rule<Plan>> explorationRules = getRuleSet().getExplorationRules();
|
||||
// validRules.addAll(getValidRules(groupExpression, explorationRules));
|
||||
List<Rule> explorationRules = getRuleSet().getExplorationRules();
|
||||
validRules.addAll(getValidRules(groupExpression, explorationRules));
|
||||
validRules.addAll(getValidRules(groupExpression, implementationRules));
|
||||
validRules.sort(Comparator.comparingInt(o -> o.getRulePromise().promise()));
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ import java.util.List;
|
||||
*/
|
||||
public class RuleSet {
|
||||
public static final List<Rule> EXPLORATION_RULES = planRuleFactories()
|
||||
.add(new JoinCommute(true))
|
||||
.add(JoinCommute.SWAP_OUTER_SWAP_ZIG_ZAG)
|
||||
.build();
|
||||
|
||||
public static final List<Rule> REWRITE_RULES = planRuleFactories()
|
||||
|
||||
@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
/**
|
||||
@ -32,6 +33,8 @@ public class JoinCommute extends OneExplorationRuleFactory {
|
||||
|
||||
public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN);
|
||||
|
||||
public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG);
|
||||
|
||||
private final SwapType swapType;
|
||||
private final boolean swapOuter;
|
||||
|
||||
@ -51,25 +54,16 @@ public class JoinCommute extends OneExplorationRuleFactory {
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return innerLogicalJoin(any(), any()).then(join -> {
|
||||
if (!check(join)) {
|
||||
return null;
|
||||
}
|
||||
boolean isBottomJoin = isBottomJoin(join);
|
||||
if (swapType == SwapType.BOTTOM_JOIN && !isBottomJoin) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return innerLogicalJoin().when(this::check).then(join -> {
|
||||
LogicalJoin newJoin = new LogicalJoin(
|
||||
join.getJoinType(),
|
||||
join.getCondition(),
|
||||
join.right(), join.left(),
|
||||
join.getJoinReorderContext()
|
||||
);
|
||||
join.getJoinReorderContext());
|
||||
newJoin.getJoinReorderContext().setHasCommute(true);
|
||||
if (swapType == SwapType.ZIG_ZAG && !isBottomJoin) {
|
||||
newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
|
||||
}
|
||||
// if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) {
|
||||
// newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
|
||||
// }
|
||||
|
||||
return newJoin;
|
||||
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
|
||||
@ -77,6 +71,14 @@ public class JoinCommute extends OneExplorationRuleFactory {
|
||||
|
||||
|
||||
private boolean check(LogicalJoin join) {
|
||||
if (!(join.left() instanceof LogicalPlan) || !(join.right() instanceof LogicalPlan)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (swapType == SwapType.BOTTOM_JOIN && !isBottomJoin(join)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (join.getJoinReorderContext().hasCommute() || join.getJoinReorderContext().hasExchange()) {
|
||||
return false;
|
||||
}
|
||||
@ -84,18 +86,12 @@ public class JoinCommute extends OneExplorationRuleFactory {
|
||||
}
|
||||
|
||||
private boolean isBottomJoin(LogicalJoin join) {
|
||||
// TODO: wait for tree model of pattern-match.
|
||||
if (join.left() instanceof LogicalProject) {
|
||||
LogicalProject project = (LogicalProject) join.left();
|
||||
if (project.child() instanceof LogicalJoin) {
|
||||
return false;
|
||||
}
|
||||
// TODO: filter need to be considered?
|
||||
if (join.left() instanceof LogicalProject && ((LogicalProject) join.left()).child() instanceof LogicalJoin) {
|
||||
return false;
|
||||
}
|
||||
if (join.right() instanceof LogicalProject) {
|
||||
LogicalProject project = (LogicalProject) join.left();
|
||||
if (project.child() instanceof LogicalJoin) {
|
||||
return false;
|
||||
}
|
||||
if (join.right() instanceof LogicalProject && ((LogicalProject) join.right()).child() instanceof LogicalJoin) {
|
||||
return false;
|
||||
}
|
||||
if (join.left() instanceof LogicalJoin || join.right() instanceof LogicalJoin) {
|
||||
return false;
|
||||
|
||||
@ -71,7 +71,7 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
|
||||
/**
|
||||
* Constructor for LogicalJoinPlan.
|
||||
*
|
||||
* @param joinType logical type for join
|
||||
* @param joinType logical type for join
|
||||
* @param condition on clause for join node
|
||||
*/
|
||||
public LogicalJoin(JoinType joinType, Optional<Expression> condition,
|
||||
@ -82,6 +82,13 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
|
||||
this.condition = Objects.requireNonNull(condition, "condition can not be null");
|
||||
}
|
||||
|
||||
public LogicalJoin(JoinType joinType, Optional<Expression> condition,
|
||||
Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties,
|
||||
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild, JoinReorderContext joinReorderContext) {
|
||||
this(joinType, condition, groupExpression, logicalProperties, leftChild, rightChild);
|
||||
this.joinReorderContext.copyFrom(joinReorderContext);
|
||||
}
|
||||
|
||||
public Optional<Expression> getCondition() {
|
||||
return condition;
|
||||
}
|
||||
@ -170,17 +177,18 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
|
||||
@Override
|
||||
public LogicalBinary<Plan, Plan> withChildren(List<Plan> children) {
|
||||
Preconditions.checkArgument(children.size() == 2);
|
||||
return new LogicalJoin<>(joinType, condition, children.get(0), children.get(1));
|
||||
return new LogicalJoin<>(joinType, condition, children.get(0), children.get(1), joinReorderContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
|
||||
return new LogicalJoin<>(joinType, condition, groupExpression,
|
||||
Optional.of(logicalProperties), left(), right());
|
||||
Optional.of(logicalProperties), left(), right(), joinReorderContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
|
||||
return new LogicalJoin<>(joinType, condition, Optional.empty(), logicalProperties, left(), right());
|
||||
return new LogicalJoin<>(joinType, condition, Optional.empty(), logicalProperties, left(), right(),
|
||||
joinReorderContext);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user