[enhance](Nereids): handle project of OuterJoin in Reorder. (#19137)

This commit is contained in:
jakevin
2023-04-27 22:17:03 +08:00
committed by GitHub
parent 0f895640d9
commit a35fc02bd4
5 changed files with 17 additions and 121 deletions

View File

@ -33,6 +33,7 @@ import org.apache.doris.nereids.rules.exploration.join.JoinExchangeBothProject;
import org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTranspose;
import org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinAssoc;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinAssocProject;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscom;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscomProject;
import org.apache.doris.nereids.rules.exploration.join.PushdownProjectThroughInnerJoin;
@ -164,6 +165,7 @@ public class RuleSet {
.add(JoinExchange.INSTANCE)
.add(JoinExchangeBothProject.INSTANCE)
.add(OuterJoinAssoc.INSTANCE)
.add(OuterJoinAssocProject.INSTANCE)
.build();
public List<Rule> getOtherReorderRules() {

View File

@ -35,19 +35,6 @@ import java.util.stream.Collectors;
* Common
*/
public class CBOUtils {
/**
* Split project according to whether namedExpr contains by splitChildExprIds.
* Notice: projects must all be Slot.
*/
public static Map<Boolean, List<NamedExpression>> splitProject(List<NamedExpression> projects,
Set<ExprId> splitChildExprIds) {
return projects.stream()
.collect(Collectors.partitioningBy(expr -> {
Slot slot = (Slot) expr;
return splitChildExprIds.contains(slot.getExprId());
}));
}
/**
* If projects is empty or project output equal plan output, return the original plan.
*/
@ -58,23 +45,6 @@ public class CBOUtils {
return new LogicalProject<>(projects, plan);
}
/**
* When project not empty, we add all slots used by hashOnCondition into projects.
*/
public static void addSlotsUsedByOn(Set<Slot> usedSlots, List<NamedExpression> projects) {
if (projects.isEmpty()) {
return;
}
Set<ExprId> projectExprIdSet = projects.stream()
.map(NamedExpression::getExprId)
.collect(Collectors.toSet());
usedSlots.forEach(slot -> {
if (!projectExprIdSet.contains(slot.getExprId())) {
projects.add(slot);
}
});
}
public static Set<Slot> joinChildConditionSlots(LogicalJoin<? extends Plan, ? extends Plan> join, boolean left) {
Set<Slot> childSlots = left ? join.left().getOutputSet() : join.right().getOutputSet();
return join.getConditionSlot().stream()

View File

@ -24,23 +24,18 @@ import org.apache.doris.nereids.rules.exploration.CBOUtils;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* OuterJoinAssocProject.
@ -68,12 +63,10 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory {
.thenApply(ctx -> {
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topJoin = ctx.root;
/* ********** init ********** */
List<NamedExpression> projects = topJoin.left().getProjects();
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
GroupPlan a = bottomJoin.left();
GroupPlan b = bottomJoin.right();
GroupPlan c = topJoin.right();
Set<ExprId> aOutputExprIds = a.getOutputExprIdSet();
/*
* Paper `On the Correct and Complete Enumeration of the Core Search Space`.
@ -92,39 +85,15 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory {
}
}
/* ********** Split projects ********** */
Map<Boolean, List<NamedExpression>> map = CBOUtils.splitProject(projects, aOutputExprIds);
List<NamedExpression> aProjects = map.get(true);
List<NamedExpression> bProjects = map.get(false);
if (bProjects.isEmpty()) {
return null;
}
Set<ExprId> aProjectsExprIds = aProjects.stream().map(NamedExpression::getExprId)
.collect(Collectors.toSet());
// topJoin condition can't contain aProject. just can (B C)
if (Stream.concat(topJoin.getHashJoinConjuncts().stream(), topJoin.getOtherJoinConjuncts().stream())
.anyMatch(expr -> Utils.isIntersecting(expr.getInputSlotExprIds(), aProjectsExprIds))) {
return null;
}
// Add all slots used by OnCondition when projects not empty.
Map<Boolean, Set<Slot>> abOnUsedSlots = Stream.concat(
bottomJoin.getHashJoinConjuncts().stream(),
bottomJoin.getHashJoinConjuncts().stream())
.flatMap(onExpr -> onExpr.getInputSlots().stream())
.collect(Collectors.partitioningBy(
slot -> aOutputExprIds.contains(slot.getExprId()), Collectors.toSet()));
CBOUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects);
CBOUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects);
bProjects.addAll(OuterJoinLAsscomProject.forceToNullable(c.getOutputSet()));
/* ********** new Plan ********** */
LogicalJoin newBottomJoin = topJoin.withChildrenNoContext(b, c);
newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext());
Plan left = CBOUtils.projectOrSelf(aProjects, a);
Plan right = CBOUtils.projectOrSelf(bProjects, newBottomJoin);
Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
bottomJoin.getHashJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
bottomJoin.getOtherJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
Plan left = CBOUtils.newProject(topUsedExprIds, a);
Plan right = CBOUtils.newProject(topUsedExprIds, newBottomJoin);
LogicalJoin newTopJoin = bottomJoin.withChildrenNoContext(left, right);
newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext());

View File

@ -92,7 +92,7 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory {
* <p>
* Same with OtherJoinConjunct.
*/
private boolean checkCondition(LogicalJoin<? extends Plan, GroupPlan> topJoin, Set<ExprId> bOutputExprIdSet) {
public static boolean checkCondition(LogicalJoin<? extends Plan, GroupPlan> topJoin, Set<ExprId> bOutputExprIdSet) {
return Stream.concat(
topJoin.getHashJoinConjuncts().stream(),
topJoin.getOtherJoinConjuncts().stream())

View File

@ -23,20 +23,13 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.CBOUtils;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.Utils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Rule for change inner join LAsscom (associative and commutive).
@ -61,43 +54,15 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(topJoin -> OuterJoinLAsscom.checkCondition(topJoin,
topJoin.left().child().right().getOutputExprIdSet()))
.when(join -> join.left().isAllSlots())
.then(topJoin -> {
/* ********** init ********** */
List<NamedExpression> projects = topJoin.left().getProjects();
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
GroupPlan a = bottomJoin.left();
GroupPlan b = bottomJoin.right();
GroupPlan c = topJoin.right();
Set<ExprId> aOutputExprIds = a.getOutputExprIdSet();
/* ********** Split projects ********** */
Map<Boolean, List<NamedExpression>> map = CBOUtils.splitProject(projects, aOutputExprIds);
List<NamedExpression> aProjects = map.get(true);
if (aProjects.isEmpty()) {
return null;
}
List<NamedExpression> bProjects = map.get(false);
Set<ExprId> bProjectsExprIds = bProjects.stream().map(NamedExpression::getExprId)
.collect(Collectors.toSet());
// topJoin condition can't contain bProject output. just can (A C)
if (Stream.concat(topJoin.getHashJoinConjuncts().stream(), topJoin.getOtherJoinConjuncts().stream())
.anyMatch(expr -> Utils.isIntersecting(expr.getInputSlotExprIds(), bProjectsExprIds))) {
return null;
}
// Add all slots used by OnCondition when projects not empty.
Map<Boolean, Set<Slot>> abOnUsedSlots = Stream.concat(
bottomJoin.getHashJoinConjuncts().stream(),
bottomJoin.getHashJoinConjuncts().stream())
.flatMap(onExpr -> onExpr.getInputSlots().stream())
.collect(Collectors.partitioningBy(
slot -> aOutputExprIds.contains(slot.getExprId()), Collectors.toSet()));
CBOUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects);
CBOUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects);
aProjects.addAll(forceToNullable(c.getOutputSet()));
/* ********** new Plan ********** */
LogicalJoin newBottomJoin = topJoin.withChildrenNoContext(a, c);
@ -105,8 +70,11 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
newBottomJoin.getJoinReorderContext().setHasLAsscom(false);
newBottomJoin.getJoinReorderContext().setHasCommute(false);
Plan left = CBOUtils.projectOrSelf(aProjects, newBottomJoin);
Plan right = CBOUtils.projectOrSelf(bProjects, b);
Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
bottomJoin.getHashJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
bottomJoin.getOtherJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin);
Plan right = CBOUtils.newProject(topUsedExprIds, b);
LogicalJoin newTopJoin = bottomJoin.withChildrenNoContext(left, right);
newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext());
@ -115,17 +83,4 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
return CBOUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin);
}).toRule(RuleType.LOGICAL_OUTER_JOIN_LASSCOM_PROJECT);
}
/**
* Force all slots in set to nullable.
*/
public static Set<Slot> forceToNullable(Set<Slot> slotSet) {
return slotSet.stream().map(s -> (Slot) s.rewriteUp(e -> {
if (e instanceof SlotReference) {
return ((SlotReference) e).withNullable(true);
} else {
return e;
}
})).collect(Collectors.toSet());
}
}