[fix](nereids): fix all bugs in mergeGroup(). (#16079)

* [fix](Nereids): fix mergeGroup()

* polish code

* fix replace children of PhysicalEnforcer

* delete `deleteBestPlan`

* delete `getInputProperties`

* after merge GroupExpression, clear owner Group
This commit is contained in:
jakevin
2023-01-19 19:15:05 +08:00
committed by GitHub
parent dd869077f8
commit c1dd1fc331
7 changed files with 187 additions and 48 deletions

View File

@ -58,7 +58,8 @@ public class ApplyRuleJob extends Job {
@Override
public void execute() throws AnalysisException {
if (groupExpression.hasApplied(rule)) {
if (groupExpression.hasApplied(rule)
|| groupExpression.isUnused()) {
return;
}
countJobExecutionTimesOfGroupExpressions(groupExpression);

View File

@ -104,6 +104,10 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
*/
@Override
public void execute() {
if (groupExpression.isUnused()) {
return;
}
countJobExecutionTimesOfGroupExpressions(groupExpression);
// Do init logic of root plan/groupExpr of `subplan`, only run once per task.
if (curChildIndex == -1) {

View File

@ -58,10 +58,10 @@ public class DeriveStatsJob extends Job {
@Override
public void execute() {
countJobExecutionTimesOfGroupExpressions(groupExpression);
if (groupExpression.isStatDerived()) {
if (groupExpression.isStatDerived() || groupExpression.isUnused()) {
return;
}
countJobExecutionTimesOfGroupExpressions(groupExpression);
if (!deriveChildren && groupExpression.arity() > 0) {
pushJob(new DeriveStatsJob(groupExpression, true, context));

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.util.TreeStringUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.statistics.StatsDeriveResult;
@ -34,10 +35,11 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
@ -213,6 +215,23 @@ public class Group {
lowestCostPlans.put(newProperty, pair);
}
/**
 * Replace oldGroupExpression with newGroupExpression in lowestCostPlans.
 *
 * <p>Every entry whose best GroupExpression equals oldGroupExpression is removed and
 * re-inserted under the same PhysicalProperties key, keeping its recorded cost but
 * pointing at newGroupExpression. Entries that reference any other GroupExpression
 * are left untouched.
 *
 * @param oldGroupExpression the GroupExpression to look for (compared with equals)
 * @param newGroupExpression the GroupExpression that takes its place
 */
public void replaceBestPlanGroupExpr(GroupExpression oldGroupExpression, GroupExpression newGroupExpression) {
    // Collect the rewritten entries first so that lowestCostPlans is not modified while it is iterated.
    Map<PhysicalProperties, Pair<Double, GroupExpression>> replacements = Maps.newHashMap();
    for (Map.Entry<PhysicalProperties, Pair<Double, GroupExpression>> bestPlan : lowestCostPlans.entrySet()) {
        Pair<Double, GroupExpression> costAndExpr = bestPlan.getValue();
        if (costAndExpr.second.equals(oldGroupExpression)) {
            replacements.put(bestPlan.getKey(), Pair.of(costAndExpr.first, newGroupExpression));
        }
    }
    // Drop the stale entries, then put the rewritten ones back under the same keys.
    lowestCostPlans.keySet().removeAll(replacements.keySet());
    lowestCostPlans.putAll(replacements);
}
/**
 * Get the statistics currently stored on this Group.
 *
 * @return the value of the statistics field
 */
public StatsDeriveResult getStatistics() {
    return this.statistics;
}
@ -262,26 +281,54 @@ public class Group {
* @param target the new owner group of expressions
*/
public void mergeTo(Group target) {
// move parentExpressions Ownership
// move parentExpressions Ownership
parentExpressions.keySet().forEach(target::addParentExpression);
// PhysicalEnforcer isn't in groupExpressions, so mergeGroup() can't replace its children.
// So we need to manually replace the children of PhysicalEnforcer in here.
parentExpressions.keySet().stream().filter(ge -> ge.getPlan() instanceof PhysicalDistribute)
.forEach(ge -> ge.children().set(0, target));
parentExpressions.clear();
// move LogicalExpression PhysicalExpression Ownership
HashSet<GroupExpression> logicalSet = new HashSet<>(target.getLogicalExpressions());
logicalExpressions.stream().filter(ge -> !logicalSet.contains(ge)).forEach(target::addLogicalExpression);
Map<GroupExpression, GroupExpression> logicalSet = target.getLogicalExpressions().stream()
.collect(Collectors.toMap(Function.identity(), Function.identity()));
for (GroupExpression logicalExpression : logicalExpressions) {
GroupExpression existGroupExpr = logicalSet.get(logicalExpression);
if (existGroupExpr != null) {
Preconditions.checkState(logicalExpression != existGroupExpr, "must not equals");
// lowCostPlans must be physical GroupExpression, don't need to replaceBestPlanGroupExpr
logicalExpression.mergeToNotOwnerRemove(existGroupExpr);
} else {
target.addLogicalExpression(logicalExpression);
}
}
logicalExpressions.clear();
// movePhysicalExpressionOwnership
HashSet<GroupExpression> physicalSet = new HashSet<>(target.getPhysicalExpressions());
physicalExpressions.stream().filter(ge -> !physicalSet.contains(ge)).forEach(target::addGroupExpression);
Map<GroupExpression, GroupExpression> physicalSet = target.getPhysicalExpressions().stream()
.collect(Collectors.toMap(Function.identity(), Function.identity()));
for (GroupExpression physicalExpression : physicalExpressions) {
GroupExpression existGroupExpr = physicalSet.get(physicalExpression);
if (existGroupExpr != null) {
Preconditions.checkState(physicalExpression != existGroupExpr, "must not equals");
physicalExpression.getOwnerGroup().replaceBestPlanGroupExpr(physicalExpression, existGroupExpr);
physicalExpression.mergeToNotOwnerRemove(existGroupExpr);
} else {
target.addPhysicalExpression(physicalExpression);
}
}
physicalExpressions.clear();
// moveLowestCostPlansOwnership
// Above we already replaceBestPlanGroupExpr, but we still need to moveLowestCostPlansOwnership.
// Because PhysicalEnforcer don't exist in physicalExpressions, so above `replaceBestPlanGroupExpr` can't
// move PhysicalEnforcer in lowestCostPlans. Following code can move PhysicalEnforcer in lowestCostPlans.
lowestCostPlans.forEach((physicalProperties, costAndGroupExpr) -> {
GroupExpression bestGroupExpression = costAndGroupExpr.second;
// change into target group.
if (bestGroupExpression.getOwnerGroup() == this || bestGroupExpression.getOwnerGroup() == null) {
// move PhysicalEnforcer into target
Preconditions.checkState(bestGroupExpression.getPlan() instanceof PhysicalDistribute);
bestGroupExpression.setOwnerGroup(target);
}
// move lowestCostPlans Ownership
if (!target.lowestCostPlans.containsKey(physicalProperties)) {
target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
} else {

View File

@ -66,6 +66,9 @@ public class GroupExpression {
// value is the request physical properties
private final Map<PhysicalProperties, PhysicalProperties> requestPropertiesMap;
// After mergeGroup(), the source Group is cleaned up, but this GroupExpression may still be in the Job Stack.
// This flag marks it as unused so that it can be skipped.
private boolean isUnused = false;
public GroupExpression(Plan plan) {
this(plan, Lists.newArrayList());
}
@ -163,6 +166,22 @@ public class GroupExpression {
this.statDerived = statDerived;
}
/**
 * Check whether this GroupExpression is marked as unused. See the comment on the
 * `isUnused` field for details.
 *
 * <p>Besides reporting the flag, this method verifies the expected state:
 * a used GroupExpression must have an owner Group, and an unused one must have
 * either no children or no owner Group.
 *
 * @return true if this GroupExpression has been marked unused, false otherwise
 */
public boolean isUnused() {
    if (!isUnused) {
        // A GroupExpression that is still in use must belong to a Group.
        Preconditions.checkState(ownerGroup != null);
        return false;
    }
    // An unused GroupExpression must have dropped its children or its owner Group.
    Preconditions.checkState(children.isEmpty() || ownerGroup == null);
    return true;
}
/**
 * Set the unused flag of this GroupExpression.
 *
 * @param isUnused the new value of the flag
 */
public void setUnused(boolean isUnused) {
    this.isUnused = isUnused;
}
/**
 * Get the lowest cost table of this GroupExpression.
 *
 * @return the lowestCostTable field itself (not a copy)
 */
public Map<PhysicalProperties, Pair<Double, List<PhysicalProperties>>> getLowestCostTable() {
    return this.lowestCostTable;
}
@ -175,6 +194,7 @@ public class GroupExpression {
/**
* Add a (outputProperties) -> (cost, childrenInputProperties) in lowestCostTable.
* if the outputProperties exists, will be covered.
*
* @return true if lowest cost table change.
*/
public boolean updateLowestCostTable(PhysicalProperties outputProperties,
@ -209,6 +229,32 @@ public class GroupExpression {
this.requestPropertiesMap.put(requiredPropertySet, outputPropertySet);
}
/**
 * Merge this GroupExpression into target.
 *
 * <p>First asks the owner Group to remove this GroupExpression, then hands the rest
 * of the work to {@link #mergeToNotOwnerRemove(GroupExpression)}.
 *
 * @param target the GroupExpression that receives the state of this one
 */
public void mergeTo(GroupExpression target) {
    // Detach from the owner Group before the state is transferred.
    ownerGroup.removeGroupExpression(this);
    mergeToNotOwnerRemove(target);
}
/**
 * Merge this GroupExpression into target, but do not ask the owner Group to remove
 * this GroupExpression.
 *
 * <p>The lowest cost table entries, the request properties map and the rule masks of
 * this GroupExpression are passed on to target. Afterwards this GroupExpression is
 * cleared: it is removed from the parent expressions of each child, its children list
 * is emptied and its owner Group is set to null.
 *
 * @param target the GroupExpression that receives the state of this one
 */
public void mergeToNotOwnerRemove(GroupExpression target) {
    // LowestCostTable: offer each (outputProperties -> cost, childrenInputProperties) entry to target.
    for (Map.Entry<PhysicalProperties, Pair<Double, List<PhysicalProperties>>> costEntry
            : getLowestCostTable().entrySet()) {
        Pair<Double, List<PhysicalProperties>> costAndInputs = costEntry.getValue();
        target.updateLowestCostTable(costEntry.getKey(), costAndInputs.second, costAndInputs.first);
    }

    // requestPropertiesMap: entries of this GroupExpression overwrite same-key entries in target.
    target.requestPropertiesMap.putAll(requestPropertiesMap);

    // ruleMasks: target ends up with the union of both masks.
    target.ruleMasks.or(ruleMasks);

    // clear: unlink from every child, then drop the children and the owner Group.
    for (Group child : children) {
        child.removeParentExpression(this);
    }
    children.clear();
    ownerGroup = null;
}
/**
 * Get the cost stored on this GroupExpression.
 *
 * @return the value of the cost field
 */
public double getCost() {
    return this.cost;
}

View File

@ -33,6 +33,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.StatsDeriveResult;
@ -444,29 +445,40 @@ public class Memo {
}
}
GROUP_MERGE_TRACER.log(GroupMergeEvent.of(source, destination, needReplaceChild));
for (GroupExpression groupExpression : needReplaceChild) {
// After change GroupExpression children, the hashcode will change,
// so need to reinsert into map.
groupExpressions.remove(groupExpression);
Utils.replaceList(groupExpression.children(), source, destination);
GroupExpression that = groupExpressions.get(groupExpression);
if (that != null && that.getOwnerGroup() != null
&& !that.getOwnerGroup().equals(groupExpression.getOwnerGroup())) {
// remove groupExpression from its owner group to avoid adding it to that.getOwnerGroup()
// that.getOwnerGroup() already has this groupExpression.
Group ownerGroup = groupExpression.getOwnerGroup();
groupExpression.getOwnerGroup().removeGroupExpression(groupExpression);
mergeGroup(ownerGroup, that.getOwnerGroup());
Map<Group, Group> needMergeGroupPairs = Maps.newHashMap();
for (GroupExpression reinsertGroupExpr : needReplaceChild) {
// After change GroupExpression children, hashcode will change, so need to reinsert into map.
groupExpressions.remove(reinsertGroupExpr);
Utils.replaceList(reinsertGroupExpr.children(), source, destination);
GroupExpression existGroupExpr = groupExpressions.get(reinsertGroupExpr);
if (existGroupExpr != null) {
Preconditions.checkState(existGroupExpr.getOwnerGroup() != null);
// remove reinsertGroupExpr from its owner group to avoid adding it to existGroupExpr.getOwnerGroup()
// existGroupExpr.getOwnerGroup() already has this reinsertGroupExpr.
reinsertGroupExpr.setUnused(true);
if (existGroupExpr.getOwnerGroup().equals(reinsertGroupExpr.getOwnerGroup())) {
// reinsertGroupExpr & existGroupExpr are in same Group, so merge them.
if (reinsertGroupExpr.getPlan() instanceof PhysicalPlan) {
reinsertGroupExpr.getOwnerGroup().replaceBestPlanGroupExpr(reinsertGroupExpr, existGroupExpr);
}
// existingGroupExpression merge the state of reinsertGroupExpr
reinsertGroupExpr.mergeTo(existGroupExpr);
} else {
// reinsertGroupExpr & existGroupExpr aren't in same group, need to merge their OwnerGroup.
needMergeGroupPairs.put(reinsertGroupExpr.getOwnerGroup(), existGroupExpr.getOwnerGroup());
}
} else {
groupExpressions.put(groupExpression, groupExpression);
groupExpressions.put(reinsertGroupExpr, reinsertGroupExpr);
}
}
if (!source.equals(destination)) {
// TODO: stats and other
source.mergeTo(destination);
groups.remove(source.getGroupId());
}
needMergeGroupPairs.forEach(this::mergeGroup);
return destination;
}

View File

@ -53,7 +53,6 @@ import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
@ -68,33 +67,63 @@ class MemoTest implements PatternMatchSupported {
private final LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>(
JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0));
/*
* ┌─────────────────────────┐ ┌───────────┐
* │ ┌─────┐ ┌─────┐ │ │ ┌─────┐ │
* │ │0┌─┐ │ │1┌─┐ │ │ │ │1┌─┐ │ │
* │ │ └┼┘ │ │ └┼┘ │ │ │ │ └┼┘ │ │
* │ └──┼──┘ └──┼──┘ │ │ └──┼──┘ │
* │Memo │ │ ├────►│Memo │ │
* │ ┌──▼──┐ ┌──▼──┐ │ │ ┌──▼──┐ │
* │ │ src │ │ dst │ │ │ │ dst │ │
* │ │2 │ │3 │ │ │ │3 │ │
* │ └─────┘ └─────┘ │ │ └─────┘ │
* └─────────────────────────┘ └───────────┘
*/
@Test
void mergeGroup() {
void testMergeGroup() {
Group srcGroup = new Group(new GroupId(2), new GroupExpression(new FakePlan()),
new LogicalProperties(ArrayList::new));
Group dstGroup = new Group(new GroupId(3), new GroupExpression(new FakePlan()),
new LogicalProperties(ArrayList::new));
FakePlan fakePlan = new FakePlan();
GroupExpression srcParentExpression = new GroupExpression(fakePlan, Lists.newArrayList(srcGroup));
Group srcParentGroup = new Group(new GroupId(0), srcParentExpression, new LogicalProperties(ArrayList::new));
srcParentGroup.setBestPlan(srcParentExpression, Double.MIN_VALUE, PhysicalProperties.ANY);
GroupExpression dstParentExpression = new GroupExpression(fakePlan, Lists.newArrayList(dstGroup));
Group dstParentGroup = new Group(new GroupId(1), dstParentExpression, new LogicalProperties(ArrayList::new));
Memo memo = new Memo();
GroupId gid2 = new GroupId(2);
Group srcGroup = new Group(gid2, new GroupExpression(new FakePlan()), new LogicalProperties(ArrayList::new));
GroupId gid3 = new GroupId(3);
Group dstGroup = new Group(gid3, new GroupExpression(new FakePlan()), new LogicalProperties(ArrayList::new));
FakePlan d = new FakePlan();
GroupExpression ge1 = new GroupExpression(d, Arrays.asList(srcGroup));
GroupId gid0 = new GroupId(0);
Group g1 = new Group(gid0, ge1, new LogicalProperties(ArrayList::new));
g1.setBestPlan(ge1, Double.MIN_VALUE, PhysicalProperties.ANY);
GroupExpression ge2 = new GroupExpression(d, Arrays.asList(dstGroup));
GroupId gid1 = new GroupId(1);
Group g2 = new Group(gid1, ge2, new LogicalProperties(ArrayList::new));
Map<GroupId, Group> groups = Deencapsulation.getField(memo, "groups");
groups.put(gid2, srcGroup);
groups.put(gid3, dstGroup);
groups.put(gid0, g1);
groups.put(gid1, g2);
groups.put(srcGroup.getGroupId(), srcGroup);
groups.put(dstGroup.getGroupId(), dstGroup);
groups.put(srcParentGroup.getGroupId(), srcParentGroup);
groups.put(dstParentGroup.getGroupId(), dstParentGroup);
Map<GroupExpression, GroupExpression> groupExpressions =
Deencapsulation.getField(memo, "groupExpressions");
groupExpressions.put(ge1, ge1);
groupExpressions.put(ge2, ge2);
groupExpressions.put(srcParentExpression, srcParentExpression);
groupExpressions.put(dstParentExpression, dstParentExpression);
memo.mergeGroup(srcGroup, dstGroup);
Assertions.assertNull(g1.getBestPlan(PhysicalProperties.ANY));
Assertions.assertEquals(ge1.getOwnerGroup(), g2);
// check
Assertions.assertEquals(0, srcGroup.getParentGroupExpressions().size());
Assertions.assertEquals(0, srcGroup.getPhysicalExpressions().size());
Assertions.assertEquals(0, srcGroup.getLogicalExpressions().size());
Assertions.assertEquals(0, srcParentGroup.getParentGroupExpressions().size());
Assertions.assertEquals(0, srcParentGroup.getPhysicalExpressions().size());
Assertions.assertEquals(0, srcParentGroup.getLogicalExpressions().size());
// TODO: add root test.
// Assertions.assertEquals(memo.getRoot(), dstParentGroup);
Assertions.assertEquals(2, dstGroup.getPhysicalExpressions().size());
Assertions.assertEquals(1, dstParentGroup.getPhysicalExpressions().size());
Assertions.assertNull(srcParentExpression.getOwnerGroup());
Assertions.assertEquals(0, srcParentExpression.arity());
}
/**