[fix](Nereids): fix SemiJoinLogicalJoinTransposeProject. (#16883)
This commit is contained in:
@ -63,8 +63,8 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory {
|
||||
return logicalJoin(logicalJoin(), group())
|
||||
.when(join -> VALID_TYPE_PAIR_SET.contains(Pair.of(join.left().getJoinType(), join.getJoinType())))
|
||||
.when(topJoin -> checkReorder(topJoin, topJoin.left()))
|
||||
.when(topJoin -> checkCondition(topJoin, topJoin.left().right().getOutputExprIdSet()))
|
||||
.whenNot(join -> join.hasJoinHint() || join.left().hasJoinHint())
|
||||
.when(topJoin -> checkCondition(topJoin, topJoin.left().right().getOutputExprIdSet()))
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
|
||||
GroupPlan a = bottomJoin.left();
|
||||
|
||||
@ -17,11 +17,13 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.JoinHint;
|
||||
@ -33,9 +35,12 @@ import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* <ul>
|
||||
@ -64,7 +69,6 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
|
||||
.whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin())
|
||||
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
|
||||
.when(join -> JoinReorderUtils.checkProject(join.left()))
|
||||
.when(this::conditionChecker)
|
||||
.then(topSemiJoin -> {
|
||||
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topSemiJoin.left();
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = project.child();
|
||||
@ -72,17 +76,17 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
|
||||
GroupPlan b = bottomJoin.right();
|
||||
GroupPlan c = topSemiJoin.right();
|
||||
|
||||
Set<ExprId> aOutputExprIdSet = a.getOutputExprIdSet();
|
||||
|
||||
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
|
||||
|
||||
boolean lasscom = false;
|
||||
for (Expression hashJoinConjunct : hashJoinConjuncts) {
|
||||
Set<ExprId> usedSlotExprIdSet = hashJoinConjunct.getInputSlotExprIds();
|
||||
lasscom = Utils.isIntersecting(usedSlotExprIdSet, aOutputExprIdSet) || lasscom;
|
||||
// push topSemiJoin down project, so we need replace conjuncts by project.
|
||||
Pair<List<Expression>, List<Expression>> conjuncts = replaceConjuncts(topSemiJoin, project);
|
||||
Set<ExprId> conjunctsIds = Stream.concat(conjuncts.first.stream(), conjuncts.second.stream())
|
||||
.flatMap(expr -> expr.getInputSlotExprIds().stream()).collect(Collectors.toSet());
|
||||
ContainsType containsType = containsChildren(conjunctsIds, a.getOutputExprIdSet(),
|
||||
b.getOutputExprIdSet());
|
||||
if (containsType == ContainsType.ALL) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (lasscom) {
|
||||
if (containsType == ContainsType.LEFT) {
|
||||
/*-
|
||||
* topSemiJoin project
|
||||
* / \ |
|
||||
@ -92,22 +96,24 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
|
||||
* / \ / \
|
||||
* A B A C
|
||||
*/
|
||||
// Preconditions.checkState(bottomJoin.getJoinType() != JoinType.RIGHT_OUTER_JOIN);
|
||||
if (bottomJoin.getJoinType() == JoinType.RIGHT_OUTER_JOIN) {
|
||||
// when bottom join is right outer join, we change it to inner join
|
||||
// if we want to do this trans. However, we do not allow different logical properties
|
||||
// in one group. So we need to change it to inner join in rewrite step.
|
||||
return topSemiJoin;
|
||||
return null;
|
||||
}
|
||||
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
|
||||
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
|
||||
topSemiJoin.getOtherJoinConjuncts(), JoinHint.NONE, a, c);
|
||||
topSemiJoin.getJoinType(), conjuncts.first, conjuncts.second, JoinHint.NONE, a, c);
|
||||
|
||||
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
|
||||
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(),
|
||||
JoinHint.NONE,
|
||||
newBottomSemiJoin, b);
|
||||
return JoinReorderUtils.projectOrSelf(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
|
||||
JoinHint.NONE, newBottomSemiJoin, b);
|
||||
return project.withChildren(newTopJoin);
|
||||
} else {
|
||||
if (leftDeep) {
|
||||
return null;
|
||||
}
|
||||
/*-
|
||||
* topSemiJoin project
|
||||
* / \ |
|
||||
@ -121,40 +127,49 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
|
||||
// when bottom join is left outer join, we change it to inner join
|
||||
// if we want to do this trans. However, we do not allow different logical properties
|
||||
// in one group. So we need to change it to inner join in rewrite step.
|
||||
return topSemiJoin;
|
||||
return null;
|
||||
}
|
||||
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
|
||||
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
|
||||
topSemiJoin.getOtherJoinConjuncts(), JoinHint.NONE, b, c);
|
||||
topSemiJoin.getJoinType(), conjuncts.first, conjuncts.second, JoinHint.NONE, b, c);
|
||||
|
||||
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
|
||||
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(),
|
||||
JoinHint.NONE,
|
||||
a, newBottomSemiJoin);
|
||||
return JoinReorderUtils.projectOrSelf(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
|
||||
JoinHint.NONE, a, newBottomSemiJoin);
|
||||
return project.withChildren(newTopJoin);
|
||||
}
|
||||
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT);
|
||||
}
|
||||
|
||||
// project of bottomJoin just return A OR B, else return false.
|
||||
private boolean conditionChecker(
|
||||
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topSemiJoin) {
|
||||
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
|
||||
|
||||
List<Slot> aOutput = topSemiJoin.left().child().left().getOutput();
|
||||
List<Slot> bOutput = topSemiJoin.left().child().right().getOutput();
|
||||
|
||||
boolean hashContainsA = false;
|
||||
boolean hashContainsB = false;
|
||||
for (Expression hashJoinConjunct : hashJoinConjuncts) {
|
||||
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
|
||||
hashContainsA = Utils.isIntersecting(usedSlot, aOutput) || hashContainsA;
|
||||
hashContainsB = Utils.isIntersecting(usedSlot, bOutput) || hashContainsB;
|
||||
private Pair<List<Expression>, List<Expression>> replaceConjuncts(LogicalJoin<? extends Plan, ? extends Plan> join,
|
||||
LogicalProject<? extends Plan> project) {
|
||||
Map<ExprId, Slot> outputToInput = new HashMap<>();
|
||||
for (NamedExpression outputExpr : project.getProjects()) {
|
||||
Set<Slot> usedSlots = outputExpr.getInputSlots();
|
||||
Preconditions.checkState(usedSlots.size() == 1);
|
||||
Slot inputSlot = usedSlots.iterator().next();
|
||||
outputToInput.put(outputExpr.getExprId(), inputSlot);
|
||||
}
|
||||
if (leftDeep && hashContainsB) {
|
||||
return false;
|
||||
List<Expression> topHashConjuncts =
|
||||
JoinReorderUtils.replaceJoinConjuncts(join.getHashJoinConjuncts(), outputToInput);
|
||||
List<Expression> topOtherConjuncts =
|
||||
JoinReorderUtils.replaceJoinConjuncts(join.getOtherJoinConjuncts(), outputToInput);
|
||||
return Pair.of(topHashConjuncts, topOtherConjuncts);
|
||||
}
|
||||
|
||||
enum ContainsType {
|
||||
LEFT, RIGHT, ALL
|
||||
}
|
||||
|
||||
private ContainsType containsChildren(Set<ExprId> conjunctsExprIdSet, Set<ExprId> left, Set<ExprId> right) {
|
||||
boolean containsLeft = Utils.isIntersecting(conjunctsExprIdSet, left);
|
||||
boolean containsRight = Utils.isIntersecting(conjunctsExprIdSet, right);
|
||||
Preconditions.checkState(containsLeft || containsRight, "join output must contain child");
|
||||
if (containsLeft && containsRight) {
|
||||
return ContainsType.ALL;
|
||||
} else if (containsLeft) {
|
||||
return ContainsType.LEFT;
|
||||
} else {
|
||||
return ContainsType.RIGHT;
|
||||
}
|
||||
Preconditions.checkState(hashContainsA || hashContainsB, "join output must contain child");
|
||||
return !(hashContainsA && hashContainsB);
|
||||
}
|
||||
}
|
||||
|
||||
@ -130,7 +130,6 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
|
||||
return otherJoinConjuncts;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Expression> getHashJoinConjuncts() {
|
||||
return hashJoinConjuncts;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user