[feature](Nereids): Left deep tree join order. (#12439)
* [feature](Nereids): Left deep tree join order.
This commit is contained in:
@ -18,7 +18,8 @@
|
||||
package org.apache.doris.nereids.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommute;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscom;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscomProject;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation;
|
||||
@ -32,6 +33,7 @@ import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalPro
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
|
||||
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
@ -43,8 +45,10 @@ import java.util.List;
|
||||
*/
|
||||
public class RuleSet {
|
||||
public static final List<Rule> EXPLORATION_RULES = planRuleFactories()
|
||||
.add(JoinCommute.SWAP_OUTER_SWAP_ZIG_ZAG)
|
||||
.add(JoinCommuteProject.SWAP_OUTER_SWAP_ZIG_ZAG)
|
||||
.add(JoinCommute.OUTER_LEFT_DEEP)
|
||||
.add(JoinLAsscom.INNER)
|
||||
.add(JoinLAsscomProject.INNER)
|
||||
.add(new MergeConsecutiveProjects())
|
||||
.build();
|
||||
|
||||
public static final List<Rule> REWRITE_RULES = planRuleFactories()
|
||||
@ -66,6 +70,40 @@ public class RuleSet {
|
||||
.add(new LogicalEmptyRelationToPhysicalEmptyRelation())
|
||||
.build();
|
||||
|
||||
public static final List<Rule> LEFT_DEEP_TREE_JOIN_REORDER = planRuleFactories()
|
||||
.add(JoinCommute.OUTER_LEFT_DEEP)
|
||||
.add(JoinLAsscom.INNER)
|
||||
.add(JoinLAsscomProject.INNER)
|
||||
.add(JoinLAsscom.OUTER)
|
||||
.add(JoinLAsscomProject.OUTER)
|
||||
// semi join Transpose ....
|
||||
.build();
|
||||
|
||||
public static final List<Rule> ZIG_ZAG_TREE_JOIN_REORDER = planRuleFactories()
|
||||
.add(JoinCommute.OUTER_ZIG_ZAG)
|
||||
.add(JoinLAsscom.INNER)
|
||||
.add(JoinLAsscomProject.INNER)
|
||||
.add(JoinLAsscom.OUTER)
|
||||
.add(JoinLAsscomProject.OUTER)
|
||||
// semi join Transpose ....
|
||||
.build();
|
||||
|
||||
public static final List<Rule> BUSHY_TREE_JOIN_REORDER = planRuleFactories()
|
||||
.add(JoinCommute.OUTER_BUSHY)
|
||||
// TODO: add more rule
|
||||
// .add(JoinLeftAssociate.INNER)
|
||||
// .add(JoinLeftAssociateProject.INNER)
|
||||
// .add(JoinRightAssociate.INNER)
|
||||
// .add(JoinRightAssociateProject.INNER)
|
||||
// .add(JoinExchange.INNER)
|
||||
// .add(JoinExchangeBothProject.INNER)
|
||||
// .add(JoinExchangeLeftProject.INNER)
|
||||
// .add(JoinExchangeRightProject.INNER)
|
||||
// .add(JoinRightAssociate.OUTER)
|
||||
.add(JoinLAsscom.OUTER)
|
||||
// semi join Transpose ....
|
||||
.build();
|
||||
|
||||
public List<Rule> getExplorationRules() {
|
||||
return EXPLORATION_RULES;
|
||||
}
|
||||
|
||||
@ -17,54 +17,72 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Join Commute
|
||||
*/
|
||||
@Developing
|
||||
public class JoinCommute extends OneExplorationRuleFactory {
|
||||
|
||||
public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN);
|
||||
public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG);
|
||||
public static final JoinCommute OUTER_LEFT_DEEP = new JoinCommute(SwapType.LEFT_DEEP);
|
||||
public static final JoinCommute OUTER_ZIG_ZAG = new JoinCommute(SwapType.ZIG_ZAG);
|
||||
public static final JoinCommute OUTER_BUSHY = new JoinCommute(SwapType.BUSHY);
|
||||
|
||||
private final boolean swapOuter;
|
||||
private final SwapType swapType;
|
||||
|
||||
public JoinCommute(boolean swapOuter) {
|
||||
this.swapOuter = swapOuter;
|
||||
this.swapType = SwapType.ALL;
|
||||
public JoinCommute(SwapType swapType) {
|
||||
this.swapType = swapType;
|
||||
}
|
||||
|
||||
public JoinCommute(boolean swapOuter, SwapType swapType) {
|
||||
this.swapOuter = swapOuter;
|
||||
this.swapType = swapType;
|
||||
enum SwapType {
|
||||
LEFT_DEEP, ZIG_ZAG, BUSHY
|
||||
}
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return innerLogicalJoin().when(JoinCommuteHelper::check).then(join -> {
|
||||
// TODO: add project for mapping column output.
|
||||
// List<NamedExpression> newOutput = new ArrayList<>(join.getOutput());
|
||||
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
|
||||
join.getJoinType(),
|
||||
join.getHashJoinConjuncts(),
|
||||
join.getOtherJoinCondition(),
|
||||
join.right(), join.left(),
|
||||
join.getJoinReorderContext());
|
||||
newJoin.getJoinReorderContext().setHasCommute(true);
|
||||
// if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) {
|
||||
// newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
|
||||
// }
|
||||
return innerLogicalJoin()
|
||||
.when(this::check)
|
||||
.then(join -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
|
||||
join.getJoinType(),
|
||||
join.getHashJoinConjuncts(),
|
||||
join.getOtherJoinCondition(),
|
||||
join.right(), join.left(),
|
||||
join.getJoinReorderContext());
|
||||
newJoin.getJoinReorderContext().setHasCommute(true);
|
||||
if (swapType == SwapType.ZIG_ZAG && isNotBottomJoin(join)) {
|
||||
newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
|
||||
}
|
||||
|
||||
// LogicalProject<LogicalJoin> project = new LogicalProject<>(newOutput, newJoin);
|
||||
return newJoin;
|
||||
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
|
||||
return JoinReorderCommon.project(new ArrayList<>(join.getOutput()), newJoin).get();
|
||||
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
|
||||
}
|
||||
|
||||
private boolean check(LogicalJoin<GroupPlan, GroupPlan> join) {
|
||||
if (swapType == SwapType.LEFT_DEEP && isNotBottomJoin(join)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange();
|
||||
}
|
||||
|
||||
private boolean isNotBottomJoin(LogicalJoin<GroupPlan, GroupPlan> join) {
|
||||
// TODO: tmp way to judge bottomJoin
|
||||
return containJoin(join.left()) || containJoin(join.right());
|
||||
}
|
||||
|
||||
private boolean containJoin(GroupPlan groupPlan) {
|
||||
// TODO: tmp way to judge containJoin
|
||||
List<SlotReference> output = Utils.getOutputSlotReference(groupPlan);
|
||||
return !output.stream().map(SlotReference::getQualifier).allMatch(output.get(0).getQualifier()::equals);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,66 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
|
||||
/**
|
||||
* Project-Join commute
|
||||
*/
|
||||
public class JoinCommuteProject extends OneExplorationRuleFactory {
|
||||
|
||||
public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN);
|
||||
public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG);
|
||||
|
||||
private final SwapType swapType;
|
||||
private final boolean swapOuter;
|
||||
|
||||
public JoinCommuteProject(boolean swapOuter) {
|
||||
this.swapOuter = swapOuter;
|
||||
this.swapType = SwapType.ALL;
|
||||
}
|
||||
|
||||
public JoinCommuteProject(boolean swapOuter, SwapType swapType) {
|
||||
this.swapOuter = swapOuter;
|
||||
this.swapType = swapType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalProject(innerLogicalJoin()).when(JoinCommuteHelper::check).then(project -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
|
||||
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
|
||||
join.getJoinType(),
|
||||
join.getHashJoinConjuncts(),
|
||||
join.getOtherJoinCondition(),
|
||||
join.right(), join.left(),
|
||||
join.getJoinReorderContext());
|
||||
newJoin.getJoinReorderContext().setHasCommute(true);
|
||||
// if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) {
|
||||
// newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
|
||||
// }
|
||||
|
||||
return newJoin;
|
||||
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
|
||||
}
|
||||
}
|
||||
@ -17,18 +17,42 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/**
|
||||
* Rule for change inner join LAsscom (associative and commutive).
|
||||
*/
|
||||
@Developing
|
||||
public class JoinLAsscom extends OneExplorationRuleFactory {
|
||||
// for inner-inner
|
||||
public static final JoinLAsscom INNER = new JoinLAsscom(Type.INNER);
|
||||
// for inner-leftOuter or leftOuter-leftOuter
|
||||
public static final JoinLAsscom OUTER = new JoinLAsscom(Type.OUTER);
|
||||
|
||||
private final Predicate<LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan>> typeChecker;
|
||||
|
||||
private final Type type;
|
||||
|
||||
/**
|
||||
* Specify join type.
|
||||
*/
|
||||
public JoinLAsscom(Type type) {
|
||||
this.type = type;
|
||||
if (type == Type.INNER) {
|
||||
typeChecker = join -> join.getJoinType().isInnerJoin() && join.left().getJoinType().isInnerJoin();
|
||||
} else {
|
||||
typeChecker = join -> JoinLAsscomHelper.outerSet.contains(
|
||||
Pair.of(join.left().getJoinType(), join.getJoinType()));
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* topJoin newTopJoin
|
||||
* / \ / \
|
||||
@ -39,18 +63,14 @@ public class JoinLAsscom extends OneExplorationRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalJoin(logicalJoin(), group())
|
||||
.when(JoinLAsscomHelper::check)
|
||||
.when(join -> join.getJoinType().isInnerJoin() || join.getJoinType().isLeftOuterJoin()
|
||||
&& (join.left().getJoinType().isInnerJoin() || join.left().getJoinType().isLeftOuterJoin()))
|
||||
.then(topJoin -> {
|
||||
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
|
||||
JoinLAsscomHelper helper = JoinLAsscomHelper.of(topJoin, bottomJoin);
|
||||
if (!helper.initJoinOnCondition()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return helper.newTopJoin();
|
||||
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
|
||||
.when(topJoin -> JoinLAsscomHelper.check(type, topJoin, topJoin.left()))
|
||||
.when(typeChecker)
|
||||
.then(topJoin -> {
|
||||
JoinLAsscomHelper helper = new JoinLAsscomHelper(topJoin, topJoin.left());
|
||||
if (!helper.initJoinOnCondition()) {
|
||||
return null;
|
||||
}
|
||||
return helper.newTopJoin();
|
||||
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,29 +18,27 @@
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Common function for JoinLAsscom
|
||||
*/
|
||||
public class JoinLAsscomHelper {
|
||||
class JoinLAsscomHelper extends ThreeJoinHelper {
|
||||
/*
|
||||
* topJoin newTopJoin
|
||||
* / \ / \
|
||||
@ -48,209 +46,79 @@ public class JoinLAsscomHelper {
|
||||
* / \ / \
|
||||
* A B A C
|
||||
*/
|
||||
private final LogicalJoin topJoin;
|
||||
private final LogicalJoin<GroupPlan, GroupPlan> bottomJoin;
|
||||
private final Plan a;
|
||||
private final Plan b;
|
||||
private final Plan c;
|
||||
|
||||
private final List<Expression> topHashJoinConjuncts;
|
||||
private final List<Expression> bottomHashJoinConjuncts;
|
||||
private final List<Expression> allNonHashJoinConjuncts = Lists.newArrayList();
|
||||
private final List<SlotReference> aOutputSlots;
|
||||
private final List<SlotReference> bOutputSlots;
|
||||
private final List<SlotReference> cOutputSlots;
|
||||
|
||||
private final List<Expression> newBottomHashJoinConjuncts = Lists.newArrayList();
|
||||
private final List<Expression> newBottomNonHashJoinConjuncts = Lists.newArrayList();
|
||||
|
||||
private final List<Expression> newTopHashJoinConjuncts = Lists.newArrayList();
|
||||
private final List<Expression> newTopNonHashJoinConjuncts = Lists.newArrayList();
|
||||
// Pair<bottomJoin, topJoin>
|
||||
// newBottomJoin Type = topJoin Type, newTopJoin Type = bottomJoin Type
|
||||
public static Set<Pair<JoinType, JoinType>> outerSet = ImmutableSet.of(
|
||||
Pair.of(JoinType.LEFT_OUTER_JOIN, JoinType.INNER_JOIN),
|
||||
Pair.of(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN),
|
||||
Pair.of(JoinType.LEFT_OUTER_JOIN, JoinType.LEFT_OUTER_JOIN));
|
||||
|
||||
/**
|
||||
* Init plan and output.
|
||||
*/
|
||||
public JoinLAsscomHelper(LogicalJoin<? extends Plan, GroupPlan> topJoin,
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin) {
|
||||
this.topJoin = topJoin;
|
||||
this.bottomJoin = bottomJoin;
|
||||
|
||||
a = bottomJoin.left();
|
||||
b = bottomJoin.right();
|
||||
c = topJoin.right();
|
||||
|
||||
Preconditions.checkArgument(!topJoin.getHashJoinConjuncts().isEmpty(),
|
||||
"topJoin hashJoinConjuncts must exist.");
|
||||
topHashJoinConjuncts = topJoin.getHashJoinConjuncts();
|
||||
if (topJoin.getOtherJoinCondition().isPresent()) {
|
||||
allNonHashJoinConjuncts.addAll(
|
||||
ExpressionUtils.extractConjunction(topJoin.getOtherJoinCondition().get()));
|
||||
}
|
||||
Preconditions.checkArgument(!bottomJoin.getHashJoinConjuncts().isEmpty(),
|
||||
"bottomJoin onClause must exist.");
|
||||
bottomHashJoinConjuncts = bottomJoin.getHashJoinConjuncts();
|
||||
if (bottomJoin.getOtherJoinCondition().isPresent()) {
|
||||
allNonHashJoinConjuncts.addAll(
|
||||
ExpressionUtils.extractConjunction(bottomJoin.getOtherJoinCondition().get()));
|
||||
}
|
||||
|
||||
aOutputSlots = Utils.getOutputSlotReference(a);
|
||||
bOutputSlots = Utils.getOutputSlotReference(b);
|
||||
cOutputSlots = Utils.getOutputSlotReference(c);
|
||||
super(topJoin, bottomJoin, bottomJoin.left(), bottomJoin.right(), topJoin.right());
|
||||
}
|
||||
|
||||
public static JoinLAsscomHelper of(LogicalJoin<? extends Plan, GroupPlan> topJoin,
|
||||
/**
|
||||
* Create newTopJoin.
|
||||
*/
|
||||
public Plan newTopJoin() {
|
||||
Pair<List<NamedExpression>, List<NamedExpression>> projectPair = splitProjectExprs(bOutput);
|
||||
List<NamedExpression> newLeftProjectExpr = projectPair.second;
|
||||
List<NamedExpression> newRightProjectExprs = projectPair.first;
|
||||
|
||||
// If add project to B, we should add all slotReference used by hashOnCondition.
|
||||
// TODO: Does nonHashOnCondition also need to be considered.
|
||||
Set<SlotReference> onUsedSlotRef = bottomJoin.getHashJoinConjuncts().stream()
|
||||
.flatMap(expr -> {
|
||||
Set<SlotReference> usedSlotRefs = expr.collect(SlotReference.class::isInstance);
|
||||
return usedSlotRefs.stream();
|
||||
}).filter(Utils.getOutputSlotReference(bottomJoin)::contains).collect(Collectors.toSet());
|
||||
boolean existRightProject = !newRightProjectExprs.isEmpty();
|
||||
boolean existLeftProject = !newLeftProjectExpr.isEmpty();
|
||||
onUsedSlotRef.forEach(slotRef -> {
|
||||
if (existRightProject && bOutput.contains(slotRef) && !newRightProjectExprs.contains(slotRef)) {
|
||||
newRightProjectExprs.add(slotRef);
|
||||
} else if (existLeftProject && aOutput.contains(slotRef) && !newLeftProjectExpr.contains(slotRef)) {
|
||||
newLeftProjectExpr.add(slotRef);
|
||||
}
|
||||
});
|
||||
|
||||
if (existLeftProject) {
|
||||
newLeftProjectExpr.addAll(cOutput);
|
||||
}
|
||||
LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(),
|
||||
newBottomHashJoinConjuncts, ExpressionUtils.andByOptional(newBottomNonHashJoinConjuncts), a, c,
|
||||
bottomJoin.getJoinReorderContext());
|
||||
newBottomJoin.getJoinReorderContext().setHasLAsscom(false);
|
||||
newBottomJoin.getJoinReorderContext().setHasCommute(false);
|
||||
|
||||
Plan left = JoinReorderCommon.project(newLeftProjectExpr, newBottomJoin).orElse(newBottomJoin);
|
||||
Plan right = JoinReorderCommon.project(newRightProjectExprs, b).orElse(b);
|
||||
|
||||
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
|
||||
newTopHashJoinConjuncts,
|
||||
ExpressionUtils.andByOptional(newTopNonHashJoinConjuncts), left, right,
|
||||
topJoin.getJoinReorderContext());
|
||||
newTopJoin.getJoinReorderContext().setHasLAsscom(true);
|
||||
|
||||
return JoinReorderCommon.project(new ArrayList<>(topJoin.getOutput()), newTopJoin).get();
|
||||
}
|
||||
|
||||
public static boolean check(Type type, LogicalJoin<? extends Plan, GroupPlan> topJoin,
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin) {
|
||||
return new JoinLAsscomHelper(topJoin, bottomJoin);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the onCondition of newTopJoin and newBottomJoin.
|
||||
*/
|
||||
public boolean initJoinOnCondition() {
|
||||
for (Expression topJoinOnClauseConjunct : topHashJoinConjuncts) {
|
||||
// Ignore join with some OnClause like:
|
||||
// Join C = B + A for above example.
|
||||
Set<Slot> topJoinUsedSlot = topJoinOnClauseConjunct.getInputSlots();
|
||||
if (topJoinUsedSlot.containsAll(aOutputSlots)
|
||||
&& topJoinUsedSlot.containsAll(bOutputSlots)
|
||||
&& topJoinUsedSlot.containsAll(cOutputSlots)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
List<Expression> allHashJoinConjuncts = Lists.newArrayList();
|
||||
allHashJoinConjuncts.addAll(topHashJoinConjuncts);
|
||||
allHashJoinConjuncts.addAll(bottomHashJoinConjuncts);
|
||||
|
||||
Set<Slot> newBottomJoinSlots = new HashSet<>(aOutputSlots);
|
||||
newBottomJoinSlots.addAll(cOutputSlots);
|
||||
|
||||
for (Expression hashConjunct : allHashJoinConjuncts) {
|
||||
Set<Slot> slots = hashConjunct.getInputSlots();
|
||||
if (newBottomJoinSlots.containsAll(slots)) {
|
||||
newBottomHashJoinConjuncts.add(hashConjunct);
|
||||
} else {
|
||||
newTopHashJoinConjuncts.add(hashConjunct);
|
||||
}
|
||||
}
|
||||
for (Expression nonHashConjunct : allNonHashJoinConjuncts) {
|
||||
Set<SlotReference> slots = nonHashConjunct.collect(SlotReference.class::isInstance);
|
||||
if (newBottomJoinSlots.containsAll(slots)) {
|
||||
newBottomNonHashJoinConjuncts.add(nonHashConjunct);
|
||||
} else {
|
||||
newTopNonHashJoinConjuncts.add(nonHashConjunct);
|
||||
}
|
||||
}
|
||||
// newBottomJoinOnCondition/newTopJoinOnCondition is empty. They are cross join.
|
||||
// Example:
|
||||
// A: col1, col2. B: col2, col3. C: col3, col4
|
||||
// (A & B on A.col2=B.col2) & C on B.col3=C.col3.
|
||||
// (A & B) & C -> (A & C) & B.
|
||||
// (A & C) will be cross join (newBottomJoinOnCondition is empty)
|
||||
if (newBottomHashJoinConjuncts.isEmpty() || newTopHashJoinConjuncts.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get projectExpr of left and right.
|
||||
* Just for project-inside.
|
||||
*/
|
||||
private Pair<List<NamedExpression>, List<NamedExpression>> getProjectExprs() {
|
||||
Preconditions.checkArgument(topJoin.left() instanceof LogicalProject);
|
||||
LogicalProject project = (LogicalProject) topJoin.left();
|
||||
|
||||
List<NamedExpression> projectExprs = project.getProjects();
|
||||
List<NamedExpression> newRightProjectExprs = Lists.newArrayList();
|
||||
List<NamedExpression> newLeftProjectExpr = Lists.newArrayList();
|
||||
|
||||
HashSet<SlotReference> bOutputSlotsSet = new HashSet<>(bOutputSlots);
|
||||
for (NamedExpression projectExpr : projectExprs) {
|
||||
Set<SlotReference> usedSlotRefs = projectExpr.collect(SlotReference.class::isInstance);
|
||||
if (bOutputSlotsSet.containsAll(usedSlotRefs)) {
|
||||
newRightProjectExprs.add(projectExpr);
|
||||
} else {
|
||||
newLeftProjectExpr.add(projectExpr);
|
||||
}
|
||||
}
|
||||
|
||||
return Pair.of(newLeftProjectExpr, newRightProjectExprs);
|
||||
}
|
||||
|
||||
private LogicalJoin<GroupPlan, GroupPlan> newBottomJoin() {
|
||||
Optional<Expression> bottomNonHashExpr;
|
||||
if (newBottomNonHashJoinConjuncts.isEmpty()) {
|
||||
bottomNonHashExpr = Optional.empty();
|
||||
if (type == Type.INNER) {
|
||||
return !bottomJoin.getJoinReorderContext().hasCommuteZigZag()
|
||||
&& !topJoin.getJoinReorderContext().hasLAsscom();
|
||||
} else {
|
||||
bottomNonHashExpr = Optional.of(ExpressionUtils.and(newBottomNonHashJoinConjuncts));
|
||||
// hasCommute will cause to lack of OuterJoinAssocRule:Left
|
||||
return !topJoin.getJoinReorderContext().hasLeftAssociate()
|
||||
&& !topJoin.getJoinReorderContext().hasRightAssociate()
|
||||
&& !topJoin.getJoinReorderContext().hasExchange()
|
||||
&& !bottomJoin.getJoinReorderContext().hasCommute();
|
||||
}
|
||||
return new LogicalJoin(
|
||||
bottomJoin.getJoinType(),
|
||||
newBottomHashJoinConjuncts,
|
||||
bottomNonHashExpr,
|
||||
a, c);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create topJoin for project-inside.
|
||||
*/
|
||||
public LogicalJoin newProjectTopJoin() {
|
||||
Plan left;
|
||||
Plan right;
|
||||
|
||||
List<NamedExpression> newLeftProjectExpr = getProjectExprs().first;
|
||||
List<NamedExpression> newRightProjectExprs = getProjectExprs().second;
|
||||
if (!newLeftProjectExpr.isEmpty()) {
|
||||
left = new LogicalProject<>(newLeftProjectExpr, newBottomJoin());
|
||||
} else {
|
||||
left = newBottomJoin();
|
||||
}
|
||||
if (!newRightProjectExprs.isEmpty()) {
|
||||
right = new LogicalProject<>(newRightProjectExprs, b);
|
||||
} else {
|
||||
right = b;
|
||||
}
|
||||
Optional<Expression> topNonHashExpr;
|
||||
if (newTopNonHashJoinConjuncts.isEmpty()) {
|
||||
topNonHashExpr = Optional.empty();
|
||||
} else {
|
||||
topNonHashExpr = Optional.of(ExpressionUtils.and(newTopNonHashJoinConjuncts));
|
||||
}
|
||||
return new LogicalJoin<>(
|
||||
topJoin.getJoinType(),
|
||||
newTopHashJoinConjuncts,
|
||||
topNonHashExpr,
|
||||
left, right);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create topJoin for no-project-inside.
|
||||
*/
|
||||
public LogicalJoin newTopJoin() {
|
||||
// TODO: add column map (use project)
|
||||
// SlotReference bind() may have solved this problem.
|
||||
// source: | A | B | C |
|
||||
// target: | A | C | B |
|
||||
Optional<Expression> topNonHashExpr;
|
||||
if (newTopNonHashJoinConjuncts.isEmpty()) {
|
||||
topNonHashExpr = Optional.empty();
|
||||
} else {
|
||||
topNonHashExpr = Optional.of(ExpressionUtils.and(newTopNonHashJoinConjuncts));
|
||||
}
|
||||
return new LogicalJoin(
|
||||
topJoin.getJoinType(),
|
||||
newTopHashJoinConjuncts,
|
||||
topNonHashExpr,
|
||||
newBottomJoin(), b);
|
||||
}
|
||||
|
||||
public static boolean check(LogicalJoin topJoin) {
|
||||
if (topJoin.getJoinReorderContext().hasCommute()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,18 +17,43 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/**
|
||||
* Rule for change inner join left associative to right.
|
||||
* Rule for change inner join LAsscom (associative and commutive).
|
||||
*/
|
||||
@Developing
|
||||
public class JoinLAsscomProject extends OneExplorationRuleFactory {
|
||||
// for inner-inner
|
||||
public static final JoinLAsscomProject INNER = new JoinLAsscomProject(Type.INNER);
|
||||
// for inner-leftOuter or leftOuter-leftOuter
|
||||
public static final JoinLAsscomProject OUTER = new JoinLAsscomProject(Type.OUTER);
|
||||
|
||||
private final Predicate<LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan>> typeChecker;
|
||||
|
||||
private final Type type;
|
||||
|
||||
/**
|
||||
* Specify join type.
|
||||
*/
|
||||
public JoinLAsscomProject(Type type) {
|
||||
this.type = type;
|
||||
if (type == Type.INNER) {
|
||||
typeChecker = join -> join.getJoinType().isInnerJoin() && join.left().child().getJoinType().isInnerJoin();
|
||||
} else {
|
||||
typeChecker = join -> JoinLAsscomHelper.outerSet.contains(
|
||||
Pair.of(join.left().child().getJoinType(), join.getJoinType()));
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* topJoin newTopJoin
|
||||
* / \ / \
|
||||
@ -41,19 +66,15 @@ public class JoinLAsscomProject extends OneExplorationRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalJoin(logicalProject(logicalJoin()), group())
|
||||
.when(JoinLAsscomHelper::check)
|
||||
.when(join -> join.getJoinType().isInnerJoin() || join.getJoinType().isLeftOuterJoin()
|
||||
&& (join.left().child().getJoinType().isInnerJoin() || join.left().child().getJoinType()
|
||||
.isLeftOuterJoin()))
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
|
||||
|
||||
JoinLAsscomHelper helper = JoinLAsscomHelper.of(topJoin, bottomJoin);
|
||||
if (!helper.initJoinOnCondition()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return helper.newProjectTopJoin();
|
||||
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
|
||||
.when(topJoin -> JoinLAsscomHelper.check(type, topJoin, topJoin.left().child()))
|
||||
.when(typeChecker)
|
||||
.then(topJoin -> {
|
||||
JoinLAsscomHelper helper = new JoinLAsscomHelper(topJoin, topJoin.left().child());
|
||||
helper.initAllProject(topJoin.left());
|
||||
if (!helper.initJoinOnCondition()) {
|
||||
return null;
|
||||
}
|
||||
return helper.newTopJoin();
|
||||
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,32 +17,24 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
/**
|
||||
* Common function for JoinCommute
|
||||
*/
|
||||
public class JoinCommuteHelper {
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
enum SwapType {
|
||||
BOTTOM_JOIN, ZIG_ZAG, ALL
|
||||
class JoinReorderCommon {
|
||||
public enum Type {
|
||||
INNER,
|
||||
OUTER
|
||||
}
|
||||
|
||||
private final boolean swapOuter;
|
||||
private final SwapType swapType;
|
||||
|
||||
public JoinCommuteHelper(boolean swapOuter, SwapType swapType) {
|
||||
this.swapOuter = swapOuter;
|
||||
this.swapType = swapType;
|
||||
}
|
||||
|
||||
public static boolean check(LogicalJoin<GroupPlan, GroupPlan> join) {
|
||||
return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange();
|
||||
}
|
||||
|
||||
public static boolean check(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
|
||||
return check(project.child());
|
||||
public static Optional<Plan> project(List<NamedExpression> projectExprs, Plan plan) {
|
||||
if (!projectExprs.isEmpty()) {
|
||||
return Optional.of(new LogicalProject<>(projectExprs, plan));
|
||||
} else {
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -15,7 +15,7 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration;
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
/**
|
||||
* JoinReorderContext for Duplicate free.
|
||||
@ -26,6 +26,7 @@ package org.apache.doris.nereids.rules.exploration;
|
||||
public class JoinReorderContext {
|
||||
// left deep tree
|
||||
private boolean hasCommute = false;
|
||||
private boolean hasLAsscom = false;
|
||||
|
||||
// zig-zag tree
|
||||
private boolean hasCommuteZigZag = false;
|
||||
@ -38,16 +39,24 @@ public class JoinReorderContext {
|
||||
public JoinReorderContext() {
|
||||
}
|
||||
|
||||
/**
|
||||
* copy a JoinReorderContext.
|
||||
*/
|
||||
public void copyFrom(JoinReorderContext joinReorderContext) {
|
||||
this.hasCommute = joinReorderContext.hasCommute;
|
||||
this.hasLAsscom = joinReorderContext.hasLAsscom;
|
||||
this.hasExchange = joinReorderContext.hasExchange;
|
||||
this.hasLeftAssociate = joinReorderContext.hasLeftAssociate;
|
||||
this.hasRightAssociate = joinReorderContext.hasRightAssociate;
|
||||
this.hasCommuteZigZag = joinReorderContext.hasCommuteZigZag;
|
||||
}
|
||||
|
||||
/**
|
||||
* clear all.
|
||||
*/
|
||||
public void clear() {
|
||||
hasCommute = false;
|
||||
hasLAsscom = false;
|
||||
hasCommuteZigZag = false;
|
||||
hasExchange = false;
|
||||
hasRightAssociate = false;
|
||||
@ -62,6 +71,14 @@ public class JoinReorderContext {
|
||||
this.hasCommute = hasCommute;
|
||||
}
|
||||
|
||||
public boolean hasLAsscom() {
|
||||
return hasLAsscom;
|
||||
}
|
||||
|
||||
public void setHasLAsscom(boolean hasLAsscom) {
|
||||
this.hasLAsscom = hasLAsscom;
|
||||
}
|
||||
|
||||
public boolean hasExchange() {
|
||||
return hasExchange;
|
||||
}
|
||||
@ -0,0 +1,165 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Common join helper for three-join.
|
||||
*/
|
||||
abstract class ThreeJoinHelper {
|
||||
protected final LogicalJoin<? extends Plan, ? extends Plan> topJoin;
|
||||
protected final LogicalJoin<GroupPlan, GroupPlan> bottomJoin;
|
||||
protected final GroupPlan a;
|
||||
protected final GroupPlan b;
|
||||
protected final GroupPlan c;
|
||||
|
||||
protected final List<SlotReference> aOutput;
|
||||
protected final List<SlotReference> bOutput;
|
||||
protected final List<SlotReference> cOutput;
|
||||
|
||||
protected final List<NamedExpression> allProjects = Lists.newArrayList();
|
||||
|
||||
protected final List<Expression> allHashJoinConjuncts = Lists.newArrayList();
|
||||
protected final List<Expression> allNonHashJoinConjuncts = Lists.newArrayList();
|
||||
|
||||
protected final List<Expression> newBottomHashJoinConjuncts = Lists.newArrayList();
|
||||
protected final List<Expression> newBottomNonHashJoinConjuncts = Lists.newArrayList();
|
||||
|
||||
protected final List<Expression> newTopHashJoinConjuncts = Lists.newArrayList();
|
||||
protected final List<Expression> newTopNonHashJoinConjuncts = Lists.newArrayList();
|
||||
|
||||
/**
|
||||
* Init plan and output.
|
||||
*/
|
||||
public ThreeJoinHelper(LogicalJoin<? extends Plan, ? extends Plan> topJoin,
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin, GroupPlan a, GroupPlan b, GroupPlan c) {
|
||||
this.topJoin = topJoin;
|
||||
this.bottomJoin = bottomJoin;
|
||||
this.a = a;
|
||||
this.b = b;
|
||||
this.c = c;
|
||||
|
||||
aOutput = Utils.getOutputSlotReference(a);
|
||||
bOutput = Utils.getOutputSlotReference(b);
|
||||
cOutput = Utils.getOutputSlotReference(c);
|
||||
|
||||
Preconditions.checkArgument(!topJoin.getHashJoinConjuncts().isEmpty(), "topJoin hashJoinConjuncts must exist.");
|
||||
Preconditions.checkArgument(!bottomJoin.getHashJoinConjuncts().isEmpty(),
|
||||
"bottomJoin hashJoinConjuncts must exist.");
|
||||
|
||||
allHashJoinConjuncts.addAll(topJoin.getHashJoinConjuncts());
|
||||
allHashJoinConjuncts.addAll(bottomJoin.getHashJoinConjuncts());
|
||||
topJoin.getOtherJoinCondition().ifPresent(otherJoinCondition -> allNonHashJoinConjuncts.addAll(
|
||||
ExpressionUtils.extractConjunction(otherJoinCondition)));
|
||||
bottomJoin.getOtherJoinCondition().ifPresent(otherJoinCondition -> allNonHashJoinConjuncts.addAll(
|
||||
ExpressionUtils.extractConjunction(otherJoinCondition)));
|
||||
}
|
||||
|
||||
@SafeVarargs
|
||||
public final void initAllProject(LogicalProject<? extends Plan>... projects) {
|
||||
for (LogicalProject<? extends Plan> project : projects) {
|
||||
allProjects.addAll(project.getProjects());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the onCondition of newTopJoin and newBottomJoin.
|
||||
*/
|
||||
public boolean initJoinOnCondition() {
|
||||
// Ignore join with some OnClause like:
|
||||
// Join C = B + A for above example.
|
||||
// TODO: also need for otherJoinCondition
|
||||
for (Expression topJoinOnClauseConjunct : topJoin.getHashJoinConjuncts()) {
|
||||
Set<SlotReference> topJoinUsedSlot = topJoinOnClauseConjunct.collect(SlotReference.class::isInstance);
|
||||
if (ExpressionUtils.isIntersecting(topJoinUsedSlot, aOutput) && ExpressionUtils.isIntersecting(
|
||||
topJoinUsedSlot, bOutput) && ExpressionUtils.isIntersecting(topJoinUsedSlot, cOutput)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Set<Slot> newBottomJoinSlots = new HashSet<>(aOutput);
|
||||
newBottomJoinSlots.addAll(cOutput);
|
||||
for (Expression hashConjunct : allHashJoinConjuncts) {
|
||||
Set<SlotReference> slots = hashConjunct.collect(SlotReference.class::isInstance);
|
||||
if (newBottomJoinSlots.containsAll(slots)) {
|
||||
newBottomHashJoinConjuncts.add(hashConjunct);
|
||||
} else {
|
||||
newTopHashJoinConjuncts.add(hashConjunct);
|
||||
}
|
||||
}
|
||||
for (Expression nonHashConjunct : allNonHashJoinConjuncts) {
|
||||
Set<SlotReference> slots = nonHashConjunct.collect(SlotReference.class::isInstance);
|
||||
if (newBottomJoinSlots.containsAll(slots)) {
|
||||
newBottomNonHashJoinConjuncts.add(nonHashConjunct);
|
||||
} else {
|
||||
newTopNonHashJoinConjuncts.add(nonHashConjunct);
|
||||
}
|
||||
}
|
||||
// newBottomJoinOnCondition/newTopJoinOnCondition is empty. They are cross join.
|
||||
// Example:
|
||||
// A: col1, col2. B: col2, col3. C: col3, col4
|
||||
// (A & B on A.col2=B.col2) & C on B.col3=C.col3.
|
||||
// (A & B) & C -> (A & C) & B.
|
||||
// (A & C) will be cross join (newBottomJoinOnCondition is empty)
|
||||
if (newBottomHashJoinConjuncts.isEmpty() || newTopHashJoinConjuncts.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Split inside-project into two part.
|
||||
*
|
||||
* @param topJoinChild output of topJoin groupPlan child.
|
||||
*/
|
||||
protected Pair<List<NamedExpression>, List<NamedExpression>> splitProjectExprs(List<SlotReference> topJoinChild) {
|
||||
List<NamedExpression> newTopJoinChildProjectExprs = Lists.newArrayList();
|
||||
List<NamedExpression> newBottomJoinProjectExprs = Lists.newArrayList();
|
||||
|
||||
HashSet<SlotReference> topJoinOutputSlotsSet = new HashSet<>(topJoinChild);
|
||||
|
||||
for (NamedExpression projectExpr : allProjects) {
|
||||
Set<SlotReference> usedSlotRefs = projectExpr.collect(SlotReference.class::isInstance);
|
||||
if (topJoinOutputSlotsSet.containsAll(usedSlotRefs)) {
|
||||
newTopJoinChildProjectExprs.add(projectExpr);
|
||||
} else {
|
||||
newBottomJoinProjectExprs.add(projectExpr);
|
||||
}
|
||||
}
|
||||
return Pair.of(newTopJoinChildProjectExprs, newBottomJoinProjectExprs);
|
||||
}
|
||||
}
|
||||
@ -19,7 +19,7 @@ package org.apache.doris.nereids.trees.plans.logical;
|
||||
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.properties.LogicalProperties;
|
||||
import org.apache.doris.nereids.rules.exploration.JoinReorderContext;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
|
||||
@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.And;
|
||||
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Or;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
@ -29,6 +30,7 @@ import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
@ -76,6 +78,14 @@ public class ExpressionUtils {
|
||||
}
|
||||
}
|
||||
|
||||
public static Optional<Expression> andByOptional(List<Expression> expressions) {
|
||||
if (expressions.isEmpty()) {
|
||||
return Optional.empty();
|
||||
} else {
|
||||
return Optional.of(ExpressionUtils.and(expressions));
|
||||
}
|
||||
}
|
||||
|
||||
public static Expression and(List<Expression> expressions) {
|
||||
return combine(And.class, expressions);
|
||||
}
|
||||
@ -120,4 +130,16 @@ public class ExpressionUtils {
|
||||
.reduce(type == And.class ? And::new : Or::new)
|
||||
.orElse(new BooleanLiteral(type == And.class));
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether lhs and rhs are intersecting.
|
||||
*/
|
||||
public static boolean isIntersecting(Set<SlotReference> lhs, List<SlotReference> rhs) {
|
||||
for (SlotReference rh : rhs) {
|
||||
if (lhs.contains(rh)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,8 +17,8 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
@ -26,8 +26,10 @@ import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -35,7 +37,6 @@ import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public class JoinCommuteTest {
|
||||
@ -51,14 +52,23 @@ public class JoinCommuteTest {
|
||||
JoinType.INNER_JOIN, Lists.newArrayList(onCondition),
|
||||
Optional.empty(), scan1, scan2);
|
||||
|
||||
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(join);
|
||||
Rule rule = new JoinCommute(true).build();
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
|
||||
.transform(JoinCommute.OUTER_LEFT_DEEP.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(2, root.getLogicalExpressions().size());
|
||||
|
||||
List<Plan> transform = rule.transform(join, cascadesContext);
|
||||
Assertions.assertEquals(1, transform.size());
|
||||
Plan newJoin = transform.get(0);
|
||||
Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof LogicalJoin);
|
||||
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalProject);
|
||||
|
||||
Assertions.assertEquals(join.child(0), newJoin.child(1));
|
||||
Assertions.assertEquals(join.child(1), newJoin.child(0));
|
||||
GroupExpression newJoinGroupExpr = root.logicalExpressionsAt(1).child(0).getLogicalExpression();
|
||||
Plan left = newJoinGroupExpr.child(0).getLogicalExpression().getPlan();
|
||||
Plan right = newJoinGroupExpr.child(1).getLogicalExpression().getPlan();
|
||||
Assertions.assertTrue(left instanceof LogicalOlapScan);
|
||||
Assertions.assertTrue(right instanceof LogicalOlapScan);
|
||||
|
||||
Assertions.assertEquals("t2", ((LogicalOlapScan) left).getTable().getName());
|
||||
Assertions.assertEquals("t1", ((LogicalOlapScan) right).getTable().getName());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,134 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public class JoinLAsscomProjectTest {
|
||||
|
||||
private static final List<LogicalOlapScan> scans = Lists.newArrayList();
|
||||
private static final List<List<SlotReference>> outputs = Lists.newArrayList();
|
||||
|
||||
@BeforeAll
|
||||
public static void init() {
|
||||
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
|
||||
scans.add(scan1);
|
||||
scans.add(scan2);
|
||||
scans.add(scan3);
|
||||
|
||||
List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1);
|
||||
List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2);
|
||||
List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3);
|
||||
|
||||
outputs.add(t1Output);
|
||||
outputs.add(t2Output);
|
||||
outputs.add(t3Output);
|
||||
}
|
||||
|
||||
private Pair<LogicalJoin, LogicalJoin> testJoinProjectLAsscom(List<NamedExpression> projects) {
|
||||
/*
|
||||
* topJoin newTopJoin
|
||||
* / \ / \
|
||||
* project C newLeftProject newRightProject
|
||||
* / ──► / \
|
||||
* bottomJoin newBottomJoin B
|
||||
* / \ / \
|
||||
* A B A C
|
||||
*/
|
||||
|
||||
Assertions.assertEquals(3, scans.size());
|
||||
|
||||
List<SlotReference> t1 = outputs.get(0);
|
||||
List<SlotReference> t2 = outputs.get(1);
|
||||
List<SlotReference> t3 = outputs.get(2);
|
||||
Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0));
|
||||
Expression topJoinOnCondition = new EqualTo(t1.get(1), t3.get(1));
|
||||
|
||||
LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>> project = new LogicalProject<>(
|
||||
projects,
|
||||
new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(bottomJoinOnCondition),
|
||||
Optional.empty(), scans.get(0), scans.get(1)));
|
||||
|
||||
LogicalJoin<LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>>, LogicalOlapScan> topJoin
|
||||
= new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
|
||||
Optional.empty(), project, scans.get(2));
|
||||
|
||||
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(topJoin);
|
||||
Rule rule = new JoinLAsscomProject().build();
|
||||
List<Plan> transform = rule.transform(topJoin, cascadesContext);
|
||||
Assertions.assertEquals(1, transform.size());
|
||||
Assertions.assertTrue(transform.get(0) instanceof LogicalJoin);
|
||||
LogicalJoin newTopJoin = (LogicalJoin) transform.get(0);
|
||||
return Pair.of(topJoin, newTopJoin);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStarJoinProjectLAsscom() {
|
||||
List<SlotReference> t1 = outputs.get(0);
|
||||
List<SlotReference> t2 = outputs.get(1);
|
||||
List<NamedExpression> projects = ImmutableList.of(
|
||||
new Alias(t2.get(0), "t2.id"),
|
||||
new Alias(t1.get(0), "t1.id"),
|
||||
t1.get(1),
|
||||
t2.get(1)
|
||||
);
|
||||
|
||||
Pair<LogicalJoin, LogicalJoin> pair = testJoinProjectLAsscom(projects);
|
||||
|
||||
LogicalJoin oldJoin = pair.first;
|
||||
LogicalJoin newTopJoin = pair.second;
|
||||
|
||||
// Join reorder successfully.
|
||||
Assertions.assertNotEquals(oldJoin, newTopJoin);
|
||||
Assertions.assertEquals("t1.id", ((Alias) ((LogicalProject) newTopJoin.left()).getProjects().get(0)).getName());
|
||||
Assertions.assertEquals("name",
|
||||
((SlotReference) ((LogicalProject) newTopJoin.left()).getProjects().get(1)).getName());
|
||||
Assertions.assertEquals("t2.id",
|
||||
((Alias) ((LogicalProject) newTopJoin.right()).getProjects().get(0)).getName());
|
||||
Assertions.assertEquals("name",
|
||||
((SlotReference) ((LogicalProject) newTopJoin.left()).getProjects().get(1)).getName());
|
||||
|
||||
}
|
||||
}
|
||||
@ -17,24 +17,23 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThan;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
@ -42,55 +41,13 @@ import java.util.Optional;
|
||||
|
||||
public class JoinLAsscomTest {
|
||||
|
||||
private static List<LogicalOlapScan> scans = Lists.newArrayList();
|
||||
private static List<List<SlotReference>> outputs = Lists.newArrayList();
|
||||
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
|
||||
@BeforeAll
|
||||
public static void init() {
|
||||
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
|
||||
scans.add(scan1);
|
||||
scans.add(scan2);
|
||||
scans.add(scan3);
|
||||
|
||||
List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1);
|
||||
List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2);
|
||||
List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3);
|
||||
outputs.add(t1Output);
|
||||
outputs.add(t2Output);
|
||||
outputs.add(t3Output);
|
||||
}
|
||||
|
||||
public Pair<LogicalJoin, LogicalJoin> testJoinLAsscom(
|
||||
Expression bottomJoinOnCondition,
|
||||
Expression bottomNonHashExpression,
|
||||
Expression topJoinOnCondition,
|
||||
Expression topNonHashExpression) {
|
||||
/*
|
||||
* topJoin newTopJoin
|
||||
* / \ / \
|
||||
* bottomJoin C --> newBottomJoin B
|
||||
* / \ / \
|
||||
* A B A C
|
||||
*/
|
||||
Assertions.assertEquals(3, scans.size());
|
||||
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
|
||||
Lists.newArrayList(bottomJoinOnCondition),
|
||||
Optional.of(bottomNonHashExpression), scans.get(0), scans.get(1));
|
||||
LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>(
|
||||
JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
|
||||
Optional.of(topNonHashExpression), bottomJoin, scans.get(2));
|
||||
|
||||
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(topJoin);
|
||||
Rule rule = new JoinLAsscom().build();
|
||||
List<Plan> transform = rule.transform(topJoin, cascadesContext);
|
||||
Assertions.assertEquals(1, transform.size());
|
||||
Assertions.assertTrue(transform.get(0) instanceof LogicalJoin);
|
||||
LogicalJoin newTopJoin = (LogicalJoin) transform.get(0);
|
||||
return Pair.of(topJoin, newTopJoin);
|
||||
}
|
||||
private final List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1);
|
||||
private final List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2);
|
||||
private final List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3);
|
||||
|
||||
@Test
|
||||
public void testStarJoinLAsscom() {
|
||||
@ -109,31 +66,35 @@ public class JoinLAsscomTest {
|
||||
* t1 t2 t1 t3
|
||||
*/
|
||||
|
||||
List<SlotReference> t1 = outputs.get(0);
|
||||
List<SlotReference> t2 = outputs.get(1);
|
||||
List<SlotReference> t3 = outputs.get(2);
|
||||
Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0));
|
||||
Expression bottomNonHashExpression = new LessThan(t1.get(0), t2.get(0));
|
||||
Expression topJoinOnCondition = new EqualTo(t1.get(1), t3.get(1));
|
||||
Expression topNonHashCondition = new LessThan(t1.get(1), t3.get(1));
|
||||
Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0), t2Output.get(0));
|
||||
Expression topJoinOnCondition = new EqualTo(t1Output.get(1), t3Output.get(1));
|
||||
|
||||
Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom(
|
||||
bottomJoinOnCondition,
|
||||
bottomNonHashExpression,
|
||||
topJoinOnCondition,
|
||||
topNonHashCondition);
|
||||
LogicalJoin oldJoin = pair.first;
|
||||
LogicalJoin newTopJoin = pair.second;
|
||||
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
|
||||
Lists.newArrayList(bottomJoinOnCondition),
|
||||
Optional.empty(), scan1, scan2);
|
||||
LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>(
|
||||
JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
|
||||
Optional.empty(), bottomJoin, scan3);
|
||||
|
||||
// Join reorder successfully.
|
||||
Assertions.assertNotEquals(oldJoin, newTopJoin);
|
||||
Assertions.assertEquals("t1",
|
||||
((LogicalOlapScan) ((LogicalJoin) newTopJoin.left()).left()).getTable().getName());
|
||||
Assertions.assertEquals("t3",
|
||||
((LogicalOlapScan) ((LogicalJoin) newTopJoin.left()).right()).getTable().getName());
|
||||
Assertions.assertEquals("t2", ((LogicalOlapScan) newTopJoin.right()).getTable().getName());
|
||||
Assertions.assertEquals(newTopJoin.getOtherJoinCondition(),
|
||||
((LogicalJoin) oldJoin.child(0)).getOtherJoinCondition());
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.transform(JoinLAsscom.INNER.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(2, root.getLogicalExpressions().size());
|
||||
|
||||
Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof LogicalJoin);
|
||||
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalProject);
|
||||
|
||||
GroupExpression newTopJoinGroupExpr = root.logicalExpressionsAt(1).child(0).getLogicalExpression();
|
||||
GroupExpression newBottomJoinGroupExpr = newTopJoinGroupExpr.child(0).getLogicalExpression();
|
||||
Plan bottomLeft = newBottomJoinGroupExpr.child(0).getLogicalExpression().getPlan();
|
||||
Plan bottomRight = newBottomJoinGroupExpr.child(1).getLogicalExpression().getPlan();
|
||||
Plan right = newTopJoinGroupExpr.child(1).getLogicalExpression().getPlan();
|
||||
|
||||
Assertions.assertEquals("t1", ((LogicalOlapScan) bottomLeft).getTable().getName());
|
||||
Assertions.assertEquals("t3", ((LogicalOlapScan) bottomRight).getTable().getName());
|
||||
Assertions.assertEquals("t2", ((LogicalOlapScan) right).getTable().getName());
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -151,27 +112,22 @@ public class JoinLAsscomTest {
|
||||
* t1 t2 t1 t3
|
||||
*/
|
||||
|
||||
List<SlotReference> t1 = outputs.get(0);
|
||||
List<SlotReference> t2 = outputs.get(1);
|
||||
List<SlotReference> t3 = outputs.get(2);
|
||||
Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0));
|
||||
Expression bottomNonHashExpression = new LessThan(t1.get(0), t2.get(0));
|
||||
Expression topJoinOnCondition = new EqualTo(t2.get(0), t3.get(0));
|
||||
Expression topNonHashExpression = new LessThan(t2.get(0), t3.get(0));
|
||||
Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0), t2Output.get(0));
|
||||
Expression topJoinOnCondition = new EqualTo(t2Output.get(0), t3Output.get(0));
|
||||
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
|
||||
Lists.newArrayList(bottomJoinOnCondition),
|
||||
Optional.empty(), scan1, scan2);
|
||||
LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>(
|
||||
JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
|
||||
Optional.empty(), bottomJoin, scan3);
|
||||
|
||||
Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom(bottomJoinOnCondition, bottomNonHashExpression,
|
||||
topJoinOnCondition, topNonHashExpression);
|
||||
LogicalJoin oldJoin = pair.first;
|
||||
LogicalJoin newTopJoin = pair.second;
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.transform(JoinLAsscom.INNER.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
|
||||
// Join reorder failed.
|
||||
// Chain-Join LAsscom directly will be failed.
|
||||
// After t1 -- t2 -- t3
|
||||
// -- join commute -->
|
||||
// t1 -- t2
|
||||
// |
|
||||
// t3
|
||||
// then, we can LAsscom for this star-join.
|
||||
Assertions.assertEquals(oldJoin, newTopJoin);
|
||||
// TODO: need infer onCondition.
|
||||
Assertions.assertEquals(1, root.getLogicalExpressions().size());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -23,7 +23,6 @@ import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.memo.Memo;
|
||||
import org.apache.doris.nereids.pattern.GroupExpressionMatching;
|
||||
import org.apache.doris.nereids.pattern.GroupExpressionMatching.GroupExpressionIterator;
|
||||
import org.apache.doris.nereids.pattern.MatchingContext;
|
||||
import org.apache.doris.nereids.pattern.PatternDescriptor;
|
||||
import org.apache.doris.nereids.pattern.PatternMatcher;
|
||||
@ -145,10 +144,8 @@ public class PlanChecker {
|
||||
|
||||
public PlanChecker transform(GroupExpression groupExpression, PatternMatcher patternMatcher) {
|
||||
GroupExpressionMatching matchResult = new GroupExpressionMatching(patternMatcher.pattern, groupExpression);
|
||||
GroupExpressionIterator iterator = matchResult.iterator();
|
||||
|
||||
while (iterator.hasNext()) {
|
||||
Plan before = iterator.next();
|
||||
for (Plan before : matchResult) {
|
||||
Plan after = patternMatcher.matchedAction.apply(
|
||||
new MatchingContext(before, patternMatcher.pattern, cascadesContext));
|
||||
if (before != after) {
|
||||
@ -162,6 +159,38 @@ public class PlanChecker {
|
||||
return this;
|
||||
}
|
||||
|
||||
public PlanChecker transform(Rule rule) {
|
||||
return transform(cascadesContext.getMemo().getRoot(), rule);
|
||||
}
|
||||
|
||||
public PlanChecker transform(Group group, Rule rule) {
|
||||
// copy groupExpressions can prevent ConcurrentModificationException
|
||||
for (GroupExpression logicalExpression : Lists.newArrayList(group.getLogicalExpressions())) {
|
||||
transform(logicalExpression, rule);
|
||||
}
|
||||
|
||||
for (GroupExpression physicalExpression : Lists.newArrayList(group.getPhysicalExpressions())) {
|
||||
transform(physicalExpression, rule);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public PlanChecker transform(GroupExpression groupExpression, Rule rule) {
|
||||
GroupExpressionMatching matchResult = new GroupExpressionMatching(rule.getPattern(), groupExpression);
|
||||
|
||||
for (Plan before : matchResult) {
|
||||
Plan after = rule.transform(before, cascadesContext).get(0);
|
||||
if (before != after) {
|
||||
cascadesContext.getMemo().copyIn(after, before.getGroupExpression().get().getOwnerGroup(), false);
|
||||
}
|
||||
}
|
||||
|
||||
for (Group childGroup : groupExpression.children()) {
|
||||
transform(childGroup, rule);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public PlanChecker matchesFromRoot(PatternDescriptor<? extends Plan> patternDesc) {
|
||||
Memo memo = cascadesContext.getMemo();
|
||||
assertMatches(memo, () -> new GroupExpressionMatching(patternDesc.pattern,
|
||||
|
||||
Reference in New Issue
Block a user