[fix](Nereids) avoid to push top Project of JoinCluster in PushdownProjectThroughJoin (#19441)

PushdownProjectThroughJoin should not push down the Project that sits on top of a JoinCluster.

For example:

```
 *      Project  (id + 1) if this project is top project of Join Cluster
 *        |     
 *       Join   
 *      /      \         
 *    Join  Join
 *    /  ....
 * Join
```
This commit is contained in:
jakevin
2023-05-11 13:58:54 +08:00
committed by GitHub
parent 834bf2eab7
commit dc497e11bb
13 changed files with 330 additions and 264 deletions

View File

@ -20,7 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.CBOUtils;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.ExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -38,75 +38,102 @@ import java.util.Set;
import java.util.stream.Collectors;
/**
* rule for pushdown project through inner/outer join
* Rule for pushdown project through inner/outer join
* Only pushes down a project that sits inside a join, to avoid pushing down the top project of a Join-Cluster.
* <pre>
* Project Join
* | ──► / \
* Join Project Project
* / \ | |
* A B A B
* </pre>
*/
public class PushdownProjectThroughInnerJoin extends OneExplorationRuleFactory {
public class PushdownProjectThroughInnerJoin implements ExplorationRuleFactory {
public static final PushdownProjectThroughInnerJoin INSTANCE = new PushdownProjectThroughInnerJoin();
/*
* Project Join
* | ──► / \
* Join Project Project
* / \ | |
* A B A B
*/
@Override
public Rule build() {
return logicalProject(logicalJoin())
.whenNot(LogicalProject::isAllSlots)
.when(project -> project.child().getJoinType().isInnerJoin())
.whenNot(project -> project.child().hasJoinHint())
.then(project -> {
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
Set<ExprId> aOutputExprIdSet = join.left().getOutputExprIdSet();
Set<ExprId> bOutputExprIdSet = join.right().getOutputExprIdSet();
// reject hyper edge in Project.
if (!project.getProjects().stream().allMatch(expr -> {
Set<ExprId> inputSlotExprIds = expr.getInputSlotExprIds();
return aOutputExprIdSet.containsAll(inputSlotExprIds)
|| bOutputExprIdSet.containsAll(inputSlotExprIds);
})) {
return null;
}
List<NamedExpression> aProjects = new ArrayList<>();
List<NamedExpression> bProjects = new ArrayList<>();
for (NamedExpression namedExpression : project.getProjects()) {
Set<ExprId> usedExprIds = namedExpression.getInputSlotExprIds();
if (aOutputExprIdSet.containsAll(usedExprIds)) {
aProjects.add(namedExpression);
} else {
bProjects.add(namedExpression);
}
}
boolean leftContains = aProjects.stream().anyMatch(e -> !(e instanceof Slot));
boolean rightContains = bProjects.stream().anyMatch(e -> !(e instanceof Slot));
// due to JoinCommute, we don't need to consider just right contains.
if (!leftContains) {
return null;
}
Builder<NamedExpression> newAProject = ImmutableList.<NamedExpression>builder().addAll(aProjects);
Set<Slot> aConditionSlots = CBOUtils.joinChildConditionSlots(join, true);
Set<Slot> aProjectSlots = aProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet());
aConditionSlots.stream().filter(slot -> !aProjectSlots.contains(slot)).forEach(newAProject::add);
Plan newLeft = CBOUtils.projectOrSelf(newAProject.build(), join.left());
if (!rightContains) {
Plan newJoin = join.withChildrenNoContext(newLeft, join.right());
return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
}
Builder<NamedExpression> newBProject = ImmutableList.<NamedExpression>builder().addAll(bProjects);
Set<Slot> bConditionSlots = CBOUtils.joinChildConditionSlots(join, false);
Set<Slot> bProjectSlots = bProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet());
bConditionSlots.stream().filter(slot -> !bProjectSlots.contains(slot)).forEach(newBProject::add);
Plan newRight = CBOUtils.projectOrSelf(newBProject.build(), join.right());
Plan newJoin = join.withChildrenNoContext(newLeft, newRight);
return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN);
public List<Rule> buildRules() {
    // Both variants require an enclosing join, so a project that is not the child of a join
    // is never matched by this factory.

    // Variant 1: the project-over-inner-join is the LEFT child of the enclosing join.
    Rule pushdownInLeftChild = logicalJoin(logicalProject(innerLogicalJoin()), group())
            // Just pushdown project with non-column expr like (t.id + 1)
            .whenNot(outerJoin -> outerJoin.left().isAllSlots())
            .whenNot(outerJoin -> outerJoin.left().child().hasJoinHint())
            .then(outerJoin -> {
                Plan rewrittenLeft = pushdownProject(outerJoin.left());
                // pushdownProject yields null when it declines to rewrite; pass that through.
                return rewrittenLeft == null
                        ? null
                        : outerJoin.withChildren(rewrittenLeft, outerJoin.right());
            })
            .toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN);

    // Variant 2: the project-over-inner-join is the RIGHT child of the enclosing join.
    Rule pushdownInRightChild = logicalJoin(group(), logicalProject(innerLogicalJoin()))
            // Just pushdown project with non-column expr like (t.id + 1)
            .whenNot(outerJoin -> outerJoin.right().isAllSlots())
            .whenNot(outerJoin -> outerJoin.right().child().hasJoinHint())
            .then(outerJoin -> {
                Plan rewrittenRight = pushdownProject(outerJoin.right());
                // pushdownProject yields null when it declines to rewrite; pass that through.
                return rewrittenRight == null
                        ? null
                        : outerJoin.withChildren(outerJoin.left(), rewrittenRight);
            })
            .toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN);

    return ImmutableList.of(pushdownInLeftChild, pushdownInRightChild);
}
/**
 * Pushes the expressions of {@code project} below its child join.
 *
 * <p>Each project expression is assigned to the join side whose output covers all of the
 * expression's input slots. The rewrite is abandoned (returns {@code null}) when
 * <ul>
 *   <li>some expression needs slots from both sides (a "hyper edge"), or</li>
 *   <li>the left side would receive only plain slots; the mirrored case is left to
 *       JoinCommute, per the original comment.</li>
 * </ul>
 *
 * @param project a project whose child is a join over two groups
 * @return the rewritten plan (projects below the join, wrapped by
 *         {@code CBOUtils.projectOrSelf} over the original project output), or {@code null}
 *         when the rewrite is not applied
 */
private Plan pushdownProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
    LogicalJoin<GroupPlan, GroupPlan> join = project.child();
    Set<ExprId> leftOutputIds = join.left().getOutputExprIdSet();
    Set<ExprId> rightOutputIds = join.right().getOutputExprIdSet();

    // Split the project expressions by the side that can evaluate them.
    List<NamedExpression> leftExprs = new ArrayList<>();
    List<NamedExpression> rightExprs = new ArrayList<>();
    for (NamedExpression expr : project.getProjects()) {
        Set<ExprId> inputIds = expr.getInputSlotExprIds();
        if (leftOutputIds.containsAll(inputIds)) {
            leftExprs.add(expr);
        } else if (rightOutputIds.containsAll(inputIds)) {
            rightExprs.add(expr);
        } else {
            // reject hyper edge in Project.
            return null;
        }
    }

    // due to JoinCommute, we don't need to consider just right contains.
    if (leftExprs.stream().allMatch(e -> e instanceof Slot)) {
        return null;
    }

    Plan newLeft = CBOUtils.projectOrSelf(
            appendMissingConditionSlots(leftExprs, CBOUtils.joinChildConditionSlots(join, true)),
            join.left());

    // The right child is only wrapped when it has a non-slot expression to evaluate.
    Plan newRight = join.right();
    if (rightExprs.stream().anyMatch(e -> !(e instanceof Slot))) {
        newRight = CBOUtils.projectOrSelf(
                appendMissingConditionSlots(rightExprs, CBOUtils.joinChildConditionSlots(join, false)),
                join.right());
    }

    Plan newJoin = join.withChildrenNoContext(newLeft, newRight);
    return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
}

