[enhance](Nereids): check cycle by getParentGroupExpressions(). (#18687)

This commit is contained in:
jakevin
2023-04-20 11:51:58 +08:00
committed by GitHub
parent 3328a65b75
commit 52d32cccad
2 changed files with 25 additions and 25 deletions

View File

@ -297,7 +297,7 @@ public class Group {
*/
public void mergeTo(Group target) {
// move parentExpressions Ownership
parentExpressions.keySet().forEach(target::addParentExpression);
parentExpressions.keySet().forEach(parent -> target.addParentExpression(parent));
// PhysicalEnforcer isn't in groupExpressions, so mergeGroup() can't replace its children.
// So we need to manually replace the children of PhysicalEnforcer in here.
// TODO: SortEnforcer?

View File

@ -37,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.Statistics;
@ -450,21 +451,23 @@ public class Memo {
*
* @param source source group
* @param destination destination group
* @return merged group
*/
public Group mergeGroup(Group source, Group destination) {
public void mergeGroup(Group source, Group destination) {
if (source.equals(destination)) {
return source;
return;
}
List<GroupExpression> needReplaceChild = Lists.newArrayList();
for (GroupExpression groupExpression : groupExpressions.values()) {
if (groupExpression.children().contains(source)) {
if (groupExpression.getOwnerGroup().equals(destination)) {
// cycle, we should not merge
return null;
}
needReplaceChild.add(groupExpression);
for (GroupExpression parent : source.getParentGroupExpressions()) {
if (parent.getOwnerGroup().equals(destination)) {
// cycle, we should not merge
return;
}
// PhysicalEnforcer don't exist in memo, so we need skip them.
if (parent.getPlan() instanceof PhysicalDistribute) {
// TODO: SortEnforcer.
continue;
}
needReplaceChild.add(parent);
}
GROUP_MERGE_TRACER.log(GroupMergeEvent.of(source, destination, needReplaceChild));
@ -494,12 +497,8 @@ public class Memo {
groupExpressions.put(reinsertGroupExpr, reinsertGroupExpr);
}
}
if (!source.equals(destination)) {
source.mergeTo(destination);
groups.remove(source.getGroupId());
}
return destination;
source.mergeTo(destination);
groups.remove(source.getGroupId());
}
/**
@ -735,11 +734,11 @@ public class Memo {
/**
* rank all plan and select n-th plan, we write the algorithm according paper:
* * Counting,Enumerating, and Sampling of Execution Plans in a Cost-Based Query Optimizer
* * Counting,Enumerating, and Sampling of Execution Plans in a Cost-Based Query Optimizer
* Specifically each physical plan in memo is assigned a unique ID in rank(). And then we sort the
* plan according their cost and choose the n-th plan. Note we don't generate any physical plan in rank
* function.
*
* <p>
* In unrank() function, we will extract the actual physical function according the unique ID
*/
public Pair<Long, Double> rank(long n) {
@ -813,9 +812,10 @@ public class Memo {
return res;
}
/** we permute all children, e.g.,
* for children [1, 2] [1, 2, 3]
* we can get: 0: [1,1] 1:[1, 2] 2:[1, 3] 3:[2, 1] 4:[2, 2] 5:[2, 3]
/**
* we permute all children, e.g.,
* for children [1, 2] [1, 2, 3]
* we can get: 0: [1,1] 1:[1, 2] 2:[1, 3] 3:[2, 1] 4:[2, 2] 5:[2, 3]
*/
private void permute(List<List<Pair<Long, Double>>> children, int index,
List<Pair<Long, List<Integer>>> result, List<Integer> current) {
@ -833,9 +833,9 @@ public class Memo {
/**
* This method is used to calculate the unique ID for one combination,
* The current is used to represent the index of the child in lists e.g.,
* for children [1], [1, 2], The possible indices and IDs are:
* [0, 0]: 0*1 + 0*1*2
* [0, 1]: 0*1 + 1*1*2
* for children [1], [1, 2], The possible indices and IDs are:
* [0, 0]: 0*1 + 0*1*2
* [0, 1]: 0*1 + 1*1*2
*/
private static long getUniqueId(List<List<Pair<Long, Double>>> lists, List<Integer> current) {
long id = 0;