diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java index ca431e8959..34b4b50b14 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java @@ -182,7 +182,7 @@ public class Group { */ public void setBestPlan(GroupExpression expression, double cost, PhysicalProperties properties) { if (lowestCostPlans.containsKey(properties)) { - if (lowestCostPlans.get(properties).first > cost) { + if (lowestCostPlans.get(properties).first >= cost) { lowestCostPlans.put(properties, Pair.of(cost, expression)); } } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java index 6137647887..9601f54b42 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java @@ -140,6 +140,10 @@ public class GroupExpression { ruleMasks.set(rule.getRuleType().ordinal()); } + public void setApplied(RuleType ruleType) { + ruleMasks.set(ruleType.ordinal()); + } + public void propagateApplied(GroupExpression toGroupExpression) { toGroupExpression.ruleMasks.or(ruleMasks); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index 70251110ec..fa480135ce 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules; import org.apache.doris.nereids.rules.exploration.join.JoinCommute; +import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject; import org.apache.doris.nereids.rules.exploration.join.JoinLAsscom; import org.apache.doris.nereids.rules.exploration.join.JoinLAsscomProject; import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg; @@ -46,7 +47,8 @@ import java.util.List; */ public class RuleSet { public static final List EXPLORATION_RULES = planRuleFactories() - .add(JoinCommute.OUTER_LEFT_DEEP) + .add(JoinCommute.LEFT_DEEP) + .add(JoinCommuteProject.LEFT_DEEP) .add(JoinLAsscom.INNER) .add(JoinLAsscomProject.INNER) .add(new PushdownFilterThroughProject()) @@ -73,7 +75,7 @@ public class RuleSet { .build(); public static final List LEFT_DEEP_TREE_JOIN_REORDER = planRuleFactories() - .add(JoinCommute.OUTER_LEFT_DEEP) + .add(JoinCommute.LEFT_DEEP) .add(JoinLAsscom.INNER) .add(JoinLAsscomProject.INNER) .add(JoinLAsscom.OUTER) @@ -82,7 +84,7 @@ public class RuleSet { .build(); public static final List ZIG_ZAG_TREE_JOIN_REORDER = planRuleFactories() - .add(JoinCommute.OUTER_ZIG_ZAG) + .add(JoinCommute.ZIG_ZAG) .add(JoinLAsscom.INNER) .add(JoinLAsscomProject.INNER) .add(JoinLAsscom.OUTER) @@ -91,7 +93,7 @@ public class RuleSet { .build(); public static final List BUSHY_TREE_JOIN_REORDER = planRuleFactories() - .add(JoinCommute.OUTER_BUSHY) + .add(JoinCommute.BUSHY) // TODO: add more rule // .add(JoinLeftAssociate.INNER) // .add(JoinLeftAssociateProject.INNER) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 98aac8e923..d90e0c089d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -114,7 +114,7 @@ public enum RuleType { // exploration rules TEST_EXPLORATION(RuleTypeClass.EXPLORATION), - LOGICAL_JOIN_COMMUTATIVE(RuleTypeClass.EXPLORATION), + LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION), LOGICAL_LEFT_JOIN_ASSOCIATIVE(RuleTypeClass.EXPLORATION), LOGICAL_JOIN_L_ASSCOM(RuleTypeClass.EXPLORATION), LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java index 1b40df3bc9..7df4b662d9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java @@ -20,23 +20,21 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; -import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.PlanUtils; -import org.apache.doris.nereids.util.Utils; import java.util.ArrayList; -import java.util.List; /** * Join Commute */ public class JoinCommute extends OneExplorationRuleFactory { - public static final JoinCommute OUTER_LEFT_DEEP = new JoinCommute(SwapType.LEFT_DEEP); - public static final JoinCommute OUTER_ZIG_ZAG = new JoinCommute(SwapType.ZIG_ZAG); - public static final JoinCommute OUTER_BUSHY = new JoinCommute(SwapType.BUSHY); + public static final JoinCommute LEFT_DEEP = new JoinCommute(SwapType.LEFT_DEEP); + public static final JoinCommute ZIG_ZAG = new JoinCommute(SwapType.ZIG_ZAG); + public static final JoinCommute BUSHY = new JoinCommute(SwapType.BUSHY); private final SwapType swapType; @@ -44,14 +42,10 @@ public class JoinCommute extends OneExplorationRuleFactory { this.swapType = swapType; } - enum SwapType { - LEFT_DEEP, ZIG_ZAG, BUSHY - } - @Override public Rule build() { return logicalJoin() - .when(this::check) + .when(join -> JoinCommuteHelper.check(swapType, join)) .then(join -> { LogicalJoin newJoin = new LogicalJoin<>( join.getJoinType().swap(), @@ -60,30 +54,11 @@ public class JoinCommute extends OneExplorationRuleFactory { join.right(), join.left(), join.getJoinReorderContext()); newJoin.getJoinReorderContext().setHasCommute(true); - if (swapType == SwapType.ZIG_ZAG && isNotBottomJoin(join)) { + if (swapType == SwapType.ZIG_ZAG && JoinCommuteHelper.isNotBottomJoin(join)) { newJoin.getJoinReorderContext().setHasCommuteZigZag(true); } return PlanUtils.project(new ArrayList<>(join.getOutput()), newJoin).get(); - }).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE); - } - - private boolean check(LogicalJoin join) { - if (swapType == SwapType.LEFT_DEEP && isNotBottomJoin(join)) { - return false; - } - - return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange(); - } - - private boolean isNotBottomJoin(LogicalJoin join) { - // TODO: tmp way to judge bottomJoin - return containJoin(join.left()) || containJoin(join.right()); - } - - private boolean containJoin(GroupPlan groupPlan) { - // TODO: tmp way to judge containJoin - List output = Utils.getOutputSlotReference(groupPlan); - return !output.stream().map(SlotReference::getQualifier).allMatch(output.get(0).getQualifier()::equals); + }).toRule(RuleType.LOGICAL_JOIN_COMMUTATE); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java new file mode 100644 index 0000000000..288c566a72 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.Utils; + +import java.util.List; + +/** + * Join Commute Helper + */ +class JoinCommuteHelper { + enum SwapType { + LEFT_DEEP, ZIG_ZAG, BUSHY + } + + public static boolean check(SwapType swapType, LogicalJoin join) { + if (swapType == SwapType.LEFT_DEEP && isNotBottomJoin(join)) { + return false; + } + + return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange(); + } + + public static boolean isNotBottomJoin(LogicalJoin join) { + // TODO: tmp way to judge bottomJoin + return containJoin(join.left()) || containJoin(join.right()); + } + + private static boolean containJoin(GroupPlan groupPlan) { + // TODO: tmp way to judge containJoin + List output = Utils.getOutputSlotReference(groupPlan); + return !output.stream().map(SlotReference::getQualifier).allMatch(output.get(0).getQualifier()::equals); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java new file mode 100644 index 0000000000..c439036722 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.PlanUtils; + +import java.util.ArrayList; + +/** + * Project-Join Commute. + * This rule can prevent double JoinCommute cause dead-loop in Memo. + */ +public class JoinCommuteProject extends OneExplorationRuleFactory { + + public static final JoinCommuteProject LEFT_DEEP = new JoinCommuteProject(SwapType.LEFT_DEEP); + public static final JoinCommuteProject ZIG_ZAG = new JoinCommuteProject(SwapType.ZIG_ZAG); + public static final JoinCommuteProject BUSHY = new JoinCommuteProject(SwapType.BUSHY); + + private final SwapType swapType; + + public JoinCommuteProject(SwapType swapType) { + this.swapType = swapType; + } + + @Override + public Rule build() { + return logicalProject(logicalJoin()) + .when(project -> JoinCommuteHelper.check(swapType, project.child())) + .then(project -> { + LogicalJoin join = project.child(); + // prevent this join match by JoinCommute. + join.getGroupExpression().get().setApplied(RuleType.LOGICAL_JOIN_COMMUTATE); + LogicalJoin newJoin = new LogicalJoin<>( + join.getJoinType().swap(), + join.getHashJoinConjuncts(), + join.getOtherJoinCondition(), + join.right(), join.left(), + join.getJoinReorderContext()); + newJoin.getJoinReorderContext().setHasCommute(true); + if (swapType == SwapType.ZIG_ZAG && JoinCommuteHelper.isNotBottomJoin(join)) { + newJoin.getJoinReorderContext().setHasCommuteZigZag(true); + } + + return PlanUtils.project(new ArrayList<>(project.getProjects()), newJoin).get(); + }).toRule(RuleType.LOGICAL_JOIN_COMMUTATE); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java index 46e31ead85..5eae21170d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java @@ -53,7 +53,7 @@ public class JoinCommuteTest { Optional.empty(), scan1, scan2); PlanChecker.from(MemoTestUtils.createConnectContext(), join) - .transform(JoinCommute.OUTER_LEFT_DEEP.build()) + .transform(JoinCommute.LEFT_DEEP.build()) .checkMemo(memo -> { Group root = memo.getRoot(); Assertions.assertEquals(2, root.getLogicalExpressions().size());