[enhancement](nereids) allow reorder mark join (#30644)
This commit is contained in:
@ -47,7 +47,7 @@ public class CollectJoinConstraint implements RewriteRuleFactory {
|
||||
@Override
|
||||
public List<Rule> buildRules() {
|
||||
return ImmutableList.of(
|
||||
logicalJoin().whenNot(LogicalJoin::isMarkJoin).thenApply(ctx -> {
|
||||
logicalJoin().thenApply(ctx -> {
|
||||
if (!ctx.cascadesContext.isLeadingJoin()) {
|
||||
return ctx.root;
|
||||
}
|
||||
|
||||
@ -56,7 +56,6 @@ public class InnerJoinLAsscom extends OneExplorationRuleFactory {
|
||||
return innerLogicalJoin(innerLogicalJoin(), group())
|
||||
.when(topJoin -> checkReorder(topJoin, topJoin.left(), leftZigZag))
|
||||
.whenNot(join -> join.hasDistributeHint() || join.left().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
|
||||
GroupPlan a = bottomJoin.left();
|
||||
@ -104,12 +103,10 @@ public class InnerJoinLAsscom extends OneExplorationRuleFactory {
|
||||
double bRows = bottomJoin.right().getGroup().getStatistics().getRowCount();
|
||||
double cRows = topJoin.right().getGroup().getStatistics().getRowCount();
|
||||
return bRows < cRows && !bottomJoin.getJoinReorderContext().hasCommuteZigZag()
|
||||
&& !topJoin.getJoinReorderContext().hasLAsscom()
|
||||
&& (!bottomJoin.isMarkJoin() && !topJoin.isMarkJoin());
|
||||
&& !topJoin.getJoinReorderContext().hasLAsscom();
|
||||
} else {
|
||||
return !bottomJoin.getJoinReorderContext().hasCommuteZigZag()
|
||||
&& !topJoin.getJoinReorderContext().hasLAsscom()
|
||||
&& (!bottomJoin.isMarkJoin() && !topJoin.isMarkJoin());
|
||||
&& !topJoin.getJoinReorderContext().hasLAsscom();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -62,7 +62,6 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory {
|
||||
return innerLogicalJoin(logicalProject(innerLogicalJoin()), group())
|
||||
.when(topJoin -> InnerJoinLAsscom.checkReorder(topJoin, topJoin.left().child(), enableLeftZigZag))
|
||||
.whenNot(join -> join.hasDistributeHint() || join.left().child().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
|
||||
.when(join -> join.left().isAllSlots())
|
||||
.then(topJoin -> {
|
||||
/* ********** init ********** */
|
||||
|
||||
@ -51,7 +51,6 @@ public class InnerJoinLeftAssociate extends OneExplorationRuleFactory {
|
||||
return innerLogicalJoin(group(), innerLogicalJoin())
|
||||
.when(InnerJoinLeftAssociate::checkReorder)
|
||||
.whenNot(join -> join.hasDistributeHint() || join.right().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.right().isMarkJoin())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right();
|
||||
GroupPlan a = topJoin.left();
|
||||
|
||||
@ -52,7 +52,6 @@ public class InnerJoinLeftAssociateProject extends OneExplorationRuleFactory {
|
||||
return innerLogicalJoin(group(), logicalProject(innerLogicalJoin()))
|
||||
.when(InnerJoinLeftAssociate::checkReorder)
|
||||
.whenNot(join -> join.hasDistributeHint() || join.right().child().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.right().child().isMarkJoin())
|
||||
.when(join -> join.right().isAllSlots())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right().child();
|
||||
|
||||
@ -49,7 +49,6 @@ public class InnerJoinRightAssociate extends OneExplorationRuleFactory {
|
||||
return innerLogicalJoin(innerLogicalJoin(), group())
|
||||
.when(InnerJoinRightAssociate::checkReorder)
|
||||
.whenNot(join -> join.hasDistributeHint() || join.left().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
|
||||
GroupPlan a = bottomJoin.left();
|
||||
|
||||
@ -50,7 +50,6 @@ public class InnerJoinRightAssociateProject extends OneExplorationRuleFactory {
|
||||
return innerLogicalJoin(logicalProject(innerLogicalJoin()), group())
|
||||
.when(InnerJoinRightAssociate::checkReorder)
|
||||
.whenNot(join -> join.hasDistributeHint() || join.left().child().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
|
||||
.when(join -> join.left().isAllSlots())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
|
||||
|
||||
@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContain
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.JoinUtils;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.thrift.TRuntimeFilterType;
|
||||
|
||||
@ -59,7 +60,11 @@ public class JoinCommute extends OneExplorationRuleFactory {
|
||||
.when(join -> check(swapType, join))
|
||||
.whenNot(LogicalJoin::hasDistributeHint)
|
||||
.whenNot(join -> joinOrderMatchBitmapRuntimeFilterOrder(join))
|
||||
.whenNot(LogicalJoin::isMarkJoin)
|
||||
// null aware mark join will be translated to null aware left semi/anti join
|
||||
// we don't support null aware right semi/anti join, so should not commute
|
||||
.whenNot(join -> JoinUtils.isNullAwareMarkJoin(join))
|
||||
// commuting nest loop mark join is not supported by be
|
||||
.whenNot(join -> join.isMarkJoin() && join.getHashJoinConjuncts().isEmpty())
|
||||
.then(join -> {
|
||||
LogicalJoin<Plan, Plan> newJoin = join.withTypeChildren(join.getJoinType().swap(),
|
||||
join.right(), join.left(), null);
|
||||
|
||||
@ -56,7 +56,6 @@ public class JoinExchange extends OneExplorationRuleFactory {
|
||||
.when(JoinExchange::checkReorder)
|
||||
.whenNot(join -> join.hasDistributeHint()
|
||||
|| join.left().hasDistributeHint() || join.right().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin() || join.right().isMarkJoin())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> leftJoin = topJoin.left();
|
||||
LogicalJoin<GroupPlan, GroupPlan> rightJoin = topJoin.right();
|
||||
|
||||
@ -59,7 +59,6 @@ public class JoinExchangeBothProject extends OneExplorationRuleFactory {
|
||||
.when(join -> join.left().isAllSlots() && join.right().isAllSlots())
|
||||
.whenNot(join -> join.hasDistributeHint()
|
||||
|| join.left().child().hasDistributeHint() || join.right().child().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin() || join.right().child().isMarkJoin())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> leftJoin = topJoin.left().child();
|
||||
LogicalJoin<GroupPlan, GroupPlan> rightJoin = topJoin.right().child();
|
||||
|
||||
@ -59,7 +59,6 @@ public class JoinExchangeLeftProject extends OneExplorationRuleFactory {
|
||||
.when(join -> join.left().isAllSlots())
|
||||
.whenNot(join -> join.hasDistributeHint()
|
||||
|| join.left().child().hasDistributeHint() || join.right().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin() || join.right().isMarkJoin())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> leftJoin = topJoin.left().child();
|
||||
LogicalJoin<GroupPlan, GroupPlan> rightJoin = topJoin.right();
|
||||
|
||||
@ -59,7 +59,6 @@ public class JoinExchangeRightProject extends OneExplorationRuleFactory {
|
||||
.when(join -> join.right().isAllSlots())
|
||||
.whenNot(join -> join.hasDistributeHint()
|
||||
|| join.left().hasDistributeHint() || join.right().child().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin() || join.right().child().isMarkJoin())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> leftJoin = topJoin.left();
|
||||
LogicalJoin<GroupPlan, GroupPlan> rightJoin = topJoin.right().child();
|
||||
|
||||
@ -43,9 +43,7 @@ public class LogicalJoinSemiJoinTranspose implements ExplorationRuleFactory {
|
||||
.when(topJoin -> (topJoin.left().getJoinType().isLeftSemiOrAntiJoin()
|
||||
&& (topJoin.getJoinType().isInnerJoin()
|
||||
|| topJoin.getJoinType().isLeftOuterJoin())))
|
||||
.whenNot(topJoin -> topJoin.hasDistributeHint() || topJoin.left().hasDistributeHint()
|
||||
|| topJoin.left().isMarkJoin())
|
||||
.whenNot(LogicalJoin::isMarkJoin)
|
||||
.whenNot(topJoin -> topJoin.hasDistributeHint() || topJoin.left().hasDistributeHint())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
|
||||
GroupPlan a = bottomJoin.left();
|
||||
@ -60,9 +58,7 @@ public class LogicalJoinSemiJoinTranspose implements ExplorationRuleFactory {
|
||||
.when(topJoin -> (topJoin.right().getJoinType().isLeftSemiOrAntiJoin()
|
||||
&& (topJoin.getJoinType().isInnerJoin()
|
||||
|| topJoin.getJoinType().isRightOuterJoin())))
|
||||
.whenNot(topJoin -> topJoin.hasDistributeHint() || topJoin.right().hasDistributeHint()
|
||||
|| topJoin.right().isMarkJoin())
|
||||
.whenNot(LogicalJoin::isMarkJoin)
|
||||
.whenNot(topJoin -> topJoin.hasDistributeHint() || topJoin.right().hasDistributeHint())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right();
|
||||
GroupPlan a = topJoin.left();
|
||||
|
||||
@ -44,9 +44,7 @@ public class LogicalJoinSemiJoinTransposeProject implements ExplorationRuleFacto
|
||||
&& (topJoin.getJoinType().isInnerJoin()
|
||||
|| topJoin.getJoinType().isLeftOuterJoin())))
|
||||
.whenNot(topJoin -> topJoin.hasDistributeHint()
|
||||
|| topJoin.left().child().hasDistributeHint()
|
||||
|| topJoin.left().child().isMarkJoin())
|
||||
.whenNot(LogicalJoin::isMarkJoin)
|
||||
|| topJoin.left().child().hasDistributeHint())
|
||||
.when(join -> join.left().isAllSlots())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
|
||||
@ -66,8 +64,7 @@ public class LogicalJoinSemiJoinTransposeProject implements ExplorationRuleFacto
|
||||
&& (topJoin.getJoinType().isInnerJoin()
|
||||
|| topJoin.getJoinType().isRightOuterJoin())))
|
||||
.whenNot(topJoin -> topJoin.hasDistributeHint()
|
||||
|| topJoin.right().child().hasDistributeHint()
|
||||
|| topJoin.right().child().isMarkJoin())
|
||||
|| topJoin.right().child().hasDistributeHint())
|
||||
.when(join -> join.right().isAllSlots())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right().child();
|
||||
|
||||
@ -59,7 +59,6 @@ public class OuterJoinAssoc extends OneExplorationRuleFactory {
|
||||
.when(join -> VALID_TYPE_PAIR_SET.contains(Pair.of(join.left().getJoinType(), join.getJoinType())))
|
||||
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left()))
|
||||
.when(topJoin -> checkCondition(topJoin, topJoin.left().left().getOutputSet()))
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
|
||||
.thenApply(ctx -> {
|
||||
LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin = ctx.root;
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
|
||||
|
||||
@ -57,7 +57,6 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory {
|
||||
Pair.of(join.left().child().getJoinType(), join.getJoinType())))
|
||||
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
|
||||
.whenNot(join -> join.hasDistributeHint() || join.left().child().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
|
||||
.when(join -> OuterJoinAssoc.checkCondition(join, join.left().child().left().getOutputSet()))
|
||||
.when(join -> join.left().isAllSlots())
|
||||
.thenApply(ctx -> {
|
||||
|
||||
@ -64,7 +64,6 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory {
|
||||
.when(topJoin -> checkReorder(topJoin, topJoin.left()))
|
||||
.whenNot(join -> join.hasDistributeHint() || join.left().hasDistributeHint())
|
||||
.when(topJoin -> checkCondition(topJoin, topJoin.left().right().getOutputExprIdSet()))
|
||||
.whenNot(LogicalJoin::isMarkJoin)
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
|
||||
GroupPlan a = bottomJoin.left();
|
||||
|
||||
@ -54,7 +54,6 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
|
||||
Pair.of(join.left().child().getJoinType(), join.getJoinType())))
|
||||
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
|
||||
.whenNot(join -> join.hasDistributeHint() || join.left().child().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
|
||||
.when(topJoin -> OuterJoinLAsscom.checkCondition(topJoin,
|
||||
topJoin.left().child().right().getOutputExprIdSet()))
|
||||
.when(join -> join.left().isAllSlots())
|
||||
|
||||
@ -56,7 +56,7 @@ public class PushDownProjectThroughInnerOuterJoin implements ExplorationRuleFact
|
||||
@Override
|
||||
public List<Rule> buildRules() {
|
||||
return ImmutableList.of(
|
||||
logicalJoin(logicalProject(logicalJoin().whenNot(LogicalJoin::isMarkJoin)), group())
|
||||
logicalJoin(logicalProject(logicalJoin()), group())
|
||||
.when(j -> j.left().child().getJoinType().isOuterJoin()
|
||||
|| j.left().child().getJoinType().isInnerJoin())
|
||||
// Just pushdown project with non-column expr like (t.id + 1)
|
||||
@ -70,7 +70,7 @@ public class PushDownProjectThroughInnerOuterJoin implements ExplorationRuleFact
|
||||
}
|
||||
return topJoin.withChildren(newLeft, topJoin.right());
|
||||
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_OUTER_JOIN_LEFT),
|
||||
logicalJoin(group(), logicalProject(logicalJoin().whenNot(LogicalJoin::isMarkJoin)))
|
||||
logicalJoin(group(), logicalProject(logicalJoin()))
|
||||
.when(j -> j.right().child().getJoinType().isOuterJoin()
|
||||
|| j.right().child().getJoinType().isInnerJoin())
|
||||
// Just pushdown project with non-column expr like (t.id + 1)
|
||||
|
||||
@ -63,7 +63,6 @@ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
|
||||
return logicalJoin(logicalJoin(), group())
|
||||
.when(this::typeChecker)
|
||||
.whenNot(join -> join.hasDistributeHint() || join.left().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
|
||||
.then(topJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
|
||||
GroupPlan a = bottomJoin.left();
|
||||
|
||||
@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.CBOUtils;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
@ -31,8 +32,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* rule for semi-semi transpose
|
||||
@ -55,7 +56,6 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory
|
||||
.when(this::typeChecker)
|
||||
.when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child(), false))
|
||||
.whenNot(join -> join.hasDistributeHint() || join.left().child().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
|
||||
.when(join -> join.left().isAllSlots())
|
||||
.then(topSemi -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomSemi = topSemi.left().child();
|
||||
@ -64,7 +64,9 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory
|
||||
GroupPlan b = bottomSemi.right();
|
||||
GroupPlan c = topSemi.right();
|
||||
Set<ExprId> aOutputExprIdSet = a.getOutputExprIdSet();
|
||||
Set<NamedExpression> acProjects = new HashSet<NamedExpression>(abProject.getProjects());
|
||||
Set<NamedExpression> acProjects = (Set<NamedExpression>) abProject.getProjects()
|
||||
.stream().filter(slot -> !(slot instanceof MarkJoinSlotReference))
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
bottomSemi.getConditionSlot()
|
||||
.forEach(slot -> {
|
||||
@ -73,6 +75,9 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory
|
||||
}
|
||||
});
|
||||
LogicalJoin newBottomSemi = topSemi.withChildrenNoContext(a, c, null);
|
||||
if (topSemi.isMarkJoin()) {
|
||||
acProjects.add(topSemi.getMarkJoinSlotReference().get());
|
||||
}
|
||||
newBottomSemi.getJoinReorderContext().copyFrom(bottomSemi.getJoinReorderContext());
|
||||
newBottomSemi.getJoinReorderContext().setHasCommute(false);
|
||||
newBottomSemi.getJoinReorderContext().setHasLAsscom(false);
|
||||
|
||||
@ -21,12 +21,16 @@ import org.apache.doris.nereids.annotation.DependsRules;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Push the predicate in the LogicalFilter to the join children.
|
||||
@ -39,13 +43,18 @@ public class PushFilterInsideJoin extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalFilter(logicalJoin())
|
||||
.whenNot(filter -> filter.child().isMarkJoin())
|
||||
// TODO: current just handle cross/inner join.
|
||||
.when(filter -> filter.child().getJoinType().isCrossJoin()
|
||||
|| filter.child().getJoinType().isInnerJoin())
|
||||
.then(filter -> {
|
||||
List<Expression> otherConditions = Lists.newArrayList(filter.getConjuncts());
|
||||
LogicalJoin<Plan, Plan> join = filter.child();
|
||||
Set<Slot> childOutput = join.getOutputSet();
|
||||
if (ExpressionUtils.getInputSlotSet(otherConditions).stream()
|
||||
.filter(MarkJoinSlotReference.class::isInstance)
|
||||
.anyMatch(slot -> childOutput.contains(slot))) {
|
||||
return null;
|
||||
}
|
||||
otherConditions.addAll(join.getOtherJoinConjuncts());
|
||||
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
|
||||
otherConditions, join.getDistributeHint(), join.getMarkJoinSlotReference(),
|
||||
|
||||
@ -35,7 +35,6 @@ public class SemiJoinCommute extends OneRewriteRuleFactory {
|
||||
.whenNot(join -> ConnectContext.get().getSessionVariable().isDisableJoinReorder())
|
||||
.whenNot(join -> join.isLeadingJoin())
|
||||
.whenNot(LogicalJoin::hasDistributeHint)
|
||||
.whenNot(LogicalJoin::isMarkJoin)
|
||||
.then(join -> join.withTypeChildren(join.getJoinType().swap(), join.right(), join.left(), null))
|
||||
.toRule(RuleType.LOGICAL_SEMI_JOIN_COMMUTE);
|
||||
}
|
||||
|
||||
@ -44,7 +44,6 @@ public class TransposeSemiJoinLogicalJoin extends OneRewriteRuleFactory {
|
||||
|| topJoin.left().getJoinType().isLeftOuterJoin()
|
||||
|| topJoin.left().getJoinType().isRightOuterJoin())))
|
||||
.whenNot(topJoin -> topJoin.hasDistributeHint() || topJoin.left().hasDistributeHint())
|
||||
.whenNot(LogicalJoin::isMarkJoin)
|
||||
.whenNot(topJoin -> topJoin.isLeadingJoin() || topJoin.left().isLeadingJoin())
|
||||
.then(topSemiJoin -> {
|
||||
LogicalJoin<Plan, Plan> bottomJoin = topSemiJoin.left();
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
@ -29,6 +30,7 @@ import org.apache.doris.nereids.util.Utils;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
@ -49,7 +51,6 @@ public class TransposeSemiJoinLogicalJoinProject extends OneRewriteRuleFactory {
|
||||
|| topJoin.left().child().getJoinType().isRightOuterJoin())))
|
||||
.when(join -> join.left().isAllSlots())
|
||||
.whenNot(join -> join.hasDistributeHint() || join.left().child().hasDistributeHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
|
||||
.whenNot(topJoin -> topJoin.isLeadingJoin() || topJoin.left().child().isLeadingJoin())
|
||||
.when(join -> join.left().getProjects().stream().allMatch(expr -> expr instanceof Slot))
|
||||
.then(topSemiJoin -> {
|
||||
@ -65,7 +66,8 @@ public class TransposeSemiJoinLogicalJoinProject extends OneRewriteRuleFactory {
|
||||
if (containsType == ContainsType.ALL) {
|
||||
return null;
|
||||
}
|
||||
|
||||
ImmutableList<NamedExpression> topProjects = topSemiJoin.getOutput().stream()
|
||||
.map(slot -> (NamedExpression) slot).collect(ImmutableList.toImmutableList());
|
||||
if (containsType == ContainsType.LEFT) {
|
||||
/*-
|
||||
* topSemiJoin project
|
||||
@ -85,7 +87,7 @@ public class TransposeSemiJoinLogicalJoinProject extends OneRewriteRuleFactory {
|
||||
|
||||
Plan newBottomSemiJoin = topSemiJoin.withChildren(a, c);
|
||||
Plan newTopJoin = bottomJoin.withChildren(newBottomSemiJoin, b);
|
||||
return project.withChildren(newTopJoin);
|
||||
return project.withProjectsAndChild(topProjects, newTopJoin);
|
||||
} else {
|
||||
/*-
|
||||
* topSemiJoin project
|
||||
@ -105,7 +107,7 @@ public class TransposeSemiJoinLogicalJoinProject extends OneRewriteRuleFactory {
|
||||
|
||||
Plan newBottomSemiJoin = topSemiJoin.withChildren(b, c);
|
||||
Plan newTopJoin = bottomJoin.withChildren(a, newBottomSemiJoin);
|
||||
return project.withChildren(newTopJoin);
|
||||
return project.withProjectsAndChild(topProjects, newTopJoin);
|
||||
}
|
||||
}).toRule(RuleType.TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN_PROJECT);
|
||||
}
|
||||
|
||||
@ -185,11 +185,8 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
|
||||
* getConditionSlot
|
||||
*/
|
||||
public Set<Slot> getConditionSlot() {
|
||||
// this function is called by rules which reject mark join
|
||||
// so markJoinConjuncts is not processed here
|
||||
Preconditions.checkState(!isMarkJoin(),
|
||||
"shouldn't call mark join's getConditionSlot method");
|
||||
return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
|
||||
return Stream.concat(Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream()),
|
||||
markJoinConjuncts.stream())
|
||||
.flatMap(expr -> expr.getInputSlots().stream())
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
}
|
||||
@ -198,11 +195,8 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
|
||||
* getConditionExprId
|
||||
*/
|
||||
public Set<ExprId> getConditionExprId() {
|
||||
// this function is called by rules which reject mark join
|
||||
// so markJoinConjuncts is not processed here
|
||||
Preconditions.checkState(!isMarkJoin(),
|
||||
"shouldn't call mark join's getConditionExprId method");
|
||||
return Stream.concat(getHashJoinConjuncts().stream(), getOtherJoinConjuncts().stream())
|
||||
return Stream.concat(Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream()),
|
||||
markJoinConjuncts.stream())
|
||||
.flatMap(expr -> expr.getInputSlotExprIds().stream()).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
|
||||
@ -36,7 +36,6 @@ import org.apache.doris.nereids.util.JoinUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
import org.apache.doris.statistics.Statistics;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
@ -260,11 +259,8 @@ public abstract class AbstractPhysicalJoin<
|
||||
* getConditionSlot
|
||||
*/
|
||||
public Set<Slot> getConditionSlot() {
|
||||
// this function is called by rules which reject mark join
|
||||
// so markJoinConjuncts is not processed here
|
||||
Preconditions.checkState(!isMarkJoin(),
|
||||
"shouldn't call mark join's getConditionSlot method");
|
||||
return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
|
||||
return Stream.concat(Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream()),
|
||||
markJoinConjuncts.stream())
|
||||
.flatMap(expr -> expr.getInputSlots().stream()).collect(ImmutableSet.toImmutableSet());
|
||||
}
|
||||
|
||||
|
||||
@ -210,11 +210,8 @@ public class PhysicalNestedLoopJoin<
|
||||
* getConditionSlot
|
||||
*/
|
||||
public Set<Slot> getConditionSlot() {
|
||||
// this function is called by rules which reject mark join
|
||||
// so markJoinConjuncts is not processed here
|
||||
Preconditions.checkState(!isMarkJoin(),
|
||||
"shouldn't call mark join's getConditionSlot method");
|
||||
return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
|
||||
return Stream.concat(Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream()),
|
||||
markJoinConjuncts.stream())
|
||||
.flatMap(expr -> expr.getInputSlots().stream())
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
}
|
||||
|
||||
@ -382,4 +382,14 @@ public class JoinUtils {
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
public static boolean hasMarkConjuncts(Join join) {
|
||||
return !join.getMarkJoinConjuncts().isEmpty();
|
||||
}
|
||||
|
||||
public static boolean isNullAwareMarkJoin(Join join) {
|
||||
// if mark join's hash conjuncts is empty, we use mark conjuncts as hash conjuncts
|
||||
// and translate join type to NULL_AWARE_LEFT_SEMI_JOIN or NULL_AWARE_LEFT_ANTI_JOIN
|
||||
return join.getHashJoinConjuncts().isEmpty() && !join.getMarkJoinConjuncts().isEmpty();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user