diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java index 978aa0e0fe..686be94c72 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java @@ -18,12 +18,9 @@ package org.apache.doris.nereids.memo; import org.apache.doris.common.Pair; -import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; -import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.util.TreeStringUtils; import org.apache.doris.statistics.StatsDeriveResult; @@ -46,6 +43,7 @@ import java.util.stream.Collectors; */ public class Group { private final GroupId groupId; + // Save all parent GroupExpression to avoid travsing whole Memo. private final IdentityHashMap parentExpressions = new IdentityHashMap<>(); private final List logicalExpressions = Lists.newArrayList(); @@ -57,7 +55,7 @@ public class Group { private final Map> lowestCostPlans = Maps.newHashMap(); private double costLowerBound = -1; private boolean isExplored = false; - private boolean hasCost = false; + private StatsDeriveResult statistics; /** @@ -87,14 +85,6 @@ public class Group { return groupId; } - public boolean isHasCost() { - return hasCost; - } - - public void setHasCost(boolean hasCost) { - this.hasCost = hasCost; - } - /** * Add new {@link GroupExpression} into this group. * @@ -132,27 +122,6 @@ public class Group { logicalExpressions.add(groupExpression); } - public void addPhysicalExpression(GroupExpression groupExpression) { - groupExpression.setOwnerGroup(this); - physicalExpressions.add(groupExpression); - } - - /** - * Rewrite the logical group expression to the new logical group expression. - * - * @param newExpression new logical group expression - * @return old logical group expression - */ - public GroupExpression rewriteLogicalExpression(GroupExpression newExpression, - LogicalProperties logicalProperties) { - newExpression.setOwnerGroup(this); - this.logicalProperties = logicalProperties; - GroupExpression oldExpression = getLogicalExpression(); - logicalExpressions.clear(); - logicalExpressions.add(newExpression); - return oldExpression; - } - public List clearLogicalExpressions() { List move = logicalExpressions.stream() .peek(groupExpr -> groupExpr.setOwnerGroup(null)) @@ -173,10 +142,6 @@ public class Group { return costLowerBound; } - public void setCostLowerBound(double costLowerBound) { - this.costLowerBound = costLowerBound; - } - /** * Set or update lowestCostPlans: properties --> Pair.of(cost, expression) */ @@ -273,32 +238,6 @@ public class Group { return Optional.ofNullable(lowestCostPlans.get(physicalProperties)); } - public Map> getLowestCostPlans() { - return lowestCostPlans; - } - - /** - * Get the first Plan from Memo. - */ - public PhysicalPlan extractPlan() throws AnalysisException { - GroupExpression groupExpression = this.physicalExpressions.get(0); - - List planChildren = com.google.common.collect.Lists.newArrayList(); - for (int i = 0; i < groupExpression.arity(); i++) { - planChildren.add(groupExpression.child(i).extractPlan()); - } - - Plan plan = groupExpression.getPlan() - .withChildren(planChildren) - .withGroupExpression(Optional.of(groupExpression)); - if (!(plan instanceof PhysicalPlan)) { - throw new AnalysisException("generate logical plan"); - } - PhysicalPlan physicalPlan = (PhysicalPlan) plan; - - return physicalPlan; - } - public List getParentGroupExpressions() { return ImmutableList.copyOf(parentExpressions.keySet()); } @@ -318,10 +257,6 @@ public class Group { return parentExpressions.size(); } - public int parentExpressionNum() { - return parentExpressions.size(); - } - @Override public boolean equals(Object o) { if (this == o) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java index a0ad369664..b8e8edcc19 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.util.Utils; import org.apache.doris.statistics.StatsDeriveResult; import com.google.common.base.Preconditions; @@ -43,7 +44,7 @@ public class GroupExpression { private double cost = 0.0; private CostEstimate costEstimate = null; private Group ownerGroup; - private List children; + private final List children; private final Plan plan; private final BitSet ruleMasks; private boolean statDerived; @@ -71,11 +72,11 @@ public class GroupExpression { this.plan = Objects.requireNonNull(plan, "plan can not be null") .withGroupExpression(Optional.of(this)); this.children = Lists.newArrayList(Objects.requireNonNull(children, "children can not be null")); + this.children.forEach(childGroup -> childGroup.addParentExpression(this)); this.ruleMasks = new BitSet(RuleType.SENTINEL.ordinal()); this.statDerived = false; this.lowestCostTable = Maps.newHashMap(); this.requestPropertiesMap = Maps.newHashMap(); - this.children.forEach(childGroup -> childGroup.addParentExpression(this)); } public PhysicalProperties getOutputProperties(PhysicalProperties requestProperties) { @@ -111,21 +112,13 @@ public class GroupExpression { /** * replaceChild. * - * @param originChild origin child group + * @param oldChild origin child group * @param newChild new child group */ - public void replaceChild(Group originChild, Group newChild) { - originChild.removeParentExpression(this); - for (int i = 0; i < children.size(); i++) { - if (children.get(i) == originChild) { - children.set(i, newChild); - newChild.addParentExpression(this); - } - } - } - - public void setChild(int index, Group group) { - this.children.set(index, group); + public void replaceChild(Group oldChild, Group newChild) { + oldChild.removeParentExpression(this); + newChild.addParentExpression(this); + Utils.replaceList(children, oldChild, newChild); } public boolean hasApplied(Rule rule) { @@ -255,7 +248,7 @@ public class GroupExpression { if (costEstimate != null) { builder.append(" est=").append(costEstimate); } - builder.append(" (plan=" + plan.toString() + ") children=["); + builder.append(" (plan=").append(plan.toString()).append(") children=["); for (Group group : children) { builder.append(group.getGroupId()).append(" "); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java index 81b1dc7c3f..31a5dee6a4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java @@ -72,7 +72,6 @@ public class Memo { /** * Add plan to Memo. - * TODO: add ut later * * @param plan {@link Plan} or {@link Expression} to be added * @param target target group to add node. null to generate new Group @@ -386,9 +385,10 @@ public class Memo { } } for (GroupExpression groupExpression : needReplaceChild) { + // After change GroupExpression children, the hashcode will change, + // so need to reinsert into map. groupExpressions.remove(groupExpression); List children = groupExpression.children(); - // TODO: use a better way to replace child, avoid traversing all groupExpression for (int i = 0; i < children.size(); i++) { if (children.get(i).equals(source)) { children.set(i, destination); @@ -480,6 +480,7 @@ public class Memo { // case 5: // if targetGroup is null or targetGroup equal to the existedExpression's ownerGroup, // then recycle the temporary new group expression + // No ownerGroup, don't need ownerGroup.removeChild() recycleExpression(newExpression); return CopyInResult.of(false, existedExpression); } @@ -512,8 +513,7 @@ public class Memo { List logicalExpressions = fromGroup.clearLogicalExpressions(); recycleGroup(fromGroup); - recycleLogicalExpressions(targetGroup); - recyclePhysicalExpressions(targetGroup); + recycleLogicalAndPhysicalExpressions(targetGroup); for (GroupExpression logicalExpression : logicalExpressions) { targetGroup.addLogicalExpression(logicalExpression); @@ -522,8 +522,7 @@ public class Memo { } private void reInitGroup(Group group, GroupExpression initLogicalExpression, LogicalProperties logicalProperties) { - recycleLogicalExpressions(group); - recyclePhysicalExpressions(group); + recycleLogicalAndPhysicalExpressions(group); group.setLogicalProperties(logicalProperties); group.addLogicalExpression(initLogicalExpression); @@ -562,42 +561,45 @@ public class Memo { } } + /** + * Notice: this func don't replace { Parent GroupExpressions -> this Group }. + */ private void recycleGroup(Group group) { + // recycle in memo. if (groups.get(group.getGroupId()) == group) { groups.remove(group.getGroupId()); } - recycleLogicalExpressions(group); - recyclePhysicalExpressions(group); + + // recycle children GroupExpression + recycleLogicalAndPhysicalExpressions(group); } - private void recycleLogicalExpressions(Group group) { - if (!group.getLogicalExpressions().isEmpty()) { - for (GroupExpression logicalExpression : group.getLogicalExpressions()) { - recycleExpression(logicalExpression); - } - group.clearLogicalExpressions(); - } - } - - private void recyclePhysicalExpressions(Group group) { - if (!group.getPhysicalExpressions().isEmpty()) { - for (GroupExpression physicalExpression : group.getPhysicalExpressions()) { - recycleExpression(physicalExpression); - } - group.clearPhysicalExpressions(); - } + private void recycleLogicalAndPhysicalExpressions(Group group) { + group.getLogicalExpressions().forEach(this::recycleExpression); + group.clearLogicalExpressions(); + + group.getPhysicalExpressions().forEach(this::recycleExpression); + group.clearPhysicalExpressions(); } + /** + * Notice: this func don't clear { OwnerGroup() -> this GroupExpression }. + */ private void recycleExpression(GroupExpression groupExpression) { + // recycle in memo. if (groupExpressions.get(groupExpression) == groupExpression) { groupExpressions.remove(groupExpression); } - for (Group childGroup : groupExpression.children()) { + + // recycle parentGroupExpr in childGroup + groupExpression.children().forEach(childGroup -> { // if not any groupExpression reference child group, then recycle the child group if (childGroup.removeParentExpression(groupExpression) == 0) { recycleGroup(childGroup); } - } + }); + + groupExpression.setOwnerGroup(null); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java index b4c41bdf3f..927c553964 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java @@ -201,4 +201,12 @@ public class Utils { public static LocalDateTime getLocalDatetimeFromLong(long dateTime) { return LocalDateTime.ofInstant(Instant.ofEpochSecond(dateTime), ZoneId.systemDefault()); } + + public static void replaceList(List list, T oldItem, T newItem) { + for (int i = 0; i < list.size(); i++) { + if (list.get(i) == oldItem) { + list.set(i, newItem); + } + } + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java index f9b04d9758..1735ecec1f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java @@ -61,13 +61,13 @@ import java.util.Optional; class MemoTest implements PatternMatchSupported { - private ConnectContext connectContext = MemoTestUtils.createConnectContext(); + private final ConnectContext connectContext = MemoTestUtils.createConnectContext(); - private LogicalJoin logicalJoinAB = new LogicalJoin<>(JoinType.INNER_JOIN, + private final LogicalJoin logicalJoinAB = new LogicalJoin<>(JoinType.INNER_JOIN, PlanConstructor.newLogicalOlapScan(0, "A", 0), PlanConstructor.newLogicalOlapScan(1, "B", 0)); - private LogicalJoin, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>( + private final LogicalJoin, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>( JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0)); @Test @@ -395,7 +395,7 @@ class MemoTest implements PatternMatchSupported { // valid case: 5 steps class A extends UnboundRelation { // 1: declare the Plan has some states - State state; + final State state; public A(List nameParts, State state) { this(nameParts, state, Optional.empty());