[test](nereids) Add some tests for PushFilterInsideJoin and FindHashConditionForJoin rule (#24550)

This commit is contained in:
JingDas
2023-09-28 16:45:05 +08:00
committed by GitHub
parent bf4fb32487
commit 230b7bd15e
2 changed files with 143 additions and 25 deletions

View File

@ -30,11 +30,13 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
@ -43,27 +45,35 @@ import java.util.Optional;
/**
* initial plan:
* join
* -hashJoinConjuncts={}
* -otherJoinCondition=
* "A.x=B.x and A.y+1=B.y and A.x=1 and (A.y=B.y or B.x=A.x) and A.x>B.x"
* join
* -hashJoinConjuncts={}
* -otherJoinCondition=
* "A.x=B.x and A.y+1=B.y and A.x=1 and (A.y=B.y or B.x=A.x) and A.x>B.x and A.x=B.x+B.y and A.x+B.x=B.y"
* after transform
* join
* -hashJoinConjuncts={A.x=B.x, A.y+1=B.y}
* -otherJoinCondition="A.x=1 and (A.x=1 or B.x=A.x) and A.x>B.x"
* join
* -hashJoinConjuncts={A.x=B.x, A.y+1=B.y, A.x=B.x+B.y}
* -otherJoinCondition="A.x=1 and (A.x=1 or B.x=A.x) and A.x>B.x and A.x+B.x=B.y"
*/
class FindHashConditionForJoinTest implements MemoPatternMatchSupported {
private static Plan studentScan;
private static Plan scoreScan;
@BeforeAll
public static void beforeAll() {
studentScan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
ImmutableList.of(""));
scoreScan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score,
ImmutableList.of(""));
}
@Test
void testFindHashCondition() {
Plan student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
ImmutableList.of(""));
Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score,
ImmutableList.of(""));
Slot studentId = student.getOutput().get(0);
Slot gender = student.getOutput().get(1);
Slot scoreId = score.getOutput().get(0);
Slot cid = score.getOutput().get(1);
Slot studentId = studentScan.getOutput().get(0);
Slot gender = studentScan.getOutput().get(1);
Slot scoreId = scoreScan.getOutput().get(0);
Slot cid = scoreScan.getOutput().get(1);
Expression eq1 = new EqualTo(studentId, scoreId); // a=b
Expression eq2 = new EqualTo(studentId, new IntegerLiteral(1)); // a=1
@ -72,15 +82,37 @@ class FindHashConditionForJoinTest implements MemoPatternMatchSupported {
new EqualTo(scoreId, studentId),
new EqualTo(gender, cid));
Expression less = new LessThan(scoreId, studentId);
List<Expression> expr = ImmutableList.of(eq1, eq2, eq3, or, less);
Expression eq4 = new EqualTo(studentId, new Add(scoreId, cid));
Expression eq5 = new EqualTo(studentId, new Add(studentId, cid));
List<Expression> expr = ImmutableList.of(eq1, eq2, eq3, or, less, eq4, eq5);
LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, new ArrayList<>(),
expr, JoinHint.NONE, Optional.empty(), student, score);
expr, JoinHint.NONE, Optional.empty(), studentScan, scoreScan);
PlanChecker.from(new ConnectContext(), join)
.applyTopDown(new FindHashConditionForJoin())
.matches(
logicalJoin()
.when(j -> j.getHashJoinConjuncts().equals(ImmutableList.of(eq1, eq3)))
.when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of(eq2, or, less))));
.applyTopDown(new FindHashConditionForJoin())
.matches(
logicalJoin()
.when(j -> j.getHashJoinConjuncts().equals(ImmutableList.of(eq1, eq3, eq4)))
.when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of(eq2, or, less, eq5))));
}
@Test
void testFindHashConditionAndConvertToInnerJoin() {
Slot studentId = studentScan.getOutput().get(0);
Slot sid = scoreScan.getOutput().get(0);
Expression eq1 = new EqualTo(studentId, sid);
Expression eq2 = new EqualTo(studentId, new IntegerLiteral(1)); // a=1
LogicalJoin join = new LogicalJoin<>(JoinType.CROSS_JOIN, new ArrayList<>(),
ImmutableList.of(eq1, eq2), JoinHint.NONE, Optional.empty(), studentScan, scoreScan);
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
.applyTopDown(new FindHashConditionForJoin())
.matches(
logicalJoin()
.when(j -> j.getHashJoinConjuncts().equals(ImmutableList.of(eq1)))
.when(j -> j.getJoinType().isInnerJoin())
);
}
}