/**
 * Returns {@code projects} followed by every slot of {@code conditionSlots} that is not
 * already among the slots the projects produce, so the join above can still see them.
 */
private List<NamedExpression> appendMissingConditionSlots(List<NamedExpression> projects,
        Set<Slot> conditionSlots) {
    Builder<NamedExpression> result = ImmutableList.<NamedExpression>builder().addAll(projects);
    Set<Slot> producedSlots = projects.stream().map(NamedExpression::toSlot)
            .collect(Collectors.toSet());
    for (Slot slot : conditionSlots) {
        if (!producedSlots.contains(slot)) {
            result.add(slot);
        }
    }
    return result.build();
}
}

View File

@ -20,8 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.CBOUtils;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.rules.exploration.ExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
@ -29,58 +28,68 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
* rule for pushdown project through left-semi/anti join
* Rule for pushdown project through left-semi/anti join
* Only pushes down a project that sits inside a join, to avoid pushing down the top project of a Join-Cluster.
* <pre>
* Join Join
* | |
* Project Join
* | ──► / \
* Join Project B
* / \ |
* A B A
* </pre>
*/
public class PushdownProjectThroughSemiJoin extends OneExplorationRuleFactory {
public class PushdownProjectThroughSemiJoin implements ExplorationRuleFactory {
public static final PushdownProjectThroughSemiJoin INSTANCE = new PushdownProjectThroughSemiJoin();
/*
* Project Join
* | ──► / \
* Join Project B
* / \ |
* A B A
*/
@Override
public Rule build() {
return logicalProject(logicalJoin())
.when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin())
// Just pushdown project with non-column expr like (t.id + 1)
.whenNot(LogicalProject::isAllSlots)
.whenNot(project -> project.child().hasJoinHint())
.then(project -> {
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
Set<Slot> conditionLeftSlots = CBOUtils.joinChildConditionSlots(join, true);
public List<Rule> buildRules() {
return ImmutableList.of(
logicalJoin(logicalProject(logicalJoin()), group())
.when(j -> j.left().child().getJoinType().isLeftSemiOrAntiJoin())
// Just pushdown project with non-column expr like (t.id + 1)
.whenNot(j -> j.left().isAllSlots())
.whenNot(j -> j.left().child().hasJoinHint())
.then(topJoin -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topJoin.left();
Plan newLeft = pushdownProject(project);
return topJoin.withChildren(newLeft, topJoin.right());
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN),
List<NamedExpression> newProject = new ArrayList<>(project.getProjects());
Set<Slot> projectUsedSlots = project.getProjects().stream().map(NamedExpression::toSlot)
.collect(Collectors.toSet());
conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot)).forEach(newProject::add);
Plan newLeft = CBOUtils.projectOrSelf(newProject, join.left());
Plan newJoin = join.withChildrenNoContext(newLeft, join.right());
return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN);
logicalJoin(group(), logicalProject(logicalJoin()))
.when(j -> j.right().child().getJoinType().isLeftSemiOrAntiJoin())
// Just pushdown project with non-column expr like (t.id + 1)
.whenNot(j -> j.right().isAllSlots())
.whenNot(j -> j.right().child().hasJoinHint())
.then(topJoin -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topJoin.right();
Plan newRight = pushdownProject(project);
return topJoin.withChildren(topJoin.left(), newRight);
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN)
);
}
List<NamedExpression> sort(List<NamedExpression> projects, Plan sortPlan) {
List<ExprId> orderExprIds = sortPlan.getOutput().stream().map(Slot::getExprId).collect(Collectors.toList());
// map { project input slot expr id -> project output expr }
Map<ExprId, NamedExpression> map = projects.stream()
.collect(Collectors.toMap(expr -> expr.getInputSlots().iterator().next().getExprId(), expr -> expr));
List<NamedExpression> newProjects = new ArrayList<>();
for (ExprId exprId : orderExprIds) {
if (map.containsKey(exprId)) {
newProjects.add(map.get(exprId));
}
}
return newProjects;
/**
 * Moves {@code project} below its child join, onto the join's left input.
 *
 * <p>The pushed-down project keeps every original project expression and additionally
 * carries any left-side slot returned by {@code CBOUtils.joinChildConditionSlots} that the
 * project does not already produce. The result is wrapped by {@code CBOUtils.projectOrSelf}
 * over the original project output.
 *
 * @param project a project whose child is a join over two groups
 * @return the rewritten plan
 */
private Plan pushdownProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
    LogicalJoin<GroupPlan, GroupPlan> join = project.child();
    List<NamedExpression> originalProjects = project.getProjects();

    Set<Slot> producedSlots = originalProjects.stream()
            .map(NamedExpression::toSlot)
            .collect(Collectors.toSet());

    // Start from the original expressions, then append the condition slots they do not cover.
    List<NamedExpression> leftProjects = new ArrayList<>(originalProjects);
    for (Slot conditionSlot : CBOUtils.joinChildConditionSlots(join, true)) {
        if (!producedSlots.contains(conditionSlot)) {
            leftProjects.add(conditionSlot);
        }
    }

    Plan newLeft = CBOUtils.projectOrSelf(leftProjects, join.left());
    Plan newJoin = join.withChildrenNoContext(newLeft, join.right());
    return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
}
}