[feature](Nereids): Left deep tree join order. (#12439)

* [feature](Nereids): Left deep tree join order.
This commit is contained in:
jakevin
2022-09-08 15:09:22 +08:00
committed by GitHub
parent 14221adbbd
commit 7c7ac86fe8
15 changed files with 553 additions and 597 deletions

View File

@ -18,7 +18,8 @@
package org.apache.doris.nereids.rules;
import org.apache.doris.nereids.rules.exploration.join.JoinCommute;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject;
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscom;
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscomProject;
import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
import org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation;
@ -32,6 +33,7 @@ import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalPro
import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort;
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
@ -43,8 +45,10 @@ import java.util.List;
*/
public class RuleSet {
public static final List<Rule> EXPLORATION_RULES = planRuleFactories()
.add(JoinCommute.SWAP_OUTER_SWAP_ZIG_ZAG)
.add(JoinCommuteProject.SWAP_OUTER_SWAP_ZIG_ZAG)
.add(JoinCommute.OUTER_LEFT_DEEP)
.add(JoinLAsscom.INNER)
.add(JoinLAsscomProject.INNER)
.add(new MergeConsecutiveProjects())
.build();
public static final List<Rule> REWRITE_RULES = planRuleFactories()
@ -66,6 +70,40 @@ public class RuleSet {
.add(new LogicalEmptyRelationToPhysicalEmptyRelation())
.build();
public static final List<Rule> LEFT_DEEP_TREE_JOIN_REORDER = planRuleFactories()
.add(JoinCommute.OUTER_LEFT_DEEP)
.add(JoinLAsscom.INNER)
.add(JoinLAsscomProject.INNER)
.add(JoinLAsscom.OUTER)
.add(JoinLAsscomProject.OUTER)
// semi join Transpose ....
.build();
public static final List<Rule> ZIG_ZAG_TREE_JOIN_REORDER = planRuleFactories()
.add(JoinCommute.OUTER_ZIG_ZAG)
.add(JoinLAsscom.INNER)
.add(JoinLAsscomProject.INNER)
.add(JoinLAsscom.OUTER)
.add(JoinLAsscomProject.OUTER)
// semi join Transpose ....
.build();
public static final List<Rule> BUSHY_TREE_JOIN_REORDER = planRuleFactories()
.add(JoinCommute.OUTER_BUSHY)
// TODO: add more rule
// .add(JoinLeftAssociate.INNER)
// .add(JoinLeftAssociateProject.INNER)
// .add(JoinRightAssociate.INNER)
// .add(JoinRightAssociateProject.INNER)
// .add(JoinExchange.INNER)
// .add(JoinExchangeBothProject.INNER)
// .add(JoinExchangeLeftProject.INNER)
// .add(JoinExchangeRightProject.INNER)
// .add(JoinRightAssociate.OUTER)
.add(JoinLAsscom.OUTER)
// semi join Transpose ....
.build();
public List<Rule> getExplorationRules() {
return EXPLORATION_RULES;
}

View File

@ -17,54 +17,72 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.Utils;
import java.util.ArrayList;
import java.util.List;
/**
* Join Commute
*/
@Developing
public class JoinCommute extends OneExplorationRuleFactory {
public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN);
public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG);
public static final JoinCommute OUTER_LEFT_DEEP = new JoinCommute(SwapType.LEFT_DEEP);
public static final JoinCommute OUTER_ZIG_ZAG = new JoinCommute(SwapType.ZIG_ZAG);
public static final JoinCommute OUTER_BUSHY = new JoinCommute(SwapType.BUSHY);
private final boolean swapOuter;
private final SwapType swapType;
public JoinCommute(boolean swapOuter) {
this.swapOuter = swapOuter;
this.swapType = SwapType.ALL;
public JoinCommute(SwapType swapType) {
this.swapType = swapType;
}
public JoinCommute(boolean swapOuter, SwapType swapType) {
this.swapOuter = swapOuter;
this.swapType = swapType;
enum SwapType {
LEFT_DEEP, ZIG_ZAG, BUSHY
}
@Override
public Rule build() {
return innerLogicalJoin().when(JoinCommuteHelper::check).then(join -> {
// TODO: add project for mapping column output.
// List<NamedExpression> newOutput = new ArrayList<>(join.getOutput());
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
join.getJoinType(),
join.getHashJoinConjuncts(),
join.getOtherJoinCondition(),
join.right(), join.left(),
join.getJoinReorderContext());
newJoin.getJoinReorderContext().setHasCommute(true);
// if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) {
// newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
// }
return innerLogicalJoin()
.when(this::check)
.then(join -> {
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
join.getJoinType(),
join.getHashJoinConjuncts(),
join.getOtherJoinCondition(),
join.right(), join.left(),
join.getJoinReorderContext());
newJoin.getJoinReorderContext().setHasCommute(true);
if (swapType == SwapType.ZIG_ZAG && isNotBottomJoin(join)) {
newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
}
// LogicalProject<LogicalJoin> project = new LogicalProject<>(newOutput, newJoin);
return newJoin;
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
return JoinReorderCommon.project(new ArrayList<>(join.getOutput()), newJoin).get();
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
}
private boolean check(LogicalJoin<GroupPlan, GroupPlan> join) {
if (swapType == SwapType.LEFT_DEEP && isNotBottomJoin(join)) {
return false;
}
return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange();
}
private boolean isNotBottomJoin(LogicalJoin<GroupPlan, GroupPlan> join) {
// TODO: tmp way to judge bottomJoin
return containJoin(join.left()) || containJoin(join.right());
}
private boolean containJoin(GroupPlan groupPlan) {
// TODO: tmp way to judge containJoin
List<SlotReference> output = Utils.getOutputSlotReference(groupPlan);
return !output.stream().map(SlotReference::getQualifier).allMatch(output.get(0).getQualifier()::equals);
}
}

View File

@ -1,66 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
/**
* Project-Join commute
*/
public class JoinCommuteProject extends OneExplorationRuleFactory {
public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN);
public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG);
private final SwapType swapType;
private final boolean swapOuter;
public JoinCommuteProject(boolean swapOuter) {
this.swapOuter = swapOuter;
this.swapType = SwapType.ALL;
}
public JoinCommuteProject(boolean swapOuter, SwapType swapType) {
this.swapOuter = swapOuter;
this.swapType = swapType;
}
@Override
public Rule build() {
return logicalProject(innerLogicalJoin()).when(JoinCommuteHelper::check).then(project -> {
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
join.getJoinType(),
join.getHashJoinConjuncts(),
join.getOtherJoinCondition(),
join.right(), join.left(),
join.getJoinReorderContext());
newJoin.getJoinReorderContext().setHasCommute(true);
// if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) {
// newJoin.getJoinReorderContext().setHasCommuteZigZag(true);
// }
return newJoin;
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
}
}

View File

@ -17,18 +17,42 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import java.util.function.Predicate;
/**
* Rule for change inner join LAsscom (associative and commutive).
*/
@Developing
public class JoinLAsscom extends OneExplorationRuleFactory {
// for inner-inner
public static final JoinLAsscom INNER = new JoinLAsscom(Type.INNER);
// for inner-leftOuter or leftOuter-leftOuter
public static final JoinLAsscom OUTER = new JoinLAsscom(Type.OUTER);
private final Predicate<LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan>> typeChecker;
private final Type type;
/**
* Specify join type.
*/
public JoinLAsscom(Type type) {
this.type = type;
if (type == Type.INNER) {
typeChecker = join -> join.getJoinType().isInnerJoin() && join.left().getJoinType().isInnerJoin();
} else {
typeChecker = join -> JoinLAsscomHelper.outerSet.contains(
Pair.of(join.left().getJoinType(), join.getJoinType()));
}
}
/*
* topJoin newTopJoin
* / \ / \
@ -39,18 +63,14 @@ public class JoinLAsscom extends OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalJoin(logicalJoin(), group())
.when(JoinLAsscomHelper::check)
.when(join -> join.getJoinType().isInnerJoin() || join.getJoinType().isLeftOuterJoin()
&& (join.left().getJoinType().isInnerJoin() || join.left().getJoinType().isLeftOuterJoin()))
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
JoinLAsscomHelper helper = JoinLAsscomHelper.of(topJoin, bottomJoin);
if (!helper.initJoinOnCondition()) {
return null;
}
return helper.newTopJoin();
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
.when(topJoin -> JoinLAsscomHelper.check(type, topJoin, topJoin.left()))
.when(typeChecker)
.then(topJoin -> {
JoinLAsscomHelper helper = new JoinLAsscomHelper(topJoin, topJoin.left());
if (!helper.initJoinOnCondition()) {
return null;
}
return helper.newTopJoin();
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}
}

View File

@ -18,29 +18,27 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.ImmutableSet;
import java.util.HashSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Common function for JoinLAsscom
*/
public class JoinLAsscomHelper {
class JoinLAsscomHelper extends ThreeJoinHelper {
/*
* topJoin newTopJoin
* / \ / \
@ -48,209 +46,79 @@ public class JoinLAsscomHelper {
* / \ / \
* A B A C
*/
private final LogicalJoin topJoin;
private final LogicalJoin<GroupPlan, GroupPlan> bottomJoin;
private final Plan a;
private final Plan b;
private final Plan c;
private final List<Expression> topHashJoinConjuncts;
private final List<Expression> bottomHashJoinConjuncts;
private final List<Expression> allNonHashJoinConjuncts = Lists.newArrayList();
private final List<SlotReference> aOutputSlots;
private final List<SlotReference> bOutputSlots;
private final List<SlotReference> cOutputSlots;
private final List<Expression> newBottomHashJoinConjuncts = Lists.newArrayList();
private final List<Expression> newBottomNonHashJoinConjuncts = Lists.newArrayList();
private final List<Expression> newTopHashJoinConjuncts = Lists.newArrayList();
private final List<Expression> newTopNonHashJoinConjuncts = Lists.newArrayList();
// Pair<bottomJoin, topJoin>
// newBottomJoin Type = topJoin Type, newTopJoin Type = bottomJoin Type
public static Set<Pair<JoinType, JoinType>> outerSet = ImmutableSet.of(
Pair.of(JoinType.LEFT_OUTER_JOIN, JoinType.INNER_JOIN),
Pair.of(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN),
Pair.of(JoinType.LEFT_OUTER_JOIN, JoinType.LEFT_OUTER_JOIN));
/**
* Init plan and output.
*/
public JoinLAsscomHelper(LogicalJoin<? extends Plan, GroupPlan> topJoin,
LogicalJoin<GroupPlan, GroupPlan> bottomJoin) {
this.topJoin = topJoin;
this.bottomJoin = bottomJoin;
a = bottomJoin.left();
b = bottomJoin.right();
c = topJoin.right();
Preconditions.checkArgument(!topJoin.getHashJoinConjuncts().isEmpty(),
"topJoin hashJoinConjuncts must exist.");
topHashJoinConjuncts = topJoin.getHashJoinConjuncts();
if (topJoin.getOtherJoinCondition().isPresent()) {
allNonHashJoinConjuncts.addAll(
ExpressionUtils.extractConjunction(topJoin.getOtherJoinCondition().get()));
}
Preconditions.checkArgument(!bottomJoin.getHashJoinConjuncts().isEmpty(),
"bottomJoin onClause must exist.");
bottomHashJoinConjuncts = bottomJoin.getHashJoinConjuncts();
if (bottomJoin.getOtherJoinCondition().isPresent()) {
allNonHashJoinConjuncts.addAll(
ExpressionUtils.extractConjunction(bottomJoin.getOtherJoinCondition().get()));
}
aOutputSlots = Utils.getOutputSlotReference(a);
bOutputSlots = Utils.getOutputSlotReference(b);
cOutputSlots = Utils.getOutputSlotReference(c);
super(topJoin, bottomJoin, bottomJoin.left(), bottomJoin.right(), topJoin.right());
}
public static JoinLAsscomHelper of(LogicalJoin<? extends Plan, GroupPlan> topJoin,
/**
* Create newTopJoin.
*/
public Plan newTopJoin() {
Pair<List<NamedExpression>, List<NamedExpression>> projectPair = splitProjectExprs(bOutput);
List<NamedExpression> newLeftProjectExpr = projectPair.second;
List<NamedExpression> newRightProjectExprs = projectPair.first;
// If add project to B, we should add all slotReference used by hashOnCondition.
// TODO: Does nonHashOnCondition also need to be considered.
Set<SlotReference> onUsedSlotRef = bottomJoin.getHashJoinConjuncts().stream()
.flatMap(expr -> {
Set<SlotReference> usedSlotRefs = expr.collect(SlotReference.class::isInstance);
return usedSlotRefs.stream();
}).filter(Utils.getOutputSlotReference(bottomJoin)::contains).collect(Collectors.toSet());
boolean existRightProject = !newRightProjectExprs.isEmpty();
boolean existLeftProject = !newLeftProjectExpr.isEmpty();
onUsedSlotRef.forEach(slotRef -> {
if (existRightProject && bOutput.contains(slotRef) && !newRightProjectExprs.contains(slotRef)) {
newRightProjectExprs.add(slotRef);
} else if (existLeftProject && aOutput.contains(slotRef) && !newLeftProjectExpr.contains(slotRef)) {
newLeftProjectExpr.add(slotRef);
}
});
if (existLeftProject) {
newLeftProjectExpr.addAll(cOutput);
}
LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(),
newBottomHashJoinConjuncts, ExpressionUtils.andByOptional(newBottomNonHashJoinConjuncts), a, c,
bottomJoin.getJoinReorderContext());
newBottomJoin.getJoinReorderContext().setHasLAsscom(false);
newBottomJoin.getJoinReorderContext().setHasCommute(false);
Plan left = JoinReorderCommon.project(newLeftProjectExpr, newBottomJoin).orElse(newBottomJoin);
Plan right = JoinReorderCommon.project(newRightProjectExprs, b).orElse(b);
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
newTopHashJoinConjuncts,
ExpressionUtils.andByOptional(newTopNonHashJoinConjuncts), left, right,
topJoin.getJoinReorderContext());
newTopJoin.getJoinReorderContext().setHasLAsscom(true);
return JoinReorderCommon.project(new ArrayList<>(topJoin.getOutput()), newTopJoin).get();
}
public static boolean check(Type type, LogicalJoin<? extends Plan, GroupPlan> topJoin,
LogicalJoin<GroupPlan, GroupPlan> bottomJoin) {
return new JoinLAsscomHelper(topJoin, bottomJoin);
}
/**
* Get the onCondition of newTopJoin and newBottomJoin.
*/
public boolean initJoinOnCondition() {
for (Expression topJoinOnClauseConjunct : topHashJoinConjuncts) {
// Ignore join with some OnClause like:
// Join C = B + A for above example.
Set<Slot> topJoinUsedSlot = topJoinOnClauseConjunct.getInputSlots();
if (topJoinUsedSlot.containsAll(aOutputSlots)
&& topJoinUsedSlot.containsAll(bOutputSlots)
&& topJoinUsedSlot.containsAll(cOutputSlots)) {
return false;
}
}
List<Expression> allHashJoinConjuncts = Lists.newArrayList();
allHashJoinConjuncts.addAll(topHashJoinConjuncts);
allHashJoinConjuncts.addAll(bottomHashJoinConjuncts);
Set<Slot> newBottomJoinSlots = new HashSet<>(aOutputSlots);
newBottomJoinSlots.addAll(cOutputSlots);
for (Expression hashConjunct : allHashJoinConjuncts) {
Set<Slot> slots = hashConjunct.getInputSlots();
if (newBottomJoinSlots.containsAll(slots)) {
newBottomHashJoinConjuncts.add(hashConjunct);
} else {
newTopHashJoinConjuncts.add(hashConjunct);
}
}
for (Expression nonHashConjunct : allNonHashJoinConjuncts) {
Set<SlotReference> slots = nonHashConjunct.collect(SlotReference.class::isInstance);
if (newBottomJoinSlots.containsAll(slots)) {
newBottomNonHashJoinConjuncts.add(nonHashConjunct);
} else {
newTopNonHashJoinConjuncts.add(nonHashConjunct);
}
}
// newBottomJoinOnCondition/newTopJoinOnCondition is empty. They are cross join.
// Example:
// A: col1, col2. B: col2, col3. C: col3, col4
// (A & B on A.col2=B.col2) & C on B.col3=C.col3.
// (A & B) & C -> (A & C) & B.
// (A & C) will be cross join (newBottomJoinOnCondition is empty)
if (newBottomHashJoinConjuncts.isEmpty() || newTopHashJoinConjuncts.isEmpty()) {
return false;
}
return true;
}
/**
* Get projectExpr of left and right.
* Just for project-inside.
*/
private Pair<List<NamedExpression>, List<NamedExpression>> getProjectExprs() {
Preconditions.checkArgument(topJoin.left() instanceof LogicalProject);
LogicalProject project = (LogicalProject) topJoin.left();
List<NamedExpression> projectExprs = project.getProjects();
List<NamedExpression> newRightProjectExprs = Lists.newArrayList();
List<NamedExpression> newLeftProjectExpr = Lists.newArrayList();
HashSet<SlotReference> bOutputSlotsSet = new HashSet<>(bOutputSlots);
for (NamedExpression projectExpr : projectExprs) {
Set<SlotReference> usedSlotRefs = projectExpr.collect(SlotReference.class::isInstance);
if (bOutputSlotsSet.containsAll(usedSlotRefs)) {
newRightProjectExprs.add(projectExpr);
} else {
newLeftProjectExpr.add(projectExpr);
}
}
return Pair.of(newLeftProjectExpr, newRightProjectExprs);
}
private LogicalJoin<GroupPlan, GroupPlan> newBottomJoin() {
Optional<Expression> bottomNonHashExpr;
if (newBottomNonHashJoinConjuncts.isEmpty()) {
bottomNonHashExpr = Optional.empty();
if (type == Type.INNER) {
return !bottomJoin.getJoinReorderContext().hasCommuteZigZag()
&& !topJoin.getJoinReorderContext().hasLAsscom();
} else {
bottomNonHashExpr = Optional.of(ExpressionUtils.and(newBottomNonHashJoinConjuncts));
// hasCommute will cause to lack of OuterJoinAssocRule:Left
return !topJoin.getJoinReorderContext().hasLeftAssociate()
&& !topJoin.getJoinReorderContext().hasRightAssociate()
&& !topJoin.getJoinReorderContext().hasExchange()
&& !bottomJoin.getJoinReorderContext().hasCommute();
}
return new LogicalJoin(
bottomJoin.getJoinType(),
newBottomHashJoinConjuncts,
bottomNonHashExpr,
a, c);
}
/**
* Create topJoin for project-inside.
*/
public LogicalJoin newProjectTopJoin() {
Plan left;
Plan right;
List<NamedExpression> newLeftProjectExpr = getProjectExprs().first;
List<NamedExpression> newRightProjectExprs = getProjectExprs().second;
if (!newLeftProjectExpr.isEmpty()) {
left = new LogicalProject<>(newLeftProjectExpr, newBottomJoin());
} else {
left = newBottomJoin();
}
if (!newRightProjectExprs.isEmpty()) {
right = new LogicalProject<>(newRightProjectExprs, b);
} else {
right = b;
}
Optional<Expression> topNonHashExpr;
if (newTopNonHashJoinConjuncts.isEmpty()) {
topNonHashExpr = Optional.empty();
} else {
topNonHashExpr = Optional.of(ExpressionUtils.and(newTopNonHashJoinConjuncts));
}
return new LogicalJoin<>(
topJoin.getJoinType(),
newTopHashJoinConjuncts,
topNonHashExpr,
left, right);
}
/**
* Create topJoin for no-project-inside.
*/
public LogicalJoin newTopJoin() {
// TODO: add column map (use project)
// SlotReference bind() may have solved this problem.
// source: | A | B | C |
// target: | A | C | B |
Optional<Expression> topNonHashExpr;
if (newTopNonHashJoinConjuncts.isEmpty()) {
topNonHashExpr = Optional.empty();
} else {
topNonHashExpr = Optional.of(ExpressionUtils.and(newTopNonHashJoinConjuncts));
}
return new LogicalJoin(
topJoin.getJoinType(),
newTopHashJoinConjuncts,
topNonHashExpr,
newBottomJoin(), b);
}
public static boolean check(LogicalJoin topJoin) {
if (topJoin.getJoinReorderContext().hasCommute()) {
return false;
}
return true;
}
}

View File

@ -17,18 +17,43 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import java.util.function.Predicate;
/**
* Rule for change inner join left associative to right.
* Rule for change inner join LAsscom (associative and commutive).
*/
@Developing
public class JoinLAsscomProject extends OneExplorationRuleFactory {
// for inner-inner
public static final JoinLAsscomProject INNER = new JoinLAsscomProject(Type.INNER);
// for inner-leftOuter or leftOuter-leftOuter
public static final JoinLAsscomProject OUTER = new JoinLAsscomProject(Type.OUTER);
private final Predicate<LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan>> typeChecker;
private final Type type;
/**
* Specify join type.
*/
public JoinLAsscomProject(Type type) {
this.type = type;
if (type == Type.INNER) {
typeChecker = join -> join.getJoinType().isInnerJoin() && join.left().child().getJoinType().isInnerJoin();
} else {
typeChecker = join -> JoinLAsscomHelper.outerSet.contains(
Pair.of(join.left().child().getJoinType(), join.getJoinType()));
}
}
/*
* topJoin newTopJoin
* / \ / \
@ -41,19 +66,15 @@ public class JoinLAsscomProject extends OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalJoin(logicalProject(logicalJoin()), group())
.when(JoinLAsscomHelper::check)
.when(join -> join.getJoinType().isInnerJoin() || join.getJoinType().isLeftOuterJoin()
&& (join.left().child().getJoinType().isInnerJoin() || join.left().child().getJoinType()
.isLeftOuterJoin()))
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
JoinLAsscomHelper helper = JoinLAsscomHelper.of(topJoin, bottomJoin);
if (!helper.initJoinOnCondition()) {
return null;
}
return helper.newProjectTopJoin();
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
.when(topJoin -> JoinLAsscomHelper.check(type, topJoin, topJoin.left().child()))
.when(typeChecker)
.then(topJoin -> {
JoinLAsscomHelper helper = new JoinLAsscomHelper(topJoin, topJoin.left().child());
helper.initAllProject(topJoin.left());
if (!helper.initJoinOnCondition()) {
return null;
}
return helper.newTopJoin();
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}
}

View File

@ -17,32 +17,24 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
/**
* Common function for JoinCommute
*/
public class JoinCommuteHelper {
import java.util.List;
import java.util.Optional;
enum SwapType {
BOTTOM_JOIN, ZIG_ZAG, ALL
class JoinReorderCommon {
public enum Type {
INNER,
OUTER
}
private final boolean swapOuter;
private final SwapType swapType;
public JoinCommuteHelper(boolean swapOuter, SwapType swapType) {
this.swapOuter = swapOuter;
this.swapType = swapType;
}
public static boolean check(LogicalJoin<GroupPlan, GroupPlan> join) {
return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange();
}
public static boolean check(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
return check(project.child());
public static Optional<Plan> project(List<NamedExpression> projectExprs, Plan plan) {
if (!projectExprs.isEmpty()) {
return Optional.of(new LogicalProject<>(projectExprs, plan));
} else {
return Optional.empty();
}
}
}

View File

@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration;
package org.apache.doris.nereids.rules.exploration.join;
/**
* JoinReorderContext for Duplicate free.
@ -26,6 +26,7 @@ package org.apache.doris.nereids.rules.exploration;
public class JoinReorderContext {
// left deep tree
private boolean hasCommute = false;
private boolean hasLAsscom = false;
// zig-zag tree
private boolean hasCommuteZigZag = false;
@ -38,16 +39,24 @@ public class JoinReorderContext {
public JoinReorderContext() {
}
/**
* copy a JoinReorderContext.
*/
public void copyFrom(JoinReorderContext joinReorderContext) {
this.hasCommute = joinReorderContext.hasCommute;
this.hasLAsscom = joinReorderContext.hasLAsscom;
this.hasExchange = joinReorderContext.hasExchange;
this.hasLeftAssociate = joinReorderContext.hasLeftAssociate;
this.hasRightAssociate = joinReorderContext.hasRightAssociate;
this.hasCommuteZigZag = joinReorderContext.hasCommuteZigZag;
}
/**
* clear all.
*/
public void clear() {
hasCommute = false;
hasLAsscom = false;
hasCommuteZigZag = false;
hasExchange = false;
hasRightAssociate = false;
@ -62,6 +71,14 @@ public class JoinReorderContext {
this.hasCommute = hasCommute;
}
public boolean hasLAsscom() {
return hasLAsscom;
}
public void setHasLAsscom(boolean hasLAsscom) {
this.hasLAsscom = hasLAsscom;
}
public boolean hasExchange() {
return hasExchange;
}

View File

@ -0,0 +1,165 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
* Common join helper for three-join.
*/
abstract class ThreeJoinHelper {
protected final LogicalJoin<? extends Plan, ? extends Plan> topJoin;
protected final LogicalJoin<GroupPlan, GroupPlan> bottomJoin;
protected final GroupPlan a;
protected final GroupPlan b;
protected final GroupPlan c;
protected final List<SlotReference> aOutput;
protected final List<SlotReference> bOutput;
protected final List<SlotReference> cOutput;
protected final List<NamedExpression> allProjects = Lists.newArrayList();
protected final List<Expression> allHashJoinConjuncts = Lists.newArrayList();
protected final List<Expression> allNonHashJoinConjuncts = Lists.newArrayList();
protected final List<Expression> newBottomHashJoinConjuncts = Lists.newArrayList();
protected final List<Expression> newBottomNonHashJoinConjuncts = Lists.newArrayList();
protected final List<Expression> newTopHashJoinConjuncts = Lists.newArrayList();
protected final List<Expression> newTopNonHashJoinConjuncts = Lists.newArrayList();
/**
* Init plan and output.
*/
public ThreeJoinHelper(LogicalJoin<? extends Plan, ? extends Plan> topJoin,
LogicalJoin<GroupPlan, GroupPlan> bottomJoin, GroupPlan a, GroupPlan b, GroupPlan c) {
this.topJoin = topJoin;
this.bottomJoin = bottomJoin;
this.a = a;
this.b = b;
this.c = c;
aOutput = Utils.getOutputSlotReference(a);
bOutput = Utils.getOutputSlotReference(b);
cOutput = Utils.getOutputSlotReference(c);
Preconditions.checkArgument(!topJoin.getHashJoinConjuncts().isEmpty(), "topJoin hashJoinConjuncts must exist.");
Preconditions.checkArgument(!bottomJoin.getHashJoinConjuncts().isEmpty(),
"bottomJoin hashJoinConjuncts must exist.");
allHashJoinConjuncts.addAll(topJoin.getHashJoinConjuncts());
allHashJoinConjuncts.addAll(bottomJoin.getHashJoinConjuncts());
topJoin.getOtherJoinCondition().ifPresent(otherJoinCondition -> allNonHashJoinConjuncts.addAll(
ExpressionUtils.extractConjunction(otherJoinCondition)));
bottomJoin.getOtherJoinCondition().ifPresent(otherJoinCondition -> allNonHashJoinConjuncts.addAll(
ExpressionUtils.extractConjunction(otherJoinCondition)));
}
@SafeVarargs
public final void initAllProject(LogicalProject<? extends Plan>... projects) {
for (LogicalProject<? extends Plan> project : projects) {
allProjects.addAll(project.getProjects());
}
}
/**
* Get the onCondition of newTopJoin and newBottomJoin.
*/
public boolean initJoinOnCondition() {
// Ignore join with some OnClause like:
// Join C = B + A for above example.
// TODO: also need for otherJoinCondition
for (Expression topJoinOnClauseConjunct : topJoin.getHashJoinConjuncts()) {
Set<SlotReference> topJoinUsedSlot = topJoinOnClauseConjunct.collect(SlotReference.class::isInstance);
if (ExpressionUtils.isIntersecting(topJoinUsedSlot, aOutput) && ExpressionUtils.isIntersecting(
topJoinUsedSlot, bOutput) && ExpressionUtils.isIntersecting(topJoinUsedSlot, cOutput)) {
return false;
}
}
Set<Slot> newBottomJoinSlots = new HashSet<>(aOutput);
newBottomJoinSlots.addAll(cOutput);
for (Expression hashConjunct : allHashJoinConjuncts) {
Set<SlotReference> slots = hashConjunct.collect(SlotReference.class::isInstance);
if (newBottomJoinSlots.containsAll(slots)) {
newBottomHashJoinConjuncts.add(hashConjunct);
} else {
newTopHashJoinConjuncts.add(hashConjunct);
}
}
for (Expression nonHashConjunct : allNonHashJoinConjuncts) {
Set<SlotReference> slots = nonHashConjunct.collect(SlotReference.class::isInstance);
if (newBottomJoinSlots.containsAll(slots)) {
newBottomNonHashJoinConjuncts.add(nonHashConjunct);
} else {
newTopNonHashJoinConjuncts.add(nonHashConjunct);
}
}
// newBottomJoinOnCondition/newTopJoinOnCondition is empty. They are cross join.
// Example:
// A: col1, col2. B: col2, col3. C: col3, col4
// (A & B on A.col2=B.col2) & C on B.col3=C.col3.
// (A & B) & C -> (A & C) & B.
// (A & C) will be cross join (newBottomJoinOnCondition is empty)
if (newBottomHashJoinConjuncts.isEmpty() || newTopHashJoinConjuncts.isEmpty()) {
return false;
}
return true;
}
/**
* Split inside-project into two part.
*
* @param topJoinChild output of topJoin groupPlan child.
*/
protected Pair<List<NamedExpression>, List<NamedExpression>> splitProjectExprs(List<SlotReference> topJoinChild) {
List<NamedExpression> newTopJoinChildProjectExprs = Lists.newArrayList();
List<NamedExpression> newBottomJoinProjectExprs = Lists.newArrayList();
HashSet<SlotReference> topJoinOutputSlotsSet = new HashSet<>(topJoinChild);
for (NamedExpression projectExpr : allProjects) {
Set<SlotReference> usedSlotRefs = projectExpr.collect(SlotReference.class::isInstance);
if (topJoinOutputSlotsSet.containsAll(usedSlotRefs)) {
newTopJoinChildProjectExprs.add(projectExpr);
} else {
newBottomJoinProjectExprs.add(projectExpr);
}
}
return Pair.of(newTopJoinChildProjectExprs, newBottomJoinProjectExprs);
}
}

View File

@ -19,7 +19,7 @@ package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.rules.exploration.JoinReorderContext;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import com.google.common.base.Preconditions;
@ -29,6 +30,7 @@ import com.google.common.collect.Sets;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
/**
@ -76,6 +78,14 @@ public class ExpressionUtils {
}
}
public static Optional<Expression> andByOptional(List<Expression> expressions) {
if (expressions.isEmpty()) {
return Optional.empty();
} else {
return Optional.of(ExpressionUtils.and(expressions));
}
}
public static Expression and(List<Expression> expressions) {
return combine(And.class, expressions);
}
@ -120,4 +130,16 @@ public class ExpressionUtils {
.reduce(type == And.class ? And::new : Or::new)
.orElse(new BooleanLiteral(type == And.class));
}
/**
* Check whether lhs and rhs are intersecting.
*/
public static boolean isIntersecting(Set<SlotReference> lhs, List<SlotReference> rhs) {
for (SlotReference rh : rhs) {
if (lhs.contains(rh)) {
return true;
}
}
return false;
}
}

View File

@ -17,8 +17,8 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
@ -26,8 +26,10 @@ import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
@ -35,7 +37,6 @@ import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
public class JoinCommuteTest {
@ -51,14 +52,23 @@ public class JoinCommuteTest {
JoinType.INNER_JOIN, Lists.newArrayList(onCondition),
Optional.empty(), scan1, scan2);
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(join);
Rule rule = new JoinCommute(true).build();
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
.transform(JoinCommute.OUTER_LEFT_DEEP.build())
.checkMemo(memo -> {
Group root = memo.getRoot();
Assertions.assertEquals(2, root.getLogicalExpressions().size());
List<Plan> transform = rule.transform(join, cascadesContext);
Assertions.assertEquals(1, transform.size());
Plan newJoin = transform.get(0);
Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof LogicalJoin);
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalProject);
Assertions.assertEquals(join.child(0), newJoin.child(1));
Assertions.assertEquals(join.child(1), newJoin.child(0));
GroupExpression newJoinGroupExpr = root.logicalExpressionsAt(1).child(0).getLogicalExpression();
Plan left = newJoinGroupExpr.child(0).getLogicalExpression().getPlan();
Plan right = newJoinGroupExpr.child(1).getLogicalExpression().getPlan();
Assertions.assertTrue(left instanceof LogicalOlapScan);
Assertions.assertTrue(right instanceof LogicalOlapScan);
Assertions.assertEquals("t2", ((LogicalOlapScan) left).getTable().getName());
Assertions.assertEquals("t1", ((LogicalOlapScan) right).getTable().getName());
});
}
}

View File

@ -1,134 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
public class JoinLAsscomProjectTest {
private static final List<LogicalOlapScan> scans = Lists.newArrayList();
private static final List<List<SlotReference>> outputs = Lists.newArrayList();
@BeforeAll
public static void init() {
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
scans.add(scan1);
scans.add(scan2);
scans.add(scan3);
List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1);
List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2);
List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3);
outputs.add(t1Output);
outputs.add(t2Output);
outputs.add(t3Output);
}
private Pair<LogicalJoin, LogicalJoin> testJoinProjectLAsscom(List<NamedExpression> projects) {
/*
* topJoin newTopJoin
* / \ / \
* project C newLeftProject newRightProject
* / ──► / \
* bottomJoin newBottomJoin B
* / \ / \
* A B A C
*/
Assertions.assertEquals(3, scans.size());
List<SlotReference> t1 = outputs.get(0);
List<SlotReference> t2 = outputs.get(1);
List<SlotReference> t3 = outputs.get(2);
Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0));
Expression topJoinOnCondition = new EqualTo(t1.get(1), t3.get(1));
LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>> project = new LogicalProject<>(
projects,
new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(bottomJoinOnCondition),
Optional.empty(), scans.get(0), scans.get(1)));
LogicalJoin<LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>>, LogicalOlapScan> topJoin
= new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
Optional.empty(), project, scans.get(2));
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(topJoin);
Rule rule = new JoinLAsscomProject().build();
List<Plan> transform = rule.transform(topJoin, cascadesContext);
Assertions.assertEquals(1, transform.size());
Assertions.assertTrue(transform.get(0) instanceof LogicalJoin);
LogicalJoin newTopJoin = (LogicalJoin) transform.get(0);
return Pair.of(topJoin, newTopJoin);
}
@Test
public void testStarJoinProjectLAsscom() {
List<SlotReference> t1 = outputs.get(0);
List<SlotReference> t2 = outputs.get(1);
List<NamedExpression> projects = ImmutableList.of(
new Alias(t2.get(0), "t2.id"),
new Alias(t1.get(0), "t1.id"),
t1.get(1),
t2.get(1)
);
Pair<LogicalJoin, LogicalJoin> pair = testJoinProjectLAsscom(projects);
LogicalJoin oldJoin = pair.first;
LogicalJoin newTopJoin = pair.second;
// Join reorder successfully.
Assertions.assertNotEquals(oldJoin, newTopJoin);
Assertions.assertEquals("t1.id", ((Alias) ((LogicalProject) newTopJoin.left()).getProjects().get(0)).getName());
Assertions.assertEquals("name",
((SlotReference) ((LogicalProject) newTopJoin.left()).getProjects().get(1)).getName());
Assertions.assertEquals("t2.id",
((Alias) ((LogicalProject) newTopJoin.right()).getProjects().get(0)).getName());
Assertions.assertEquals("name",
((SlotReference) ((LogicalProject) newTopJoin.left()).getProjects().get(1)).getName());
}
}

