diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java index 64c9518a2d..58537792ce 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java @@ -98,6 +98,16 @@ public class Memo { */ public Plan copyOut(Group group, boolean includeGroupExpression) { GroupExpression logicalExpression = group.getLogicalExpression(); + return copyOut(logicalExpression, includeGroupExpression); + } + + /** + * Copies out the given logical expression as a plan. + * @param logicalExpression the logical expression to copy out + * @param includeGroupExpression whether to include the group expression in the plan + * @return the copied-out plan + */ + public Plan copyOut(GroupExpression logicalExpression, boolean includeGroupExpression) { List children = Lists.newArrayList(); for (Group child : logicalExpression.children()) { children.add(copyOut(child, includeGroupExpression)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java new file mode 100644 index 0000000000..f5c5580d8d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.base.Preconditions; + +import java.util.Set; + +/** + * Planner rule that pushes a SemiJoin down in a tree past a LogicalJoin + * in order to trigger other rules that will convert {@code SemiJoin}s. + * + * + *

+ * Whether this first or second conversion is applied depends on
+ * which operands actually participate in the semi-join.
+ */
+public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
+    /**
+     * Builds the exploration rule.
+     *
+     * <p>The pattern is a left semi join whose left child is a logical join and whose right
+     * child is a group. When {@code conditionChecker} accepts the match, the semi join is
+     * rebuilt directly above A or B and the former bottom join becomes the new root.
+     *
+     * @return the rule, registered as {@code RuleType.LOGICAL_JOIN_L_ASSCOM}
+     */
+    @Override
+    public Rule build() {
+        // NOTE(review): this reuses LOGICAL_JOIN_L_ASSCOM; a dedicated rule type may be wanted - confirm.
+        return leftSemiLogicalJoin(logicalJoin(), group())
+                .when(this::conditionChecker)
+                .then(topSemiJoin -> {
+                    LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topSemiJoin.left();
+                    GroupPlan a = bottomJoin.left();
+                    GroupPlan b = bottomJoin.right();
+                    GroupPlan c = topSemiJoin.right();
+
+                    // The guard let through only bottom joins that expose a single child;
+                    // if every slot of A is exposed, A is that child, otherwise it is B.
+                    boolean lasscom = bottomJoin.getOutputSet().containsAll(a.getOutput());
+
+                    if (lasscom) {
+                        /*
+                         *    topSemiJoin                  newTopJoin
+                         *      /     \                    /        \
+                         * bottomJoin  C   -->  newBottomSemiJoin    B
+                         *  /    \                  /      \
+                         * A      B                A        C
+                         */
+                        LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
+                                topSemiJoin.getJoinType(),
+                                topSemiJoin.getHashJoinConjuncts(), topSemiJoin.getOtherJoinCondition(), a, c);
+                        return new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
+                                bottomJoin.getOtherJoinCondition(), newBottomSemiJoin, b);
+                    } else {
+                        /*
+                         *    topSemiJoin            newTopJoin
+                         *      /     \              /        \
+                         * bottomJoin  C   -->      A    newBottomSemiJoin
+                         *  /    \                           /      \
+                         * A      B                         B        C
+                         */
+                        LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
+                                topSemiJoin.getJoinType(),
+                                topSemiJoin.getHashJoinConjuncts(), topSemiJoin.getOtherJoinCondition(), b, c);
+                        return new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
+                                bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin);
+                    }
+                }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
+    }
+
+    /**
+     * Accepts the match only when the bottom join exposes the output of exactly one child.
+     *
+     * @param topJoin the matched semi join; its left child is the bottom join over A and B
+     * @return {@code true} if the bottom join output overlaps A or B but not both,
+     *         {@code false} if it overlaps both
+     * @throws IllegalStateException if the bottom join output overlaps neither child
+     */
+    private boolean conditionChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) {
+        Set<Slot> bottomOutputSet = topJoin.left().getOutputSet();
+
+        Set<Slot> aOutputSet = topJoin.left().left().getOutputSet();
+        Set<Slot> bOutputSet = topJoin.left().right().getOutputSet();
+
+        // Each flag is true when the bottom join output overlaps that child (presumably what
+        // isIntersecting reports - confirm). The flags used to be negated, which made the state
+        // check below fail for every bottom join that exposes both children.
+        boolean outputsA = ExpressionUtils.isIntersecting(bottomOutputSet, aOutputSet);
+        boolean outputsB = ExpressionUtils.isIntersecting(bottomOutputSet, bOutputSet);
+
+        Preconditions.checkState(outputsA || outputsB, "join output must contain child");
+        // NOTE(review): a bottom join exposing both children is rejected here, so the rule does not
+        // fire for it; picking the side from the semi join condition would cover that case.
+        return !(outputsA && outputsB);
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
new file mode 100644
index 0000000000..45cdc7a19e
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.base.Preconditions; + +import java.util.List; +import java.util.Set; + +/** + * Planner rule that pushes a SemiJoin down in a tree past a LogicalJoin + * in order to trigger other rules that will convert {@code SemiJoin}s. + * + *

+ *

+ * Whether this first or second conversion is applied depends on
+ * which operands actually participate in the semi-join.
+ */
+public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFactory {
+    /**
+     * Builds the exploration rule.
+     *
+     * <p>The pattern is a left semi join whose left child is a project over a logical join and
+     * whose right child is a group. When {@code conditionChecker} accepts the match, the project
+     * and the semi join are rebuilt above A or B, the former bottom join is placed on top of
+     * them, and the same project list is applied once more at the root.
+     *
+     * @return the rule, registered as {@code RuleType.LOGICAL_JOIN_L_ASSCOM}
+     */
+    @Override
+    public Rule build() {
+        // NOTE(review): this reuses LOGICAL_JOIN_L_ASSCOM; a dedicated rule type may be wanted - confirm.
+        return leftSemiLogicalJoin(logicalProject(logicalJoin()), group())
+                .when(this::conditionChecker)
+                .then(topSemiJoin -> {
+                    LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topSemiJoin.left();
+                    LogicalJoin<GroupPlan, GroupPlan> bottomJoin = project.child();
+                    GroupPlan a = bottomJoin.left();
+                    GroupPlan b = bottomJoin.right();
+                    GroupPlan c = topSemiJoin.right();
+                    List<NamedExpression> projects = project.getProjects();
+
+                    // NOTE(review): the bottom join conditions are reused unchanged above the pushed
+                    // project; if they read a slot the project drops, the new plan looks invalid -
+                    // confirm whether callers guarantee this cannot happen.
+                    boolean lasscom = a.getOutputSet().containsAll(project.getOutput());
+
+                    if (lasscom) {
+                        /*-
+                         *     topSemiJoin                   newTopProject
+                         *       /     \                          |
+                         *    project   C                     newTopJoin
+                         *      |             -->             /        \
+                         *  bottomJoin              newBottomSemiJoin   B
+                         *    /    \                    /      \
+                         *   A      B             aNewProject   C
+                         *                             |
+                         *                             A
+                         */
+                        LogicalProject<GroupPlan> aNewProject = new LogicalProject<>(projects, a);
+                        LogicalJoin<LogicalProject<GroupPlan>, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
+                                topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
+                                topSemiJoin.getOtherJoinCondition(), aNewProject, c);
+                        LogicalJoin<LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>, GroupPlan> newTopJoin
+                                = new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
+                                        bottomJoin.getOtherJoinCondition(), newBottomSemiJoin, b);
+                        return new LogicalProject<>(projects, newTopJoin);
+                    } else {
+                        /*-
+                         *     topSemiJoin               newTopProject
+                         *       /     \                       |
+                         *    project   C                  newTopJoin
+                         *      |             -->          /        \
+                         *  bottomJoin                    A    newBottomSemiJoin
+                         *    /    \                               /      \
+                         *   A      B                        bNewProject   C
+                         *                                        |
+                         *                                        B
+                         */
+                        // conditionChecker only accepts projects fully produced by one child, so a
+                        // project that is not covered by A must be covered by B.
+                        Preconditions.checkState(b.getOutputSet().containsAll(project.getOutput()),
+                                "project must contain child");
+                        LogicalProject<GroupPlan> bNewProject = new LogicalProject<>(projects, b);
+                        LogicalJoin<LogicalProject<GroupPlan>, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
+                                topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
+                                topSemiJoin.getOtherJoinCondition(), bNewProject, c);
+
+                        LogicalJoin<GroupPlan, LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>> newTopJoin
+                                = new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
+                                        bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin);
+                        return new LogicalProject<>(projects, newTopJoin);
+                    }
+                }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
+    }
+
+    /**
+     * Accepts the match only when the whole project output is produced by exactly one child
+     * of the bottom join.
+     *
+     * <p>Returns {@code false} instead of throwing when the project output overlaps both
+     * children, and also rejects a project that overlaps one child but has further outputs
+     * that this child does not provide, because the rewrite could not place such a project
+     * on either side.
+     *
+     * @param topJoin the matched semi join; its left child is the project over the bottom join
+     * @return {@code true} if A or B alone contains every slot of the project output
+     */
+    private boolean conditionChecker(
+            LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topJoin) {
+        Set<Slot> projectOutputSet = topJoin.left().getOutputSet();
+
+        Set<Slot> aOutputSet = topJoin.left().child().left().getOutputSet();
+        Set<Slot> bOutputSet = topJoin.left().child().right().getOutputSet();
+
+        // Each flag is true when the project output overlaps that child (presumably what
+        // isIntersecting reports - confirm).
+        boolean usesA = ExpressionUtils.isIntersecting(projectOutputSet, aOutputSet);
+        boolean usesB = ExpressionUtils.isIntersecting(projectOutputSet, bOutputSet);
+        if (usesA == usesB) {
+            // Overlaps both children or neither of them: nothing can be pushed to a single side.
+            return false;
+        }
+
+        Set<Slot> pushedSide = usesA ? aOutputSet : bOutputSet;
+        return pushedSide.containsAll(projectOutputSet);
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
new file mode 100644
index 0000000000..31d326612d
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.
See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +/** + * Rule for SemiJoinTranspose. + *

+ * LEFT-Semi/ANTI(LEFT-Semi/ANTI(X, Y), Z)
+ * ->
+ * LEFT-Semi/ANTI(LEFT-Semi/ANTI(X, Z), Y)
+ */
+public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
+
+    // Accepted (top join type, bottom join type) combinations.
+    public static Set<Pair<JoinType, JoinType>> typeSet = ImmutableSet.of(
+            Pair.of(JoinType.LEFT_SEMI_JOIN, JoinType.LEFT_SEMI_JOIN),
+            Pair.of(JoinType.LEFT_ANTI_JOIN, JoinType.LEFT_ANTI_JOIN),
+            Pair.of(JoinType.LEFT_SEMI_JOIN, JoinType.LEFT_ANTI_JOIN),
+            Pair.of(JoinType.LEFT_ANTI_JOIN, JoinType.LEFT_SEMI_JOIN));
+
+    /**
+     * Builds the rule that swaps the right children of two stacked joins.
+     *
+     * <pre>
+     *      topJoin                newTopJoin
+     *      /     \                 /     \
+     * bottomJoin  C   -->  newBottomJoin  B
+     *  /    \                  /    \
+     * A      B                A      C
+     * </pre>
+     *
+     * @return the rule, registered as {@code RuleType.LOGICAL_JOIN_L_ASSCOM}
+     */
+    @Override
+    public Rule build() {
+        return logicalJoin(logicalJoin(), group())
+                .when(this::typeChecker)
+                .then(topJoin -> {
+                    LogicalJoin<GroupPlan, GroupPlan> lower = topJoin.left();
+                    GroupPlan a = lower.left();
+                    GroupPlan b = lower.right();
+                    GroupPlan c = topJoin.right();
+
+                    // The former top join now sits directly above A and C ...
+                    LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(),
+                            topJoin.getHashJoinConjuncts(), topJoin.getOtherJoinCondition(), a, c);
+                    // ... and the former bottom join becomes the root, with B as its right child.
+                    return new LogicalJoin<>(lower.getJoinType(), lower.getHashJoinConjuncts(),
+                            lower.getOtherJoinCondition(), newBottomJoin, b);
+                }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
+    }
+
+    /**
+     * Checks whether the join types of the matched joins form a pair listed in {@code typeSet}.
+     *
+     * @param topJoin the matched top join; its left child is the bottom join
+     * @return {@code true} if (top join type, bottom join type) is contained in {@code typeSet}
+     */
+    private boolean typeChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) {
+        JoinType topType = topJoin.getJoinType();
+        JoinType bottomType = topJoin.left().getJoinType();
+        return typeSet.contains(Pair.of(topType, bottomType));
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinTransposeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinTransposeTest.java
new file mode 100644
index 0000000000..c7aa852449
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinTransposeTest.java
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.
See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration.join;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.memo.Group;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Test for the semi join transpose rules.
+ */
+public class SemiJoinTransposeTest {
+    public static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+    public static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
+    public static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
+
+    /**
+     * Applies {@code SemiJoinLogicalJoinTransposeProject} to a left semi join over a projected
+     * inner join and expects a second root expression with the inner join above the semi join.
+     */
+    @Test
+    public void testSemiJoinLogicalTransposeCommute() {
+        LogicalPlanBuilder builder = new LogicalPlanBuilder(scan1)
+                .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .project(ImmutableList.of(0))
+                .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0));
+        LogicalPlan semiJoinOverProject = builder.build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), semiJoinOverProject)
+                .transform((new SemiJoinLogicalJoinTransposeProject()).build())
+                .checkMemo(memo -> {
+                    Group rootGroup = memo.getRoot();
+                    // The rewritten alternative is expected next to the original expression.
+                    Assertions.assertEquals(2, rootGroup.getLogicalExpressions().size());
+                    Plan rewritten = memo.copyOut(rootGroup.getLogicalExpressions().get(1), false);
+
+                    // The root of the alternative is the project, so the new top join is its child.
+                    Plan projectChild = rewritten.child(0);
+                    Assertions.assertTrue(projectChild instanceof LogicalJoin);
+
+                    LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) projectChild;
+                    Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType());
+
+                    LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.left();
+                    Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType());
+                });
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
new file mode 100644
index 0000000000..950647a674
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.util;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Fluent helper for assembling logical plans in tests. Every step returns a new builder
+ * that wraps the plan produced so far.
+ */
+public class LogicalPlanBuilder {
+    private final LogicalPlan plan;
+
+    public LogicalPlanBuilder(LogicalPlan plan) {
+        this.plan = plan;
+    }
+
+    /**
+     * Returns the plan assembled so far.
+     */
+    public LogicalPlan build() {
+        return plan;
+    }
+
+    /**
+     * Creates a builder for the given plan; the plan held by this builder is not used.
+     */
+    public LogicalPlanBuilder from(LogicalPlan plan) {
+        return new LogicalPlanBuilder(plan);
+    }
+
+    /**
+     * Creates a builder for a new olap scan; the plan held by this builder is not used.
+     *
+     * @param tableId passed through to {@code PlanConstructor.newLogicalOlapScan}
+     * @param tableName passed through to {@code PlanConstructor.newLogicalOlapScan}
+     * @param hashColumn passed through to {@code PlanConstructor.newLogicalOlapScan}
+     */
+    public LogicalPlanBuilder scan(long tableId, String tableName, int hashColumn) {
+        LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(tableId, tableName, hashColumn);
+        return from(scan);
+    }
+
+    /**
+     * Puts a project with the given expressions on top of the current plan.
+     */
+    public LogicalPlanBuilder projectWithExprs(List<NamedExpression> projectExprs) {
+        LogicalProject<LogicalPlan> project = new LogicalProject<>(projectExprs, this.plan);
+        return from(project);
+    }
+
+    /**
+     * Puts a project on top of the current plan that keeps the output slots at the given
+     * positions, in the given order.
+     *
+     * @param slots zero-based indexes into the output of the current plan
+     * @throws IndexOutOfBoundsException if an index is outside the output of the current plan
+     */
+    public LogicalPlanBuilder project(List<Integer> slots) {
+        List<NamedExpression> projectExprs = Lists.newArrayList();
+        for (int slotIndex : slots) {
+            // Look up the requested index itself; the loop counter was used here before, so the
+            // first slots.size() outputs were projected whatever the caller asked for.
+            projectExprs.add(this.plan.getOutput().get(slotIndex));
+        }
+        LogicalProject<LogicalPlan> project = new LogicalProject<>(projectExprs, this.plan);
+        return from(project);
+    }
+
+    /**
+     * Joins the current plan (left) with {@code right} on a single equality conjunct.
+     *
+     * @param right the right child of the new join
+     * @param joinType the type of the new join
+     * @param hashOnSlots output index of the left plan ({@code first}) and of the right plan
+     *        ({@code second}) that are compared for equality
+     */
+    public LogicalPlanBuilder hashJoinUsing(LogicalPlan right, JoinType joinType, Pair<Integer, Integer> hashOnSlots) {
+        ImmutableList<EqualTo> hashConjuncts = ImmutableList.of(
+                new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second)));
+
+        LogicalJoin<LogicalPlan, LogicalPlan> join = new LogicalJoin<>(joinType, new ArrayList<>(hashConjuncts),
+                Optional.empty(), this.plan, right);
+        return from(join);
+    }
+}