[test](nereids) Add some tests for PushFilterInsideJoin and FindHashConditionForJoin rule (#24550)
This commit is contained in:
@ -30,11 +30,13 @@ import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@ -43,27 +45,35 @@ import java.util.Optional;
|
||||
|
||||
/**
|
||||
* initial plan:
|
||||
* join
|
||||
* -hashJoinConjuncts={}
|
||||
* -otherJoinCondition=
|
||||
* "A.x=B.x and A.y+1=B.y and A.x=1 and (A.y=B.y or B.x=A.x) and A.x>B.x"
|
||||
* join
|
||||
* -hashJoinConjuncts={}
|
||||
* -otherJoinCondition=
|
||||
* "A.x=B.x and A.y+1=B.y and A.x=1 and (A.y=B.y or B.x=A.x) and A.x>B.x and A.x=B.x+B.y and A.x+B.x=B.y"
|
||||
* after transform
|
||||
* join
|
||||
* -hashJoinConjuncts={A.x=B.x, A.y+1=B.y}
|
||||
* -otherJoinCondition="A.x=1 and (A.x=1 or B.x=A.x) and A.x>B.x"
|
||||
* join
|
||||
* -hashJoinConjuncts={A.x=B.x, A.y+1=B.y, A.x=B.x+B.y}
|
||||
* -otherJoinCondition="A.x=1 and (A.x=1 or B.x=A.x) and A.x>B.x and A.x+B.x=B.y"
|
||||
*/
|
||||
class FindHashConditionForJoinTest implements MemoPatternMatchSupported {
|
||||
|
||||
private static Plan studentScan;
|
||||
private static Plan scoreScan;
|
||||
|
||||
@BeforeAll
|
||||
public static void beforeAll() {
|
||||
studentScan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
|
||||
ImmutableList.of(""));
|
||||
scoreScan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score,
|
||||
ImmutableList.of(""));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFindHashCondition() {
|
||||
Plan student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
|
||||
ImmutableList.of(""));
|
||||
Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score,
|
||||
ImmutableList.of(""));
|
||||
|
||||
Slot studentId = student.getOutput().get(0);
|
||||
Slot gender = student.getOutput().get(1);
|
||||
Slot scoreId = score.getOutput().get(0);
|
||||
Slot cid = score.getOutput().get(1);
|
||||
Slot studentId = studentScan.getOutput().get(0);
|
||||
Slot gender = studentScan.getOutput().get(1);
|
||||
Slot scoreId = scoreScan.getOutput().get(0);
|
||||
Slot cid = scoreScan.getOutput().get(1);
|
||||
|
||||
Expression eq1 = new EqualTo(studentId, scoreId); // a=b
|
||||
Expression eq2 = new EqualTo(studentId, new IntegerLiteral(1)); // a=1
|
||||
@ -72,15 +82,37 @@ class FindHashConditionForJoinTest implements MemoPatternMatchSupported {
|
||||
new EqualTo(scoreId, studentId),
|
||||
new EqualTo(gender, cid));
|
||||
Expression less = new LessThan(scoreId, studentId);
|
||||
List<Expression> expr = ImmutableList.of(eq1, eq2, eq3, or, less);
|
||||
Expression eq4 = new EqualTo(studentId, new Add(scoreId, cid));
|
||||
Expression eq5 = new EqualTo(studentId, new Add(studentId, cid));
|
||||
List<Expression> expr = ImmutableList.of(eq1, eq2, eq3, or, less, eq4, eq5);
|
||||
LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, new ArrayList<>(),
|
||||
expr, JoinHint.NONE, Optional.empty(), student, score);
|
||||
expr, JoinHint.NONE, Optional.empty(), studentScan, scoreScan);
|
||||
|
||||
PlanChecker.from(new ConnectContext(), join)
|
||||
.applyTopDown(new FindHashConditionForJoin())
|
||||
.matches(
|
||||
logicalJoin()
|
||||
.when(j -> j.getHashJoinConjuncts().equals(ImmutableList.of(eq1, eq3)))
|
||||
.when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of(eq2, or, less))));
|
||||
.applyTopDown(new FindHashConditionForJoin())
|
||||
.matches(
|
||||
logicalJoin()
|
||||
.when(j -> j.getHashJoinConjuncts().equals(ImmutableList.of(eq1, eq3, eq4)))
|
||||
.when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of(eq2, or, less, eq5))));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFindHashConditionAndConvertToInnerJoin() {
|
||||
Slot studentId = studentScan.getOutput().get(0);
|
||||
Slot sid = scoreScan.getOutput().get(0);
|
||||
|
||||
Expression eq1 = new EqualTo(studentId, sid);
|
||||
Expression eq2 = new EqualTo(studentId, new IntegerLiteral(1)); // a=1
|
||||
|
||||
LogicalJoin join = new LogicalJoin<>(JoinType.CROSS_JOIN, new ArrayList<>(),
|
||||
ImmutableList.of(eq1, eq2), JoinHint.NONE, Optional.empty(), studentScan, scoreScan);
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
|
||||
.applyTopDown(new FindHashConditionForJoin())
|
||||
.matches(
|
||||
logicalJoin()
|
||||
.when(j -> j.getHashJoinConjuncts().equals(ImmutableList.of(eq1)))
|
||||
.when(j -> j.getJoinType().isInnerJoin())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,8 +17,14 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.And;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThan;
|
||||
import org.apache.doris.nereids.trees.expressions.Or;
|
||||
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
@ -28,15 +34,35 @@ import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class PushFilterInsideJoinTest implements MemoPatternMatchSupported {
|
||||
|
||||
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
private static LogicalOlapScan scan1;
|
||||
private static LogicalOlapScan scan2;
|
||||
private static LogicalOlapScan scoreScan;
|
||||
private static LogicalOlapScan studentScan;
|
||||
private static LogicalOlapScan courseScan;
|
||||
|
||||
@BeforeAll
|
||||
public static void beforeAll() {
|
||||
scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
scoreScan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(),
|
||||
PlanConstructor.score,
|
||||
ImmutableList.of(""));
|
||||
studentScan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(),
|
||||
PlanConstructor.student,
|
||||
ImmutableList.of(""));
|
||||
courseScan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(),
|
||||
PlanConstructor.course,
|
||||
ImmutableList.of(""));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testPushInside() {
|
||||
void testPushInsideCrossJoin() {
|
||||
Expression predicates = new GreaterThan(scan1.getOutput().get(1), scan2.getOutput().get(1));
|
||||
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
@ -51,4 +77,64 @@ class PushFilterInsideJoinTest implements MemoPatternMatchSupported {
|
||||
logicalJoin().when(join -> join.getOtherJoinConjuncts().get(0).equals(predicates))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPushInsideInnerJoin() {
|
||||
Expression predicates = new Or(new And(
|
||||
// score.sid = student.id
|
||||
new EqualTo(scoreScan.getOutput().get(0), studentScan.getOutput().get(0)),
|
||||
// score.cid = course.cid
|
||||
new EqualTo(scoreScan.getOutput().get(1), courseScan.getOutput().get(0))),
|
||||
// grade > 2
|
||||
new GreaterThan(scoreScan.getOutput().get(2), new DoubleLiteral(2)));
|
||||
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scoreScan)
|
||||
.joinEmptyOn(studentScan, JoinType.CROSS_JOIN)
|
||||
.join(courseScan, JoinType.INNER_JOIN,
|
||||
// score.cid = course.cid
|
||||
ImmutableList.of(new EqualTo(scoreScan.getOutput().get(1), courseScan.getOutput().get(0))),
|
||||
// grade < 10
|
||||
ImmutableList.of(new LessThan(scoreScan.getOutput().get(2), new DoubleLiteral(10))))
|
||||
.filter(predicates)
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new PushFilterInsideJoin())
|
||||
.printlnTree()
|
||||
.matchesFromRoot(
|
||||
logicalJoin().when(join -> join.getOtherJoinConjuncts().get(0).equals(predicates)));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShouldNotPushInsideJoin() {
|
||||
for (JoinType joinType : JoinType.values()) {
|
||||
if (JoinType.INNER_JOIN == joinType || JoinType.CROSS_JOIN == joinType) {
|
||||
continue;
|
||||
}
|
||||
shouldNotPushInsideJoin(joinType);
|
||||
}
|
||||
}
|
||||
|
||||
private void shouldNotPushInsideJoin(JoinType joinType) {
|
||||
// score.sid = student.id
|
||||
Expression eq = new EqualTo(scoreScan.getOutput().get(0), studentScan.getOutput().get(0));
|
||||
// default use left side column grade > 2
|
||||
Expression predicate = new GreaterThan(scoreScan.getOutput().get(2), new DoubleLiteral(2));
|
||||
if (JoinType.RIGHT_ANTI_JOIN == joinType || JoinType.RIGHT_SEMI_JOIN == joinType) {
|
||||
// use right side column age < 10
|
||||
predicate = new LessThan(studentScan.getOutput().get(3), new DoubleLiteral(10));
|
||||
}
|
||||
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scoreScan)
|
||||
.join(studentScan, joinType, ImmutableList.of(eq), ImmutableList.of())
|
||||
.filter(predicate)
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new PushFilterInsideJoin())
|
||||
.printlnTree()
|
||||
.matches(
|
||||
logicalJoin().when(join -> join.getHashJoinConjuncts().get(0).equals(eq))
|
||||
.when(join -> join.getOtherJoinConjuncts().isEmpty()));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user