[fix](Nereids): fix SemiJoinLogicalJoinTransposeProject. (#16883)

This commit is contained in:
jakevin
2023-02-18 23:12:34 +08:00
committed by GitHub
parent e2e6a0dd83
commit d4cebb39ba
3 changed files with 57 additions and 43 deletions

View File

@ -63,8 +63,8 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory {
return logicalJoin(logicalJoin(), group())
.when(join -> VALID_TYPE_PAIR_SET.contains(Pair.of(join.left().getJoinType(), join.getJoinType())))
.when(topJoin -> checkReorder(topJoin, topJoin.left()))
.when(topJoin -> checkCondition(topJoin, topJoin.left().right().getOutputExprIdSet()))
.whenNot(join -> join.hasJoinHint() || join.left().hasJoinHint())
.when(topJoin -> checkCondition(topJoin, topJoin.left().right().getOutputExprIdSet()))
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();

View File

@ -17,11 +17,13 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinHint;
@ -33,9 +35,12 @@ import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* <ul>
@ -64,7 +69,6 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
.whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin())
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.when(join -> JoinReorderUtils.checkProject(join.left()))
.when(this::conditionChecker)
.then(topSemiJoin -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topSemiJoin.left();
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = project.child();
@ -72,17 +76,17 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
GroupPlan b = bottomJoin.right();
GroupPlan c = topSemiJoin.right();
Set<ExprId> aOutputExprIdSet = a.getOutputExprIdSet();
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
boolean lasscom = false;
for (Expression hashJoinConjunct : hashJoinConjuncts) {
Set<ExprId> usedSlotExprIdSet = hashJoinConjunct.getInputSlotExprIds();
lasscom = Utils.isIntersecting(usedSlotExprIdSet, aOutputExprIdSet) || lasscom;
// push topSemiJoin down project, so we need replace conjuncts by project.
Pair<List<Expression>, List<Expression>> conjuncts = replaceConjuncts(topSemiJoin, project);
Set<ExprId> conjunctsIds = Stream.concat(conjuncts.first.stream(), conjuncts.second.stream())
.flatMap(expr -> expr.getInputSlotExprIds().stream()).collect(Collectors.toSet());
ContainsType containsType = containsChildren(conjunctsIds, a.getOutputExprIdSet(),
b.getOutputExprIdSet());
if (containsType == ContainsType.ALL) {
return null;
}
if (lasscom) {
if (containsType == ContainsType.LEFT) {
/*-
* topSemiJoin project
* / \ |
@ -92,22 +96,24 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
* / \ / \
* A B A C
*/
// Preconditions.checkState(bottomJoin.getJoinType() != JoinType.RIGHT_OUTER_JOIN);
if (bottomJoin.getJoinType() == JoinType.RIGHT_OUTER_JOIN) {
// when bottom join is right outer join, we change it to inner join
// if we want to do this trans. However, we do not allow different logical properties
// in one group. So we need to change it to inner join in rewrite step.
return topSemiJoin;
return null;
}
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
topSemiJoin.getOtherJoinConjuncts(), JoinHint.NONE, a, c);
topSemiJoin.getJoinType(), conjuncts.first, conjuncts.second, JoinHint.NONE, a, c);
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(),
JoinHint.NONE,
newBottomSemiJoin, b);
return JoinReorderUtils.projectOrSelf(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
JoinHint.NONE, newBottomSemiJoin, b);
return project.withChildren(newTopJoin);
} else {
if (leftDeep) {
return null;
}
/*-
* topSemiJoin project
* / \ |
@ -121,40 +127,49 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
// when bottom join is left outer join, we change it to inner join
// if we want to do this trans. However, we do not allow different logical properties
// in one group. So we need to change it to inner join in rewrite step.
return topSemiJoin;
return null;
}
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
topSemiJoin.getOtherJoinConjuncts(), JoinHint.NONE, b, c);
topSemiJoin.getJoinType(), conjuncts.first, conjuncts.second, JoinHint.NONE, b, c);
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(),
JoinHint.NONE,
a, newBottomSemiJoin);
return JoinReorderUtils.projectOrSelf(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
JoinHint.NONE, a, newBottomSemiJoin);
return project.withChildren(newTopJoin);
}
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT);
}
// project of bottomJoin just return A OR B, else return false.
private boolean conditionChecker(
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topSemiJoin) {
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
List<Slot> aOutput = topSemiJoin.left().child().left().getOutput();
List<Slot> bOutput = topSemiJoin.left().child().right().getOutput();
boolean hashContainsA = false;
boolean hashContainsB = false;
for (Expression hashJoinConjunct : hashJoinConjuncts) {
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
hashContainsA = Utils.isIntersecting(usedSlot, aOutput) || hashContainsA;
hashContainsB = Utils.isIntersecting(usedSlot, bOutput) || hashContainsB;
private Pair<List<Expression>, List<Expression>> replaceConjuncts(LogicalJoin<? extends Plan, ? extends Plan> join,
LogicalProject<? extends Plan> project) {
Map<ExprId, Slot> outputToInput = new HashMap<>();
for (NamedExpression outputExpr : project.getProjects()) {
Set<Slot> usedSlots = outputExpr.getInputSlots();
Preconditions.checkState(usedSlots.size() == 1);
Slot inputSlot = usedSlots.iterator().next();
outputToInput.put(outputExpr.getExprId(), inputSlot);
}
if (leftDeep && hashContainsB) {
return false;
List<Expression> topHashConjuncts =
JoinReorderUtils.replaceJoinConjuncts(join.getHashJoinConjuncts(), outputToInput);
List<Expression> topOtherConjuncts =
JoinReorderUtils.replaceJoinConjuncts(join.getOtherJoinConjuncts(), outputToInput);
return Pair.of(topHashConjuncts, topOtherConjuncts);
}
enum ContainsType {
LEFT, RIGHT, ALL
}
private ContainsType containsChildren(Set<ExprId> conjunctsExprIdSet, Set<ExprId> left, Set<ExprId> right) {
boolean containsLeft = Utils.isIntersecting(conjunctsExprIdSet, left);
boolean containsRight = Utils.isIntersecting(conjunctsExprIdSet, right);
Preconditions.checkState(containsLeft || containsRight, "join output must contain child");
if (containsLeft && containsRight) {
return ContainsType.ALL;
} else if (containsLeft) {
return ContainsType.LEFT;
} else {
return ContainsType.RIGHT;
}
Preconditions.checkState(hashContainsA || hashContainsB, "join output must contain child");
return !(hashContainsA && hashContainsB);
}
}

View File

@ -130,7 +130,6 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
return otherJoinConjuncts;
}
@Override
public List<Expression> getHashJoinConjuncts() {
return hashJoinConjuncts;
}