[fix](Nereids): avoid commute cause dead-loop. (#12616)
* [fix](Nereids): avoid commute cause dead-loop. * update best plan
This commit is contained in:
@ -182,7 +182,7 @@ public class Group {
|
||||
*/
|
||||
public void setBestPlan(GroupExpression expression, double cost, PhysicalProperties properties) {
|
||||
if (lowestCostPlans.containsKey(properties)) {
|
||||
if (lowestCostPlans.get(properties).first > cost) {
|
||||
if (lowestCostPlans.get(properties).first >= cost) {
|
||||
lowestCostPlans.put(properties, Pair.of(cost, expression));
|
||||
}
|
||||
} else {
|
||||
|
||||
@ -140,6 +140,10 @@ public class GroupExpression {
|
||||
ruleMasks.set(rule.getRuleType().ordinal());
|
||||
}
|
||||
|
||||
public void setApplied(RuleType ruleType) {
|
||||
ruleMasks.set(ruleType.ordinal());
|
||||
}
|
||||
|
||||
public void propagateApplied(GroupExpression toGroupExpression) {
|
||||
toGroupExpression.ruleMasks.or(ruleMasks);
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommute;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscom;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscomProject;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
|
||||
@ -46,7 +47,8 @@ import java.util.List;
|
||||
*/
|
||||
public class RuleSet {
|
||||
public static final List<Rule> EXPLORATION_RULES = planRuleFactories()
|
||||
.add(JoinCommute.OUTER_LEFT_DEEP)
|
||||
.add(JoinCommute.LEFT_DEEP)
|
||||
.add(JoinCommuteProject.LEFT_DEEP)
|
||||
.add(JoinLAsscom.INNER)
|
||||
.add(JoinLAsscomProject.INNER)
|
||||
.add(new PushdownFilterThroughProject())
|
||||
@ -73,7 +75,7 @@ public class RuleSet {
|
||||
.build();
|
||||
|
||||
public static final List<Rule> LEFT_DEEP_TREE_JOIN_REORDER = planRuleFactories()
|
||||
.add(JoinCommute.OUTER_LEFT_DEEP)
|
||||
.add(JoinCommute.LEFT_DEEP)
|
||||
.add(JoinLAsscom.INNER)
|
||||
.add(JoinLAsscomProject.INNER)
|
||||
.add(JoinLAsscom.OUTER)
|
||||
@ -82,7 +84,7 @@ public class RuleSet {
|
||||
.build();
|
||||
|
||||
public static final List<Rule> ZIG_ZAG_TREE_JOIN_REORDER = planRuleFactories()
|
||||
.add(JoinCommute.OUTER_ZIG_ZAG)
|
||||
.add(JoinCommute.ZIG_ZAG)
|
||||
.add(JoinLAsscom.INNER)
|
||||
.add(JoinLAsscomProject.INNER)
|
||||
.add(JoinLAsscom.OUTER)
|
||||
@ -91,7 +93,7 @@ public class RuleSet {
|
||||
.build();
|
||||
|
||||
public static final List<Rule> BUSHY_TREE_JOIN_REORDER = planRuleFactories()
|
||||
.add(JoinCommute.OUTER_BUSHY)
|
||||
.add(JoinCommute.BUSHY)
|
||||
// TODO: add more rule
|
||||
// .add(JoinLeftAssociate.INNER)
|
||||
// .add(JoinLeftAssociateProject.INNER)
|
||||
|
||||
@ -114,7 +114,7 @@ public enum RuleType {
|
||||
|
||||
// exploration rules
|
||||
TEST_EXPLORATION(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_COMMUTATIVE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_LEFT_JOIN_ASSOCIATIVE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_L_ASSCOM(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION),
|
||||
|
||||
@ -20,23 +20,21 @@ package org.apache.doris.nereids.rules.exploration.join;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Join Commute
|
||||
*/
|
||||
public class JoinCommute extends OneExplorationRuleFactory {
|
||||
|
||||
public static final JoinCommute OUTER_LEFT_DEEP = new JoinCommute(SwapType.LEFT_DEEP);
|
||||
public static final JoinCommute OUTER_ZIG_ZAG = new JoinCommute(SwapType.ZIG_ZAG);
|
||||
public static final JoinCommute OUTER_BUSHY = new JoinCommute(SwapType.BUSHY);
|
||||
public static final JoinCommute LEFT_DEEP = new JoinCommute(SwapType.LEFT_DEEP);
|
||||
public static final JoinCommute ZIG_ZAG = new JoinCommute(SwapType.ZIG_ZAG);
|
||||
public static final JoinCommute BUSHY = new JoinCommute(SwapType.BUSHY);
|
||||
|
||||
private final SwapType swapType;
|
||||
|
||||
@ -44,14 +42,10 @@ public class JoinCommute extends OneExplorationRuleFactory {
|
||||
this.swapType = swapType;
|
||||
}
|
||||
|
||||
enum SwapType {
|
||||
LEFT_DEEP, ZIG_ZAG, BUSHY
|
||||
}
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalJoin()
|
||||
.when(this::check)
|
||||
.when(join -> JoinCommuteHelper.check(swapType, join))
|
||||
.then(join -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
|
||||
join.getJoinType().swap(),
|
||||
@ -60,30 +54,11 @@ public class JoinCommute extends OneExplorationRuleFactory {
|
||||
join.right(), join.left(),
|
||||
join.getJoinReorderContext());
|
||||
newJoin.getJoinReorderContext().setHasCommute(true);
|
||||
if (swapType == SwapType.ZIG_ZAG && isNotBottomJoin(join)) {
|
||||
if (swapType == SwapType.ZIG_ZAG && JoinCommuteHelper.isNotBottomJoin(join)) {
|
||||
newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
|
||||
}
|
||||
|
||||
return PlanUtils.project(new ArrayList<>(join.getOutput()), newJoin).get();
|
||||
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
|
||||
}
|
||||
|
||||
private boolean check(LogicalJoin<GroupPlan, GroupPlan> join) {
|
||||
if (swapType == SwapType.LEFT_DEEP && isNotBottomJoin(join)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange();
|
||||
}
|
||||
|
||||
private boolean isNotBottomJoin(LogicalJoin<GroupPlan, GroupPlan> join) {
|
||||
// TODO: tmp way to judge bottomJoin
|
||||
return containJoin(join.left()) || containJoin(join.right());
|
||||
}
|
||||
|
||||
private boolean containJoin(GroupPlan groupPlan) {
|
||||
// TODO: tmp way to judge containJoin
|
||||
List<SlotReference> output = Utils.getOutputSlotReference(groupPlan);
|
||||
return !output.stream().map(SlotReference::getQualifier).allMatch(output.get(0).getQualifier()::equals);
|
||||
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATE);
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,53 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Join Commute Helper
|
||||
*/
|
||||
class JoinCommuteHelper {
|
||||
enum SwapType {
|
||||
LEFT_DEEP, ZIG_ZAG, BUSHY
|
||||
}
|
||||
|
||||
public static boolean check(SwapType swapType, LogicalJoin<GroupPlan, GroupPlan> join) {
|
||||
if (swapType == SwapType.LEFT_DEEP && isNotBottomJoin(join)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange();
|
||||
}
|
||||
|
||||
public static boolean isNotBottomJoin(LogicalJoin<GroupPlan, GroupPlan> join) {
|
||||
// TODO: tmp way to judge bottomJoin
|
||||
return containJoin(join.left()) || containJoin(join.right());
|
||||
}
|
||||
|
||||
private static boolean containJoin(GroupPlan groupPlan) {
|
||||
// TODO: tmp way to judge containJoin
|
||||
List<SlotReference> output = Utils.getOutputSlotReference(groupPlan);
|
||||
return !output.stream().map(SlotReference::getQualifier).allMatch(output.get(0).getQualifier()::equals);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,68 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
/**
|
||||
* Project-Join Commute.
|
||||
* This rule can prevent double JoinCommute cause dead-loop in Memo.
|
||||
*/
|
||||
public class JoinCommuteProject extends OneExplorationRuleFactory {
|
||||
|
||||
public static final JoinCommuteProject LEFT_DEEP = new JoinCommuteProject(SwapType.LEFT_DEEP);
|
||||
public static final JoinCommuteProject ZIG_ZAG = new JoinCommuteProject(SwapType.ZIG_ZAG);
|
||||
public static final JoinCommuteProject BUSHY = new JoinCommuteProject(SwapType.BUSHY);
|
||||
|
||||
private final SwapType swapType;
|
||||
|
||||
public JoinCommuteProject(SwapType swapType) {
|
||||
this.swapType = swapType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalProject(logicalJoin())
|
||||
.when(project -> JoinCommuteHelper.check(swapType, project.child()))
|
||||
.then(project -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
|
||||
// prevent this join match by JoinCommute.
|
||||
join.getGroupExpression().get().setApplied(RuleType.LOGICAL_JOIN_COMMUTATE);
|
||||
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
|
||||
join.getJoinType().swap(),
|
||||
join.getHashJoinConjuncts(),
|
||||
join.getOtherJoinCondition(),
|
||||
join.right(), join.left(),
|
||||
join.getJoinReorderContext());
|
||||
newJoin.getJoinReorderContext().setHasCommute(true);
|
||||
if (swapType == SwapType.ZIG_ZAG && JoinCommuteHelper.isNotBottomJoin(join)) {
|
||||
newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
|
||||
}
|
||||
|
||||
return PlanUtils.project(new ArrayList<>(project.getProjects()), newJoin).get();
|
||||
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATE);
|
||||
}
|
||||
}
|
||||
@ -53,7 +53,7 @@ public class JoinCommuteTest {
|
||||
Optional.empty(), scan1, scan2);
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
|
||||
.transform(JoinCommute.OUTER_LEFT_DEEP.build())
|
||||
.transform(JoinCommute.LEFT_DEEP.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(2, root.getLogicalExpressions().size());
|
||||
|
||||
Reference in New Issue
Block a user