[improve](Nereids): ReorderJoin eliminate this recursion (#13505)
This commit is contained in:
@ -38,6 +38,7 @@ import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
@ -79,8 +80,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
|
||||
|
||||
/**
|
||||
* Recursively convert to
|
||||
* {@link LogicalJoin} or
|
||||
* {@link LogicalFilter}--{@link LogicalJoin}
|
||||
* {@link LogicalJoin} or {@link LogicalFilter}--{@link LogicalJoin}
|
||||
* --> {@link MultiJoin}
|
||||
*/
|
||||
public Plan joinToMultiJoin(Plan plan) {
|
||||
@ -182,20 +182,20 @@ public class ReorderJoin extends OneRewriteRuleFactory {
|
||||
* <li> A JOIN B RIGHT JOIN (C JOIN D) --> MJ(A, B, MJ([ROJ]C, D))
|
||||
* </ul>
|
||||
* </p>
|
||||
* <p>
|
||||
* Graphic presentation:
|
||||
* <pre>
|
||||
* A JOIN B JOIN C LEFT JOIN D JOIN F
|
||||
* left left│
|
||||
* A B C D F ──► A B C │ D F ──► MJ(LOJ A,B,C,MJ(DF)
|
||||
* <p>
|
||||
*
|
||||
* A JOIN B RIGHT JOIN C JOIN D JOIN F
|
||||
* right │right
|
||||
* A B C D F ──► A B │ C D F ──► MJ(A,B,MJ(ROJ C,D,F)
|
||||
* <p>
|
||||
*
|
||||
* (A JOIN B JOIN C) FULL JOIN (D JOIN F)
|
||||
* full │
|
||||
* A B C D F ──► A B C │ D F ──► MJ(FOJ MJ(A,B,C) MJ(D,F))
|
||||
* </p>
|
||||
* </pre>
|
||||
*/
|
||||
public Plan multiJoinToJoin(MultiJoin multiJoin) {
|
||||
if (multiJoin.arity() == 1) {
|
||||
@ -272,24 +272,22 @@ public class ReorderJoin extends OneRewriteRuleFactory {
|
||||
}
|
||||
|
||||
// following this multiJoin just contain INNER/CROSS.
|
||||
List<Expression> joinFilter = multiJoinHandleChildren.getJoinFilter();
|
||||
Set<Expression> joinFilter = new HashSet<>(multiJoinHandleChildren.getJoinFilter());
|
||||
|
||||
Plan left = multiJoinHandleChildren.child(0);
|
||||
List<Plan> candidates = multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity());
|
||||
Set<Integer> usedPlansIndex = new HashSet<>();
|
||||
usedPlansIndex.add(0);
|
||||
|
||||
LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, candidates, joinFilter);
|
||||
List<Plan> newInputs = Lists.newArrayList();
|
||||
newInputs.add(join);
|
||||
newInputs.addAll(candidates.stream().filter(plan -> !join.right().equals(plan)).collect(Collectors.toList()));
|
||||
while (usedPlansIndex.size() != multiJoinHandleChildren.children().size()) {
|
||||
LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, multiJoinHandleChildren.children(),
|
||||
joinFilter, usedPlansIndex);
|
||||
join.getHashJoinConjuncts().forEach(joinFilter::remove);
|
||||
join.getOtherJoinConjuncts().forEach(joinFilter::remove);
|
||||
|
||||
joinFilter.removeAll(join.getHashJoinConjuncts());
|
||||
joinFilter.removeAll(join.getOtherJoinConjuncts());
|
||||
// TODO(wj): eliminate this recursion.
|
||||
return multiJoinToJoin(new MultiJoin(
|
||||
newInputs,
|
||||
joinFilter,
|
||||
JoinType.INNER_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION));
|
||||
left = join;
|
||||
}
|
||||
|
||||
return PlanUtils.filterOrSelf(new ArrayList<>(joinFilter), left);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -319,9 +317,14 @@ public class ReorderJoin extends OneRewriteRuleFactory {
|
||||
* @return InnerJoin or CrossJoin{left, last of [candidates]}
|
||||
*/
|
||||
private LogicalJoin<? extends Plan, ? extends Plan> findInnerJoin(Plan left, List<Plan> candidates,
|
||||
List<Expression> joinFilter) {
|
||||
Set<Expression> joinFilter, Set<Integer> usedPlansIndex) {
|
||||
List<Expression> otherJoinConditions = Lists.newArrayList();
|
||||
Set<Slot> leftOutputSet = left.getOutputSet();
|
||||
for (int i = 0; i < candidates.size(); i++) {
|
||||
if (usedPlansIndex.contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Plan candidate = candidates.get(i);
|
||||
Set<Slot> rightOutputSet = candidate.getOutputSet();
|
||||
|
||||
@ -330,34 +333,35 @@ public class ReorderJoin extends OneRewriteRuleFactory {
|
||||
List<Expression> currentJoinFilter = joinFilter.stream()
|
||||
.filter(expr -> {
|
||||
Set<Slot> exprInputSlots = expr.getInputSlots();
|
||||
Preconditions.checkState(exprInputSlots.size() > 1,
|
||||
"Predicate like table.col > 1 must have pushdown.");
|
||||
if (leftOutputSet.containsAll(exprInputSlots)) {
|
||||
return false;
|
||||
}
|
||||
if (rightOutputSet.containsAll(exprInputSlots)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return joinOutput.containsAll(exprInputSlots);
|
||||
return !leftOutputSet.containsAll(exprInputSlots)
|
||||
&& !rightOutputSet.containsAll(exprInputSlots)
|
||||
&& joinOutput.containsAll(exprInputSlots);
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(
|
||||
left.getOutput(), candidate.getOutput(), currentJoinFilter);
|
||||
List<Expression> hashJoinConditions = pair.first;
|
||||
List<Expression> otherJoinConditions = pair.second;
|
||||
otherJoinConditions = pair.second;
|
||||
if (!hashJoinConditions.isEmpty()) {
|
||||
usedPlansIndex.add(i);
|
||||
return new LogicalJoin<>(JoinType.INNER_JOIN,
|
||||
hashJoinConditions, otherJoinConditions,
|
||||
left, candidate);
|
||||
}
|
||||
|
||||
if (i == candidates.size() - 1) {
|
||||
return new LogicalJoin<>(JoinType.CROSS_JOIN,
|
||||
hashJoinConditions, otherJoinConditions,
|
||||
left, candidate);
|
||||
}
|
||||
}
|
||||
// All { left -> one in [candidates] } is CrossJoin
|
||||
// Generate a CrossJoin
|
||||
for (int j = candidates.size() - 1; j >= 0; j--) {
|
||||
if (usedPlansIndex.contains(j)) {
|
||||
continue;
|
||||
}
|
||||
usedPlansIndex.add(j);
|
||||
return new LogicalJoin<>(JoinType.CROSS_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION,
|
||||
otherJoinConditions,
|
||||
left, candidates.get(j));
|
||||
}
|
||||
|
||||
throw new RuntimeException("findInnerJoin: can't reach here");
|
||||
}
|
||||
}
|
||||
|
||||
@ -130,17 +130,44 @@ class ReorderJoinTest implements PatternMatchSupported {
|
||||
check(plans);
|
||||
}
|
||||
|
||||
public void check(List<LogicalPlan> plans) {
|
||||
@Test
|
||||
public void testCrossJoin() {
|
||||
ImmutableList<LogicalPlan> plans = ImmutableList.of(
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinEmptyOn(scan2, JoinType.CROSS_JOIN)
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.filter(new EqualTo(scan1.getOutput().get(0), scan3.getOutput().get(0)))
|
||||
.build(),
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinEmptyOn(scan2, JoinType.CROSS_JOIN)
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.filter(new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0)))
|
||||
.build()
|
||||
);
|
||||
|
||||
for (LogicalPlan plan : plans) {
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyBottomUp(new ReorderJoin())
|
||||
.matchesFromRoot(
|
||||
logicalJoin(
|
||||
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
|
||||
leafPlan()
|
||||
).when(join -> join.getJoinType().isCrossJoin())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void check(List<LogicalPlan> plans) {
|
||||
for (LogicalPlan plan : plans) {
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.rewrite()
|
||||
.printlnTree()
|
||||
.matchesFromRoot(
|
||||
logicalJoin(
|
||||
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
|
||||
leafPlan()
|
||||
).whenNot(join -> join.getJoinType().isCrossJoin())
|
||||
)
|
||||
.printlnTree();
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
@ -29,8 +30,9 @@ public class MultiJoinTest extends SqlTestBase {
|
||||
@Test
|
||||
void testMultiJoinEliminateCross() {
|
||||
List<String> sqls = ImmutableList.<String>builder()
|
||||
.add("SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id")
|
||||
.add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id")
|
||||
.add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id AND T1.score > 0")
|
||||
.add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id AND T1.score > 0 AND T1.id + T2.id + T3.id > 0")
|
||||
.build();
|
||||
|
||||
for (String sql : sqls) {
|
||||
@ -47,6 +49,41 @@ public class MultiJoinTest extends SqlTestBase {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
// TODO: MultiJoin And EliminateOuter
|
||||
void testEliminateBelowOuter() {
|
||||
String sql = "SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id";
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.applyBottomUp(new ReorderJoin())
|
||||
.printlnTree();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testPushdownAndEliminateOuter() {
|
||||
String sql = "SELECT * FROM T1 LEFT JOIN T2 ON T1.id = T2.id WHERE T2.score > 0";
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
.printlnTree()
|
||||
.matches(
|
||||
logicalJoin().when(join -> join.getJoinType().isInnerJoin())
|
||||
);
|
||||
|
||||
String sql1 = "SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id AND T3.score > 0";
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql1)
|
||||
.rewrite()
|
||||
.printlnTree()
|
||||
.matches(
|
||||
logicalJoin(
|
||||
logicalJoin().when(join -> join.getJoinType().isInnerJoin()),
|
||||
any()
|
||||
).when(join -> join.getJoinType().isInnerJoin())
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultiJoinExistCross() {
|
||||
List<String> sqls = ImmutableList.<String>builder()
|
||||
|
||||
@ -363,6 +363,7 @@ public class PlanChecker {
|
||||
|
||||
public PlanChecker printlnTree() {
|
||||
System.out.println(cascadesContext.getMemo().copyOut().treeString());
|
||||
System.out.println("-----------------------------");
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user