View File

@ -17,8 +17,14 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@ -28,15 +34,35 @@ import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
class PushFilterInsideJoinTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
private static LogicalOlapScan scan1;
private static LogicalOlapScan scan2;
private static LogicalOlapScan scoreScan;
private static LogicalOlapScan studentScan;
private static LogicalOlapScan courseScan;
@BeforeAll
public static void beforeAll() {
scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
scoreScan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(),
PlanConstructor.score,
ImmutableList.of(""));
studentScan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(),
PlanConstructor.student,
ImmutableList.of(""));
courseScan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(),
PlanConstructor.course,
ImmutableList.of(""));
}
@Test
void testPushInside() {
void testPushInsideCrossJoin() {
Expression predicates = new GreaterThan(scan1.getOutput().get(1), scan2.getOutput().get(1));
LogicalPlan plan = new LogicalPlanBuilder(scan1)
@ -51,4 +77,64 @@ class PushFilterInsideJoinTest implements MemoPatternMatchSupported {
logicalJoin().when(join -> join.getOtherJoinConjuncts().get(0).equals(predicates))
);
}
@Test
public void testPushInsideInnerJoin() {
Expression predicates = new Or(new And(
// score.sid = student.id
new EqualTo(scoreScan.getOutput().get(0), studentScan.getOutput().get(0)),
// score.cid = course.cid
new EqualTo(scoreScan.getOutput().get(1), courseScan.getOutput().get(0))),
// grade > 2
new GreaterThan(scoreScan.getOutput().get(2), new DoubleLiteral(2)));
LogicalPlan plan = new LogicalPlanBuilder(scoreScan)
.joinEmptyOn(studentScan, JoinType.CROSS_JOIN)
.join(courseScan, JoinType.INNER_JOIN,
// score.cid = course.cid
ImmutableList.of(new EqualTo(scoreScan.getOutput().get(1), courseScan.getOutput().get(0))),
// grade < 10
ImmutableList.of(new LessThan(scoreScan.getOutput().get(2), new DoubleLiteral(10))))
.filter(predicates)
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new PushFilterInsideJoin())
.printlnTree()
.matchesFromRoot(
logicalJoin().when(join -> join.getOtherJoinConjuncts().get(0).equals(predicates)));
}
@Test
public void testShouldNotPushInsideJoin() {
for (JoinType joinType : JoinType.values()) {
if (JoinType.INNER_JOIN == joinType || JoinType.CROSS_JOIN == joinType) {
continue;
}
shouldNotPushInsideJoin(joinType);
}
}
private void shouldNotPushInsideJoin(JoinType joinType) {
// score.sid = student.id
Expression eq = new EqualTo(scoreScan.getOutput().get(0), studentScan.getOutput().get(0));
// default use left side column grade > 2
Expression predicate = new GreaterThan(scoreScan.getOutput().get(2), new DoubleLiteral(2));
if (JoinType.RIGHT_ANTI_JOIN == joinType || JoinType.RIGHT_SEMI_JOIN == joinType) {
// use right side column age < 10
predicate = new LessThan(studentScan.getOutput().get(3), new DoubleLiteral(10));
}
LogicalPlan plan = new LogicalPlanBuilder(scoreScan)
.join(studentScan, joinType, ImmutableList.of(eq), ImmutableList.of())
.filter(predicate)
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new PushFilterInsideJoin())
.printlnTree()
.matches(
logicalJoin().when(join -> join.getHashJoinConjuncts().get(0).equals(eq))
.when(join -> join.getOtherJoinConjuncts().isEmpty()));
}
}