[fix](Nereids): fix copyIn() in Memo when useless project with groupplan (#18223)

This commit is contained in:
jakevin
2023-03-30 23:49:21 +08:00
committed by GitHub
parent d6b0fe9072
commit 28793b6441
14 changed files with 73 additions and 69 deletions

View File

@ -202,8 +202,8 @@ public class NereidsPlanner extends Planner {
optimize();
//print memo before choose plan.
//if chooseNthPlan failed, we could get memo to debug
// print memo before choose plan.
// if chooseNthPlan failed, we could get memo to debug
if (ConnectContext.get().getSessionVariable().isDumpNereidsMemo()) {
String memo = cascadesContext.getMemo().toString();
LOG.info(memo);

View File

@ -107,6 +107,22 @@ public class Memo {
return groupExpressions;
}
private Plan skipProject(Plan plan, Group targetGroup) {
if (plan instanceof LogicalProject) {
LogicalProject<Plan> logicalProject = (LogicalProject<Plan>) plan;
if (targetGroup != root) {
if (logicalProject.getOutputSet().equals(logicalProject.child().getOutputSet())) {
return skipProject(logicalProject.child(), targetGroup);
}
} else {
if (logicalProject.getOutput().equals(logicalProject.child().getOutput())) {
return skipProject(logicalProject.child(), targetGroup);
}
}
}
return plan;
}
/**
* Add plan to Memo.
*
@ -122,7 +138,7 @@ public class Memo {
if (rewrite) {
result = doRewrite(plan, target);
} else {
result = doCopyIn(plan, target);
result = doCopyIn(skipProject(plan, target), target);
}
maybeAddStateId(result);
return result;
@ -316,6 +332,17 @@ public class Memo {
}
}
private Plan skipProjectGetChild(Plan plan) {
if (plan instanceof LogicalProject) {
LogicalProject<Plan> logicalProject = (LogicalProject<Plan>) plan;
Plan child = logicalProject.child();
if (logicalProject.getOutputSet().equals(child.getOutputSet())) {
return skipProjectGetChild(child);
}
}
return plan;
}
/**
* add the plan into the target group
* @param plan the plan which want added
@ -326,20 +353,7 @@ public class Memo {
* and the second element is a reference of node in Memo
*/
private CopyInResult doCopyIn(Plan plan, @Nullable Group targetGroup) {
// TODO: this is same with EliminateUnnecessaryProject,
// we need a infra to rewrite plan after every exploration job
if (plan instanceof LogicalProject) {
LogicalProject<Plan> logicalProject = (LogicalProject<Plan>) plan;
if (targetGroup != root) {
if (logicalProject.getOutputSet().equals(logicalProject.child().getOutputSet())) {
return doCopyIn(logicalProject.child(), targetGroup);
}
} else {
if (logicalProject.getOutput().equals(logicalProject.child().getOutput())) {
return doCopyIn(logicalProject.child(), targetGroup);
}
}
}
Preconditions.checkArgument(!(plan instanceof GroupPlan), "plan can not be GroupPlan");
// check logicalproperties, must same output in a Group.
if (targetGroup != null && !plan.getLogicalProperties().equals(targetGroup.getLogicalProperties())) {
throw new IllegalStateException("Insert a plan into targetGroup but differ in logicalproperties");
@ -350,13 +364,14 @@ public class Memo {
}
List<Group> childrenGroups = Lists.newArrayList();
for (int i = 0; i < plan.children().size(); i++) {
Plan child = plan.children().get(i);
// skip useless project.
Plan child = skipProjectGetChild(plan.child(i));
if (child instanceof GroupPlan) {
childrenGroups.add(((GroupPlan) child).getGroup());
} else if (child.getGroupExpression().isPresent()) {
childrenGroups.add(child.getGroupExpression().get().getOwnerGroup());
} else {
childrenGroups.add(copyIn(child, null, false).correspondingExpression.getOwnerGroup());
childrenGroups.add(doCopyIn(child, null).correspondingExpression.getOwnerGroup());
}
}
plan = replaceChildrenToGroupPlan(plan, childrenGroups);

View File

@ -28,6 +28,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.Utils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -85,13 +86,11 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory {
newBottomJoin.getJoinReorderContext().setHasCommute(false);
// merge newTopHashConjuncts newTopOtherConjuncts topJoin.getOutputExprIdSet()
// Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
// newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
// newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
// Plan left = JoinReorderUtils.newProject(topUsedExprIds, newBottomJoin);
// Plan right = JoinReorderUtils.newProject(topUsedExprIds, b);
Plan left = newBottomJoin;
Plan right = b;
Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
Plan left = JoinReorderUtils.newProject(topUsedExprIds, newBottomJoin);
Plan right = JoinReorderUtils.newProject(topUsedExprIds, b);
LogicalJoin<Plan, Plan> newTopJoin = bottomJoin.withConjunctsChildren(newTopHashConjuncts,
newTopOtherConjuncts, left, right);

View File

@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -77,13 +78,11 @@ public class InnerJoinLeftAssociateProject extends OneExplorationRuleFactory {
newBottomHashConjuncts, newBottomOtherConjuncts, a, b);
// new Project.
// Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
// newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
// newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
// Plan left = JoinReorderUtils.newProject(topUsedExprIds, newBottomJoin);
// Plan right = JoinReorderUtils.newProject(topUsedExprIds, c);
Plan left = newBottomJoin;
Plan right = c;
Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
Plan left = JoinReorderUtils.newProject(topUsedExprIds, newBottomJoin);
Plan right = JoinReorderUtils.newProject(topUsedExprIds, c);
LogicalJoin<Plan, Plan> newTopJoin = bottomJoin.withConjunctsChildren(
newTopHashConjuncts, newTopOtherConjuncts, left, right);

View File

@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -74,13 +75,11 @@ public class InnerJoinRightAssociateProject extends OneExplorationRuleFactory {
newBottomHashConjuncts, newBottomOtherConjuncts, b, c);
// new Project.
// Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
// newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
// newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
// Plan left = JoinReorderUtils.newProject(topUsedExprIds, a);
// Plan right = JoinReorderUtils.newProject(topUsedExprIds, newBottomJoin);
Plan left = a;
Plan right = newBottomJoin;
Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
Plan left = JoinReorderUtils.newProject(topUsedExprIds, a);
Plan right = JoinReorderUtils.newProject(topUsedExprIds, newBottomJoin);
LogicalJoin<Plan, Plan> newTopJoin = bottomJoin.withConjunctsChildren(
newTopHashConjuncts, newTopOtherConjuncts, left, right);

View File

@ -32,6 +32,7 @@ import org.apache.doris.nereids.util.JoinUtils;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@ -91,14 +92,14 @@ public class JoinExchangeBothProject extends OneExplorationRuleFactory {
newLeftJoinHashJoinConjuncts, newLeftJoinOtherJoinConjuncts, JoinHint.NONE, a, c);
LogicalJoin<GroupPlan, GroupPlan> newRightJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
newRightJoinHashJoinConjuncts, newRightJoinOtherJoinConjuncts, JoinHint.NONE, b, d);
// Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
// newTopJoinHashJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
// newTopJoinOtherJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
// Plan left = JoinReorderUtils.newProject(topUsedExprIds, newLeftJoin);
// Plan right = JoinReorderUtils.newProject(topUsedExprIds, newRightJoin);
Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
newTopJoinHashJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
newTopJoinOtherJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
Plan left = JoinReorderUtils.newProject(topUsedExprIds, newLeftJoin);
Plan right = JoinReorderUtils.newProject(topUsedExprIds, newRightJoin);
LogicalJoin newTopJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
newTopJoinHashJoinConjuncts, newTopJoinOtherJoinConjuncts, JoinHint.NONE,
newLeftJoin, newRightJoin);
left, right);
JoinExchange.setNewLeftJoinReorder(newLeftJoin, leftJoin);
JoinExchange.setNewRightJoinReorder(newRightJoin, leftJoin);
JoinExchange.setNewTopJoinReorder(newTopJoin, topJoin);

View File

@ -32,6 +32,7 @@ import org.apache.doris.nereids.util.JoinUtils;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@ -90,14 +91,14 @@ public class JoinExchangeLeftProject extends OneExplorationRuleFactory {
newLeftJoinHashJoinConjuncts, newLeftJoinOtherJoinConjuncts, JoinHint.NONE, a, c);
LogicalJoin<GroupPlan, GroupPlan> newRightJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
newRightJoinHashJoinConjuncts, newRightJoinOtherJoinConjuncts, JoinHint.NONE, b, d);
// Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
// newTopJoinHashJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
// newTopJoinOtherJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
// Plan left = JoinReorderUtils.newProject(topUsedExprIds, newLeftJoin);
// Plan right = JoinReorderUtils.newProject(topUsedExprIds, newRightJoin);
Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
newTopJoinHashJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
newTopJoinOtherJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
Plan left = JoinReorderUtils.newProject(topUsedExprIds, newLeftJoin);
Plan right = JoinReorderUtils.newProject(topUsedExprIds, newRightJoin);
LogicalJoin newTopJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
newTopJoinHashJoinConjuncts, newTopJoinOtherJoinConjuncts, JoinHint.NONE,
newLeftJoin, newRightJoin);
left, right);
JoinExchange.setNewLeftJoinReorder(newLeftJoin, leftJoin);
JoinExchange.setNewRightJoinReorder(newRightJoin, leftJoin);
JoinExchange.setNewTopJoinReorder(newTopJoin, topJoin);

