[fix](Nereids): fix push filter through outer join in ReorderJoinRule. (#14952)

This commit is contained in:
jakevin
2022-12-12 00:38:21 +08:00
committed by GitHub
parent 5fe5f596f2
commit 9e0e376a72
3 changed files with 75 additions and 52 deletions

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
@ -41,6 +40,7 @@ import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@ -203,7 +203,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
}
Builder<Plan> builder = ImmutableList.builder();
// recursively hanlde multiJoin children.
// recursively handle multiJoin children.
for (Plan child : multiJoin.children()) {
if (child instanceof MultiJoin) {
MultiJoin childMultiJoin = (MultiJoin) child;
@ -215,60 +215,56 @@ public class ReorderJoin extends OneRewriteRuleFactory {
MultiJoin multiJoinHandleChildren = multiJoin.withChildren(builder.build());
if (!multiJoinHandleChildren.getJoinType().isInnerOrCrossJoin()) {
List<Expression> leftFilter = Lists.newArrayList();
List<Expression> rightFilter = Lists.newArrayList();
List<Expression> remainingFilter = Lists.newArrayList();
List<Expression> remainingFilter;
Plan left;
Plan right;
if (multiJoinHandleChildren.getJoinType().isLeftJoin()) {
right = multiJoinHandleChildren.child(multiJoinHandleChildren.arity() - 1);
Set<Slot> rightOutputSet = right.getOutputSet();
Map<Boolean, List<Expression>> split = multiJoin.getJoinFilter().stream()
.collect(Collectors.partitioningBy(expr ->
ExpressionUtils.isIntersecting(rightOutputSet, expr.getInputSlots())
));
remainingFilter = split.get(true);
List<Expression> pushedFilter = split.get(false);
left = multiJoinToJoin(new MultiJoin(
multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1),
leftFilter,
pushedFilter,
JoinType.INNER_JOIN,
ExpressionUtils.EMPTY_CONDITION));
right = multiJoinHandleChildren.child(multiJoinHandleChildren.arity() - 1);
} else if (multiJoinHandleChildren.getJoinType().isRightJoin()) {
left = multiJoinHandleChildren.child(0);
Set<Slot> leftOutputSet = left.getOutputSet();
Map<Boolean, List<Expression>> split = multiJoin.getJoinFilter().stream()
.collect(Collectors.partitioningBy(expr ->
ExpressionUtils.isIntersecting(leftOutputSet, expr.getInputSlots())
));
remainingFilter = split.get(true);
List<Expression> pushedFilter = split.get(false);
right = multiJoinToJoin(new MultiJoin(
multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()),
rightFilter,
pushedFilter,
JoinType.INNER_JOIN,
ExpressionUtils.EMPTY_CONDITION));
} else {
left = multiJoinToJoin(new MultiJoin(
multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1),
leftFilter,
JoinType.INNER_JOIN,
ExpressionUtils.EMPTY_CONDITION));
right = multiJoinToJoin(new MultiJoin(
multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()),
rightFilter,
JoinType.INNER_JOIN,
ExpressionUtils.EMPTY_CONDITION));
}
// split filter
for (Expression expr : multiJoinHandleChildren.getJoinFilter()) {
Set<Slot> exprInputSlots = expr.getInputSlots();
Preconditions.checkState(!exprInputSlots.isEmpty());
if (left.getOutputSet().containsAll(exprInputSlots)) {
leftFilter.add(expr);
} else if (right.getOutputSet().containsAll(exprInputSlots)) {
rightFilter.add(expr);
} else if (multiJoin.getOutputSet().containsAll(exprInputSlots)) {
remainingFilter.add(expr);
} else {
NereidsPlanner.LOG.error("invalid expression, exist slot that multiJoin don't contains.");
throw new RuntimeException("invalid expression, exist slot that multiJoin don't contains.");
}
remainingFilter = multiJoin.getJoinFilter();
Preconditions.checkState(multiJoinHandleChildren.arity() == 2);
List<Plan> children = multiJoinHandleChildren.children().stream().map(child -> {
if (child instanceof MultiJoin) {
return multiJoinToJoin((MultiJoin) child);
} else {
return child;
}
}).collect(Collectors.toList());
left = children.get(0);
right = children.get(1);
}
return PlanUtils.filterOrSelf(remainingFilter, new LogicalJoin<>(
multiJoinHandleChildren.getJoinType(),
ExpressionUtils.EMPTY_CONDITION,
multiJoinHandleChildren.getNotInnerJoinConditions(),
PlanUtils.filterOrSelf(leftFilter, left), PlanUtils.filterOrSelf(rightFilter, right)));
ExpressionUtils.EMPTY_CONDITION, multiJoinHandleChildren.getNotInnerJoinConditions(),
left, right));
}
// following this multiJoin just contain INNER/CROSS.

View File

@ -96,20 +96,31 @@ class ReorderJoinTest implements PatternMatchSupported {
@Test
public void testRightSemiJoin() {
ImmutableList<LogicalPlan> plans = ImmutableList.of(
new LogicalPlanBuilder(scan1)
.hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0))
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
.filter(new EqualTo(scan3.getOutput().get(0), scan2.getOutput().get(0)))
.build(),
new LogicalPlanBuilder(scan1)
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
.hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0))
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
.build()
);
LogicalPlan plan1 = new LogicalPlanBuilder(scan1)
.hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0))
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
.filter(new EqualTo(scan3.getOutput().get(0), scan2.getOutput().get(0)))
.build();
check(ImmutableList.of(plan1));
LogicalPlan plan2 = new LogicalPlanBuilder(scan2)
.hashJoinUsing(
new LogicalPlanBuilder(scan1)
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
.build(),
JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0)
)
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan2)
.rewrite()
.matchesFromRoot(
rightSemiLogicalJoin(
leafPlan(),
innerLogicalJoin()
)
);
check(plans);
}
@Test
@ -171,7 +182,7 @@ class ReorderJoinTest implements PatternMatchSupported {
}
}
/**
/*
* join
* crossjoin / \
* / \ join D

View File

@ -104,4 +104,20 @@ public class MultiJoinTest extends SqlTestBase {
.printlnTree();
}
}
@Test
void testOuterJoin() {
String sql = "SELECT * FROM T1 LEFT OUTER JOIN T2 ON T1.id = T2.id, T3 WHERE T2.score > 0";
PlanChecker.from(connectContext)
.analyze(sql)
.applyBottomUp(new ReorderJoin())
.printlnTree()
.matches(
crossLogicalJoin(
leftOuterLogicalJoin()
.when(join -> join.getOtherJoinConjuncts().size() == 1),
logicalOlapScan()
)
);
}
}