View File

@ -17,24 +17,23 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import java.util.List;
@ -42,55 +41,13 @@ import java.util.Optional;
public class JoinLAsscomTest {
private static List<LogicalOlapScan> scans = Lists.newArrayList();
private static List<List<SlotReference>> outputs = Lists.newArrayList();
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
@BeforeAll
public static void init() {
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
scans.add(scan1);
scans.add(scan2);
scans.add(scan3);
List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1);
List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2);
List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3);
outputs.add(t1Output);
outputs.add(t2Output);
outputs.add(t3Output);
}
public Pair<LogicalJoin, LogicalJoin> testJoinLAsscom(
Expression bottomJoinOnCondition,
Expression bottomNonHashExpression,
Expression topJoinOnCondition,
Expression topNonHashExpression) {
/*
* topJoin newTopJoin
* / \ / \
* bottomJoin C --> newBottomJoin B
* / \ / \
* A B A C
*/
Assertions.assertEquals(3, scans.size());
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
Lists.newArrayList(bottomJoinOnCondition),
Optional.of(bottomNonHashExpression), scans.get(0), scans.get(1));
LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>(
JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
Optional.of(topNonHashExpression), bottomJoin, scans.get(2));
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(topJoin);
Rule rule = new JoinLAsscom().build();
List<Plan> transform = rule.transform(topJoin, cascadesContext);
Assertions.assertEquals(1, transform.size());
Assertions.assertTrue(transform.get(0) instanceof LogicalJoin);
LogicalJoin newTopJoin = (LogicalJoin) transform.get(0);
return Pair.of(topJoin, newTopJoin);
}
private final List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1);
private final List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2);
private final List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3);
@Test
public void testStarJoinLAsscom() {
@ -109,31 +66,35 @@ public class JoinLAsscomTest {
* t1 t2 t1 t3
*/
List<SlotReference> t1 = outputs.get(0);
List<SlotReference> t2 = outputs.get(1);
List<SlotReference> t3 = outputs.get(2);
Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0));
Expression bottomNonHashExpression = new LessThan(t1.get(0), t2.get(0));
Expression topJoinOnCondition = new EqualTo(t1.get(1), t3.get(1));
Expression topNonHashCondition = new LessThan(t1.get(1), t3.get(1));
Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0), t2Output.get(0));
Expression topJoinOnCondition = new EqualTo(t1Output.get(1), t3Output.get(1));
Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom(
bottomJoinOnCondition,
bottomNonHashExpression,
topJoinOnCondition,
topNonHashCondition);
LogicalJoin oldJoin = pair.first;
LogicalJoin newTopJoin = pair.second;
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
Lists.newArrayList(bottomJoinOnCondition),
Optional.empty(), scan1, scan2);
LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>(
JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
Optional.empty(), bottomJoin, scan3);
// Join reorder successfully.
Assertions.assertNotEquals(oldJoin, newTopJoin);
Assertions.assertEquals("t1",
((LogicalOlapScan) ((LogicalJoin) newTopJoin.left()).left()).getTable().getName());
Assertions.assertEquals("t3",
((LogicalOlapScan) ((LogicalJoin) newTopJoin.left()).right()).getTable().getName());
Assertions.assertEquals("t2", ((LogicalOlapScan) newTopJoin.right()).getTable().getName());
Assertions.assertEquals(newTopJoin.getOtherJoinCondition(),
((LogicalJoin) oldJoin.child(0)).getOtherJoinCondition());
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
.transform(JoinLAsscom.INNER.build())
.checkMemo(memo -> {
Group root = memo.getRoot();
Assertions.assertEquals(2, root.getLogicalExpressions().size());
Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof LogicalJoin);
Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalProject);
GroupExpression newTopJoinGroupExpr = root.logicalExpressionsAt(1).child(0).getLogicalExpression();
GroupExpression newBottomJoinGroupExpr = newTopJoinGroupExpr.child(0).getLogicalExpression();
Plan bottomLeft = newBottomJoinGroupExpr.child(0).getLogicalExpression().getPlan();
Plan bottomRight = newBottomJoinGroupExpr.child(1).getLogicalExpression().getPlan();
Plan right = newTopJoinGroupExpr.child(1).getLogicalExpression().getPlan();
Assertions.assertEquals("t1", ((LogicalOlapScan) bottomLeft).getTable().getName());
Assertions.assertEquals("t3", ((LogicalOlapScan) bottomRight).getTable().getName());
Assertions.assertEquals("t2", ((LogicalOlapScan) right).getTable().getName());
});
}
@Test
@ -151,27 +112,22 @@ public class JoinLAsscomTest {
* t1 t2 t1 t3
*/
List<SlotReference> t1 = outputs.get(0);
List<SlotReference> t2 = outputs.get(1);
List<SlotReference> t3 = outputs.get(2);
Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0));
Expression bottomNonHashExpression = new LessThan(t1.get(0), t2.get(0));
Expression topJoinOnCondition = new EqualTo(t2.get(0), t3.get(0));
Expression topNonHashExpression = new LessThan(t2.get(0), t3.get(0));
Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0), t2Output.get(0));
Expression topJoinOnCondition = new EqualTo(t2Output.get(0), t3Output.get(0));
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
Lists.newArrayList(bottomJoinOnCondition),
Optional.empty(), scan1, scan2);
LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>(
JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
Optional.empty(), bottomJoin, scan3);
Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom(bottomJoinOnCondition, bottomNonHashExpression,
topJoinOnCondition, topNonHashExpression);
LogicalJoin oldJoin = pair.first;
LogicalJoin newTopJoin = pair.second;
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
.transform(JoinLAsscom.INNER.build())
.checkMemo(memo -> {
Group root = memo.getRoot();
// Join reorder failed.
// Chain-Join LAsscom directly will be failed.
// After t1 -- t2 -- t3
// -- join commute -->
// t1 -- t2
// |
// t3
// then, we can LAsscom for this star-join.
Assertions.assertEquals(oldJoin, newTopJoin);
// TODO: need infer onCondition.
Assertions.assertEquals(1, root.getLogicalExpressions().size());
});
}
}

