[enhancement](nereids) allow reorder mark join (#30644)

This commit is contained in:
starocean999
2024-03-07 10:51:50 +08:00
committed by yiguolei
parent 474cacd572
commit 5905ffa1da
44 changed files with 449 additions and 467 deletions

View File

@ -47,7 +47,7 @@ public class CollectJoinConstraint implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
logicalJoin().whenNot(LogicalJoin::isMarkJoin).thenApply(ctx -> {
logicalJoin().thenApply(ctx -> {
if (!ctx.cascadesContext.isLeadingJoin()) {
return ctx.root;
}

View File

@ -56,7 +56,6 @@ public class InnerJoinLAsscom extends OneExplorationRuleFactory {
return innerLogicalJoin(innerLogicalJoin(), group())
.when(topJoin -> checkReorder(topJoin, topJoin.left(), leftZigZag))
.whenNot(join -> join.hasDistributeHint() || join.left().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();
@ -104,12 +103,10 @@ public class InnerJoinLAsscom extends OneExplorationRuleFactory {
double bRows = bottomJoin.right().getGroup().getStatistics().getRowCount();
double cRows = topJoin.right().getGroup().getStatistics().getRowCount();
return bRows < cRows && !bottomJoin.getJoinReorderContext().hasCommuteZigZag()
&& !topJoin.getJoinReorderContext().hasLAsscom()
&& (!bottomJoin.isMarkJoin() && !topJoin.isMarkJoin());
&& !topJoin.getJoinReorderContext().hasLAsscom();
} else {
return !bottomJoin.getJoinReorderContext().hasCommuteZigZag()
&& !topJoin.getJoinReorderContext().hasLAsscom()
&& (!bottomJoin.isMarkJoin() && !topJoin.isMarkJoin());
&& !topJoin.getJoinReorderContext().hasLAsscom();
}
}

View File

@ -62,7 +62,6 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory {
return innerLogicalJoin(logicalProject(innerLogicalJoin()), group())
.when(topJoin -> InnerJoinLAsscom.checkReorder(topJoin, topJoin.left().child(), enableLeftZigZag))
.whenNot(join -> join.hasDistributeHint() || join.left().child().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> join.left().isAllSlots())
.then(topJoin -> {
/* ********** init ********** */

View File

@ -51,7 +51,6 @@ public class InnerJoinLeftAssociate extends OneExplorationRuleFactory {
return innerLogicalJoin(group(), innerLogicalJoin())
.when(InnerJoinLeftAssociate::checkReorder)
.whenNot(join -> join.hasDistributeHint() || join.right().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.right().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right();
GroupPlan a = topJoin.left();

View File

@ -52,7 +52,6 @@ public class InnerJoinLeftAssociateProject extends OneExplorationRuleFactory {
return innerLogicalJoin(group(), logicalProject(innerLogicalJoin()))
.when(InnerJoinLeftAssociate::checkReorder)
.whenNot(join -> join.hasDistributeHint() || join.right().child().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.right().child().isMarkJoin())
.when(join -> join.right().isAllSlots())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right().child();

View File

@ -49,7 +49,6 @@ public class InnerJoinRightAssociate extends OneExplorationRuleFactory {
return innerLogicalJoin(innerLogicalJoin(), group())
.when(InnerJoinRightAssociate::checkReorder)
.whenNot(join -> join.hasDistributeHint() || join.left().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();

View File

@ -50,7 +50,6 @@ public class InnerJoinRightAssociateProject extends OneExplorationRuleFactory {
return innerLogicalJoin(logicalProject(innerLogicalJoin()), group())
.when(InnerJoinRightAssociate::checkReorder)
.whenNot(join -> join.hasDistributeHint() || join.left().child().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> join.left().isAllSlots())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();

View File

@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContain
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TRuntimeFilterType;
@ -59,7 +60,11 @@ public class JoinCommute extends OneExplorationRuleFactory {
.when(join -> check(swapType, join))
.whenNot(LogicalJoin::hasDistributeHint)
.whenNot(join -> joinOrderMatchBitmapRuntimeFilterOrder(join))
.whenNot(LogicalJoin::isMarkJoin)
// null aware mark join will be translated to null aware left semi/anti join
// we don't support null aware right semi/anti join, so should not commute
.whenNot(join -> JoinUtils.isNullAwareMarkJoin(join))
// commuting nest loop mark join is not supported by be
.whenNot(join -> join.isMarkJoin() && join.getHashJoinConjuncts().isEmpty())
.then(join -> {
LogicalJoin<Plan, Plan> newJoin = join.withTypeChildren(join.getJoinType().swap(),
join.right(), join.left(), null);

View File

@ -56,7 +56,6 @@ public class JoinExchange extends OneExplorationRuleFactory {
.when(JoinExchange::checkReorder)
.whenNot(join -> join.hasDistributeHint()
|| join.left().hasDistributeHint() || join.right().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin() || join.right().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> leftJoin = topJoin.left();
LogicalJoin<GroupPlan, GroupPlan> rightJoin = topJoin.right();

View File

@ -59,7 +59,6 @@ public class JoinExchangeBothProject extends OneExplorationRuleFactory {
.when(join -> join.left().isAllSlots() && join.right().isAllSlots())
.whenNot(join -> join.hasDistributeHint()
|| join.left().child().hasDistributeHint() || join.right().child().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin() || join.right().child().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> leftJoin = topJoin.left().child();
LogicalJoin<GroupPlan, GroupPlan> rightJoin = topJoin.right().child();

View File

@ -59,7 +59,6 @@ public class JoinExchangeLeftProject extends OneExplorationRuleFactory {
.when(join -> join.left().isAllSlots())
.whenNot(join -> join.hasDistributeHint()
|| join.left().child().hasDistributeHint() || join.right().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin() || join.right().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> leftJoin = topJoin.left().child();
LogicalJoin<GroupPlan, GroupPlan> rightJoin = topJoin.right();

View File

@ -59,7 +59,6 @@ public class JoinExchangeRightProject extends OneExplorationRuleFactory {
.when(join -> join.right().isAllSlots())
.whenNot(join -> join.hasDistributeHint()
|| join.left().hasDistributeHint() || join.right().child().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin() || join.right().child().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> leftJoin = topJoin.left();
LogicalJoin<GroupPlan, GroupPlan> rightJoin = topJoin.right().child();

View File

@ -43,9 +43,7 @@ public class LogicalJoinSemiJoinTranspose implements ExplorationRuleFactory {
.when(topJoin -> (topJoin.left().getJoinType().isLeftSemiOrAntiJoin()
&& (topJoin.getJoinType().isInnerJoin()
|| topJoin.getJoinType().isLeftOuterJoin())))
.whenNot(topJoin -> topJoin.hasDistributeHint() || topJoin.left().hasDistributeHint()
|| topJoin.left().isMarkJoin())
.whenNot(LogicalJoin::isMarkJoin)
.whenNot(topJoin -> topJoin.hasDistributeHint() || topJoin.left().hasDistributeHint())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();
@ -60,9 +58,7 @@ public class LogicalJoinSemiJoinTranspose implements ExplorationRuleFactory {
.when(topJoin -> (topJoin.right().getJoinType().isLeftSemiOrAntiJoin()
&& (topJoin.getJoinType().isInnerJoin()
|| topJoin.getJoinType().isRightOuterJoin())))
.whenNot(topJoin -> topJoin.hasDistributeHint() || topJoin.right().hasDistributeHint()
|| topJoin.right().isMarkJoin())
.whenNot(LogicalJoin::isMarkJoin)
.whenNot(topJoin -> topJoin.hasDistributeHint() || topJoin.right().hasDistributeHint())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right();
GroupPlan a = topJoin.left();

View File

@ -44,9 +44,7 @@ public class LogicalJoinSemiJoinTransposeProject implements ExplorationRuleFacto
&& (topJoin.getJoinType().isInnerJoin()
|| topJoin.getJoinType().isLeftOuterJoin())))
.whenNot(topJoin -> topJoin.hasDistributeHint()
|| topJoin.left().child().hasDistributeHint()
|| topJoin.left().child().isMarkJoin())
.whenNot(LogicalJoin::isMarkJoin)
|| topJoin.left().child().hasDistributeHint())
.when(join -> join.left().isAllSlots())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
@ -66,8 +64,7 @@ public class LogicalJoinSemiJoinTransposeProject implements ExplorationRuleFacto
&& (topJoin.getJoinType().isInnerJoin()
|| topJoin.getJoinType().isRightOuterJoin())))
.whenNot(topJoin -> topJoin.hasDistributeHint()
|| topJoin.right().child().hasDistributeHint()
|| topJoin.right().child().isMarkJoin())
|| topJoin.right().child().hasDistributeHint())
.when(join -> join.right().isAllSlots())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right().child();

View File

@ -59,7 +59,6 @@ public class OuterJoinAssoc extends OneExplorationRuleFactory {
.when(join -> VALID_TYPE_PAIR_SET.contains(Pair.of(join.left().getJoinType(), join.getJoinType())))
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left()))
.when(topJoin -> checkCondition(topJoin, topJoin.left().left().getOutputSet()))
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
.thenApply(ctx -> {
LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin = ctx.root;
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();

View File

@ -57,7 +57,6 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory {
Pair.of(join.left().child().getJoinType(), join.getJoinType())))
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
.whenNot(join -> join.hasDistributeHint() || join.left().child().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> OuterJoinAssoc.checkCondition(join, join.left().child().left().getOutputSet()))
.when(join -> join.left().isAllSlots())
.thenApply(ctx -> {

View File

@ -64,7 +64,6 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory {
.when(topJoin -> checkReorder(topJoin, topJoin.left()))
.whenNot(join -> join.hasDistributeHint() || join.left().hasDistributeHint())
.when(topJoin -> checkCondition(topJoin, topJoin.left().right().getOutputExprIdSet()))
.whenNot(LogicalJoin::isMarkJoin)
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();

View File

@ -54,7 +54,6 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
Pair.of(join.left().child().getJoinType(), join.getJoinType())))
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
.whenNot(join -> join.hasDistributeHint() || join.left().child().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(topJoin -> OuterJoinLAsscom.checkCondition(topJoin,
topJoin.left().child().right().getOutputExprIdSet()))
.when(join -> join.left().isAllSlots())

View File

@ -56,7 +56,7 @@ public class PushDownProjectThroughInnerOuterJoin implements ExplorationRuleFact
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
logicalJoin(logicalProject(logicalJoin().whenNot(LogicalJoin::isMarkJoin)), group())
logicalJoin(logicalProject(logicalJoin()), group())
.when(j -> j.left().child().getJoinType().isOuterJoin()
|| j.left().child().getJoinType().isInnerJoin())
// Just pushdown project with non-column expr like (t.id + 1)
@ -70,7 +70,7 @@ public class PushDownProjectThroughInnerOuterJoin implements ExplorationRuleFact
}
return topJoin.withChildren(newLeft, topJoin.right());
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_OUTER_JOIN_LEFT),
logicalJoin(group(), logicalProject(logicalJoin().whenNot(LogicalJoin::isMarkJoin)))
logicalJoin(group(), logicalProject(logicalJoin()))
.when(j -> j.right().child().getJoinType().isOuterJoin()
|| j.right().child().getJoinType().isInnerJoin())
// Just pushdown project with non-column expr like (t.id + 1)

View File

@ -63,7 +63,6 @@ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
return logicalJoin(logicalJoin(), group())
.when(this::typeChecker)
.whenNot(join -> join.hasDistributeHint() || join.left().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.CBOUtils;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
@ -31,8 +32,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
/**
* rule for semi-semi transpose
@ -55,7 +56,6 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory
.when(this::typeChecker)
.when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child(), false))
.whenNot(join -> join.hasDistributeHint() || join.left().child().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> join.left().isAllSlots())
.then(topSemi -> {
LogicalJoin<GroupPlan, GroupPlan> bottomSemi = topSemi.left().child();
@ -64,7 +64,9 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory
GroupPlan b = bottomSemi.right();
GroupPlan c = topSemi.right();
Set<ExprId> aOutputExprIdSet = a.getOutputExprIdSet();
Set<NamedExpression> acProjects = new HashSet<NamedExpression>(abProject.getProjects());
Set<NamedExpression> acProjects = (Set<NamedExpression>) abProject.getProjects()
.stream().filter(slot -> !(slot instanceof MarkJoinSlotReference))
.collect(Collectors.toSet());
bottomSemi.getConditionSlot()
.forEach(slot -> {
@ -73,6 +75,9 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory
}
});
LogicalJoin newBottomSemi = topSemi.withChildrenNoContext(a, c, null);
if (topSemi.isMarkJoin()) {
acProjects.add(topSemi.getMarkJoinSlotReference().get());
}
newBottomSemi.getJoinReorderContext().copyFrom(bottomSemi.getJoinReorderContext());
newBottomSemi.getJoinReorderContext().setHasCommute(false);
newBottomSemi.getJoinReorderContext().setHasLAsscom(false);

View File

@ -21,12 +21,16 @@ import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.Set;
/**
* Push the predicate in the LogicalFilter to the join children.
@ -39,13 +43,18 @@ public class PushFilterInsideJoin extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalFilter(logicalJoin())
.whenNot(filter -> filter.child().isMarkJoin())
// TODO: current just handle cross/inner join.
.when(filter -> filter.child().getJoinType().isCrossJoin()
|| filter.child().getJoinType().isInnerJoin())
.then(filter -> {
List<Expression> otherConditions = Lists.newArrayList(filter.getConjuncts());
LogicalJoin<Plan, Plan> join = filter.child();
Set<Slot> childOutput = join.getOutputSet();
if (ExpressionUtils.getInputSlotSet(otherConditions).stream()
.filter(MarkJoinSlotReference.class::isInstance)
.anyMatch(slot -> childOutput.contains(slot))) {
return null;
}
otherConditions.addAll(join.getOtherJoinConjuncts());
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
otherConditions, join.getDistributeHint(), join.getMarkJoinSlotReference(),

View File

@ -35,7 +35,6 @@ public class SemiJoinCommute extends OneRewriteRuleFactory {
.whenNot(join -> ConnectContext.get().getSessionVariable().isDisableJoinReorder())
.whenNot(join -> join.isLeadingJoin())
.whenNot(LogicalJoin::hasDistributeHint)
.whenNot(LogicalJoin::isMarkJoin)
.then(join -> join.withTypeChildren(join.getJoinType().swap(), join.right(), join.left(), null))
.toRule(RuleType.LOGICAL_SEMI_JOIN_COMMUTE);
}

View File

@ -44,7 +44,6 @@ public class TransposeSemiJoinLogicalJoin extends OneRewriteRuleFactory {
|| topJoin.left().getJoinType().isLeftOuterJoin()
|| topJoin.left().getJoinType().isRightOuterJoin())))
.whenNot(topJoin -> topJoin.hasDistributeHint() || topJoin.left().hasDistributeHint())
.whenNot(LogicalJoin::isMarkJoin)
.whenNot(topJoin -> topJoin.isLeadingJoin() || topJoin.left().isLeadingJoin())
.then(topSemiJoin -> {
LogicalJoin<Plan, Plan> bottomJoin = topSemiJoin.left();

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -29,6 +30,7 @@ import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.Set;
@ -49,7 +51,6 @@ public class TransposeSemiJoinLogicalJoinProject extends OneRewriteRuleFactory {
|| topJoin.left().child().getJoinType().isRightOuterJoin())))
.when(join -> join.left().isAllSlots())
.whenNot(join -> join.hasDistributeHint() || join.left().child().hasDistributeHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.whenNot(topJoin -> topJoin.isLeadingJoin() || topJoin.left().child().isLeadingJoin())
.when(join -> join.left().getProjects().stream().allMatch(expr -> expr instanceof Slot))
.then(topSemiJoin -> {
@ -65,7 +66,8 @@ public class TransposeSemiJoinLogicalJoinProject extends OneRewriteRuleFactory {
if (containsType == ContainsType.ALL) {
return null;
}
ImmutableList<NamedExpression> topProjects = topSemiJoin.getOutput().stream()
.map(slot -> (NamedExpression) slot).collect(ImmutableList.toImmutableList());
if (containsType == ContainsType.LEFT) {
/*-
* topSemiJoin project
@ -85,7 +87,7 @@ public class TransposeSemiJoinLogicalJoinProject extends OneRewriteRuleFactory {
Plan newBottomSemiJoin = topSemiJoin.withChildren(a, c);
Plan newTopJoin = bottomJoin.withChildren(newBottomSemiJoin, b);
return project.withChildren(newTopJoin);
return project.withProjectsAndChild(topProjects, newTopJoin);
} else {
/*-
* topSemiJoin project
@ -105,7 +107,7 @@ public class TransposeSemiJoinLogicalJoinProject extends OneRewriteRuleFactory {
Plan newBottomSemiJoin = topSemiJoin.withChildren(b, c);
Plan newTopJoin = bottomJoin.withChildren(a, newBottomSemiJoin);
return project.withChildren(newTopJoin);
return project.withProjectsAndChild(topProjects, newTopJoin);
}
}).toRule(RuleType.TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN_PROJECT);
}

View File

@ -185,11 +185,8 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
* getConditionSlot
*/
public Set<Slot> getConditionSlot() {
// this function is called by rules which reject mark join
// so markJoinConjuncts is not processed here
Preconditions.checkState(!isMarkJoin(),
"shouldn't call mark join's getConditionSlot method");
return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
return Stream.concat(Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream()),
markJoinConjuncts.stream())
.flatMap(expr -> expr.getInputSlots().stream())
.collect(ImmutableSet.toImmutableSet());
}
@ -198,11 +195,8 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
* getConditionExprId
*/
public Set<ExprId> getConditionExprId() {
// this function is called by rules which reject mark join
// so markJoinConjuncts is not processed here
Preconditions.checkState(!isMarkJoin(),
"shouldn't call mark join's getConditionExprId method");
return Stream.concat(getHashJoinConjuncts().stream(), getOtherJoinConjuncts().stream())
return Stream.concat(Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream()),
markJoinConjuncts.stream())
.flatMap(expr -> expr.getInputSlotExprIds().stream()).collect(Collectors.toSet());
}

View File

@ -36,7 +36,6 @@ import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
@ -260,11 +259,8 @@ public abstract class AbstractPhysicalJoin<
* getConditionSlot
*/
public Set<Slot> getConditionSlot() {
// this function is called by rules which reject mark join
// so markJoinConjuncts is not processed here
Preconditions.checkState(!isMarkJoin(),
"shouldn't call mark join's getConditionSlot method");
return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
return Stream.concat(Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream()),
markJoinConjuncts.stream())
.flatMap(expr -> expr.getInputSlots().stream()).collect(ImmutableSet.toImmutableSet());
}

View File

@ -210,11 +210,8 @@ public class PhysicalNestedLoopJoin<
* getConditionSlot
*/
public Set<Slot> getConditionSlot() {
// this function is called by rules which reject mark join
// so markJoinConjuncts is not processed here
Preconditions.checkState(!isMarkJoin(),
"shouldn't call mark join's getConditionSlot method");
return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
return Stream.concat(Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream()),
markJoinConjuncts.stream())
.flatMap(expr -> expr.getInputSlots().stream())
.collect(ImmutableSet.toImmutableSet());
}

View File

@ -382,4 +382,14 @@ public class JoinUtils {
.build();
}
}
public static boolean hasMarkConjuncts(Join join) {
return !join.getMarkJoinConjuncts().isEmpty();
}
public static boolean isNullAwareMarkJoin(Join join) {
// if mark join's hash conjuncts is empty, we use mark conjuncts as hash conjuncts
// and translate join type to NULL_AWARE_LEFT_SEMI_JOIN or NULL_AWARE_LEFT_ANTI_JOIN
return join.getHashJoinConjuncts().isEmpty() && !join.getMarkJoinConjuncts().isEmpty();
}
}