diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java index f584425dfd..5bd11c71e7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java @@ -297,7 +297,7 @@ public class Group { */ public void mergeTo(Group target) { // move parentExpressions Ownership - parentExpressions.keySet().forEach(target::addParentExpression); + parentExpressions.keySet().forEach(parent -> target.addParentExpression(parent)); // PhysicalEnforcer isn't in groupExpressions, so mergeGroup() can't replace its children. // So we need to manually replace the children of PhysicalEnforcer in here. // TODO: SortEnforcer? diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java index bc376b1a37..6011508bd2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java @@ -37,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.qe.ConnectContext; import org.apache.doris.statistics.Statistics; @@ -450,21 +451,23 @@ public class Memo { * * @param source source group * @param destination destination group - * @return merged group */ - public Group mergeGroup(Group source, Group destination) { + public void mergeGroup(Group source, Group destination) { if (source.equals(destination)) { - return source; + return; } List needReplaceChild = Lists.newArrayList(); - for (GroupExpression groupExpression : groupExpressions.values()) { - if (groupExpression.children().contains(source)) { - if (groupExpression.getOwnerGroup().equals(destination)) { - // cycle, we should not merge - return null; - } - needReplaceChild.add(groupExpression); + for (GroupExpression parent : source.getParentGroupExpressions()) { + if (parent.getOwnerGroup().equals(destination)) { + // cycle, we should not merge + return; } + // PhysicalEnforcer don't exist in memo, so we need skip them. + if (parent.getPlan() instanceof PhysicalDistribute) { + // TODO: SortEnforcer. + continue; + } + needReplaceChild.add(parent); } GROUP_MERGE_TRACER.log(GroupMergeEvent.of(source, destination, needReplaceChild)); @@ -494,12 +497,8 @@ public class Memo { groupExpressions.put(reinsertGroupExpr, reinsertGroupExpr); } } - if (!source.equals(destination)) { - source.mergeTo(destination); - groups.remove(source.getGroupId()); - } - - return destination; + source.mergeTo(destination); + groups.remove(source.getGroupId()); } /** @@ -735,11 +734,11 @@ public class Memo { /** * rank all plan and select n-th plan, we write the algorithm according paper: - * * Counting,Enumerating, and Sampling of Execution Plans in a Cost-Based Query Optimizer + * * Counting,Enumerating, and Sampling of Execution Plans in a Cost-Based Query Optimizer * Specifically each physical plan in memo is assigned a unique ID in rank(). And then we sort the * plan according their cost and choose the n-th plan. Note we don't generate any physical plan in rank * function. - * + *

* In unrank() function, we will extract the actual physical function according the unique ID */ public Pair rank(long n) { @@ -813,9 +812,10 @@ public class Memo { return res; } - /** we permute all children, e.g., - * for children [1, 2] [1, 2, 3] - * we can get: 0: [1,1] 1:[1, 2] 2:[1, 3] 3:[2, 1] 4:[2, 2] 5:[2, 3] + /** + * we permute all children, e.g., + * for children [1, 2] [1, 2, 3] + * we can get: 0: [1,1] 1:[1, 2] 2:[1, 3] 3:[2, 1] 4:[2, 2] 5:[2, 3] */ private void permute(List>> children, int index, List>> result, List current) { @@ -833,9 +833,9 @@ public class Memo { /** * This method is used to calculate the unique ID for one combination, * The current is used to represent the index of the child in lists e.g., - * for children [1], [1, 2], The possible indices and IDs are: - * [0, 0]: 0*1 + 0*1*2 - * [0, 1]: 0*1 + 1*1*2 + * for children [1], [1, 2], The possible indices and IDs are: + * [0, 0]: 0*1 + 0*1*2 + * [0, 1]: 0*1 + 1*1*2 */ private static long getUniqueId(List>> lists, List current) { long id = 0;