View File

@ -32,6 +32,7 @@ import org.apache.doris.nereids.util.JoinUtils;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@ -90,14 +91,14 @@ public class JoinExchangeRightProject extends OneExplorationRuleFactory {
newLeftJoinHashJoinConjuncts, newLeftJoinOtherJoinConjuncts, JoinHint.NONE, a, c);
LogicalJoin<GroupPlan, GroupPlan> newRightJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
newRightJoinHashJoinConjuncts, newRightJoinOtherJoinConjuncts, JoinHint.NONE, b, d);
// Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
// newTopJoinHashJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
// newTopJoinOtherJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
// Plan left = JoinReorderUtils.newProject(topUsedExprIds, newLeftJoin);
// Plan right = JoinReorderUtils.newProject(topUsedExprIds, newRightJoin);
Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
newTopJoinHashJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
newTopJoinOtherJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
Plan left = JoinReorderUtils.newProject(topUsedExprIds, newLeftJoin);
Plan right = JoinReorderUtils.newProject(topUsedExprIds, newRightJoin);
LogicalJoin newTopJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
newTopJoinHashJoinConjuncts, newTopJoinOtherJoinConjuncts, JoinHint.NONE,
newLeftJoin, newRightJoin);
left, right);
JoinExchange.setNewLeftJoinReorder(newLeftJoin, leftJoin);
JoinExchange.setNewRightJoinReorder(newRightJoin, leftJoin);
JoinExchange.setNewTopJoinReorder(newTopJoin, topJoin);

View File

@ -52,7 +52,6 @@ class InnerJoinLAsscomProjectTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
@Test
@Disabled
void testSimple() {
/*
* Star-Join

View File

@ -28,7 +28,6 @@ import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
class InnerJoinLeftAssociateProjectTest implements MemoPatternMatchSupported {
@ -37,7 +36,6 @@ class InnerJoinLeftAssociateProjectTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
@Test
@Disabled
void testSimple() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(

View File

@ -28,7 +28,6 @@ import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
class InnerJoinRightAssociateProjectTest implements MemoPatternMatchSupported {
@ -37,7 +36,6 @@ class InnerJoinRightAssociateProjectTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
@Test
@Disabled
void testSimple() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))

View File

@ -28,12 +28,10 @@ import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
class JoinExchangeBothProjectTest implements MemoPatternMatchSupported {
@Test
@Disabled
public void testSimple() {
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

View File

@ -28,12 +28,10 @@ import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
class JoinExchangeLeftProjectTest implements MemoPatternMatchSupported {
@Test
@Disabled
public void testSimple() {
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

View File

@ -28,12 +28,10 @@ import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
class JoinExchangeRightProjectTest implements MemoPatternMatchSupported {
@Test
@Disabled
public void testSimple() {
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);