[refactor](Nereids): cascades refactor (#9470)

Overview of changes:

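- rename `PlanReference` to `GroupExpression`
- use a `HashSet<GroupExpression> groupExpressions` in `Memo` to track group expressions
- add the `area/nereids` label for CI
- remove the `PlanReference` getter/setter (the group-expression back-reference) from `Plan`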
This commit is contained in:
jakevin
2022-05-11 11:07:58 +08:00
committed by GitHub
parent ad88eb739b
commit 74352c807e
14 changed files with 107 additions and 114 deletions

View File

@ -31,6 +31,9 @@ kind/test:
area/vectorization:
- be/src/vec/**/*
area/nereids:
- fe/fe-core/src/main/java/org/apache/doris/nereids/**/*
area/planner:
- fe/fe-core/src/main/java/org/apache/doris/planner/**/*
- fe/fe-core/src/main/java/org/apache/doris/analysis/**/*

View File

@ -19,7 +19,7 @@ package org.apache.doris.nereids.jobs;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.trees.plans.Plan;
@ -46,7 +46,7 @@ public abstract class Job {
return context.getOptimizerContext().getRuleSet();
}
public void prunedInvalidRules(PlanReference planReference, List<Rule<Plan>> candidateRules) {
public void prunedInvalidRules(GroupExpression groupExpression, List<Rule<Plan>> candidateRules) {
}

View File

@ -21,7 +21,7 @@ import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.pattern.PatternMatching;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
@ -30,30 +30,30 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import java.util.List;
/**
* Job to apply rule on {@link PlanReference}.
* Job to apply rule on {@link GroupExpression}.
*/
public class ApplyRuleJob extends Job {
private final PlanReference planReference;
private final GroupExpression groupExpression;
private final Rule<Plan> rule;
private final boolean exploredOnly;
/**
* Constructor of ApplyRuleJob.
*
* @param planReference apply rule on this {@link PlanReference}
* @param groupExpression apply rule on this {@link GroupExpression}
* @param rule rule to be applied
* @param context context of optimization
*/
public ApplyRuleJob(PlanReference planReference, Rule<Plan> rule, PlannerContext context) {
public ApplyRuleJob(GroupExpression groupExpression, Rule<Plan> rule, PlannerContext context) {
super(JobType.APPLY_RULE, context);
this.planReference = planReference;
this.groupExpression = groupExpression;
this.rule = rule;
this.exploredOnly = false;
}
@Override
public void execute() throws AnalysisException {
if (planReference.hasExplored(rule)) {
if (groupExpression.hasExplored(rule)) {
return;
}
@ -65,20 +65,20 @@ public class ApplyRuleJob extends Job {
}
List<Plan> newPlanList = rule.transform(plan, context);
for (Plan newPlan : newPlanList) {
PlanReference newReference = context.getOptimizerContext().getMemo()
.newPlanReference(newPlan, planReference.getParent());
GroupExpression newGroupExpression = context.getOptimizerContext().getMemo()
.newGroupExpression(newPlan, groupExpression.getParent());
// TODO: need to check that the returned value is a new reference; otherwise this will fall into an infinite loop
if (newPlan instanceof LogicalPlan) {
pushTask(new DeriveStatsJob(newReference, context));
pushTask(new DeriveStatsJob(newGroupExpression, context));
if (exploredOnly) {
pushTask(new ExplorePlanJob(newReference, context));
pushTask(new ExplorePlanJob(newGroupExpression, context));
}
pushTask(new OptimizePlanJob(newReference, context));
pushTask(new OptimizePlanJob(newGroupExpression, context));
} else {
pushTask(new CostAndEnforcerJob(newReference, context));
pushTask(new CostAndEnforcerJob(newGroupExpression, context));
}
}
}
planReference.setExplored(rule);
groupExpression.setExplored(rule);
}
}

View File

@ -20,17 +20,17 @@ package org.apache.doris.nereids.jobs.cascades;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;
/**
* Job to compute cost and add enforcer.
*/
public class CostAndEnforcerJob extends Job {
private final PlanReference planReference;
private final GroupExpression groupExpression;
public CostAndEnforcerJob(PlanReference planReference, PlannerContext context) {
public CostAndEnforcerJob(GroupExpression groupExpression, PlannerContext context) {
super(JobType.OPTIMIZE_CHILDREN, context);
this.planReference = planReference;
this.groupExpression = groupExpression;
}
@Override

View File

@ -21,24 +21,24 @@ import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;
/**
* Job to derive stats for {@link PlanReference} in {@link org.apache.doris.nereids.memo.Memo}.
* Job to derive stats for {@link GroupExpression} in {@link org.apache.doris.nereids.memo.Memo}.
*/
public class DeriveStatsJob extends Job {
private final PlanReference planReference;
private final GroupExpression groupExpression;
private boolean deriveChildren;
/**
* Constructor for DeriveStatsJob.
*
* @param planReference Derive stats on this {@link PlanReference}
* @param groupExpression Derive stats on this {@link GroupExpression}
* @param context context of optimization
*/
public DeriveStatsJob(PlanReference planReference, PlannerContext context) {
public DeriveStatsJob(GroupExpression groupExpression, PlannerContext context) {
super(JobType.DERIVE_STATS, context);
this.planReference = planReference;
this.groupExpression = groupExpression;
this.deriveChildren = false;
}
@ -49,7 +49,7 @@ public class DeriveStatsJob extends Job {
*/
public DeriveStatsJob(DeriveStatsJob other) {
super(JobType.DERIVE_STATS, other.context);
this.planReference = other.planReference;
this.groupExpression = other.groupExpression;
this.deriveChildren = other.deriveChildren;
}
@ -58,14 +58,14 @@ public class DeriveStatsJob extends Job {
if (!deriveChildren) {
deriveChildren = true;
pushTask(new DeriveStatsJob(this));
for (Group childSet : planReference.getChildren()) {
for (Group childSet : groupExpression.getChildren()) {
if (!childSet.getLogicalPlanList().isEmpty()) {
pushTask(new DeriveStatsJob(childSet.getLogicalPlanList().get(0), context));
}
}
} else {
// TODO: derive stat here
planReference.setStatDerived(true);
groupExpression.setStatDerived(true);
}
}

View File

@ -21,7 +21,7 @@ import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;
/**
* Job to explore {@link Group} in {@link org.apache.doris.nereids.memo.Memo}.
@ -45,8 +45,8 @@ public class ExploreGroupJob extends Job {
if (group.isExplored()) {
return;
}
for (PlanReference planReference : group.getLogicalPlanList()) {
pushTask(new ExplorePlanJob(planReference, context));
for (GroupExpression groupExpression : group.getLogicalPlanList()) {
pushTask(new ExplorePlanJob(groupExpression, context));
}
group.setExplored(true);
}

View File

@ -21,7 +21,7 @@ import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
@ -30,34 +30,34 @@ import java.util.Comparator;
import java.util.List;
/**
* Job to explore {@link PlanReference} in {@link org.apache.doris.nereids.memo.Memo}.
* Job to explore {@link GroupExpression} in {@link org.apache.doris.nereids.memo.Memo}.
*/
public class ExplorePlanJob extends Job {
private final PlanReference planReference;
private final GroupExpression groupExpression;
/**
* Constructor for ExplorePlanJob.
*
* @param planReference {@link PlanReference} to be explored
* @param groupExpression {@link GroupExpression} to be explored
* @param context context of optimization
*/
public ExplorePlanJob(PlanReference planReference, PlannerContext context) {
public ExplorePlanJob(GroupExpression groupExpression, PlannerContext context) {
super(JobType.EXPLORE_PLAN, context);
this.planReference = planReference;
this.groupExpression = groupExpression;
}
@Override
public void execute() {
List<Rule<Plan>> explorationRules = getRuleSet().getExplorationRules();
prunedInvalidRules(planReference, explorationRules);
prunedInvalidRules(groupExpression, explorationRules);
explorationRules.sort(Comparator.comparingInt(o -> o.getRulePromise().promise()));
for (Rule rule : explorationRules) {
pushTask(new ApplyRuleJob(planReference, rule, context));
pushTask(new ApplyRuleJob(groupExpression, rule, context));
for (int i = 0; i < rule.getPattern().children().size(); ++i) {
Pattern childPattern = rule.getPattern().child(i);
if (childPattern.arity() > 0) {
Group childSet = planReference.getChildren().get(i);
Group childSet = groupExpression.getChildren().get(i);
pushTask(new ExploreGroupJob(childSet, context));
}
}

View File

@ -21,7 +21,7 @@ import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;
/**
* Job to optimize {@link Group} in {@link org.apache.doris.nereids.memo.Memo}.
@ -41,12 +41,12 @@ public class OptimizeGroupJob extends Job {
return;
}
if (!group.isExplored()) {
for (PlanReference logicalPlanReference : group.getLogicalPlanList()) {
context.getOptimizerContext().pushTask(new OptimizePlanJob(logicalPlanReference, context));
for (GroupExpression logicalGroupExpression : group.getLogicalPlanList()) {
context.getOptimizerContext().pushTask(new OptimizePlanJob(logicalGroupExpression, context));
}
}
for (PlanReference physicalPlanReference : group.getPhysicalPlanList()) {
context.getOptimizerContext().pushTask(new CostAndEnforcerJob(physicalPlanReference, context));
for (GroupExpression physicalGroupExpression : group.getPhysicalPlanList()) {
context.getOptimizerContext().pushTask(new CostAndEnforcerJob(physicalGroupExpression, context));
}
group.setExplored(true);
}

View File

@ -21,7 +21,7 @@ import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
@ -34,11 +34,11 @@ import java.util.List;
* Job to optimize {@link org.apache.doris.nereids.trees.plans.Plan} in {@link org.apache.doris.nereids.memo.Memo}.
*/
public class OptimizePlanJob extends Job {
private final PlanReference planReference;
private final GroupExpression groupExpression;
public OptimizePlanJob(PlanReference planReference, PlannerContext context) {
public OptimizePlanJob(GroupExpression groupExpression, PlannerContext context) {
super(JobType.OPTIMIZE_PLAN, context);
this.planReference = planReference;
this.groupExpression = groupExpression;
}
@Override
@ -46,21 +46,21 @@ public class OptimizePlanJob extends Job {
List<Rule<Plan>> validRules = new ArrayList<>();
List<Rule<Plan>> explorationRules = getRuleSet().getExplorationRules();
List<Rule<Plan>> implementationRules = getRuleSet().getImplementationRules();
prunedInvalidRules(planReference, explorationRules);
prunedInvalidRules(planReference, implementationRules);
prunedInvalidRules(groupExpression, explorationRules);
prunedInvalidRules(groupExpression, implementationRules);
validRules.addAll(explorationRules);
validRules.addAll(implementationRules);
validRules.sort(Comparator.comparingInt(o -> o.getRulePromise().promise()));
for (Rule rule : validRules) {
pushTask(new ApplyRuleJob(planReference, rule, context));
pushTask(new ApplyRuleJob(groupExpression, rule, context));
// If child_pattern has any more children (i.e., it is a non-leaf), then we will explore the
// child before applying the rule. (This assumes the task pool is effectively a stack.)
for (int i = 0; i < rule.getPattern().children().size(); ++i) {
Pattern childPattern = rule.getPattern().child(i);
if (childPattern.arity() > 0) {
Group childSet = planReference.getChildren().get(i);
Group childSet = groupExpression.getChildren().get(i);
pushTask(new ExploreGroupJob(childSet, context));
}
}

View File

@ -36,32 +36,32 @@ import java.util.Optional;
public class Group {
private final GroupId groupId = GroupId.newPlanSetId();
private final List<PlanReference> logicalPlanList = Lists.newArrayList();
private final List<PlanReference> physicalPlanList = Lists.newArrayList();
private final List<GroupExpression> logicalPlanList = Lists.newArrayList();
private final List<GroupExpression> physicalPlanList = Lists.newArrayList();
private final LogicalProperties logicalProperties;
private Map<PhysicalProperties, Pair<Double, PlanReference>> lowestCostPlans;
private Map<PhysicalProperties, Pair<Double, GroupExpression>> lowestCostPlans;
private double costLowerBound = -1;
private boolean isExplored = false;
/**
* Constructor for Group.
*
* @param planReference first {@link PlanReference} in this Group
* @param groupExpression first {@link GroupExpression} in this Group
*/
public Group(PlanReference planReference) {
if (planReference.getPlan() instanceof LogicalPlan) {
this.logicalPlanList.add(planReference);
public Group(GroupExpression groupExpression) {
if (groupExpression.getPlan() instanceof LogicalPlan) {
this.logicalPlanList.add(groupExpression);
} else {
this.physicalPlanList.add(planReference);
this.physicalPlanList.add(groupExpression);
}
logicalProperties = new LogicalProperties();
try {
logicalProperties.setOutput(planReference.getPlan().getOutput());
logicalProperties.setOutput(groupExpression.getPlan().getOutput());
} catch (UnboundException e) {
throw new RuntimeException(e);
}
planReference.setParent(this);
groupExpression.setParent(this);
}
public GroupId getGroupId() {
@ -69,18 +69,18 @@ public class Group {
}
/**
* Add new {@link PlanReference} into this group.
* Add new {@link GroupExpression} into this group.
*
* @param planReference {@link PlanReference} to be added
* @return added {@link PlanReference}
* @param groupExpression {@link GroupExpression} to be added
* @return added {@link GroupExpression}
*/
public PlanReference addPlanReference(PlanReference planReference) {
if (planReference.getPlan() instanceof LogicalPlan) {
logicalPlanList.add(planReference);
public GroupExpression addGroupExpression(GroupExpression groupExpression) {
if (groupExpression.getPlan() instanceof LogicalPlan) {
logicalPlanList.add(groupExpression);
} else {
physicalPlanList.add(planReference);
physicalPlanList.add(groupExpression);
}
return planReference;
return groupExpression;
}
public double getCostLowerBound() {
@ -91,11 +91,11 @@ public class Group {
this.costLowerBound = costLowerBound;
}
public List<PlanReference> getLogicalPlanList() {
public List<GroupExpression> getLogicalPlanList() {
return logicalPlanList;
}
public List<PlanReference> getPhysicalPlanList() {
public List<GroupExpression> getPhysicalPlanList() {
return physicalPlanList;
}
@ -116,9 +116,9 @@ public class Group {
* which meeting the physical property constraints in this Group.
*
* @param physicalProperties the physical property constraints
* @return {@link Optional} of cost and {@link PlanReference} of physical plan pair.
* @return {@link Optional} of cost and {@link GroupExpression} of physical plan pair.
*/
public Optional<Pair<Double, PlanReference>> getLowestCostPlan(PhysicalProperties physicalProperties) {
public Optional<Pair<Double, GroupExpression>> getLowestCostPlan(PhysicalProperties physicalProperties) {
if (physicalProperties == null || CollectionUtils.isEmpty(lowestCostPlans)) {
return Optional.empty();
}

View File

@ -29,29 +29,28 @@ import java.util.List;
/**
* Representation for group expression in cascades optimizer.
*/
public class PlanReference {
public class GroupExpression {
private Group parent;
private List<Group> children;
private final Plan<?> plan;
private final BitSet ruleMasks;
private boolean statDerived;
public PlanReference(Plan<?> plan) {
public GroupExpression(Plan<?> plan) {
this(plan, Lists.newArrayList());
}
/**
* Constructor for PlanReference.
* Constructor for GroupExpression.
*
* @param plan {@link Plan} to reference
* @param children children groups in memo
*/
public PlanReference(Plan<?> plan, List<Group> children) {
public GroupExpression(Plan<?> plan, List<Group> children) {
this.plan = plan;
this.children = children;
this.ruleMasks = new BitSet(RuleType.SENTINEL.ordinal());
this.statDerived = false;
plan.setPlanReference(this);
}
public void addChild(Group child) {

View File

@ -21,19 +21,21 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.List;
import java.util.Set;
/**
* Representation for memo in cascades optimizer.
*/
public class Memo {
private final List<Group> groups = Lists.newArrayList();
private final List<PlanReference> planReferences = Lists.newArrayList();
private final Set<GroupExpression> groupExpressions = Sets.newHashSet();
private Group rootSet;
public void initialize(LogicalPlan plan) {
rootSet = newPlanReference(plan, null).getParent();
rootSet = newGroupExpression(plan, null).getParent();
}
public Group getRootSet() {
@ -43,31 +45,37 @@ public class Memo {
/**
* Add plan to Memo.
*
* @param plan {@link Plan} to be added
* @param plan {@link Plan} to be added
* @param target target group to add plan. null to generate new Group
* @return Reference of plan in Memo
*/
// TODO: need to merge the PlanRefSet if the new PlanRef is the same as one already in the memo
public PlanReference newPlanReference(Plan<?> plan, Group target) {
if (plan.getPlanReference() != null) {
return plan.getPlanReference();
}
List<PlanReference> childReferences = Lists.newArrayList();
public GroupExpression newGroupExpression(Plan<?> plan, Group target) {
List<GroupExpression> childGroupExpr = Lists.newArrayList();
for (Plan<?> childrenPlan : plan.children()) {
childReferences.add(newPlanReference(childrenPlan, null));
childGroupExpr.add(newGroupExpression(childrenPlan, null));
}
PlanReference newPlanReference = new PlanReference(plan);
planReferences.add(newPlanReference);
for (PlanReference childReference : childReferences) {
newPlanReference.addChild(childReference.getParent());
GroupExpression newGroupExpression = new GroupExpression(plan);
for (GroupExpression childReference : childGroupExpr) {
newGroupExpression.addChild(childReference.getParent());
}
return insertGroupExpression(newGroupExpression, target);
}
private GroupExpression insertGroupExpression(GroupExpression groupExpression, Group target) {
if (groupExpressions.contains(groupExpression)) {
return groupExpression;
}
groupExpressions.add(groupExpression);
if (target != null) {
target.addPlanReference(newPlanReference);
target.addGroupExpression(groupExpression);
} else {
Group group = new Group(newPlanReference);
Group group = new Group(groupExpression);
groups.add(group);
}
return newPlanReference;
return groupExpression;
}
}

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.plans;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.trees.AbstractTreeNode;
import org.apache.doris.nereids.trees.NodeType;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -37,7 +36,6 @@ import java.util.List;
public abstract class AbstractPlan<PLAN_TYPE extends AbstractPlan<PLAN_TYPE>>
extends AbstractTreeNode<PLAN_TYPE> implements Plan<PLAN_TYPE> {
protected PlanReference planReference;
protected List<Slot> output;
public AbstractPlan(NodeType type, Plan...children) {
@ -47,16 +45,6 @@ public abstract class AbstractPlan<PLAN_TYPE extends AbstractPlan<PLAN_TYPE>>
@Override
public abstract List<Slot> getOutput() throws UnboundException;
@Override
public PlanReference getPlanReference() {
return planReference;
}
@Override
public void setPlanReference(PlanReference planReference) {
this.planReference = planReference;
}
@Override
public List<Plan> children() {
return (List) children;

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.plans;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -31,10 +30,6 @@ public interface Plan<PLAN_TYPE extends Plan<PLAN_TYPE>> extends TreeNode<PLAN_T
List<Slot> getOutput() throws UnboundException;
PlanReference getPlanReference();
void setPlanReference(PlanReference planReference);
String treeString();
@Override