diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 4f8d96a0d9..9b696f8235 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -216,7 +216,7 @@ public enum RuleType { // exploration rules TEST_EXPLORATION(RuleTypeClass.EXPLORATION), - LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION), + LOGICAL_JOIN_COMMUTE(RuleTypeClass.EXPLORATION), LOGICAL_INNER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION), LOGICAL_INNER_JOIN_LASSCOM_PROJECT(RuleTypeClass.EXPLORATION), LOGICAL_OUTER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION), @@ -225,7 +225,10 @@ public enum RuleType { LOGICAL_OUTER_JOIN_ASSOC_PROJECT(RuleTypeClass.EXPLORATION), LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION), LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION), - LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE(RuleTypeClass.EXPLORATION), + LOGICAL_JOIN_LOGICAL_SEMI_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION), + LOGICAL_JOIN_LOGICAL_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION), + LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION), + LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION), LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION), LOGICAL_JOIN_EXCHANGE_LEFT_PROJECT(RuleTypeClass.EXPLORATION), LOGICAL_JOIN_EXCHANGE_RIGHT_PROJECT(RuleTypeClass.EXPLORATION), @@ -234,7 +237,6 @@ public enum RuleType { LOGICAL_INNER_JOIN_LEFT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION), LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE(RuleTypeClass.EXPLORATION), LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION), - LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION), PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION), PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java index 367b00530e..11f9955a76 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java @@ -67,7 +67,7 @@ public class JoinCommute extends OneExplorationRuleFactory { } return newJoin; - }).toRule(RuleType.LOGICAL_JOIN_COMMUTATE); + }).toRule(RuleType.LOGICAL_JOIN_COMMUTE); } enum SwapType { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java index 0b57ed1b70..73f202cd26 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java @@ -58,7 +58,6 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory { && (topJoin.left().getJoinType().isInnerJoin() || topJoin.left().getJoinType().isLeftOuterJoin() || topJoin.left().getJoinType().isRightOuterJoin()))) - .whenNot(topJoin -> topJoin.left().getJoinType().isSemiOrAntiJoin()) .when(this::conditionChecker) .whenNot(topJoin -> topJoin.hasJoinHint() || topJoin.left().hasJoinHint()) .whenNot(LogicalJoin::isMarkJoin) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java index afa88aad98..a645ca3062 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java @@ -17,14 +17,10 @@ package org.apache.doris.nereids.rules.exploration.join; -import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; @@ -34,9 +30,6 @@ import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; -import java.util.HashMap; -import java.util.List; -import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -65,7 +58,6 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto && (topJoin.left().child().getJoinType().isInnerJoin() || topJoin.left().child().getJoinType().isLeftOuterJoin() || topJoin.left().child().getJoinType().isRightOuterJoin()))) - .whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin()) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin()) .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) @@ -76,9 +68,8 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto GroupPlan b = bottomJoin.right(); GroupPlan c = topSemiJoin.right(); - // push topSemiJoin down project, so we need replace conjuncts by project. - Pair, List> conjuncts = replaceConjuncts(topSemiJoin, project); - Set conjunctsIds = Stream.concat(conjuncts.first.stream(), conjuncts.second.stream()) + Set conjunctsIds = Stream.concat(topSemiJoin.getHashJoinConjuncts().stream(), + topSemiJoin.getOtherJoinConjuncts().stream()) .flatMap(expr -> expr.getInputSlotExprIds().stream()).collect(Collectors.toSet()); ContainsType containsType = containsChildren(conjunctsIds, a.getOutputExprIdSet(), b.getOutputExprIdSet()); @@ -99,8 +90,7 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto // RIGHT_OUTER_JOIN should be eliminated in rewrite phase Preconditions.checkState(bottomJoin.getJoinType() != JoinType.RIGHT_OUTER_JOIN); - Plan newBottomSemiJoin = topSemiJoin.withConjunctsChildren(conjuncts.first, conjuncts.second, - a, c); + Plan newBottomSemiJoin = topSemiJoin.withChildren(a, c); Plan newTopJoin = bottomJoin.withChildren(newBottomSemiJoin, b); return project.withChildren(newTopJoin); } else { @@ -112,37 +102,20 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto * / \ | * project C newTopJoin * | / \ - * bottomJoin C --> A newBottomSemiJoin - * / \ / \ - * A B B C + * bottomJoin C --> A newBottomSemiJoin + * / \ / \ + * A B B C */ // LEFT_OUTER_JOIN should be eliminated in rewrite phase Preconditions.checkState(bottomJoin.getJoinType() != JoinType.LEFT_OUTER_JOIN); - Plan newBottomSemiJoin = topSemiJoin.withConjunctsChildren(conjuncts.first, conjuncts.second, - b, c); + Plan newBottomSemiJoin = topSemiJoin.withChildren(b, c); Plan newTopJoin = bottomJoin.withChildren(a, newBottomSemiJoin); return project.withChildren(newTopJoin); } }).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT); } - private Pair, List> replaceConjuncts(LogicalJoin join, - LogicalProject project) { - Map outputToInput = new HashMap<>(); - for (NamedExpression outputExpr : project.getProjects()) { - Set usedSlots = outputExpr.getInputSlots(); - Preconditions.checkState(usedSlots.size() == 1); - Slot inputSlot = usedSlots.iterator().next(); - outputToInput.put(outputExpr.getExprId(), inputSlot); - } - List topHashConjuncts = - JoinReorderUtils.replaceJoinConjuncts(join.getHashJoinConjuncts(), outputToInput); - List topOtherConjuncts = - JoinReorderUtils.replaceJoinConjuncts(join.getOtherJoinConjuncts(), outputToInput); - return Pair.of(topHashConjuncts, topOtherConjuncts); - } - enum ContainsType { LEFT, RIGHT, ALL } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java index 6977dd9d62..67c97e5788 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java @@ -73,7 +73,7 @@ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory { Plan newBottomJoin = topJoin.withChildren(a, c); Plan newTopJoin = bottomJoin.withChildren(newBottomJoin, b); return newTopJoin; - }).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE); + }).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE); } private boolean typeChecker(LogicalJoin, GroupPlan> topJoin) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java index be00e49dc1..7b42fe4e5d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java @@ -20,11 +20,10 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.common.Pair; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; @@ -33,13 +32,13 @@ import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -public class SemiJoinLogicalJoinTransposeProjectTest { +class SemiJoinLogicalJoinTransposeProjectTest implements MemoPatternMatchSupported { private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); @Test - public void testSemiJoinLogicalTransposeProjectLAsscom() { + void testSemiJoinLogicalTransposeProjectLAsscom() { /*- * topSemiJoin project * / \ | @@ -57,28 +56,21 @@ public class SemiJoinLogicalJoinTransposeProjectTest { PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) .applyExploration(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build()) - .checkMemo(memo -> { - Group root = memo.getRoot(); - Assertions.assertEquals(2, root.getLogicalExpressions().size()); - Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false); - - LogicalJoin newTopJoin = (LogicalJoin) plan.child(0); - LogicalJoin newBottomJoin = (LogicalJoin) newTopJoin.left(); - Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType()); - Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType()); - - LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left(); - LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right(); - LogicalOlapScan newTopJoinRight = (LogicalOlapScan) newTopJoin.right(); - - Assertions.assertEquals("t1", newBottomJoinLeft.getTable().getName()); - Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName()); - Assertions.assertEquals("t2", newTopJoinRight.getTable().getName()); - }); + .matchesExploration( + logicalProject( + innerLogicalJoin( + leftSemiLogicalJoin( + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")), + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) + ), + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")) + ) + ) + ); } @Test - public void testSemiJoinLogicalTransposeProjectLAsscomFail() { + void testSemiJoinLogicalTransposeProjectLAsscomFail() { LogicalPlan topJoin = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id .project(ImmutableList.of(0, 2)) // t1.id, t2.id @@ -94,15 +86,15 @@ public class SemiJoinLogicalJoinTransposeProjectTest { } @Test - public void testSemiJoinLogicalTransposeProjectAll() { + void testSemiJoinLogicalTransposeProjectAll() { /*- * topSemiJoin project * / \ | * project C newTopJoin * | / \ * bottomJoin C --> A newBottomSemiJoin - * / \ / \ - * A B B C + * / \ / \ + * A B B C */ LogicalPlan topJoin = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id @@ -112,23 +104,16 @@ public class SemiJoinLogicalJoinTransposeProjectTest { PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) .applyExploration(SemiJoinLogicalJoinTransposeProject.ALL.build()) - .checkMemo(memo -> { - Group root = memo.getRoot(); - Assertions.assertEquals(2, root.getLogicalExpressions().size()); - Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false); - - LogicalJoin newTopJoin = (LogicalJoin) plan.child(0); - LogicalJoin newBottomJoin = (LogicalJoin) newTopJoin.right(); - Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType()); - Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType()); - - LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left(); - LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right(); - LogicalOlapScan newTopJoinLeft = (LogicalOlapScan) newTopJoin.left(); - - Assertions.assertEquals("t1", newTopJoinLeft.getTable().getName()); - Assertions.assertEquals("t2", newBottomJoinLeft.getTable().getName()); - Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName()); - }); + .matchesExploration( + logicalProject( + logicalJoin( + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")), + leftSemiLogicalJoin( + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")), + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) + ) + ) + ) + ); } }