diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index e2aac0ae37..546d7f3e74 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscomProject; import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTranspose; import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTransposeProject; import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose; +import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTransposeProject; import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg; import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows; import org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation; @@ -66,6 +67,7 @@ public class RuleSet { .add(SemiJoinLogicalJoinTranspose.LEFT_DEEP) .add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP) .add(SemiJoinSemiJoinTranspose.INSTANCE) + .add(SemiJoinSemiJoinTransposeProject.INSTANCE) .add(new AggregateDisassemble()) .add(new PushdownFilterThroughProject()) .add(new MergeProjects()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 8e17deff01..9728b49b38 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -156,6 +156,7 @@ public enum RuleType { LOGICAL_ASSERT_NUM_ROWS_TO_PHYSICAL_ASSERT_NUM_ROWS(RuleTypeClass.IMPLEMENTATION), IMPLEMENTATION_SENTINEL(RuleTypeClass.IMPLEMENTATION), + LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE_PROJECT(RuleTypeClass.EXPLORATION), // sentinel, use to count rules SENTINEL(RuleTypeClass.SENTINEL), ; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java new file mode 100644 index 0000000000..682f6a0bde --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * rule for semi-semi transpose + */ +public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory { + public static final SemiJoinSemiJoinTransposeProject INSTANCE = new SemiJoinSemiJoinTransposeProject(); + + /* + * topSemi newTopSemi + * / \ / \ + * abProject C acProject B + * / ──► / + * bottomSemi newBottomSemi + * / \ / \ + * A B A C + */ + @Override + public Rule build() { + return logicalJoin(logicalProject(logicalJoin()), group()) + .when(this::typeChecker) + .when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child())) + .then(topSemi -> { + LogicalJoin bottomSemi = topSemi.left().child(); + LogicalProject abProject = topSemi.left(); + GroupPlan a = bottomSemi.left(); + GroupPlan b = bottomSemi.right(); + GroupPlan c = topSemi.right(); + Set aOutputSet = a.getOutputSet(); + Set acProjects = new HashSet(abProject.getProjects()); + + bottomSemi.getHashJoinConjuncts().stream().forEach( + expression -> { + expression.getInputSlots().stream().forEach( + slot -> { + if (aOutputSet.contains(slot)) { + acProjects.add(slot); + } + }); + } + ); + LogicalJoin newBottomSemi = new LogicalJoin(topSemi.getJoinType(), topSemi.getHashJoinConjuncts(), + topSemi.getOtherJoinConjuncts(), a, c, + bottomSemi.getJoinReorderContext()); + newBottomSemi.getJoinReorderContext().setHasCommute(false); + newBottomSemi.getJoinReorderContext().setHasLAsscom(false); + LogicalProject acProject = new LogicalProject(acProjects.stream().collect(Collectors.toList()), + newBottomSemi); + LogicalJoin newTopSemi = new LogicalJoin(bottomSemi.getJoinType(), + bottomSemi.getHashJoinConjuncts(), bottomSemi.getOtherJoinConjuncts(), + acProject, b, + topSemi.getJoinReorderContext()); + newTopSemi.getJoinReorderContext().setHasLAsscom(true); + //return newTopSemi; + if (topSemi.getLogicalProperties().equals(newTopSemi)) { + return newTopSemi; + } else { + return new LogicalProject<>(new ArrayList<>(topSemi.getOutput()), newTopSemi); + } + }).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE_PROJECT); + } + + public boolean typeChecker(LogicalJoin>, GroupPlan> topJoin) { + return SemiJoinSemiJoinTranspose.VALID_TYPE_PAIR_SET + .contains(Pair.of(topJoin.getJoinType(), topJoin.left().child().getJoinType())); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java new file mode 100644 index 0000000000..56d80933eb --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +public class SemiJoinSemiJoinTransposeProjectTest implements PatternMatchSupported { + public static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + public static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + public static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); + + @Test + public void testSemiProjectSemiCommute() { + /* + * t1.name=t3.name t1.id=t2.id + * topJoin newTopJoin + * / \ / \ + * project t3 project t2 + * t1.name t1.name, t1.id + * | | + * t1.id=t2.id t1.name=t3.name + * bottomJoin --> newBottomJoin + * / \ / \ + * t1 t2 t1 t3 + */ + LogicalPlan topJoin = new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.LEFT_ANTI_JOIN, Pair.of(0, 0)) + .project(ImmutableList.of(1)) + .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 1)) + .build(); + PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) + .applyExploration(SemiJoinSemiJoinTransposeProject.INSTANCE.build()) + .printlnExploration() + .matchesExploration( + logicalProject( + logicalJoin( + logicalProject( + logicalJoin( + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")), + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) + ).when(join -> join.getJoinType() == JoinType.LEFT_SEMI_JOIN) + ).when(project -> project.getProjects().size() == 2 + && project.getProjects().get(0).getName().equals("id") + && project.getProjects().get(1).getName().equals("name") + ), + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")) + ).when(join -> join.getJoinType() == JoinType.LEFT_ANTI_JOIN) + ) + ); + } +}