View File

@ -23,7 +23,6 @@ import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.pattern.GroupExpressionMatching;
import org.apache.doris.nereids.pattern.GroupExpressionMatching.GroupExpressionIterator;
import org.apache.doris.nereids.pattern.MatchingContext;
import org.apache.doris.nereids.pattern.PatternDescriptor;
import org.apache.doris.nereids.pattern.PatternMatcher;
@ -145,10 +144,8 @@ public class PlanChecker {
public PlanChecker transform(GroupExpression groupExpression, PatternMatcher patternMatcher) {
GroupExpressionMatching matchResult = new GroupExpressionMatching(patternMatcher.pattern, groupExpression);
GroupExpressionIterator iterator = matchResult.iterator();
while (iterator.hasNext()) {
Plan before = iterator.next();
for (Plan before : matchResult) {
Plan after = patternMatcher.matchedAction.apply(
new MatchingContext(before, patternMatcher.pattern, cascadesContext));
if (before != after) {
@ -162,6 +159,38 @@ public class PlanChecker {
return this;
}
public PlanChecker transform(Rule rule) {
return transform(cascadesContext.getMemo().getRoot(), rule);
}
public PlanChecker transform(Group group, Rule rule) {
// copy groupExpressions can prevent ConcurrentModificationException
for (GroupExpression logicalExpression : Lists.newArrayList(group.getLogicalExpressions())) {
transform(logicalExpression, rule);
}
for (GroupExpression physicalExpression : Lists.newArrayList(group.getPhysicalExpressions())) {
transform(physicalExpression, rule);
}
return this;
}
public PlanChecker transform(GroupExpression groupExpression, Rule rule) {
GroupExpressionMatching matchResult = new GroupExpressionMatching(rule.getPattern(), groupExpression);
for (Plan before : matchResult) {
Plan after = rule.transform(before, cascadesContext).get(0);
if (before != after) {
cascadesContext.getMemo().copyIn(after, before.getGroupExpression().get().getOwnerGroup(), false);
}
}
for (Group childGroup : groupExpression.children()) {
transform(childGroup, rule);
}
return this;
}
public PlanChecker matchesFromRoot(PatternDescriptor<? extends Plan> patternDesc) {
Memo memo = cascadesContext.getMemo();
assertMatches(memo, () -> new GroupExpressionMatching(patternDesc.pattern,