[improve](Nereids): ReorderJoin eliminate this recursion (#13505)

This commit is contained in:
jakevin
2022-10-24 17:11:43 +08:00
committed by GitHub
parent 7faad9f004
commit 409bd76999
4 changed files with 111 additions and 42 deletions

View File

@ -38,6 +38,7 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@ -79,8 +80,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
/**
* Recursively convert to
* {@link LogicalJoin} or
* {@link LogicalFilter}--{@link LogicalJoin}
* {@link LogicalJoin} or {@link LogicalFilter}--{@link LogicalJoin}
* --> {@link MultiJoin}
*/
public Plan joinToMultiJoin(Plan plan) {
@ -182,20 +182,20 @@ public class ReorderJoin extends OneRewriteRuleFactory {
* <li> A JOIN B RIGHT JOIN (C JOIN D) --> MJ(A, B, MJ([ROJ]C, D))
* </ul>
* </p>
* <p>
* Graphic presentation:
* <pre>
* A JOIN B JOIN C LEFT JOIN D JOIN F
* left left│
* A B C D F ──► A B C │ D F ──► MJ(LOJ A,B,C,MJ(DF)
* <p>
*
* A JOIN B RIGHT JOIN C JOIN D JOIN F
* right │right
* A B C D F ──► A B │ C D F ──► MJ(A,B,MJ(ROJ C,D,F)
* <p>
*
* (A JOIN B JOIN C) FULL JOIN (D JOIN F)
* full │
* A B C D F ──► A B C │ D F ──► MJ(FOJ MJ(A,B,C) MJ(D,F))
* </p>
* </pre>
*/
public Plan multiJoinToJoin(MultiJoin multiJoin) {
if (multiJoin.arity() == 1) {
@ -272,24 +272,22 @@ public class ReorderJoin extends OneRewriteRuleFactory {
}
// following this multiJoin just contain INNER/CROSS.
List<Expression> joinFilter = multiJoinHandleChildren.getJoinFilter();
Set<Expression> joinFilter = new HashSet<>(multiJoinHandleChildren.getJoinFilter());
Plan left = multiJoinHandleChildren.child(0);
List<Plan> candidates = multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity());
Set<Integer> usedPlansIndex = new HashSet<>();
usedPlansIndex.add(0);
LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, candidates, joinFilter);
List<Plan> newInputs = Lists.newArrayList();
newInputs.add(join);
newInputs.addAll(candidates.stream().filter(plan -> !join.right().equals(plan)).collect(Collectors.toList()));
while (usedPlansIndex.size() != multiJoinHandleChildren.children().size()) {
LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, multiJoinHandleChildren.children(),
joinFilter, usedPlansIndex);
join.getHashJoinConjuncts().forEach(joinFilter::remove);
join.getOtherJoinConjuncts().forEach(joinFilter::remove);
joinFilter.removeAll(join.getHashJoinConjuncts());
joinFilter.removeAll(join.getOtherJoinConjuncts());
// TODO(wj): eliminate this recursion.
return multiJoinToJoin(new MultiJoin(
newInputs,
joinFilter,
JoinType.INNER_JOIN,
ExpressionUtils.EMPTY_CONDITION));
left = join;
}
return PlanUtils.filterOrSelf(new ArrayList<>(joinFilter), left);
}
/**
@ -319,9 +317,14 @@ public class ReorderJoin extends OneRewriteRuleFactory {
* @return InnerJoin or CrossJoin{left, last of [candidates]}
*/
private LogicalJoin<? extends Plan, ? extends Plan> findInnerJoin(Plan left, List<Plan> candidates,
List<Expression> joinFilter) {
Set<Expression> joinFilter, Set<Integer> usedPlansIndex) {
List<Expression> otherJoinConditions = Lists.newArrayList();
Set<Slot> leftOutputSet = left.getOutputSet();
for (int i = 0; i < candidates.size(); i++) {
if (usedPlansIndex.contains(i)) {
continue;
}
Plan candidate = candidates.get(i);
Set<Slot> rightOutputSet = candidate.getOutputSet();
@ -330,34 +333,35 @@ public class ReorderJoin extends OneRewriteRuleFactory {
List<Expression> currentJoinFilter = joinFilter.stream()
.filter(expr -> {
Set<Slot> exprInputSlots = expr.getInputSlots();
Preconditions.checkState(exprInputSlots.size() > 1,
"Predicate like table.col > 1 must have pushdown.");
if (leftOutputSet.containsAll(exprInputSlots)) {
return false;
}
if (rightOutputSet.containsAll(exprInputSlots)) {
return false;
}
return joinOutput.containsAll(exprInputSlots);
return !leftOutputSet.containsAll(exprInputSlots)
&& !rightOutputSet.containsAll(exprInputSlots)
&& joinOutput.containsAll(exprInputSlots);
}).collect(Collectors.toList());
Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(
left.getOutput(), candidate.getOutput(), currentJoinFilter);
List<Expression> hashJoinConditions = pair.first;
List<Expression> otherJoinConditions = pair.second;
otherJoinConditions = pair.second;
if (!hashJoinConditions.isEmpty()) {
usedPlansIndex.add(i);
return new LogicalJoin<>(JoinType.INNER_JOIN,
hashJoinConditions, otherJoinConditions,
left, candidate);
}
if (i == candidates.size() - 1) {
return new LogicalJoin<>(JoinType.CROSS_JOIN,
hashJoinConditions, otherJoinConditions,
left, candidate);
}
}
// All { left -> one in [candidates] } is CrossJoin
// Generate a CrossJoin
for (int j = candidates.size() - 1; j >= 0; j--) {
if (usedPlansIndex.contains(j)) {
continue;
}
usedPlansIndex.add(j);
return new LogicalJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION,
otherJoinConditions,
left, candidates.get(j));
}
throw new RuntimeException("findInnerJoin: can't reach here");
}
}

View File

@ -130,17 +130,44 @@ class ReorderJoinTest implements PatternMatchSupported {
check(plans);
}
public void check(List<LogicalPlan> plans) {
@Test
public void testCrossJoin() {
ImmutableList<LogicalPlan> plans = ImmutableList.of(
new LogicalPlanBuilder(scan1)
.hashJoinEmptyOn(scan2, JoinType.CROSS_JOIN)
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
.filter(new EqualTo(scan1.getOutput().get(0), scan3.getOutput().get(0)))
.build(),
new LogicalPlanBuilder(scan1)
.hashJoinEmptyOn(scan2, JoinType.CROSS_JOIN)
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
.filter(new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0)))
.build()
);
for (LogicalPlan plan : plans) {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyBottomUp(new ReorderJoin())
.matchesFromRoot(
logicalJoin(
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
leafPlan()
).when(join -> join.getJoinType().isCrossJoin())
);
}
}
public void check(List<LogicalPlan> plans) {
for (LogicalPlan plan : plans) {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.rewrite()
.printlnTree()
.matchesFromRoot(
logicalJoin(
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
leafPlan()
).whenNot(join -> join.getJoinType().isCrossJoin())
)
.printlnTree();
);
}
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
import org.apache.doris.nereids.util.PlanChecker;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.util.List;
@ -29,8 +30,9 @@ public class MultiJoinTest extends SqlTestBase {
@Test
void testMultiJoinEliminateCross() {
List<String> sqls = ImmutableList.<String>builder()
.add("SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id")
.add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id")
.add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id AND T1.score > 0")
.add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id AND T1.score > 0 AND T1.id + T2.id + T3.id > 0")
.build();
for (String sql : sqls) {
@ -47,6 +49,41 @@ public class MultiJoinTest extends SqlTestBase {
}
}
@Test
@Disabled
// TODO: MultiJoin And EliminateOuter
void testEliminateBelowOuter() {
String sql = "SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id";
PlanChecker.from(connectContext)
.analyze(sql)
.applyBottomUp(new ReorderJoin())
.printlnTree();
}
@Test
void testPushdownAndEliminateOuter() {
String sql = "SELECT * FROM T1 LEFT JOIN T2 ON T1.id = T2.id WHERE T2.score > 0";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.printlnTree()
.matches(
logicalJoin().when(join -> join.getJoinType().isInnerJoin())
);
String sql1 = "SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id AND T3.score > 0";
PlanChecker.from(connectContext)
.analyze(sql1)
.rewrite()
.printlnTree()
.matches(
logicalJoin(
logicalJoin().when(join -> join.getJoinType().isInnerJoin()),
any()
).when(join -> join.getJoinType().isInnerJoin())
);
}
@Test
void testMultiJoinExistCross() {
List<String> sqls = ImmutableList.<String>builder()

View File

@ -363,6 +363,7 @@ public class PlanChecker {
public PlanChecker printlnTree() {
System.out.println(cascadesContext.getMemo().copyOut().treeString());
System.out.println("-----------------------------");
return this;
}