[fix](Nereids): avoid commute cause dead-loop. (#12616)

* [fix](Nereids): avoid commute cause dead-loop.

* update best plan
This commit is contained in:
jakevin
2022-09-15 10:47:11 +08:00
committed by GitHub
parent 8aa5899484
commit 6543924790
8 changed files with 141 additions and 39 deletions

View File

@ -182,7 +182,7 @@ public class Group {
*/
public void setBestPlan(GroupExpression expression, double cost, PhysicalProperties properties) {
if (lowestCostPlans.containsKey(properties)) {
if (lowestCostPlans.get(properties).first > cost) {
if (lowestCostPlans.get(properties).first >= cost) {
lowestCostPlans.put(properties, Pair.of(cost, expression));
}
} else {

View File

@ -140,6 +140,10 @@ public class GroupExpression {
ruleMasks.set(rule.getRuleType().ordinal());
}
public void setApplied(RuleType ruleType) {
ruleMasks.set(ruleType.ordinal());
}
public void propagateApplied(GroupExpression toGroupExpression) {
toGroupExpression.ruleMasks.or(ruleMasks);
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules;
import org.apache.doris.nereids.rules.exploration.join.JoinCommute;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject;
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscom;
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscomProject;
import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
@ -46,7 +47,8 @@ import java.util.List;
*/
public class RuleSet {
public static final List<Rule> EXPLORATION_RULES = planRuleFactories()
.add(JoinCommute.OUTER_LEFT_DEEP)
.add(JoinCommute.LEFT_DEEP)
.add(JoinCommuteProject.LEFT_DEEP)
.add(JoinLAsscom.INNER)
.add(JoinLAsscomProject.INNER)
.add(new PushdownFilterThroughProject())
@ -73,7 +75,7 @@ public class RuleSet {
.build();
public static final List<Rule> LEFT_DEEP_TREE_JOIN_REORDER = planRuleFactories()
.add(JoinCommute.OUTER_LEFT_DEEP)
.add(JoinCommute.LEFT_DEEP)
.add(JoinLAsscom.INNER)
.add(JoinLAsscomProject.INNER)
.add(JoinLAsscom.OUTER)
@ -82,7 +84,7 @@ public class RuleSet {
.build();
public static final List<Rule> ZIG_ZAG_TREE_JOIN_REORDER = planRuleFactories()
.add(JoinCommute.OUTER_ZIG_ZAG)
.add(JoinCommute.ZIG_ZAG)
.add(JoinLAsscom.INNER)
.add(JoinLAsscomProject.INNER)
.add(JoinLAsscom.OUTER)
@ -91,7 +93,7 @@ public class RuleSet {
.build();
public static final List<Rule> BUSHY_TREE_JOIN_REORDER = planRuleFactories()
.add(JoinCommute.OUTER_BUSHY)
.add(JoinCommute.BUSHY)
// TODO: add more rule
// .add(JoinLeftAssociate.INNER)
// .add(JoinLeftAssociateProject.INNER)

View File

@ -114,7 +114,7 @@ public enum RuleType {
// exploration rules
TEST_EXPLORATION(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_COMMUTATIVE(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION),
LOGICAL_LEFT_JOIN_ASSOCIATIVE(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_L_ASSCOM(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION),

View File

@ -20,23 +20,21 @@ package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.nereids.util.Utils;
import java.util.ArrayList;
import java.util.List;
/**
* Join Commute
*/
public class JoinCommute extends OneExplorationRuleFactory {
public static final JoinCommute OUTER_LEFT_DEEP = new JoinCommute(SwapType.LEFT_DEEP);
public static final JoinCommute OUTER_ZIG_ZAG = new JoinCommute(SwapType.ZIG_ZAG);
public static final JoinCommute OUTER_BUSHY = new JoinCommute(SwapType.BUSHY);
public static final JoinCommute LEFT_DEEP = new JoinCommute(SwapType.LEFT_DEEP);
public static final JoinCommute ZIG_ZAG = new JoinCommute(SwapType.ZIG_ZAG);
public static final JoinCommute BUSHY = new JoinCommute(SwapType.BUSHY);
private final SwapType swapType;
@ -44,14 +42,10 @@ public class JoinCommute extends OneExplorationRuleFactory {
this.swapType = swapType;
}
enum SwapType {
LEFT_DEEP, ZIG_ZAG, BUSHY
}
@Override
public Rule build() {
return logicalJoin()
.when(this::check)
.when(join -> JoinCommuteHelper.check(swapType, join))
.then(join -> {
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
join.getJoinType().swap(),
@ -60,30 +54,11 @@ public class JoinCommute extends OneExplorationRuleFactory {
join.right(), join.left(),
join.getJoinReorderContext());
newJoin.getJoinReorderContext().setHasCommute(true);
if (swapType == SwapType.ZIG_ZAG && isNotBottomJoin(join)) {
if (swapType == SwapType.ZIG_ZAG && JoinCommuteHelper.isNotBottomJoin(join)) {
newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
}
return PlanUtils.project(new ArrayList<>(join.getOutput()), newJoin).get();
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
}
private boolean check(LogicalJoin<GroupPlan, GroupPlan> join) {
if (swapType == SwapType.LEFT_DEEP && isNotBottomJoin(join)) {
return false;
}
return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange();
}
private boolean isNotBottomJoin(LogicalJoin<GroupPlan, GroupPlan> join) {
// TODO: tmp way to judge bottomJoin
return containJoin(join.left()) || containJoin(join.right());
}
private boolean containJoin(GroupPlan groupPlan) {
// TODO: tmp way to judge containJoin
List<SlotReference> output = Utils.getOutputSlotReference(groupPlan);
return !output.stream().map(SlotReference::getQualifier).allMatch(output.get(0).getQualifier()::equals);
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATE);
}
}

View File

@ -0,0 +1,53 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.Utils;
import java.util.List;
/**
* Join Commute Helper
*/
class JoinCommuteHelper {
enum SwapType {
LEFT_DEEP, ZIG_ZAG, BUSHY
}
public static boolean check(SwapType swapType, LogicalJoin<GroupPlan, GroupPlan> join) {
if (swapType == SwapType.LEFT_DEEP && isNotBottomJoin(join)) {
return false;
}
return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange();
}
public static boolean isNotBottomJoin(LogicalJoin<GroupPlan, GroupPlan> join) {
// TODO: tmp way to judge bottomJoin
return containJoin(join.left()) || containJoin(join.right());
}
private static boolean containJoin(GroupPlan groupPlan) {
// TODO: tmp way to judge containJoin
List<SlotReference> output = Utils.getOutputSlotReference(groupPlan);
return !output.stream().map(SlotReference::getQualifier).allMatch(output.get(0).getQualifier()::equals);
}
}

View File

@ -0,0 +1,68 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.PlanUtils;
import java.util.ArrayList;
/**
* Project-Join Commute.
* This rule can prevent double JoinCommute cause dead-loop in Memo.
*/
public class JoinCommuteProject extends OneExplorationRuleFactory {
public static final JoinCommuteProject LEFT_DEEP = new JoinCommuteProject(SwapType.LEFT_DEEP);
public static final JoinCommuteProject ZIG_ZAG = new JoinCommuteProject(SwapType.ZIG_ZAG);
public static final JoinCommuteProject BUSHY = new JoinCommuteProject(SwapType.BUSHY);
private final SwapType swapType;
public JoinCommuteProject(SwapType swapType) {
this.swapType = swapType;
}
@Override
public Rule build() {
return logicalProject(logicalJoin())
.when(project -> JoinCommuteHelper.check(swapType, project.child()))
.then(project -> {
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
// prevent this join match by JoinCommute.
join.getGroupExpression().get().setApplied(RuleType.LOGICAL_JOIN_COMMUTATE);
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
join.getJoinType().swap(),
join.getHashJoinConjuncts(),
join.getOtherJoinCondition(),
join.right(), join.left(),
join.getJoinReorderContext());
newJoin.getJoinReorderContext().setHasCommute(true);
if (swapType == SwapType.ZIG_ZAG && JoinCommuteHelper.isNotBottomJoin(join)) {
newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
}
return PlanUtils.project(new ArrayList<>(project.getProjects()), newJoin).get();
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATE);
}
}

View File

@ -53,7 +53,7 @@ public class JoinCommuteTest {
Optional.empty(), scan1, scan2);
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
.transform(JoinCommute.OUTER_LEFT_DEEP.build())
.transform(JoinCommute.LEFT_DEEP.build())
.checkMemo(memo -> {
Group root = memo.getRoot();
Assertions.assertEquals(2, root.getLogicalExpressions().size());