diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/FindHashConditionForJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/FindHashConditionForJoinTest.java index 470b12dfb3..e768a0771a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/FindHashConditionForJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/FindHashConditionForJoinTest.java @@ -30,11 +30,13 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import java.util.ArrayList; @@ -43,27 +45,35 @@ import java.util.Optional; /** * initial plan: - * join - * -hashJoinConjuncts={} - * -otherJoinCondition= - * "A.x=B.x and A.y+1=B.y and A.x=1 and (A.y=B.y or B.x=A.x) and A.x>B.x" + * join + * -hashJoinConjuncts={} + * -otherJoinCondition= + * "A.x=B.x and A.y+1=B.y and A.x=1 and (A.y=B.y or B.x=A.x) and A.x>B.x and A.x=B.x+B.y and A.x+B.x=B.y" * after transform - * join - * -hashJoinConjuncts={A.x=B.x, A.y+1=B.y} - * -otherJoinCondition="A.x=1 and (A.x=1 or B.x=A.x) and A.x>B.x" + * join + * -hashJoinConjuncts={A.x=B.x, A.y+1=B.y, A.x=B.x+B.y} + * -otherJoinCondition="A.x=1 and (A.x=1 or B.x=A.x) and A.x>B.x and A.x+B.x=B.y" */ class FindHashConditionForJoinTest implements MemoPatternMatchSupported { + + private static Plan studentScan; + private static Plan scoreScan; + + @BeforeAll + public static void beforeAll() { + studentScan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, + ImmutableList.of("")); + scoreScan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, + ImmutableList.of("")); + } + @Test void testFindHashCondition() { - Plan student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, - ImmutableList.of("")); - Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, - ImmutableList.of("")); - Slot studentId = student.getOutput().get(0); - Slot gender = student.getOutput().get(1); - Slot scoreId = score.getOutput().get(0); - Slot cid = score.getOutput().get(1); + Slot studentId = studentScan.getOutput().get(0); + Slot gender = studentScan.getOutput().get(1); + Slot scoreId = scoreScan.getOutput().get(0); + Slot cid = scoreScan.getOutput().get(1); Expression eq1 = new EqualTo(studentId, scoreId); // a=b Expression eq2 = new EqualTo(studentId, new IntegerLiteral(1)); // a=1 @@ -72,15 +82,37 @@ class FindHashConditionForJoinTest implements MemoPatternMatchSupported { new EqualTo(scoreId, studentId), new EqualTo(gender, cid)); Expression less = new LessThan(scoreId, studentId); - List expr = ImmutableList.of(eq1, eq2, eq3, or, less); + Expression eq4 = new EqualTo(studentId, new Add(scoreId, cid)); + Expression eq5 = new EqualTo(studentId, new Add(studentId, cid)); + List expr = ImmutableList.of(eq1, eq2, eq3, or, less, eq4, eq5); LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, new ArrayList<>(), - expr, JoinHint.NONE, Optional.empty(), student, score); + expr, JoinHint.NONE, Optional.empty(), studentScan, scoreScan); PlanChecker.from(new ConnectContext(), join) - .applyTopDown(new FindHashConditionForJoin()) - .matches( - logicalJoin() - .when(j -> j.getHashJoinConjuncts().equals(ImmutableList.of(eq1, eq3))) - .when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of(eq2, or, less)))); + .applyTopDown(new FindHashConditionForJoin()) + .matches( + logicalJoin() + .when(j -> j.getHashJoinConjuncts().equals(ImmutableList.of(eq1, eq3, eq4))) + .when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of(eq2, or, less, eq5)))); + } + + @Test + void testFindHashConditionAndConvertToInnerJoin() { + Slot studentId = studentScan.getOutput().get(0); + Slot sid = scoreScan.getOutput().get(0); + + Expression eq1 = new EqualTo(studentId, sid); + Expression eq2 = new EqualTo(studentId, new IntegerLiteral(1)); // a=1 + + LogicalJoin join = new LogicalJoin<>(JoinType.CROSS_JOIN, new ArrayList<>(), + ImmutableList.of(eq1, eq2), JoinHint.NONE, Optional.empty(), studentScan, scoreScan); + + PlanChecker.from(MemoTestUtils.createConnectContext(), join) + .applyTopDown(new FindHashConditionForJoin()) + .matches( + logicalJoin() + .when(j -> j.getHashJoinConjuncts().equals(ImmutableList.of(eq1))) + .when(j -> j.getJoinType().isInnerJoin()) + ); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushFilterInsideJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushFilterInsideJoinTest.java index 4249f7644d..4fd436c54a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushFilterInsideJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushFilterInsideJoinTest.java @@ -17,8 +17,14 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.LessThan; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; +import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -28,15 +34,35 @@ import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; class PushFilterInsideJoinTest implements MemoPatternMatchSupported { - private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); - private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + private static LogicalOlapScan scan1; + private static LogicalOlapScan scan2; + private static LogicalOlapScan scoreScan; + private static LogicalOlapScan studentScan; + private static LogicalOlapScan courseScan; + + @BeforeAll + public static void beforeAll() { + scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + scoreScan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), + PlanConstructor.score, + ImmutableList.of("")); + studentScan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), + PlanConstructor.student, + ImmutableList.of("")); + courseScan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), + PlanConstructor.course, + ImmutableList.of("")); + } @Test - void testPushInside() { + void testPushInsideCrossJoin() { Expression predicates = new GreaterThan(scan1.getOutput().get(1), scan2.getOutput().get(1)); LogicalPlan plan = new LogicalPlanBuilder(scan1) @@ -51,4 +77,64 @@ class PushFilterInsideJoinTest implements MemoPatternMatchSupported { logicalJoin().when(join -> join.getOtherJoinConjuncts().get(0).equals(predicates)) ); } + + @Test + public void testPushInsideInnerJoin() { + Expression predicates = new Or(new And( + // score.sid = student.id + new EqualTo(scoreScan.getOutput().get(0), studentScan.getOutput().get(0)), + // score.cid = course.cid + new EqualTo(scoreScan.getOutput().get(1), courseScan.getOutput().get(0))), + // grade > 2 + new GreaterThan(scoreScan.getOutput().get(2), new DoubleLiteral(2))); + + LogicalPlan plan = new LogicalPlanBuilder(scoreScan) + .joinEmptyOn(studentScan, JoinType.CROSS_JOIN) + .join(courseScan, JoinType.INNER_JOIN, + // score.cid = course.cid + ImmutableList.of(new EqualTo(scoreScan.getOutput().get(1), courseScan.getOutput().get(0))), + // grade < 10 + ImmutableList.of(new LessThan(scoreScan.getOutput().get(2), new DoubleLiteral(10)))) + .filter(predicates) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushFilterInsideJoin()) + .printlnTree() + .matchesFromRoot( + logicalJoin().when(join -> join.getOtherJoinConjuncts().get(0).equals(predicates))); + } + + @Test + public void testShouldNotPushInsideJoin() { + for (JoinType joinType : JoinType.values()) { + if (JoinType.INNER_JOIN == joinType || JoinType.CROSS_JOIN == joinType) { + continue; + } + shouldNotPushInsideJoin(joinType); + } + } + + private void shouldNotPushInsideJoin(JoinType joinType) { + // score.sid = student.id + Expression eq = new EqualTo(scoreScan.getOutput().get(0), studentScan.getOutput().get(0)); + // default use left side column grade > 2 + Expression predicate = new GreaterThan(scoreScan.getOutput().get(2), new DoubleLiteral(2)); + if (JoinType.RIGHT_ANTI_JOIN == joinType || JoinType.RIGHT_SEMI_JOIN == joinType) { + // use right side column age < 10 + predicate = new LessThan(studentScan.getOutput().get(3), new DoubleLiteral(10)); + } + + LogicalPlan plan = new LogicalPlanBuilder(scoreScan) + .join(studentScan, joinType, ImmutableList.of(eq), ImmutableList.of()) + .filter(predicate) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushFilterInsideJoin()) + .printlnTree() + .matches( + logicalJoin().when(join -> join.getHashJoinConjuncts().get(0).equals(eq)) + .when(join -> join.getOtherJoinConjuncts().isEmpty())); + } }