[fix](Nereids): other cond should be kept for each anti join when expanding anti join such as (#31521)

This commit is contained in:
谢健
2024-02-28 18:14:47 +08:00
committed by yiguolei
parent ac38356058
commit 4de25ede85
3 changed files with 425 additions and 52 deletions

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
@ -150,11 +151,11 @@ public class OrExpansion extends OneExplorationRuleFactory {
}
// expand Anti Join:
// Left Anti join cond1 or cond2 Left Anti join cond1
// / \ / \
//left right ===> Anti join cond2 CTERight2
// / \
// CTELeft CTERight1
// Left Anti join cond1 or cond2, other Left Anti join cond1 and other
// / \ / \
//left right ===> Anti join cond2 and other CTERight2
// / \
// CTELeft CTERight1
private Plan expandLeftAntiJoin(CascadesContext ctx,
Pair<List<Expression>, List<Expression>> hashOtherConditions,
LogicalJoin<? extends Plan, ? extends Plan> originJoin,
@ -171,14 +172,14 @@ public class OrExpansion extends OneExplorationRuleFactory {
replaced.putAll(right.getProducerToConsumerOutputMap());
List<Expression> disjunctions = hashOtherConditions.first;
List<Expression> otherConditions = hashOtherConditions.second;
otherConditions = otherConditions.stream()
List<Expression> newOtherConditions = otherConditions.stream()
.map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s)).collect(Collectors.toList());
Expression hashCond = disjunctions.get(0);
hashCond = hashCond.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s);
Plan newPlan = new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(hashCond),
otherConditions, originJoin.getDistributeHint(),
originJoin.getMarkJoinSlotReference(), left, right, null);
newOtherConditions, originJoin.getDistributeHint(),
originJoin.getMarkJoinSlotReference(), left, right, JoinReorderContext.EMPTY);
if (hashCond.children().stream().anyMatch(e -> !(e instanceof Slot))) {
Plan normalizedPlan = PushDownExpressionsInHashCondition.pushDownHashExpression(
(LogicalJoin<? extends Plan, ? extends Plan>) newPlan);
@ -192,10 +193,13 @@ public class OrExpansion extends OneExplorationRuleFactory {
ctx.putCTEIdToConsumer(newRight);
Map<Slot, Slot> newReplaced = new HashMap<>(left.getProducerToConsumerOutputMap());
newReplaced.putAll(newRight.getProducerToConsumerOutputMap());
newOtherConditions = otherConditions.stream()
.map(e -> e.rewriteUp(s -> newReplaced.containsKey(s) ? newReplaced.get(s) : s))
.collect(Collectors.toList());
hashCond = hashCond.rewriteUp(s -> newReplaced.containsKey(s) ? newReplaced.get(s) : s);
newPlan = new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(hashCond),
new ArrayList<>(), originJoin.getDistributeHint(),
originJoin.getMarkJoinSlotReference(), newPlan, newRight, null);
newOtherConditions, originJoin.getDistributeHint(),
originJoin.getMarkJoinSlotReference(), newPlan, newRight, JoinReorderContext.EMPTY);
if (hashCond.children().stream().anyMatch(e -> !(e instanceof Slot))) {
newPlan = PushDownExpressionsInHashCondition.pushDownHashExpression(
(LogicalJoin<? extends Plan, ? extends Plan>) newPlan);