[improve](Nereids): remove redundant code, add annotation in Memo. (#14083)

This commit is contained in:
jakevin
2022-11-09 13:39:20 +08:00
committed by GitHub
parent aff62655c4
commit b144d2b4f4
5 changed files with 51 additions and 113 deletions

View File

@ -18,12 +18,9 @@
package org.apache.doris.nereids.memo;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.util.TreeStringUtils;
import org.apache.doris.statistics.StatsDeriveResult;
@ -46,6 +43,7 @@ import java.util.stream.Collectors;
*/
public class Group {
private final GroupId groupId;
// Save all parent GroupExpression to avoid travsing whole Memo.
private final IdentityHashMap<GroupExpression, Void> parentExpressions = new IdentityHashMap<>();
private final List<GroupExpression> logicalExpressions = Lists.newArrayList();
@ -57,7 +55,7 @@ public class Group {
private final Map<PhysicalProperties, Pair<Double, GroupExpression>> lowestCostPlans = Maps.newHashMap();
private double costLowerBound = -1;
private boolean isExplored = false;
private boolean hasCost = false;
private StatsDeriveResult statistics;
/**
@ -87,14 +85,6 @@ public class Group {
return groupId;
}
public boolean isHasCost() {
return hasCost;
}
public void setHasCost(boolean hasCost) {
this.hasCost = hasCost;
}
/**
* Add new {@link GroupExpression} into this group.
*
@ -132,27 +122,6 @@ public class Group {
logicalExpressions.add(groupExpression);
}
public void addPhysicalExpression(GroupExpression groupExpression) {
groupExpression.setOwnerGroup(this);
physicalExpressions.add(groupExpression);
}
/**
* Rewrite the logical group expression to the new logical group expression.
*
* @param newExpression new logical group expression
* @return old logical group expression
*/
public GroupExpression rewriteLogicalExpression(GroupExpression newExpression,
LogicalProperties logicalProperties) {
newExpression.setOwnerGroup(this);
this.logicalProperties = logicalProperties;
GroupExpression oldExpression = getLogicalExpression();
logicalExpressions.clear();
logicalExpressions.add(newExpression);
return oldExpression;
}
public List<GroupExpression> clearLogicalExpressions() {
List<GroupExpression> move = logicalExpressions.stream()
.peek(groupExpr -> groupExpr.setOwnerGroup(null))
@ -173,10 +142,6 @@ public class Group {
return costLowerBound;
}
public void setCostLowerBound(double costLowerBound) {
this.costLowerBound = costLowerBound;
}
/**
* Set or update lowestCostPlans: properties --> Pair.of(cost, expression)
*/
@ -273,32 +238,6 @@ public class Group {
return Optional.ofNullable(lowestCostPlans.get(physicalProperties));
}
public Map<PhysicalProperties, Pair<Double, GroupExpression>> getLowestCostPlans() {
return lowestCostPlans;
}
/**
* Get the first Plan from Memo.
*/
public PhysicalPlan extractPlan() throws AnalysisException {
GroupExpression groupExpression = this.physicalExpressions.get(0);
List<Plan> planChildren = com.google.common.collect.Lists.newArrayList();
for (int i = 0; i < groupExpression.arity(); i++) {
planChildren.add(groupExpression.child(i).extractPlan());
}
Plan plan = groupExpression.getPlan()
.withChildren(planChildren)
.withGroupExpression(Optional.of(groupExpression));
if (!(plan instanceof PhysicalPlan)) {
throw new AnalysisException("generate logical plan");
}
PhysicalPlan physicalPlan = (PhysicalPlan) plan;
return physicalPlan;
}
public List<GroupExpression> getParentGroupExpressions() {
return ImmutableList.copyOf(parentExpressions.keySet());
}
@ -318,10 +257,6 @@ public class Group {
return parentExpressions.size();
}
public int parentExpressionNum() {
return parentExpressions.size();
}
@Override
public boolean equals(Object o) {
if (this == o) {

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.base.Preconditions;
@ -43,7 +44,7 @@ public class GroupExpression {
private double cost = 0.0;
private CostEstimate costEstimate = null;
private Group ownerGroup;
private List<Group> children;
private final List<Group> children;
private final Plan plan;
private final BitSet ruleMasks;
private boolean statDerived;
@ -71,11 +72,11 @@ public class GroupExpression {
this.plan = Objects.requireNonNull(plan, "plan can not be null")
.withGroupExpression(Optional.of(this));
this.children = Lists.newArrayList(Objects.requireNonNull(children, "children can not be null"));
this.children.forEach(childGroup -> childGroup.addParentExpression(this));
this.ruleMasks = new BitSet(RuleType.SENTINEL.ordinal());
this.statDerived = false;
this.lowestCostTable = Maps.newHashMap();
this.requestPropertiesMap = Maps.newHashMap();
this.children.forEach(childGroup -> childGroup.addParentExpression(this));
}
public PhysicalProperties getOutputProperties(PhysicalProperties requestProperties) {
@ -111,21 +112,13 @@ public class GroupExpression {
/**
* replaceChild.
*
* @param originChild origin child group
* @param oldChild origin child group
* @param newChild new child group
*/
public void replaceChild(Group originChild, Group newChild) {
originChild.removeParentExpression(this);
for (int i = 0; i < children.size(); i++) {
if (children.get(i) == originChild) {
children.set(i, newChild);
newChild.addParentExpression(this);
}
}
}
public void setChild(int index, Group group) {
this.children.set(index, group);
public void replaceChild(Group oldChild, Group newChild) {
oldChild.removeParentExpression(this);
newChild.addParentExpression(this);
Utils.replaceList(children, oldChild, newChild);
}
public boolean hasApplied(Rule rule) {
@ -255,7 +248,7 @@ public class GroupExpression {
if (costEstimate != null) {
builder.append(" est=").append(costEstimate);
}
builder.append(" (plan=" + plan.toString() + ") children=[");
builder.append(" (plan=").append(plan.toString()).append(") children=[");
for (Group group : children) {
builder.append(group.getGroupId()).append(" ");
}

View File

@ -72,7 +72,6 @@ public class Memo {
/**
* Add plan to Memo.
* TODO: add ut later
*
* @param plan {@link Plan} or {@link Expression} to be added
* @param target target group to add node. null to generate new Group
@ -386,9 +385,10 @@ public class Memo {
}
}
for (GroupExpression groupExpression : needReplaceChild) {
// After change GroupExpression children, the hashcode will change,
// so need to reinsert into map.
groupExpressions.remove(groupExpression);
List<Group> children = groupExpression.children();
// TODO: use a better way to replace child, avoid traversing all groupExpression
for (int i = 0; i < children.size(); i++) {
if (children.get(i).equals(source)) {
children.set(i, destination);
@ -480,6 +480,7 @@ public class Memo {
// case 5:
// if targetGroup is null or targetGroup equal to the existedExpression's ownerGroup,
// then recycle the temporary new group expression
// No ownerGroup, don't need ownerGroup.removeChild()
recycleExpression(newExpression);
return CopyInResult.of(false, existedExpression);
}
@ -512,8 +513,7 @@ public class Memo {
List<GroupExpression> logicalExpressions = fromGroup.clearLogicalExpressions();
recycleGroup(fromGroup);
recycleLogicalExpressions(targetGroup);
recyclePhysicalExpressions(targetGroup);
recycleLogicalAndPhysicalExpressions(targetGroup);
for (GroupExpression logicalExpression : logicalExpressions) {
targetGroup.addLogicalExpression(logicalExpression);
@ -522,8 +522,7 @@ public class Memo {
}
private void reInitGroup(Group group, GroupExpression initLogicalExpression, LogicalProperties logicalProperties) {
recycleLogicalExpressions(group);
recyclePhysicalExpressions(group);
recycleLogicalAndPhysicalExpressions(group);
group.setLogicalProperties(logicalProperties);
group.addLogicalExpression(initLogicalExpression);
@ -562,42 +561,45 @@ public class Memo {
}
}
/**
* Notice: this func don't replace { Parent GroupExpressions -> this Group }.
*/
private void recycleGroup(Group group) {
// recycle in memo.
if (groups.get(group.getGroupId()) == group) {
groups.remove(group.getGroupId());
}
recycleLogicalExpressions(group);
recyclePhysicalExpressions(group);
// recycle children GroupExpression
recycleLogicalAndPhysicalExpressions(group);
}
private void recycleLogicalExpressions(Group group) {
if (!group.getLogicalExpressions().isEmpty()) {
for (GroupExpression logicalExpression : group.getLogicalExpressions()) {
recycleExpression(logicalExpression);
}
group.clearLogicalExpressions();
}
}
private void recyclePhysicalExpressions(Group group) {
if (!group.getPhysicalExpressions().isEmpty()) {
for (GroupExpression physicalExpression : group.getPhysicalExpressions()) {
recycleExpression(physicalExpression);
}
group.clearPhysicalExpressions();
}
private void recycleLogicalAndPhysicalExpressions(Group group) {
group.getLogicalExpressions().forEach(this::recycleExpression);
group.clearLogicalExpressions();
group.getPhysicalExpressions().forEach(this::recycleExpression);
group.clearPhysicalExpressions();
}
/**
* Notice: this func don't clear { OwnerGroup() -> this GroupExpression }.
*/
private void recycleExpression(GroupExpression groupExpression) {
// recycle in memo.
if (groupExpressions.get(groupExpression) == groupExpression) {
groupExpressions.remove(groupExpression);
}
for (Group childGroup : groupExpression.children()) {
// recycle parentGroupExpr in childGroup
groupExpression.children().forEach(childGroup -> {
// if not any groupExpression reference child group, then recycle the child group
if (childGroup.removeParentExpression(groupExpression) == 0) {
recycleGroup(childGroup);
}
}
});
groupExpression.setOwnerGroup(null);
}
@Override

View File

@ -201,4 +201,12 @@ public class Utils {
public static LocalDateTime getLocalDatetimeFromLong(long dateTime) {
return LocalDateTime.ofInstant(Instant.ofEpochSecond(dateTime), ZoneId.systemDefault());
}
public static <T> void replaceList(List<T> list, T oldItem, T newItem) {
for (int i = 0; i < list.size(); i++) {
if (list.get(i) == oldItem) {
list.set(i, newItem);
}
}
}
}

View File

@ -61,13 +61,13 @@ import java.util.Optional;
class MemoTest implements PatternMatchSupported {
private ConnectContext connectContext = MemoTestUtils.createConnectContext();
private final ConnectContext connectContext = MemoTestUtils.createConnectContext();
private LogicalJoin<LogicalOlapScan, LogicalOlapScan> logicalJoinAB = new LogicalJoin<>(JoinType.INNER_JOIN,
private final LogicalJoin<LogicalOlapScan, LogicalOlapScan> logicalJoinAB = new LogicalJoin<>(JoinType.INNER_JOIN,
PlanConstructor.newLogicalOlapScan(0, "A", 0),
PlanConstructor.newLogicalOlapScan(1, "B", 0));
private LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>(
private final LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>(
JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0));
@Test
@ -395,7 +395,7 @@ class MemoTest implements PatternMatchSupported {
// valid case: 5 steps
class A extends UnboundRelation {
// 1: declare the Plan has some states
State state;
final State state;
public A(List<String> nameParts, State state) {
this(nameParts, state, Optional.empty());