[improve](Nereids): remove redundant code, add annotation in Memo. (#14083)
This commit is contained in:
@ -18,12 +18,9 @@
|
||||
package org.apache.doris.nereids.memo;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.properties.LogicalProperties;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.nereids.util.TreeStringUtils;
|
||||
import org.apache.doris.statistics.StatsDeriveResult;
|
||||
|
||||
@ -46,6 +43,7 @@ import java.util.stream.Collectors;
|
||||
*/
|
||||
public class Group {
|
||||
private final GroupId groupId;
|
||||
// Save all parent GroupExpression to avoid travsing whole Memo.
|
||||
private final IdentityHashMap<GroupExpression, Void> parentExpressions = new IdentityHashMap<>();
|
||||
|
||||
private final List<GroupExpression> logicalExpressions = Lists.newArrayList();
|
||||
@ -57,7 +55,7 @@ public class Group {
|
||||
private final Map<PhysicalProperties, Pair<Double, GroupExpression>> lowestCostPlans = Maps.newHashMap();
|
||||
private double costLowerBound = -1;
|
||||
private boolean isExplored = false;
|
||||
private boolean hasCost = false;
|
||||
|
||||
private StatsDeriveResult statistics;
|
||||
|
||||
/**
|
||||
@ -87,14 +85,6 @@ public class Group {
|
||||
return groupId;
|
||||
}
|
||||
|
||||
public boolean isHasCost() {
|
||||
return hasCost;
|
||||
}
|
||||
|
||||
public void setHasCost(boolean hasCost) {
|
||||
this.hasCost = hasCost;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add new {@link GroupExpression} into this group.
|
||||
*
|
||||
@ -132,27 +122,6 @@ public class Group {
|
||||
logicalExpressions.add(groupExpression);
|
||||
}
|
||||
|
||||
public void addPhysicalExpression(GroupExpression groupExpression) {
|
||||
groupExpression.setOwnerGroup(this);
|
||||
physicalExpressions.add(groupExpression);
|
||||
}
|
||||
|
||||
/**
|
||||
* Rewrite the logical group expression to the new logical group expression.
|
||||
*
|
||||
* @param newExpression new logical group expression
|
||||
* @return old logical group expression
|
||||
*/
|
||||
public GroupExpression rewriteLogicalExpression(GroupExpression newExpression,
|
||||
LogicalProperties logicalProperties) {
|
||||
newExpression.setOwnerGroup(this);
|
||||
this.logicalProperties = logicalProperties;
|
||||
GroupExpression oldExpression = getLogicalExpression();
|
||||
logicalExpressions.clear();
|
||||
logicalExpressions.add(newExpression);
|
||||
return oldExpression;
|
||||
}
|
||||
|
||||
public List<GroupExpression> clearLogicalExpressions() {
|
||||
List<GroupExpression> move = logicalExpressions.stream()
|
||||
.peek(groupExpr -> groupExpr.setOwnerGroup(null))
|
||||
@ -173,10 +142,6 @@ public class Group {
|
||||
return costLowerBound;
|
||||
}
|
||||
|
||||
public void setCostLowerBound(double costLowerBound) {
|
||||
this.costLowerBound = costLowerBound;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set or update lowestCostPlans: properties --> Pair.of(cost, expression)
|
||||
*/
|
||||
@ -273,32 +238,6 @@ public class Group {
|
||||
return Optional.ofNullable(lowestCostPlans.get(physicalProperties));
|
||||
}
|
||||
|
||||
public Map<PhysicalProperties, Pair<Double, GroupExpression>> getLowestCostPlans() {
|
||||
return lowestCostPlans;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the first Plan from Memo.
|
||||
*/
|
||||
public PhysicalPlan extractPlan() throws AnalysisException {
|
||||
GroupExpression groupExpression = this.physicalExpressions.get(0);
|
||||
|
||||
List<Plan> planChildren = com.google.common.collect.Lists.newArrayList();
|
||||
for (int i = 0; i < groupExpression.arity(); i++) {
|
||||
planChildren.add(groupExpression.child(i).extractPlan());
|
||||
}
|
||||
|
||||
Plan plan = groupExpression.getPlan()
|
||||
.withChildren(planChildren)
|
||||
.withGroupExpression(Optional.of(groupExpression));
|
||||
if (!(plan instanceof PhysicalPlan)) {
|
||||
throw new AnalysisException("generate logical plan");
|
||||
}
|
||||
PhysicalPlan physicalPlan = (PhysicalPlan) plan;
|
||||
|
||||
return physicalPlan;
|
||||
}
|
||||
|
||||
public List<GroupExpression> getParentGroupExpressions() {
|
||||
return ImmutableList.copyOf(parentExpressions.keySet());
|
||||
}
|
||||
@ -318,10 +257,6 @@ public class Group {
|
||||
return parentExpressions.size();
|
||||
}
|
||||
|
||||
public int parentExpressionNum() {
|
||||
return parentExpressions.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
|
||||
@ -24,6 +24,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
import org.apache.doris.statistics.StatsDeriveResult;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
@ -43,7 +44,7 @@ public class GroupExpression {
|
||||
private double cost = 0.0;
|
||||
private CostEstimate costEstimate = null;
|
||||
private Group ownerGroup;
|
||||
private List<Group> children;
|
||||
private final List<Group> children;
|
||||
private final Plan plan;
|
||||
private final BitSet ruleMasks;
|
||||
private boolean statDerived;
|
||||
@ -71,11 +72,11 @@ public class GroupExpression {
|
||||
this.plan = Objects.requireNonNull(plan, "plan can not be null")
|
||||
.withGroupExpression(Optional.of(this));
|
||||
this.children = Lists.newArrayList(Objects.requireNonNull(children, "children can not be null"));
|
||||
this.children.forEach(childGroup -> childGroup.addParentExpression(this));
|
||||
this.ruleMasks = new BitSet(RuleType.SENTINEL.ordinal());
|
||||
this.statDerived = false;
|
||||
this.lowestCostTable = Maps.newHashMap();
|
||||
this.requestPropertiesMap = Maps.newHashMap();
|
||||
this.children.forEach(childGroup -> childGroup.addParentExpression(this));
|
||||
}
|
||||
|
||||
public PhysicalProperties getOutputProperties(PhysicalProperties requestProperties) {
|
||||
@ -111,21 +112,13 @@ public class GroupExpression {
|
||||
/**
|
||||
* replaceChild.
|
||||
*
|
||||
* @param originChild origin child group
|
||||
* @param oldChild origin child group
|
||||
* @param newChild new child group
|
||||
*/
|
||||
public void replaceChild(Group originChild, Group newChild) {
|
||||
originChild.removeParentExpression(this);
|
||||
for (int i = 0; i < children.size(); i++) {
|
||||
if (children.get(i) == originChild) {
|
||||
children.set(i, newChild);
|
||||
newChild.addParentExpression(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void setChild(int index, Group group) {
|
||||
this.children.set(index, group);
|
||||
public void replaceChild(Group oldChild, Group newChild) {
|
||||
oldChild.removeParentExpression(this);
|
||||
newChild.addParentExpression(this);
|
||||
Utils.replaceList(children, oldChild, newChild);
|
||||
}
|
||||
|
||||
public boolean hasApplied(Rule rule) {
|
||||
@ -255,7 +248,7 @@ public class GroupExpression {
|
||||
if (costEstimate != null) {
|
||||
builder.append(" est=").append(costEstimate);
|
||||
}
|
||||
builder.append(" (plan=" + plan.toString() + ") children=[");
|
||||
builder.append(" (plan=").append(plan.toString()).append(") children=[");
|
||||
for (Group group : children) {
|
||||
builder.append(group.getGroupId()).append(" ");
|
||||
}
|
||||
|
||||
@ -72,7 +72,6 @@ public class Memo {
|
||||
|
||||
/**
|
||||
* Add plan to Memo.
|
||||
* TODO: add ut later
|
||||
*
|
||||
* @param plan {@link Plan} or {@link Expression} to be added
|
||||
* @param target target group to add node. null to generate new Group
|
||||
@ -386,9 +385,10 @@ public class Memo {
|
||||
}
|
||||
}
|
||||
for (GroupExpression groupExpression : needReplaceChild) {
|
||||
// After change GroupExpression children, the hashcode will change,
|
||||
// so need to reinsert into map.
|
||||
groupExpressions.remove(groupExpression);
|
||||
List<Group> children = groupExpression.children();
|
||||
// TODO: use a better way to replace child, avoid traversing all groupExpression
|
||||
for (int i = 0; i < children.size(); i++) {
|
||||
if (children.get(i).equals(source)) {
|
||||
children.set(i, destination);
|
||||
@ -480,6 +480,7 @@ public class Memo {
|
||||
// case 5:
|
||||
// if targetGroup is null or targetGroup equal to the existedExpression's ownerGroup,
|
||||
// then recycle the temporary new group expression
|
||||
// No ownerGroup, don't need ownerGroup.removeChild()
|
||||
recycleExpression(newExpression);
|
||||
return CopyInResult.of(false, existedExpression);
|
||||
}
|
||||
@ -512,8 +513,7 @@ public class Memo {
|
||||
List<GroupExpression> logicalExpressions = fromGroup.clearLogicalExpressions();
|
||||
recycleGroup(fromGroup);
|
||||
|
||||
recycleLogicalExpressions(targetGroup);
|
||||
recyclePhysicalExpressions(targetGroup);
|
||||
recycleLogicalAndPhysicalExpressions(targetGroup);
|
||||
|
||||
for (GroupExpression logicalExpression : logicalExpressions) {
|
||||
targetGroup.addLogicalExpression(logicalExpression);
|
||||
@ -522,8 +522,7 @@ public class Memo {
|
||||
}
|
||||
|
||||
private void reInitGroup(Group group, GroupExpression initLogicalExpression, LogicalProperties logicalProperties) {
|
||||
recycleLogicalExpressions(group);
|
||||
recyclePhysicalExpressions(group);
|
||||
recycleLogicalAndPhysicalExpressions(group);
|
||||
|
||||
group.setLogicalProperties(logicalProperties);
|
||||
group.addLogicalExpression(initLogicalExpression);
|
||||
@ -562,42 +561,45 @@ public class Memo {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Notice: this func don't replace { Parent GroupExpressions -> this Group }.
|
||||
*/
|
||||
private void recycleGroup(Group group) {
|
||||
// recycle in memo.
|
||||
if (groups.get(group.getGroupId()) == group) {
|
||||
groups.remove(group.getGroupId());
|
||||
}
|
||||
recycleLogicalExpressions(group);
|
||||
recyclePhysicalExpressions(group);
|
||||
|
||||
// recycle children GroupExpression
|
||||
recycleLogicalAndPhysicalExpressions(group);
|
||||
}
|
||||
|
||||
private void recycleLogicalExpressions(Group group) {
|
||||
if (!group.getLogicalExpressions().isEmpty()) {
|
||||
for (GroupExpression logicalExpression : group.getLogicalExpressions()) {
|
||||
recycleExpression(logicalExpression);
|
||||
}
|
||||
group.clearLogicalExpressions();
|
||||
}
|
||||
}
|
||||
|
||||
private void recyclePhysicalExpressions(Group group) {
|
||||
if (!group.getPhysicalExpressions().isEmpty()) {
|
||||
for (GroupExpression physicalExpression : group.getPhysicalExpressions()) {
|
||||
recycleExpression(physicalExpression);
|
||||
}
|
||||
group.clearPhysicalExpressions();
|
||||
}
|
||||
private void recycleLogicalAndPhysicalExpressions(Group group) {
|
||||
group.getLogicalExpressions().forEach(this::recycleExpression);
|
||||
group.clearLogicalExpressions();
|
||||
|
||||
group.getPhysicalExpressions().forEach(this::recycleExpression);
|
||||
group.clearPhysicalExpressions();
|
||||
}
|
||||
|
||||
/**
|
||||
* Notice: this func don't clear { OwnerGroup() -> this GroupExpression }.
|
||||
*/
|
||||
private void recycleExpression(GroupExpression groupExpression) {
|
||||
// recycle in memo.
|
||||
if (groupExpressions.get(groupExpression) == groupExpression) {
|
||||
groupExpressions.remove(groupExpression);
|
||||
}
|
||||
for (Group childGroup : groupExpression.children()) {
|
||||
|
||||
// recycle parentGroupExpr in childGroup
|
||||
groupExpression.children().forEach(childGroup -> {
|
||||
// if not any groupExpression reference child group, then recycle the child group
|
||||
if (childGroup.removeParentExpression(groupExpression) == 0) {
|
||||
recycleGroup(childGroup);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
groupExpression.setOwnerGroup(null);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -201,4 +201,12 @@ public class Utils {
|
||||
public static LocalDateTime getLocalDatetimeFromLong(long dateTime) {
|
||||
return LocalDateTime.ofInstant(Instant.ofEpochSecond(dateTime), ZoneId.systemDefault());
|
||||
}
|
||||
|
||||
public static <T> void replaceList(List<T> list, T oldItem, T newItem) {
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
if (list.get(i) == oldItem) {
|
||||
list.set(i, newItem);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -61,13 +61,13 @@ import java.util.Optional;
|
||||
|
||||
class MemoTest implements PatternMatchSupported {
|
||||
|
||||
private ConnectContext connectContext = MemoTestUtils.createConnectContext();
|
||||
private final ConnectContext connectContext = MemoTestUtils.createConnectContext();
|
||||
|
||||
private LogicalJoin<LogicalOlapScan, LogicalOlapScan> logicalJoinAB = new LogicalJoin<>(JoinType.INNER_JOIN,
|
||||
private final LogicalJoin<LogicalOlapScan, LogicalOlapScan> logicalJoinAB = new LogicalJoin<>(JoinType.INNER_JOIN,
|
||||
PlanConstructor.newLogicalOlapScan(0, "A", 0),
|
||||
PlanConstructor.newLogicalOlapScan(1, "B", 0));
|
||||
|
||||
private LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>(
|
||||
private final LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>(
|
||||
JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0));
|
||||
|
||||
@Test
|
||||
@ -395,7 +395,7 @@ class MemoTest implements PatternMatchSupported {
|
||||
// valid case: 5 steps
|
||||
class A extends UnboundRelation {
|
||||
// 1: declare the Plan has some states
|
||||
State state;
|
||||
final State state;
|
||||
|
||||
public A(List<String> nameParts, State state) {
|
||||
this(nameParts, state, Optional.empty());
|
||||
|
||||
Reference in New Issue
Block a user