[fix](Nereids): fix push filter through outer join in ReorderJoinRule. (#14952)
This commit is contained in:
@ -18,7 +18,6 @@
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.NereidsPlanner;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
@ -41,6 +40,7 @@ import com.google.common.collect.Lists;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@ -203,7 +203,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
|
||||
}
|
||||
|
||||
Builder<Plan> builder = ImmutableList.builder();
|
||||
// recursively hanlde multiJoin children.
|
||||
// recursively handle multiJoin children.
|
||||
for (Plan child : multiJoin.children()) {
|
||||
if (child instanceof MultiJoin) {
|
||||
MultiJoin childMultiJoin = (MultiJoin) child;
|
||||
@ -215,60 +215,56 @@ public class ReorderJoin extends OneRewriteRuleFactory {
|
||||
MultiJoin multiJoinHandleChildren = multiJoin.withChildren(builder.build());
|
||||
|
||||
if (!multiJoinHandleChildren.getJoinType().isInnerOrCrossJoin()) {
|
||||
List<Expression> leftFilter = Lists.newArrayList();
|
||||
List<Expression> rightFilter = Lists.newArrayList();
|
||||
List<Expression> remainingFilter = Lists.newArrayList();
|
||||
List<Expression> remainingFilter;
|
||||
|
||||
Plan left;
|
||||
Plan right;
|
||||
if (multiJoinHandleChildren.getJoinType().isLeftJoin()) {
|
||||
right = multiJoinHandleChildren.child(multiJoinHandleChildren.arity() - 1);
|
||||
Set<Slot> rightOutputSet = right.getOutputSet();
|
||||
Map<Boolean, List<Expression>> split = multiJoin.getJoinFilter().stream()
|
||||
.collect(Collectors.partitioningBy(expr ->
|
||||
ExpressionUtils.isIntersecting(rightOutputSet, expr.getInputSlots())
|
||||
));
|
||||
remainingFilter = split.get(true);
|
||||
List<Expression> pushedFilter = split.get(false);
|
||||
left = multiJoinToJoin(new MultiJoin(
|
||||
multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1),
|
||||
leftFilter,
|
||||
pushedFilter,
|
||||
JoinType.INNER_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION));
|
||||
right = multiJoinHandleChildren.child(multiJoinHandleChildren.arity() - 1);
|
||||
} else if (multiJoinHandleChildren.getJoinType().isRightJoin()) {
|
||||
left = multiJoinHandleChildren.child(0);
|
||||
Set<Slot> leftOutputSet = left.getOutputSet();
|
||||
Map<Boolean, List<Expression>> split = multiJoin.getJoinFilter().stream()
|
||||
.collect(Collectors.partitioningBy(expr ->
|
||||
ExpressionUtils.isIntersecting(leftOutputSet, expr.getInputSlots())
|
||||
));
|
||||
remainingFilter = split.get(true);
|
||||
List<Expression> pushedFilter = split.get(false);
|
||||
right = multiJoinToJoin(new MultiJoin(
|
||||
multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()),
|
||||
rightFilter,
|
||||
pushedFilter,
|
||||
JoinType.INNER_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION));
|
||||
} else {
|
||||
left = multiJoinToJoin(new MultiJoin(
|
||||
multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1),
|
||||
leftFilter,
|
||||
JoinType.INNER_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION));
|
||||
right = multiJoinToJoin(new MultiJoin(
|
||||
multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()),
|
||||
rightFilter,
|
||||
JoinType.INNER_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION));
|
||||
}
|
||||
|
||||
// split filter
|
||||
for (Expression expr : multiJoinHandleChildren.getJoinFilter()) {
|
||||
Set<Slot> exprInputSlots = expr.getInputSlots();
|
||||
Preconditions.checkState(!exprInputSlots.isEmpty());
|
||||
|
||||
if (left.getOutputSet().containsAll(exprInputSlots)) {
|
||||
leftFilter.add(expr);
|
||||
} else if (right.getOutputSet().containsAll(exprInputSlots)) {
|
||||
rightFilter.add(expr);
|
||||
} else if (multiJoin.getOutputSet().containsAll(exprInputSlots)) {
|
||||
remainingFilter.add(expr);
|
||||
} else {
|
||||
NereidsPlanner.LOG.error("invalid expression, exist slot that multiJoin don't contains.");
|
||||
throw new RuntimeException("invalid expression, exist slot that multiJoin don't contains.");
|
||||
}
|
||||
remainingFilter = multiJoin.getJoinFilter();
|
||||
Preconditions.checkState(multiJoinHandleChildren.arity() == 2);
|
||||
List<Plan> children = multiJoinHandleChildren.children().stream().map(child -> {
|
||||
if (child instanceof MultiJoin) {
|
||||
return multiJoinToJoin((MultiJoin) child);
|
||||
} else {
|
||||
return child;
|
||||
}
|
||||
}).collect(Collectors.toList());
|
||||
left = children.get(0);
|
||||
right = children.get(1);
|
||||
}
|
||||
|
||||
return PlanUtils.filterOrSelf(remainingFilter, new LogicalJoin<>(
|
||||
multiJoinHandleChildren.getJoinType(),
|
||||
ExpressionUtils.EMPTY_CONDITION,
|
||||
multiJoinHandleChildren.getNotInnerJoinConditions(),
|
||||
PlanUtils.filterOrSelf(leftFilter, left), PlanUtils.filterOrSelf(rightFilter, right)));
|
||||
ExpressionUtils.EMPTY_CONDITION, multiJoinHandleChildren.getNotInnerJoinConditions(),
|
||||
left, right));
|
||||
}
|
||||
|
||||
// following this multiJoin just contain INNER/CROSS.
|
||||
|
||||
@ -96,20 +96,31 @@ class ReorderJoinTest implements PatternMatchSupported {
|
||||
|
||||
@Test
|
||||
public void testRightSemiJoin() {
|
||||
ImmutableList<LogicalPlan> plans = ImmutableList.of(
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0))
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan2.getOutput().get(0)))
|
||||
.build(),
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0))
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
|
||||
.build()
|
||||
);
|
||||
LogicalPlan plan1 = new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0))
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan2.getOutput().get(0)))
|
||||
.build();
|
||||
check(ImmutableList.of(plan1));
|
||||
|
||||
LogicalPlan plan2 = new LogicalPlanBuilder(scan2)
|
||||
.hashJoinUsing(
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.build(),
|
||||
JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0)
|
||||
)
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
|
||||
.build();
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan2)
|
||||
.rewrite()
|
||||
.matchesFromRoot(
|
||||
rightSemiLogicalJoin(
|
||||
leafPlan(),
|
||||
innerLogicalJoin()
|
||||
)
|
||||
);
|
||||
|
||||
check(plans);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -171,7 +182,7 @@ class ReorderJoinTest implements PatternMatchSupported {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
/*
|
||||
* join
|
||||
* crossjoin / \
|
||||
* / \ join D
|
||||
|
||||
@ -104,4 +104,20 @@ public class MultiJoinTest extends SqlTestBase {
|
||||
.printlnTree();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testOuterJoin() {
|
||||
String sql = "SELECT * FROM T1 LEFT OUTER JOIN T2 ON T1.id = T2.id, T3 WHERE T2.score > 0";
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.applyBottomUp(new ReorderJoin())
|
||||
.printlnTree()
|
||||
.matches(
|
||||
crossLogicalJoin(
|
||||
leftOuterLogicalJoin()
|
||||
.when(join -> join.getOtherJoinConjuncts().size() == 1),
|
||||
logicalOlapScan()
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user