[feature](nereids)support mark join (#30133)

Co-authored-by: Jerry Hu <mrhhsg@gmail.com>
This commit is contained in:
starocean999
2024-01-26 18:45:18 +08:00
committed by yiguolei
parent f25af15842
commit 713798d549
79 changed files with 3696 additions and 791 deletions

View File

@ -37,7 +37,9 @@ public enum JoinOperator {
// NOT IN subqueries. It can have a single equality join conjunct
// that returns TRUE when the rhs is NULL.
NULL_AWARE_LEFT_ANTI_JOIN("NULL AWARE LEFT ANTI JOIN",
TJoinOp.NULL_AWARE_LEFT_ANTI_JOIN);
TJoinOp.NULL_AWARE_LEFT_ANTI_JOIN),
NULL_AWARE_LEFT_SEMI_JOIN("NULL AWARE LEFT SEMI JOIN",
TJoinOp.NULL_AWARE_LEFT_SEMI_JOIN);
private final String description;
private final TJoinOp thriftJoinOp;

View File

@ -1183,9 +1183,22 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
.map(e -> JoinUtils.swapEqualToForChildrenOrder(e, hashJoin.left().getOutputSet()))
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
List<Expr> markConjuncts = ImmutableList.of();
boolean isHashJoinConjunctsEmpty = hashJoin.getHashJoinConjuncts().isEmpty();
boolean isMarkJoinConjunctsEmpty = hashJoin.getMarkJoinConjuncts().isEmpty();
if (isHashJoinConjunctsEmpty) {
// if hash join conjuncts is empty, means mark join conjuncts must be EqualPredicate
// BE should use mark join conjuncts to build hash table
Preconditions.checkState(!isMarkJoinConjunctsEmpty, "mark join conjuncts should not be empty.");
markConjuncts = hashJoin.getMarkJoinConjuncts().stream()
.map(EqualPredicate.class::cast)
.map(e -> JoinUtils.swapEqualToForChildrenOrder(e, hashJoin.left().getOutputSet()))
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
}
HashJoinNode hashJoinNode = new HashJoinNode(context.nextPlanNodeId(), leftPlanRoot,
rightPlanRoot, JoinType.toJoinOperator(joinType), execEqConjuncts, Lists.newArrayList(),
rightPlanRoot, JoinType.toJoinOperator(joinType), execEqConjuncts, Lists.newArrayList(), markConjuncts,
null, null, null, hashJoin.isMarkJoin());
hashJoinNode.setNereidsId(hashJoin.getId());
hashJoinNode.setDistributeExprLists(distributeExprLists);
@ -1246,6 +1259,15 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
.flatMap(e -> e.getInputSlots().stream())
.map(SlotReference.class::cast)
.forEach(s -> hashOutputSlotReferenceMap.put(s.getExprId(), s));
if (!isHashJoinConjunctsEmpty && !isMarkJoinConjunctsEmpty) {
// if hash join conjuncts is NOT empty, mark join conjuncts would be processed like other conjuncts
// BE should deal with mark join conjuncts differently, its result is 3 value bool(true, false, null)
hashJoin.getMarkJoinConjuncts()
.stream()
.flatMap(e -> e.getInputSlots().stream())
.map(SlotReference.class::cast)
.forEach(s -> hashOutputSlotReferenceMap.put(s.getExprId(), s));
}
hashJoin.getFilterConjuncts().stream()
.filter(e -> !(e.equals(BooleanLiteral.TRUE)))
.flatMap(e -> e.getInputSlots().stream())
@ -1271,7 +1293,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
List<SlotDescriptor> rightIntermediateSlotDescriptor = Lists.newArrayList();
TupleDescriptor intermediateDescriptor = context.generateTupleDesc();
if (hashJoin.getOtherJoinConjuncts().isEmpty()
if (hashJoin.getOtherJoinConjuncts().isEmpty() && (isHashJoinConjunctsEmpty != isMarkJoinConjunctsEmpty)
&& (joinType == JoinType.LEFT_ANTI_JOIN
|| joinType == JoinType.LEFT_SEMI_JOIN
|| joinType == JoinType.NULL_AWARE_LEFT_ANTI_JOIN)) {
@ -1294,7 +1316,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
}
leftIntermediateSlotDescriptor.add(sd);
}
} else if (hashJoin.getOtherJoinConjuncts().isEmpty()
} else if (hashJoin.getOtherJoinConjuncts().isEmpty() && (isHashJoinConjunctsEmpty != isMarkJoinConjunctsEmpty)
&& (joinType == JoinType.RIGHT_ANTI_JOIN || joinType == JoinType.RIGHT_SEMI_JOIN)) {
for (SlotDescriptor rightSlotDescriptor : rightSlotDescriptors) {
if (!rightSlotDescriptor.isMaterialized()) {
@ -1391,6 +1413,15 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
hashJoinNode.setOtherJoinConjuncts(otherJoinConjuncts);
if (!isHashJoinConjunctsEmpty && !isMarkJoinConjunctsEmpty) {
// add mark join conjuncts to hash join node
List<Expr> markJoinConjuncts = hashJoin.getMarkJoinConjuncts()
.stream()
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
hashJoinNode.setMarkJoinConjuncts(markJoinConjuncts);
}
hashJoinNode.setvIntermediateTupleDescList(Lists.newArrayList(intermediateDescriptor));
if (hashJoin.isShouldTranslateOutput()) {
@ -1564,6 +1595,12 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
nestedLoopJoinNode.setJoinConjuncts(joinConjuncts);
if (!nestedLoopJoin.getOtherJoinConjuncts().isEmpty()) {
List<Expr> markJoinConjuncts = nestedLoopJoin.getMarkJoinConjuncts().stream()
.map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
nestedLoopJoinNode.setMarkJoinConjuncts(markJoinConjuncts);
}
nestedLoopJoin.getFilterConjuncts().stream()
.filter(e -> !(e.equals(BooleanLiteral.TRUE)))
.map(e -> ExpressionTranslator.translate(e, context))
@ -1713,6 +1750,13 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
for (Expr expr : otherConjuncts) {
Expr.extractSlots(expr, requiredOtherConjunctsSlotIdSet);
}
if (!((HashJoinNode) joinNode).getEqJoinConjuncts().isEmpty()
&& !((HashJoinNode) joinNode).getMarkJoinConjuncts().isEmpty()) {
List<Expr> markConjuncts = ((HashJoinNode) joinNode).getMarkJoinConjuncts();
for (Expr expr : markConjuncts) {
Expr.extractSlots(expr, requiredOtherConjunctsSlotIdSet);
}
}
requiredOtherConjunctsSlotIdSet.forEach(e -> requiredExprIds.add(context.findExprId(e)));
requiredSlotIdSet.forEach(e -> requiredExprIds.add(context.findExprId(e)));
for (ExprId exprId : requiredExprIds) {

View File

@ -211,6 +211,7 @@ public enum RuleType {
ELIMINATE_NOT_NULL(RuleTypeClass.REWRITE),
ELIMINATE_UNNECESSARY_PROJECT(RuleTypeClass.REWRITE),
ELIMINATE_OUTER_JOIN(RuleTypeClass.REWRITE),
ELIMINATE_MARK_JOIN(RuleTypeClass.REWRITE),
ELIMINATE_GROUP_BY(RuleTypeClass.REWRITE),
ELIMINATE_JOIN_BY_UK(RuleTypeClass.REWRITE),
ELIMINATE_DEDUP_JOIN_CONDITION(RuleTypeClass.REWRITE),

View File

@ -47,7 +47,7 @@ public class CollectJoinConstraint implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
logicalJoin().thenApply(ctx -> {
logicalJoin().whenNot(LogicalJoin::isMarkJoin).thenApply(ctx -> {
if (!ctx.cascadesContext.isLeadingJoin()) {
return ctx.root;
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.rules.TrySimplifyPredicateWithMarkJoinSlot;
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
@ -49,6 +50,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
@ -75,8 +77,6 @@ public class SubqueryToApply implements AnalysisRuleFactory {
RuleType.FILTER_SUBQUERY_TO_APPLY.build(
logicalFilter().thenApply(ctx -> {
LogicalFilter<Plan> filter = ctx.root;
boolean shouldOutputMarkJoinSlot = filter.getConjuncts().stream()
.anyMatch(expr -> shouldOutputMarkJoinSlot(expr, SearchState.SearchNot));
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = filter.getConjuncts().stream()
.<Set<SubqueryExpr>>map(e -> e.collect(SubqueryToApply::canConvertToSupply))
.collect(ImmutableList.toImmutableList());
@ -84,6 +84,11 @@ public class SubqueryToApply implements AnalysisRuleFactory {
.flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) {
return filter;
}
ImmutableList<Boolean> shouldOutputMarkJoinSlot =
filter.getConjuncts().stream()
.map(expr -> !(expr instanceof SubqueryExpr)
&& expr.containsType(SubqueryExpr.class))
.collect(ImmutableList.toImmutableList());
List<Expression> oldConjuncts = ImmutableList.copyOf(filter.getConjuncts());
ImmutableList.Builder<Expression> newConjuncts = new ImmutableList.Builder<>();
@ -101,15 +106,29 @@ public class SubqueryToApply implements AnalysisRuleFactory {
// first step: Replace the subquery of predicate in LogicalFilter
// second step: Replace subquery with LogicalApply
ReplaceSubquery replaceSubquery = new ReplaceSubquery(
ctx.statementContext, shouldOutputMarkJoinSlot);
ctx.statementContext, shouldOutputMarkJoinSlot.get(i));
SubqueryContext context = new SubqueryContext(subqueryExprs);
Expression conjunct = replaceSubquery.replace(oldConjuncts.get(i), context);
/*
* the idea is replacing each mark join slot with null and false literal
* then run FoldConstant rule, if the evaluate result are:
* 1. all true
* 2. all null and false (in logicalFilter, we discard both null and false values)
* the mark slot can be non-nullable boolean
* we pass this info to LogicalApply. And in InApplyToJoin rule
* if it's semi join with non-null mark slot
* we can safely change the mark conjunct to hash conjunct
*/
boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class)
? ExpressionUtils.canInferNotNullForMarkSlot(
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, null))
: false;
applyPlan = subqueryToApply(subqueryExprs.stream()
.collect(ImmutableList.toImmutableList()), tmpPlan,
context.getSubqueryToMarkJoinSlot(),
ctx.cascadesContext,
Optional.of(conjunct), false);
Optional.of(conjunct), false, isMarkSlotNotNull);
tmpPlan = applyPlan;
newConjuncts.add(conjunct);
}
@ -120,8 +139,9 @@ public class SubqueryToApply implements AnalysisRuleFactory {
return new LogicalProject<>(applyPlan.getOutput().stream()
.filter(s -> !(s instanceof MarkJoinSlotReference))
.collect(ImmutableList.toImmutableList()), newFilter);
} else {
return newFilter;
}
return new LogicalFilter<>(conjuncts, applyPlan);
})
),
RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> {
@ -155,7 +175,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
childPlan, context.getSubqueryToMarkJoinSlot(),
ctx.cascadesContext,
Optional.of(newProject), true);
Optional.of(newProject), true, false);
childPlan = applyPlan;
newProjects.add((NamedExpression) newProject);
}
@ -216,12 +236,25 @@ public class SubqueryToApply implements AnalysisRuleFactory {
ReplaceSubquery replaceSubquery = new ReplaceSubquery(ctx.statementContext, true);
SubqueryContext context = new SubqueryContext(subqueryExprs);
Expression conjunct = replaceSubquery.replace(subqueryConjuncts.get(i), context);
/*
* the idea is replacing each mark join slot with null and false literal
* then run FoldConstant rule, if the evaluate result are:
* 1. all true
* 2. all null and false (in logicalFilter, we discard both null and false values)
* the mark slot can be non-nullable boolean
* we pass this info to LogicalApply. And in InApplyToJoin rule
* if it's semi join with non-null mark slot
* we can safely change the mark conjunct to hash conjunct
*/
boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class)
? ExpressionUtils.canInferNotNullForMarkSlot(
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, null))
: false;
applyPlan = subqueryToApply(
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
relatedInfoList.get(i) == RelatedInfo.RelatedToLeft ? leftChildPlan : rightChildPlan,
context.getSubqueryToMarkJoinSlot(),
ctx.cascadesContext, Optional.of(conjunct), false);
ctx.cascadesContext, Optional.of(conjunct), false, isMarkSlotNotNull);
if (relatedInfoList.get(i) == RelatedInfo.RelatedToLeft) {
leftChildPlan = applyPlan;
} else {
@ -282,7 +315,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
SubqueryExpr subqueryExpr = subqueryExprs.get(0);
List<Slot> correlatedSlots = subqueryExpr.getCorrelateSlots();
if (subqueryExpr instanceof ScalarSubquery) {
Set<Slot> inputSlots = expression.getInputSlots();
Set<Slot> inputSlots = subqueryExpr.getInputSlots();
if (correlatedSlots.isEmpty() && inputSlots.isEmpty()) {
relatedInfo = RelatedInfo.Unrelated;
} else if (leftOutputSlots.containsAll(inputSlots)
@ -322,7 +355,8 @@ public class SubqueryToApply implements AnalysisRuleFactory {
private LogicalPlan subqueryToApply(List<SubqueryExpr> subqueryExprs, LogicalPlan childPlan,
Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot,
CascadesContext ctx,
Optional<Expression> conjunct, boolean isProject) {
Optional<Expression> conjunct, boolean isProject,
boolean isMarkJoinSlotNotNull) {
LogicalPlan tmpPlan = childPlan;
for (int i = 0; i < subqueryExprs.size(); ++i) {
SubqueryExpr subqueryExpr = subqueryExprs.get(i);
@ -336,7 +370,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
if (!ctx.subqueryIsAnalyzed(subqueryExpr)) {
tmpPlan = addApply(subqueryExpr, tmpPlan,
subqueryToMarkJoinSlot, ctx, conjunct,
isProject, subqueryExprs.size() == 1);
isProject, subqueryExprs.size() == 1, isMarkJoinSlotNotNull);
}
}
return tmpPlan;
@ -354,7 +388,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
private LogicalPlan addApply(SubqueryExpr subquery, LogicalPlan childPlan,
Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot,
CascadesContext ctx, Optional<Expression> conjunct,
boolean isProject, boolean singleSubquery) {
boolean isProject, boolean singleSubquery, boolean isMarkJoinSlotNotNull) {
ctx.setSubqueryExprIsAnalyzed(subquery, true);
boolean needAddScalarSubqueryOutputToProjects = isConjunctContainsScalarSubqueryOutput(
subquery, conjunct, isProject, singleSubquery);
@ -362,7 +396,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
subquery.getCorrelateSlots(),
subquery, Optional.empty(),
subqueryToMarkJoinSlot.get(subquery),
needAddScalarSubqueryOutputToProjects, isProject,
needAddScalarSubqueryOutputToProjects, isProject, isMarkJoinSlotNotNull,
childPlan, subquery.getQueryPlan());
List<NamedExpression> projects = ImmutableList.<NamedExpression>builder()

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.expression;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
@ -39,7 +40,6 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
@ -179,34 +179,39 @@ public class ExpressionRewrite implements RewriteRuleFactory {
LogicalJoin<Plan, Plan> join = ctx.root;
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
List<Expression> otherJoinConjuncts = join.getOtherJoinConjuncts();
if (otherJoinConjuncts.isEmpty() && hashJoinConjuncts.isEmpty()) {
List<Expression> markJoinConjuncts = join.getMarkJoinConjuncts();
if (otherJoinConjuncts.isEmpty() && hashJoinConjuncts.isEmpty()
&& markJoinConjuncts.isEmpty()) {
return join;
}
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
List<Expression> rewriteHashJoinConjuncts = Lists.newArrayList();
boolean hashJoinConjunctsChanged = false;
for (Expression expr : hashJoinConjuncts) {
Expression newExpr = rewriter.rewrite(expr, context);
hashJoinConjunctsChanged = hashJoinConjunctsChanged || !newExpr.equals(expr);
rewriteHashJoinConjuncts.addAll(ExpressionUtils.extractConjunction(newExpr));
}
Pair<Boolean, List<Expression>> newHashJoinConjuncts = rewriteConjuncts(hashJoinConjuncts, context);
Pair<Boolean, List<Expression>> newOtherJoinConjuncts = rewriteConjuncts(otherJoinConjuncts, context);
Pair<Boolean, List<Expression>> newMarkJoinConjuncts = rewriteConjuncts(markJoinConjuncts, context);
List<Expression> rewriteOtherJoinConjuncts = Lists.newArrayList();
boolean otherJoinConjunctsChanged = false;
for (Expression expr : otherJoinConjuncts) {
Expression newExpr = rewriter.rewrite(expr, context);
otherJoinConjunctsChanged = otherJoinConjunctsChanged || !newExpr.equals(expr);
rewriteOtherJoinConjuncts.addAll(ExpressionUtils.extractConjunction(newExpr));
}
if (!hashJoinConjunctsChanged && !otherJoinConjunctsChanged) {
if (!newHashJoinConjuncts.first && !newOtherJoinConjuncts.first
&& !newMarkJoinConjuncts.first) {
return join;
}
return new LogicalJoin<>(join.getJoinType(), rewriteHashJoinConjuncts,
rewriteOtherJoinConjuncts, join.getDistributeHint(), join.getMarkJoinSlotReference(),
join.children());
return new LogicalJoin<>(join.getJoinType(), newHashJoinConjuncts.second,
newOtherJoinConjuncts.second, newMarkJoinConjuncts.second,
join.getDistributeHint(), join.getMarkJoinSlotReference(), join.children());
}).toRule(RuleType.REWRITE_JOIN_EXPRESSION);
}
private Pair<Boolean, List<Expression>> rewriteConjuncts(List<Expression> conjuncts,
ExpressionRewriteContext context) {
boolean isChanged = false;
ImmutableList.Builder<Expression> rewrittenConjuncts = new ImmutableList.Builder<>();
for (Expression expr : conjuncts) {
Expression newExpr = rewriter.rewrite(expr, context);
isChanged = isChanged || !newExpr.equals(expr);
rewrittenConjuncts.addAll(ExpressionUtils.extractConjunction(newExpr));
}
return Pair.of(isChanged, rewrittenConjuncts.build());
}
}
private class SortExpressionRewrite extends OneRewriteRuleFactory {

View File

@ -0,0 +1,113 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
/**
* TrySimplifyPredicateWithMarkJoinSlot
*/
public class TrySimplifyPredicateWithMarkJoinSlot extends AbstractExpressionRewriteRule {
public static final TrySimplifyPredicateWithMarkJoinSlot INSTANCE =
new TrySimplifyPredicateWithMarkJoinSlot();
@Override
public Expression visitAnd(And and, ExpressionRewriteContext context) {
/*
* predicate(with mark slot) and predicate(no mark slot)
* false and TRUE -> false(*) -> discard
* false and NULL -> null -> discard
* false and FALSE -> false -> discard
*
* null and TRUE -> null(*) -> discard
* null and NULL -> null -> discard
* null and FALSE -> false -> discard
*
* true and TRUE -> true(x) -> keep
* true and NULL -> null -> discard
* true and FALSE -> false -> discard
*
* we can see only 'predicate(with mark slot) and TRUE' may produce different results(*)
* because in filter predicate, we discard null and false values and only keep true values
* we can substitute mark slot with null and false to evaluate the predicate
* if the result are true, or result is either false or null, we can use non-nullable mark slot
* see ExpressionUtils.canInferNotNullForMarkSlot for more info
* we change 'predicate(with mark slot) and predicate(no mark slot)' -> predicate(with mark slot) and true
* to evaluate the predicate
*/
Expression left = and.left();
Expression newLeft = left.accept(this, context);
if (newLeft.getInputSlots().stream().noneMatch(MarkJoinSlotReference.class::isInstance)) {
newLeft = BooleanLiteral.TRUE;
}
Expression right = and.right();
Expression newRight = right.accept(this, context);
if (newRight.getInputSlots().stream().noneMatch(MarkJoinSlotReference.class::isInstance)) {
newRight = BooleanLiteral.TRUE;
}
Expression expr = new And(newLeft, newRight);
return expr;
}
@Override
public Expression visitOr(Or or, ExpressionRewriteContext context) {
/*
* predicate(with mark slot) or predicate(no mark slot)
* false or TRUE -> true -> keep
* false or NULL -> null(^) -> discard
* false or FALSE -> false(*) -> discard
*
* null or TRUE -> true -> keep
* null or NULL -> null(^) -> discard
* null or FALSE -> null(*) -> discard
*
* true or TRUE -> true -> keep
* true or NULL -> true(#) -> keep
* true or FALSE -> true(x) -> keep
*
* like And operator, even there are more differences. we can get the same conclusion.
* by substituting mark slot with null and false to evaluate the predicate
* if the result are true, or result is either false or null, we can use non-nullable mark slot
* we change 'predicate(with mark slot) or predicate(no mark slot)' -> predicate(with mark slot) or false
* to evaluate the predicate
*/
Expression left = or.left();
Expression newLeft = left.accept(this, context);
if (newLeft.getInputSlots().stream().noneMatch(MarkJoinSlotReference.class::isInstance)) {
newLeft = BooleanLiteral.FALSE;
}
Expression right = or.right();
Expression newRight = right.accept(this, context);
if (newRight.getInputSlots().stream().noneMatch(MarkJoinSlotReference.class::isInstance)) {
newRight = BooleanLiteral.FALSE;
}
Expression expr = new Or(newLeft, newRight);
return expr;
}
}

View File

@ -34,6 +34,7 @@ public class LogicalJoinToHashJoin extends OneImplementationRuleFactory {
join.getJoinType(),
join.getHashJoinConjuncts(),
join.getOtherJoinConjuncts(),
join.getMarkJoinConjuncts(),
join.getDistributeHint(),
join.getMarkJoinSlotReference(),
join.getLogicalProperties(),

View File

@ -34,6 +34,7 @@ public class LogicalJoinToNestedLoopJoin extends OneImplementationRuleFactory {
join.getJoinType(),
join.getHashJoinConjuncts(),
join.getOtherJoinConjuncts(),
join.getMarkJoinConjuncts(),
join.getMarkJoinSlotReference(),
join.getLogicalProperties(),
join.left(),

View File

@ -64,6 +64,9 @@ public class AdjustConjunctsReturnType extends DefaultPlanRewriter<Void> impleme
List<Expression> otherConjuncts = join.getOtherJoinConjuncts().stream()
.map(expr -> TypeCoercionUtils.castIfNotSameType(expr, BooleanType.INSTANCE))
.collect(Collectors.toList());
return join.withJoinConjuncts(hashConjuncts, otherConjuncts);
List<Expression> markConjuncts = join.getMarkJoinConjuncts().stream()
.map(expr -> TypeCoercionUtils.castIfNotSameType(expr, BooleanType.INSTANCE))
.collect(Collectors.toList());
return join.withJoinConjuncts(hashConjuncts, otherConjuncts, markConjuncts);
}
}

View File

@ -49,6 +49,7 @@ import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
@ -117,9 +118,22 @@ public class AdjustNullable extends DefaultPlanRewriter<Map<ExprId, Slot>> imple
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Map<ExprId, Slot> replaceMap) {
join = (LogicalJoin<? extends Plan, ? extends Plan>) super.visit(join, replaceMap);
List<Expression> hashConjuncts = updateExpressions(join.getHashJoinConjuncts(), replaceMap);
List<Expression> markConjuncts;
if (hashConjuncts.isEmpty()) {
// if hashConjuncts is empty, mark join conjuncts may used to build hash table
// so need call updateExpressions for mark join conjuncts before adjust nullable by output slot
markConjuncts = updateExpressions(join.getMarkJoinConjuncts(), replaceMap);
} else {
markConjuncts = null;
}
join.getOutputSet().forEach(o -> replaceMap.put(o.getExprId(), o));
if (markConjuncts == null) {
// hashConjuncts is not empty, mark join conjuncts are processed like other join conjuncts
Preconditions.checkState(!hashConjuncts.isEmpty(), "hash conjuncts should not be empty");
markConjuncts = updateExpressions(join.getMarkJoinConjuncts(), replaceMap);
}
List<Expression> otherConjuncts = updateExpressions(join.getOtherJoinConjuncts(), replaceMap);
return join.withJoinConjuncts(hashConjuncts, otherConjuncts).recomputeLogicalProperties();
return join.withJoinConjuncts(hashConjuncts, otherConjuncts, markConjuncts).recomputeLogicalProperties();
}
@Override

View File

@ -33,6 +33,7 @@ import org.apache.doris.nereids.types.UnsupportedType;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Set;
/**
@ -58,7 +59,12 @@ public class CheckDataTypes implements CustomRewriter {
}
private void checkLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> plan) {
plan.getHashJoinConjuncts().forEach(expr -> {
List<Expression> conjuncts = plan.getHashJoinConjuncts();
if (conjuncts.isEmpty()) {
// if hash conjuncts are empty, we may use mark conjuncts to build hash table
conjuncts = plan.getMarkJoinConjuncts();
}
conjuncts.forEach(expr -> {
DataType leftType = expr.child(0).getDataType();
DataType rightType = expr.child(1).getDataType();
if (!leftType.acceptsType(rightType)) {

View File

@ -37,11 +37,13 @@ public class ConvertInnerOrCrossJoin implements RewriteRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
innerLogicalJoin()
.when(join -> join.getHashJoinConjuncts().size() == 0 && join.getOtherJoinConjuncts().size() == 0)
.when(join -> join.getHashJoinConjuncts().isEmpty() && join.getOtherJoinConjuncts().isEmpty()
&& join.getMarkJoinConjuncts().isEmpty())
.then(join -> join.withJoinType(JoinType.CROSS_JOIN))
.toRule(RuleType.INNER_TO_CROSS_JOIN),
crossLogicalJoin()
.when(join -> join.getHashJoinConjuncts().size() != 0 || join.getOtherJoinConjuncts().size() != 0)
.when(join -> !join.getHashJoinConjuncts().isEmpty() || !join.getOtherJoinConjuncts().isEmpty()
|| !join.getMarkJoinConjuncts().isEmpty())
.then(join -> join.withJoinType(JoinType.INNER_JOIN))
.toRule(RuleType.CROSS_TO_INNER_JOIN)
);

View File

@ -37,11 +37,14 @@ public class EliminateDedupJoinCondition extends OneRewriteRuleFactory {
.distinct().collect(Collectors.toList());
List<Expression> dedupOtherJoinConjuncts = join.getOtherJoinConjuncts().stream()
.distinct().collect(Collectors.toList());
List<Expression> dedupMarkJoinConjuncts = join.getMarkJoinConjuncts().stream()
.distinct().collect(Collectors.toList());
if (dedupHashJoinConjuncts.size() == join.getHashJoinConjuncts().size()
&& dedupOtherJoinConjuncts.size() == join.getOtherJoinConjuncts().size()) {
&& dedupOtherJoinConjuncts.size() == join.getOtherJoinConjuncts().size()
&& dedupMarkJoinConjuncts.size() == join.getMarkJoinConjuncts().size()) {
return null;
}
return join.withJoinConjuncts(dedupHashJoinConjuncts, dedupOtherJoinConjuncts);
return join.withJoinConjuncts(dedupHashJoinConjuncts, dedupOtherJoinConjuncts, dedupMarkJoinConjuncts);
}).toRule(RuleType.ELIMINATE_DEDUP_JOIN_CONDITION);
}
}

View File

@ -39,11 +39,15 @@ public class EliminateJoinCondition extends OneRewriteRuleFactory {
List<Expression> otherJoinConjuncts = join.getOtherJoinConjuncts().stream()
.filter(expression -> !expression.equals(BooleanLiteral.TRUE))
.collect(Collectors.toList());
List<Expression> markJoinConjuncts = join.getMarkJoinConjuncts().stream()
.filter(expression -> !expression.equals(BooleanLiteral.TRUE))
.collect(Collectors.toList());
if (hashJoinConjuncts.size() == join.getHashJoinConjuncts().size()
&& otherJoinConjuncts.size() == join.getOtherJoinConjuncts().size()) {
&& otherJoinConjuncts.size() == join.getOtherJoinConjuncts().size()
&& markJoinConjuncts.size() == join.getMarkJoinConjuncts().size()) {
return null;
}
return join.withJoinConjuncts(hashJoinConjuncts, otherJoinConjuncts);
return join.withJoinConjuncts(hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts);
}).toRule(RuleType.ELIMINATE_JOIN_CONDITION);
}
}

View File

@ -0,0 +1,59 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.rules.TrySimplifyPredicateWithMarkJoinSlot;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import java.util.Set;
/**
* Eliminate mark join.
*/
public class EliminateMarkJoin extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalFilter(logicalJoin().when(
join -> join.getJoinType().isSemiJoin() && !join.getMarkJoinConjuncts().isEmpty()))
.when(filter -> canSimplifyMarkJoin(filter.getConjuncts()))
.then(filter -> filter.withChildren(eliminateMarkJoin(filter.child())))
.toRule(RuleType.ELIMINATE_MARK_JOIN);
}
private boolean canSimplifyMarkJoin(Set<Expression> predicates) {
return ExpressionUtils
.canInferNotNullForMarkSlot(TrySimplifyPredicateWithMarkJoinSlot.INSTANCE
.rewrite(ExpressionUtils.and(predicates), null));
}
private LogicalJoin<Plan, Plan> eliminateMarkJoin(LogicalJoin<Plan, Plan> join) {
ImmutableList.Builder<Expression> newHashConjuncts = ImmutableList.builder();
newHashConjuncts.addAll(join.getHashJoinConjuncts());
newHashConjuncts.addAll(join.getMarkJoinConjuncts());
return join.withJoinConjuncts(newHashConjuncts.build(), join.getOtherJoinConjuncts(),
ExpressionUtils.EMPTY_CONDITION);
}
}

View File

@ -31,7 +31,10 @@ public class EliminateNullAwareLeftAntiJoin extends OneRewriteRuleFactory {
@Override
public Rule build() {
return nullAwareLeftAntiLogicalJoin().then(antiJoin -> {
if (Stream.concat(antiJoin.getHashJoinConjuncts().stream(), antiJoin.getOtherJoinConjuncts().stream())
if (Stream.concat(Stream.concat(
antiJoin.getHashJoinConjuncts().stream(),
antiJoin.getOtherJoinConjuncts().stream()),
antiJoin.getMarkJoinConjuncts().stream())
.noneMatch(expression -> expression.nullable())) {
return antiJoin.withJoinType(JoinType.LEFT_ANTI_JOIN);
} else {

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import java.util.List;
@ -35,6 +36,7 @@ public class EliminateSemiJoin extends OneRewriteRuleFactory {
public Rule build() {
return logicalJoin()
// right will be converted to left
.whenNot(LogicalJoin::isMarkJoin)
.when(join -> join.getJoinType().isLeftSemiOrAntiJoin())
.when(join -> join.getHashJoinConjuncts().isEmpty())
.then(join -> {

View File

@ -40,7 +40,9 @@ public class ExtractFilterFromCrossJoin extends OneRewriteRuleFactory {
return crossLogicalJoin()
.then(join -> {
LogicalJoin<Plan, Plan> newJoin = new LogicalJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, join.getDistributeHint(),
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION,
join.getMarkJoinConjuncts(),
join.getDistributeHint(),
join.getMarkJoinSlotReference(), join.children());
Set<Expression> predicates = Stream.concat(join.getHashJoinConjuncts().stream(),
join.getOtherJoinConjuncts().stream())

View File

@ -73,6 +73,7 @@ public class FindHashConditionForJoin extends OneRewriteRuleFactory {
return new LogicalJoin<>(joinType,
combinedHashJoinConjuncts,
remainedNonHashJoinConjuncts,
join.getMarkJoinConjuncts(),
join.getDistributeHint(),
join.getMarkJoinSlotReference(),
join.children());

View File

@ -100,32 +100,54 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
// TODO: trick here, because when deep copy logical plan the apply right child
// is not same with query plan in subquery expr, since the scan node copy twice
Expression right = inSubquery.getSubqueryOutput((LogicalPlan) apply.right());
if (apply.isCorrelated()) {
if (inSubquery.isNot()) {
predicate = ExpressionUtils.and(ExpressionUtils.or(new EqualTo(left, right),
new IsNull(left), new IsNull(right)),
apply.getCorrelationFilter().get());
} else {
predicate = ExpressionUtils.and(new EqualTo(left, right),
apply.getCorrelationFilter().get());
}
} else {
if (apply.isMarkJoin()) {
List<Expression> joinConjuncts = apply.getCorrelationFilter().isPresent()
? ExpressionUtils.extractConjunction(apply.getCorrelationFilter().get())
: Lists.newArrayList();
predicate = new EqualTo(left, right);
}
List<Expression> conjuncts = ExpressionUtils.extractConjunction(predicate);
if (inSubquery.isNot()) {
List<Expression> markConjuncts = Lists.newArrayList(predicate);
if (!predicate.nullable() || (apply.isMarkJoinSlotNotNull() && !inSubquery.isNot())) {
// we can merge mark conjuncts with hash conjuncts in 2 scenarios
// 1. the mark join predicate is not nullable, so no null value would be produced
// 2. semi join with non-nullable mark slot.
// because semi join only care about mark slot with true value and discard false and null
// it's safe the use false instead of null in this case
joinConjuncts.addAll(markConjuncts);
markConjuncts.clear();
}
return new LogicalJoin<>(
predicate.nullable() && !apply.isCorrelated()
? JoinType.NULL_AWARE_LEFT_ANTI_JOIN
: JoinType.LEFT_ANTI_JOIN,
Lists.newArrayList(), conjuncts, new DistributeHint(DistributeType.NONE),
apply.getMarkJoinSlotReference(), apply.children());
} else {
return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(),
conjuncts,
inSubquery.isNot() ? JoinType.LEFT_ANTI_JOIN : JoinType.LEFT_SEMI_JOIN,
Lists.newArrayList(), joinConjuncts, markConjuncts,
new DistributeHint(DistributeType.NONE), apply.getMarkJoinSlotReference(),
apply.children());
} else {
if (apply.isCorrelated()) {
if (inSubquery.isNot()) {
predicate = ExpressionUtils.and(ExpressionUtils.or(new EqualTo(left, right),
new IsNull(left), new IsNull(right)),
apply.getCorrelationFilter().get());
} else {
predicate = ExpressionUtils.and(new EqualTo(left, right),
apply.getCorrelationFilter().get());
}
} else {
predicate = new EqualTo(left, right);
}
List<Expression> conjuncts = ExpressionUtils.extractConjunction(predicate);
if (inSubquery.isNot()) {
return new LogicalJoin<>(
predicate.nullable() && !apply.isCorrelated()
? JoinType.NULL_AWARE_LEFT_ANTI_JOIN
: JoinType.LEFT_ANTI_JOIN,
Lists.newArrayList(), conjuncts, new DistributeHint(DistributeType.NONE),
apply.getMarkJoinSlotReference(), apply.children());
} else {
return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(),
conjuncts,
new DistributeHint(DistributeType.NONE), apply.getMarkJoinSlotReference(),
apply.children());
}
}
}).toRule(RuleType.IN_APPLY_TO_JOIN);
}

View File

@ -74,6 +74,7 @@ public class OrExpansion extends OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalJoin(any(), any()).when(JoinUtils::shouldNestedLoopJoin)
.whenNot(LogicalJoin::isMarkJoin)
.when(join -> supportJoinType.contains(join.getJoinType())
&& ConnectContext.get().getSessionVariable().getEnablePipelineEngine())
.thenApply(ctx -> {

View File

@ -96,8 +96,9 @@ public class PushDownAliasThroughJoin extends OneRewriteRuleFactory {
List<Expression> newHash = replaceJoinConjuncts(join.getHashJoinConjuncts(), replaceMap);
List<Expression> newOther = replaceJoinConjuncts(join.getOtherJoinConjuncts(), replaceMap);
List<Expression> newMark = replaceJoinConjuncts(join.getMarkJoinConjuncts(), replaceMap);
Plan newJoin = join.withConjunctsChildren(newHash, newOther, left, right);
Plan newJoin = join.withConjunctsChildren(newHash, newOther, newMark, left, right);
return project.withProjectsAndChild(newProjects, newJoin);
}).toRule(RuleType.PUSH_DOWN_ALIAS_THROUGH_JOIN);
}

View File

@ -99,6 +99,19 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory {
}
});
// add mark conjuncts used slots to project exprs
join.getMarkJoinConjuncts().stream().flatMap(conjunct ->
conjunct.getInputSlots().stream()
).forEach(slot -> {
if (leftExprIdSet.contains(slot.getExprId())) {
// belong to left child
leftProjectExprs.add(slot);
} else {
// belong to right child
rightProjectExprs.add(slot);
}
});
List<Expression> newHashConjuncts = join.getHashJoinConjuncts().stream()
.map(equalTo -> equalTo.withChildren(equalTo.children()
.stream().map(expr -> exprReplaceMap.get(expr).toSlot())

View File

@ -137,6 +137,7 @@ public class PushDownFilterThroughJoin extends OneRewriteRuleFactory {
new LogicalJoin<>(join.getJoinType(),
join.getHashJoinConjuncts(),
joinConditions,
join.getMarkJoinConjuncts(),
join.getDistributeHint(),
join.getMarkJoinSlotReference(),
PlanUtils.filterOrSelf(leftPredicates, join.left()),

View File

@ -89,7 +89,8 @@ public class PushDownJoinOtherCondition extends OneRewriteRuleFactory {
Plan right = PlanUtils.filterOrSelf(rightConjuncts, join.right());
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
remainingOther, join.getDistributeHint(), join.getMarkJoinSlotReference(), left, right);
remainingOther, join.getMarkJoinConjuncts(), join.getDistributeHint(),
join.getMarkJoinSlotReference(), left, right);
}).toRule(RuleType.PUSH_DOWN_JOIN_OTHER_CONDITION);
}

View File

@ -89,7 +89,7 @@ public class UnCorrelatedApplyAggregateFilter extends OneRewriteRuleFactory {
ExpressionUtils.optionalAnd(correlatedPredicate),
apply.getMarkJoinSlotReference(),
apply.isNeedAddSubOutputToProjects(),
apply.isInProject(), apply.left(), newAgg);
apply.isInProject(), apply.isMarkJoinSlotNotNull(), apply.left(), newAgg);
}).toRule(RuleType.UN_CORRELATED_APPLY_AGGREGATE_FILTER);
}
}

View File

@ -69,7 +69,7 @@ public class UnCorrelatedApplyFilter extends OneRewriteRuleFactory {
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(),
apply.isNeedAddSubOutputToProjects(),
apply.isInProject(), apply.left(), child);
apply.isInProject(), apply.isMarkJoinSlotNotNull(), apply.left(), child);
}).toRule(RuleType.UN_CORRELATED_APPLY_FILTER);
}
}

View File

@ -90,7 +90,7 @@ public class UnCorrelatedApplyProjectFilter extends OneRewriteRuleFactory {
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(),
apply.isNeedAddSubOutputToProjects(),
apply.isInProject(), apply.left(), newProject);
apply.isInProject(), apply.isMarkJoinSlotNotNull(), apply.left(), newProject);
}).toRule(RuleType.UN_CORRELATED_APPLY_PROJECT_FILTER);
}
}

View File

@ -60,7 +60,7 @@ public class JoinEstimation {
private static boolean hashJoinConditionContainsUnknownColumnStats(Statistics leftStats,
Statistics rightStats, Join join) {
for (Expression expr : join.getHashJoinConjuncts()) {
for (Expression expr : join.getEqualPredicates()) {
for (Slot slot : expr.getInputSlots()) {
ColumnStatistic colStats = leftStats.findColumnStatistics(slot);
if (colStats == null) {
@ -87,7 +87,7 @@ public class JoinEstimation {
boolean leftBigger = leftStats.getRowCount() > rightStats.getRowCount();
double rightStatsRowCount = StatsMathUtil.nonZeroDivisor(rightStats.getRowCount());
double leftStatsRowCount = StatsMathUtil.nonZeroDivisor(leftStats.getRowCount());
List<EqualPredicate> trustableConditions = join.getHashJoinConjuncts().stream()
List<EqualPredicate> trustableConditions = join.getEqualPredicates().stream()
.map(expression -> (EqualPredicate) expression)
.filter(
expression -> {
@ -173,7 +173,7 @@ public class JoinEstimation {
private static double computeSelectivityForBuildSideWhenColStatsUnknown(Statistics buildStats, Join join) {
double sel = 1.0;
for (Expression cond : join.getHashJoinConjuncts()) {
for (Expression cond : join.getEqualPredicates()) {
if (cond instanceof EqualTo) {
EqualTo equal = (EqualTo) cond;
if (equal.left() instanceof Slot && equal.right() instanceof Slot) {
@ -204,7 +204,7 @@ public class JoinEstimation {
}
Statistics innerJoinStats;
if (join.getHashJoinConjuncts().isEmpty()) {
if (join.getEqualPredicates().isEmpty()) {
innerJoinStats = estimateNestLoopJoin(leftStats, rightStats, join);
} else {
innerJoinStats = estimateHashJoin(leftStats, rightStats, join);
@ -283,7 +283,7 @@ public class JoinEstimation {
}
}
double rowCount = Double.POSITIVE_INFINITY;
for (Expression conjunct : join.getHashJoinConjuncts()) {
for (Expression conjunct : join.getEqualPredicates()) {
double eqRowCount = estimateSemiOrAntiRowCountBySlotsEqual(leftStats, rightStats,
join, (EqualPredicate) conjunct);
if (rowCount > eqRowCount) {
@ -359,7 +359,7 @@ public class JoinEstimation {
*/
private static Statistics updateJoinResultStatsByHashJoinCondition(Statistics innerStats, Join join) {
Map<Expression, ColumnStatistic> updatedCols = new HashMap<>();
for (Expression expr : join.getHashJoinConjuncts()) {
for (Expression expr : join.getEqualPredicates()) {
EqualPredicate equalTo = (EqualPredicate) expr;
ColumnStatistic leftColStats = ExpressionEstimation.estimate(equalTo.left(), innerStats);
ColumnStatistic rightColStats = ExpressionEstimation.estimate(equalTo.right(), innerStats);

View File

@ -115,7 +115,8 @@ public class LogicalPlanDeepCopier extends DefaultPlanRewriter<DeepCopierContext
Optional<MarkJoinSlotReference> markJoinSlotReference = apply.getMarkJoinSlotReference()
.map(m -> (MarkJoinSlotReference) ExpressionDeepCopier.INSTANCE.deepCopy(m, context));
return new LogicalApply<>(correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, apply.isNeedAddSubOutputToProjects(), apply.isInProject(), left, right);
markJoinSlotReference, apply.isNeedAddSubOutputToProjects(), apply.isInProject(),
apply.isMarkJoinSlotNotNull(), left, right);
}
@Override
@ -336,7 +337,10 @@ public class LogicalPlanDeepCopier extends DefaultPlanRewriter<DeepCopierContext
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts().stream()
.map(c -> ExpressionDeepCopier.INSTANCE.deepCopy(c, context))
.collect(ImmutableList.toImmutableList());
return new LogicalJoin<>(join.getJoinType(), hashJoinConjuncts, otherJoinConjuncts,
List<Expression> markJoinConjuncts = join.getMarkJoinConjuncts().stream()
.map(c -> ExpressionDeepCopier.INSTANCE.deepCopy(c, context))
.collect(ImmutableList.toImmutableList());
return new LogicalJoin<>(join.getJoinType(), hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
join.getDistributeHint(), join.getMarkJoinSlotReference(), children);
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.plans.algebra;
import org.apache.doris.nereids.hint.DistributeHint;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
@ -28,6 +29,7 @@ import org.apache.doris.nereids.trees.plans.JoinType;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Common interface for logical/physical join.
@ -42,25 +44,28 @@ public interface Join {
.collect(Collectors.toList());
}
default List<EqualPredicate> getEqualPredicates() {
return Stream.concat(getHashJoinConjuncts().stream(), getMarkJoinConjuncts().stream())
.filter(EqualPredicate.class::isInstance).map(EqualPredicate.class::cast)
.collect(Collectors.toList());
}
List<Expression> getOtherJoinConjuncts();
List<Expression> getMarkJoinConjuncts();
Optional<Expression> getOnClauseCondition();
DistributeHint getDistributeHint();
boolean isMarkJoin();
Optional<MarkJoinSlotReference> getMarkJoinSlotReference();
default boolean hasDistributeHint() {
return getDistributeHint().distributeType != DistributeType.NONE;
}
/**
* The join plan has join condition or not.
*/
default boolean hasJoinCondition() {
return !getHashJoinConjuncts().isEmpty() || !getOtherJoinConjuncts().isEmpty();
}
default JoinDistributeType getLeftHint() {
return JoinDistributeType.NONE;
}
@ -78,8 +83,4 @@ public interface Join {
return JoinDistributeType.NONE;
}
}
default Optional<MarkJoinSlotReference> getLeftMarkJoinSlotReference() {
return Optional.empty();
}
}

View File

@ -62,6 +62,14 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
// Whether adding the subquery's output to projects
private final boolean needAddSubOutputToProjects;
/*
* This flag is indicate the mark join slot can be non-null or not
* in InApplyToJoin rule, if it's semi join with non-null mark slot
* we can safely change the mark conjunct to hash conjunct
* see SubqueryToApply rule for more info
*/
private final boolean isMarkJoinSlotNotNull;
private LogicalApply(Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties,
List<Expression> correlationSlot,
@ -69,6 +77,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
Optional<MarkJoinSlotReference> markJoinSlotReference,
boolean needAddSubOutputToProjects,
boolean inProject,
boolean isMarkJoinSlotNotNull,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
super(PlanType.LOGICAL_APPLY, groupExpression, logicalProperties, leftChild, rightChild);
this.correlationSlot = correlationSlot == null ? ImmutableList.of() : ImmutableList.copyOf(correlationSlot);
@ -77,14 +86,15 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
this.markJoinSlotReference = markJoinSlotReference;
this.needAddSubOutputToProjects = needAddSubOutputToProjects;
this.inProject = inProject;
this.isMarkJoinSlotNotNull = isMarkJoinSlotNotNull;
}
public LogicalApply(List<Expression> correlationSlot, SubqueryExpr subqueryExpr,
Optional<Expression> correlationFilter, Optional<MarkJoinSlotReference> markJoinSlotReference,
boolean needAddSubOutputToProjects, boolean inProject,
boolean needAddSubOutputToProjects, boolean inProject, boolean isMarkJoinSlotNotNull,
LEFT_CHILD_TYPE input, RIGHT_CHILD_TYPE subquery) {
this(Optional.empty(), Optional.empty(), correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, needAddSubOutputToProjects, inProject, input,
markJoinSlotReference, needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull, input,
subquery);
}
@ -136,6 +146,10 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
return inProject;
}
public boolean isMarkJoinSlotNotNull() {
return isMarkJoinSlotNotNull;
}
@Override
public List<Slot> computeOutput() {
return ImmutableList.<Slot>builder()
@ -153,6 +167,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
"correlationSlot", correlationSlot,
"correlationFilter", correlationFilter,
"isMarkJoin", markJoinSlotReference.isPresent(),
"isMarkJoinSlotNotNull", isMarkJoinSlotNotNull,
"MarkJoinSlotReference", markJoinSlotReference.isPresent() ? markJoinSlotReference.get() : "empty");
}
@ -170,14 +185,15 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
&& Objects.equals(correlationFilter, that.getCorrelationFilter())
&& Objects.equals(markJoinSlotReference, that.getMarkJoinSlotReference())
&& needAddSubOutputToProjects == that.needAddSubOutputToProjects
&& inProject == that.inProject;
&& inProject == that.inProject
&& isMarkJoinSlotNotNull == that.isMarkJoinSlotNotNull;
}
@Override
public int hashCode() {
return Objects.hash(
correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, needAddSubOutputToProjects, inProject);
markJoinSlotReference, needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull);
}
@Override
@ -201,14 +217,15 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
public LogicalApply<Plan, Plan> withSubqueryExprAndChildren(SubqueryExpr subqueryExpr, List<Plan> children) {
return new LogicalApply<>(correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, needAddSubOutputToProjects, inProject, children.get(0), children.get(1));
markJoinSlotReference, needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull,
children.get(0), children.get(1));
}
@Override
public LogicalApply<Plan, Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new LogicalApply<>(correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, needAddSubOutputToProjects, inProject,
markJoinSlotReference, needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull,
children.get(0), children.get(1));
}
@ -216,7 +233,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalApply<>(groupExpression, Optional.of(getLogicalProperties()),
correlationSlot, subqueryExpr, correlationFilter, markJoinSlotReference,
needAddSubOutputToProjects, inProject, left(), right());
needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull, left(), right());
}
@Override
@ -225,6 +242,6 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
Preconditions.checkArgument(children.size() == 2);
return new LogicalApply<>(groupExpression, logicalProperties, correlationSlot, subqueryExpr,
correlationFilter, markJoinSlotReference,
needAddSubOutputToProjects, inProject, children.get(0), children.get(1));
needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull, children.get(0), children.get(1));
}
}

View File

@ -67,6 +67,7 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
private final JoinType joinType;
private final List<Expression> otherJoinConjuncts;
private final List<Expression> hashJoinConjuncts;
private final List<Expression> markJoinConjuncts;
// When the predicate condition contains subqueries and disjunctions, the join will be marked as MarkJoin.
private final Optional<MarkJoinSlotReference> markJoinSlotReference;
@ -80,114 +81,92 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
public LogicalJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
this(joinType, ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION,
new DistributeHint(DistributeType.NONE),
Optional.empty(), Optional.empty(), Optional.empty(), leftChild, rightChild);
ExpressionUtils.EMPTY_CONDITION, new DistributeHint(DistributeType.NONE),
Optional.empty(), Optional.empty(), Optional.empty(),
ImmutableList.of(leftChild, rightChild), null);
}
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, ExpressionUtils.EMPTY_CONDITION,
new DistributeHint(DistributeType.NONE), Optional.empty(),
Optional.empty(), Optional.empty(), leftChild, rightChild);
ExpressionUtils.EMPTY_CONDITION, new DistributeHint(DistributeType.NONE),
Optional.empty(), Optional.empty(), Optional.empty(),
ImmutableList.of(leftChild, rightChild), null);
}
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts,
this(joinType, hashJoinConjuncts, otherJoinConjuncts, ExpressionUtils.EMPTY_CONDITION,
new DistributeHint(DistributeType.NONE), Optional.empty(),
Optional.empty(), Optional.empty(), leftChild, rightChild);
Optional.empty(), Optional.empty(), ImmutableList.of(leftChild, rightChild), null);
}
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts,
DistributeHint hint, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, Optional.empty(), Optional.empty(),
Optional.empty(), leftChild, rightChild);
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts, DistributeHint hint, LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, ExpressionUtils.EMPTY_CONDITION, hint,
Optional.empty(), Optional.empty(), Optional.empty(),
ImmutableList.of(leftChild, rightChild), null);
}
public LogicalJoin(
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts,
otherJoinConjuncts, hint, markJoinSlotReference,
Optional.empty(), Optional.empty(), leftChild, rightChild);
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts, DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference, LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, ExpressionUtils.EMPTY_CONDITION, hint,
markJoinSlotReference, Optional.empty(), Optional.empty(),
ImmutableList.of(leftChild, rightChild), null);
}
public LogicalJoin(
long bitmap,
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts,
otherJoinConjuncts, hint, markJoinSlotReference,
Optional.empty(), Optional.empty(), leftChild, rightChild);
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts, List<Expression> markJoinConjuncts, DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference, LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, hint,
markJoinSlotReference, Optional.empty(), Optional.empty(),
ImmutableList.of(leftChild, rightChild), null);
}
public LogicalJoin(long bitmap, JoinType joinType, List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts, DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference, LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, ExpressionUtils.EMPTY_CONDITION, hint,
markJoinSlotReference, Optional.empty(), Optional.empty(),
ImmutableList.of(leftChild, rightChild), null);
this.bitmap = LongBitmap.or(this.bitmap, bitmap);
}
public LogicalJoin(
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
List<Plan> children) {
this(joinType, hashJoinConjuncts,
otherJoinConjuncts, hint, markJoinSlotReference,
Optional.empty(), Optional.empty(), children);
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts, DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference, List<Plan> children) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, ExpressionUtils.EMPTY_CONDITION, hint,
markJoinSlotReference, Optional.empty(), Optional.empty(), children, null);
}
private LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts,
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts, List<Expression> markJoinConjuncts, DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference, List<Plan> children) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, hint,
markJoinSlotReference, Optional.empty(), Optional.empty(), children, null);
}
private LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts, List<Expression> markJoinConjuncts,
DistributeHint hint, Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties,
List<Plan> children, JoinReorderContext joinReorderContext) {
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, List<Plan> children,
JoinReorderContext joinReorderContext) {
// Just use in withXXX method. Don't need check/copyOf()
super(PlanType.LOGICAL_JOIN, groupExpression, logicalProperties, children);
this.joinType = Objects.requireNonNull(joinType, "joinType can not be null");
this.hashJoinConjuncts = hashJoinConjuncts;
this.otherJoinConjuncts = otherJoinConjuncts;
this.hint = Objects.requireNonNull(hint, "hint can not be null");
this.joinReorderContext.copyFrom(joinReorderContext);
this.markJoinSlotReference = markJoinSlotReference;
}
private LogicalJoin(
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties,
LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
super(PlanType.LOGICAL_JOIN, groupExpression, logicalProperties, leftChild, rightChild);
this.joinType = Objects.requireNonNull(joinType, "joinType can not be null");
this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts);
this.otherJoinConjuncts = ImmutableList.copyOf(otherJoinConjuncts);
this.hint = Objects.requireNonNull(hint, "hint can not be null");
this.markJoinSlotReference = markJoinSlotReference;
}
private LogicalJoin(
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties,
List<Plan> children) {
super(PlanType.LOGICAL_JOIN, groupExpression, logicalProperties, children);
this.joinType = Objects.requireNonNull(joinType, "joinType can not be null");
this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts);
this.otherJoinConjuncts = ImmutableList.copyOf(otherJoinConjuncts);
this.markJoinConjuncts = ImmutableList.copyOf(markJoinConjuncts);
this.hint = Objects.requireNonNull(hint, "hint can not be null");
if (joinReorderContext != null) {
this.joinReorderContext.copyFrom(joinReorderContext);
}
this.markJoinSlotReference = markJoinSlotReference;
}
@ -204,26 +183,56 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
return hashJoinConjuncts;
}
/**
* getConditionSlot
*/
public Set<Slot> getConditionSlot() {
// this function is called by rules which reject mark join
// so markJoinConjuncts is not processed here
Preconditions.checkState(!isMarkJoin(),
"shouldn't call mark join's getConditionSlot method");
return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
.flatMap(expr -> expr.getInputSlots().stream()).collect(ImmutableSet.toImmutableSet());
.flatMap(expr -> expr.getInputSlots().stream())
.collect(ImmutableSet.toImmutableSet());
}
/**
* getConditionExprId
*/
public Set<ExprId> getConditionExprId() {
// this function is called by rules which reject mark join
// so markJoinConjuncts is not processed here
Preconditions.checkState(!isMarkJoin(),
"shouldn't call mark join's getConditionExprId method");
return Stream.concat(getHashJoinConjuncts().stream(), getOtherJoinConjuncts().stream())
.flatMap(expr -> expr.getInputSlotExprIds().stream()).collect(Collectors.toSet());
}
/**
* getLeftConditionSlot
*/
public Set<Slot> getLeftConditionSlot() {
// TODO this function is used by TransposeSemiJoinAgg, we assume it can handle mark join correctly.
Set<Slot> leftOutputSet = this.left().getOutputSet();
return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
.flatMap(expr -> expr.getInputSlots().stream())
.filter(leftOutputSet::contains)
return Stream
.concat(Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream()),
markJoinConjuncts.stream())
.flatMap(expr -> expr.getInputSlots().stream()).filter(leftOutputSet::contains)
.collect(ImmutableSet.toImmutableSet());
}
/**
* getOnClauseCondition
*/
public Optional<Expression> getOnClauseCondition() {
return ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts);
// TODO this function is called by AggScalarSubQueryToWindowFunction and InferPredicates
// we assume they can handle mark join correctly
Optional<Expression> normalJoinConjuncts =
ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts);
return normalJoinConjuncts.isPresent()
? ExpressionUtils.optionalAnd(ImmutableList.of(normalJoinConjuncts.get()),
markJoinConjuncts)
: ExpressionUtils.optionalAnd(markJoinConjuncts);
}
public JoinType getJoinType() {
@ -242,6 +251,10 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
return markJoinSlotReference.isPresent();
}
public List<Expression> getMarkJoinConjuncts() {
return markJoinConjuncts;
}
public JoinReorderContext getJoinReorderContext() {
return joinReorderContext;
}
@ -261,7 +274,8 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
"type", joinType,
"markJoinSlotReference", markJoinSlotReference,
"hashJoinConjuncts", hashJoinConjuncts,
"otherJoinConjuncts", otherJoinConjuncts);
"otherJoinConjuncts", otherJoinConjuncts,
"markJoinConjuncts", markJoinConjuncts);
if (hint.distributeType != DistributeType.NONE) {
args.add("hint");
args.add(hint.getExplainString());
@ -282,12 +296,13 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
&& hint.equals(that.hint)
&& hashJoinConjuncts.equals(that.hashJoinConjuncts)
&& otherJoinConjuncts.equals(that.otherJoinConjuncts)
&& markJoinConjuncts.equals(that.markJoinConjuncts)
&& Objects.equals(markJoinSlotReference, that.markJoinSlotReference);
}
@Override
public int hashCode() {
return Objects.hash(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference);
return Objects.hash(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, markJoinSlotReference);
}
@Override
@ -300,6 +315,7 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
return new ImmutableList.Builder<Expression>()
.addAll(hashJoinConjuncts)
.addAll(otherJoinConjuncts)
.addAll(markJoinConjuncts)
.build();
}
@ -328,67 +344,86 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
@Override
public LogicalJoin<Plan, Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
Optional.empty(), Optional.empty(), children, joinReorderContext);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, Optional.empty(), Optional.empty(), children,
joinReorderContext);
}
@Override
public LogicalJoin<Plan, Plan> withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
groupExpression, Optional.of(getLogicalProperties()), children, joinReorderContext);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, groupExpression, Optional.of(getLogicalProperties()),
children, joinReorderContext);
}
@Override
public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
groupExpression, logicalProperties, children, joinReorderContext);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, groupExpression, logicalProperties, children,
joinReorderContext);
}
public LogicalJoin<Plan, Plan> withChildrenNoContext(Plan left, Plan right) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
markJoinSlotReference, left, right);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, Optional.empty(), Optional.empty(),
ImmutableList.of(left, right), null);
}
public LogicalJoin<Plan, Plan> withJoinConjuncts(
List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts,
hint, markJoinSlotReference, children);
public LogicalJoin<Plan, Plan> withJoinConjuncts(List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, Optional.empty(), Optional.empty(), children, null);
}
public LogicalJoin<Plan, Plan> withJoinConjuncts(List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
List<Expression> markJoinConjuncts) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, Optional.empty(), Optional.empty(), children, null);
}
public LogicalJoin<Plan, Plan> withHashJoinConjunctsAndChildren(
List<Expression> hashJoinConjuncts, Plan left, Plan right) {
Preconditions.checkArgument(children.size() == 2);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
markJoinSlotReference, left, right);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, Optional.empty(), Optional.empty(),
ImmutableList.of(left, right), null);
}
public LogicalJoin<Plan, Plan> withConjunctsChildren(List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts, Plan left, Plan right) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference, left,
right);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, Optional.empty(), Optional.empty(),
ImmutableList.of(left, right), null);
}
public LogicalJoin<Plan, Plan> withConjunctsChildren(List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
List<Expression> markJoinConjuncts, Plan left, Plan right) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, Optional.empty(), Optional.empty(),
ImmutableList.of(left, right), null);
}
public LogicalJoin<Plan, Plan> withJoinType(JoinType joinType) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
markJoinSlotReference, children);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, Optional.empty(), Optional.empty(), children, null);
}
public LogicalJoin<Plan, Plan> withTypeChildren(JoinType joinType, Plan left, Plan right) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
markJoinSlotReference, left, right);
}
public LogicalJoin<Plan, Plan> withOtherJoinConjuncts(List<Expression> otherJoinConjuncts) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
markJoinSlotReference, children);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, Optional.empty(), Optional.empty(),
ImmutableList.of(left, right), null);
}
/**
* extractNullRejectHashKeys
*/
public @Nullable Pair<Set<Slot>, Set<Slot>> extractNullRejectHashKeys() {
// this function is only used by computeFuncDeps, and function dependence calculation is disabled for mark join
// so markJoinConjuncts is not processed now
Set<Slot> leftKeys = new HashSet<>();
Set<Slot> rightKeys = new HashSet<>();
for (Expression expression : hashJoinConjuncts) {
@ -413,6 +448,10 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
@Override
public FunctionalDependencies computeFuncDeps(Supplier<List<Slot>> outputSupplier) {
if (isMarkJoin()) {
// TODO disable function dependence calculation for mark join, but need re-think this in future.
return FunctionalDependencies.EMPTY_FUNC_DEPS;
}
//1. NALAJ and FOJ block functional dependencies
if (joinType.isNullAwareLeftAntiJoin() || joinType.isFullOuterJoin()) {
return FunctionalDependencies.EMPTY_FUNC_DEPS;
@ -478,6 +517,8 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
* get Equal slot from join
*/
public ImmutableEqualSet<Slot> getEqualSlots() {
// this function is only used by EliminateJoinByFK rule, and EliminateJoinByFK is disabled for mark join
// so markJoinConjuncts is not processed now
// TODO: Use fd in the future
if (!joinType.isInnerJoin() && !joinType.isSemiJoin()) {
return ImmutableEqualSet.empty();
@ -499,6 +540,7 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
properties.put("JoinType", joinType.toString());
properties.put("HashJoinConjuncts", hashJoinConjuncts.toString());
properties.put("OtherJoinConjuncts", otherJoinConjuncts.toString());
properties.put("MarkJoinConjuncts", markJoinConjuncts.toString());
properties.put("DistributeHint", hint.toString());
properties.put("MarkJoinSlotReference", markJoinSlotReference.toString());
logicalJoin.put("Properties", properties);

View File

@ -33,7 +33,6 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import org.apache.commons.collections.CollectionUtils;
import java.util.List;
import java.util.Optional;
@ -48,13 +47,12 @@ public class UsingJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Pl
private final ImmutableList<Expression> otherJoinConjuncts;
private final ImmutableList<Expression> hashJoinConjuncts;
private final DistributeHint hint;
private final Optional<MarkJoinSlotReference> markJoinSlotReference;
public UsingJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild,
List<Expression> expressions, List<Expression> hashJoinConjuncts,
DistributeHint hint) {
this(joinType, leftChild, rightChild, expressions,
hashJoinConjuncts, Optional.empty(), Optional.empty(), hint, Optional.empty());
hashJoinConjuncts, Optional.empty(), Optional.empty(), hint);
}
/**
@ -63,13 +61,12 @@ public class UsingJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Pl
public UsingJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild,
List<Expression> expressions, List<Expression> hashJoinConjuncts, Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties,
DistributeHint hint, Optional<MarkJoinSlotReference> markJoinSlotReference) {
DistributeHint hint) {
super(PlanType.LOGICAL_USING_JOIN, groupExpression, logicalProperties, leftChild, rightChild);
this.joinType = joinType;
this.otherJoinConjuncts = ImmutableList.copyOf(expressions);
this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts);
this.hint = hint;
this.markJoinSlotReference = markJoinSlotReference;
}
@Override
@ -114,20 +111,20 @@ public class UsingJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Pl
@Override
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
return new UsingJoin(joinType, child(0), child(1), otherJoinConjuncts,
hashJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), hint, markJoinSlotReference);
hashJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), hint);
}
@Override
public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, List<Plan> children) {
return new UsingJoin(joinType, children.get(0), children.get(1), otherJoinConjuncts,
hashJoinConjuncts, groupExpression, logicalProperties, hint, markJoinSlotReference);
hashJoinConjuncts, groupExpression, logicalProperties, hint);
}
@Override
public Plan withChildren(List<Plan> children) {
return new UsingJoin(joinType, children.get(0), children.get(1), otherJoinConjuncts,
hashJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), hint, markJoinSlotReference);
hashJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), hint);
}
@Override
@ -160,11 +157,15 @@ public class UsingJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Pl
}
public boolean isMarkJoin() {
return markJoinSlotReference.isPresent();
return false;
}
public Optional<MarkJoinSlotReference> getMarkJoinSlotReference() {
return markJoinSlotReference;
return Optional.empty();
}
public List<Expression> getMarkJoinConjuncts() {
return ExpressionUtils.EMPTY_CONDITION;
}
@Override
@ -176,9 +177,4 @@ public class UsingJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Pl
public boolean hasDistributeHint() {
return hint != null;
}
@Override
public boolean hasJoinCondition() {
return !CollectionUtils.isEmpty(hashJoinConjuncts);
}
}

View File

@ -36,6 +36,7 @@ import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
@ -61,6 +62,7 @@ public abstract class AbstractPhysicalJoin<
protected final JoinType joinType;
protected final List<Expression> hashJoinConjuncts;
protected final List<Expression> otherJoinConjuncts;
protected final List<Expression> markJoinConjuncts;
protected final DistributeHint hint;
protected final Optional<MarkJoinSlotReference> markJoinSlotReference;
protected final List<RuntimeFilter> runtimeFilters = Lists.newArrayList();
@ -81,12 +83,9 @@ public abstract class AbstractPhysicalJoin<
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
LogicalProperties logicalProperties, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
super(type, groupExpression, logicalProperties, leftChild, rightChild);
this.joinType = Objects.requireNonNull(joinType, "joinType can not be null");
this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts);
this.otherJoinConjuncts = ImmutableList.copyOf(otherJoinConjuncts);
this.hint = Objects.requireNonNull(hint, "hint can not be null");
this.markJoinSlotReference = markJoinSlotReference;
this(type, joinType, hashJoinConjuncts, otherJoinConjuncts, ExpressionUtils.EMPTY_CONDITION,
hint, markJoinSlotReference, groupExpression, logicalProperties, null, null,
leftChild, rightChild);
}
/**
@ -105,10 +104,30 @@ public abstract class AbstractPhysicalJoin<
Statistics statistics,
LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
this(type, joinType, hashJoinConjuncts, otherJoinConjuncts, ExpressionUtils.EMPTY_CONDITION,
hint, markJoinSlotReference, groupExpression, logicalProperties, physicalProperties,
statistics, leftChild, rightChild);
}
protected AbstractPhysicalJoin(
PlanType type,
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
List<Expression> markJoinConjuncts,
DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
LogicalProperties logicalProperties,
PhysicalProperties physicalProperties,
Statistics statistics,
LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
super(type, groupExpression, logicalProperties, physicalProperties, statistics, leftChild, rightChild);
this.joinType = Objects.requireNonNull(joinType, "joinType can not be null");
this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts);
this.otherJoinConjuncts = ImmutableList.copyOf(otherJoinConjuncts);
this.markJoinConjuncts = ImmutableList.copyOf(markJoinConjuncts);
this.hint = hint;
this.markJoinSlotReference = markJoinSlotReference;
}
@ -137,11 +156,16 @@ public abstract class AbstractPhysicalJoin<
return markJoinSlotReference.isPresent();
}
public List<Expression> getMarkJoinConjuncts() {
return markJoinConjuncts;
}
@Override
public List<? extends Expression> getExpressions() {
return new Builder<Expression>()
.addAll(hashJoinConjuncts)
.addAll(otherJoinConjuncts).build();
.addAll(otherJoinConjuncts)
.addAll(markJoinConjuncts).build();
}
// TODO:
@ -158,13 +182,14 @@ public abstract class AbstractPhysicalJoin<
return joinType == that.joinType
&& hashJoinConjuncts.equals(that.hashJoinConjuncts)
&& otherJoinConjuncts.equals(that.otherJoinConjuncts)
&& markJoinConjuncts.equals(that.markJoinConjuncts)
&& hint.equals(that.hint)
&& Objects.equals(markJoinSlotReference, that.markJoinSlotReference);
}
@Override
public int hashCode() {
return Objects.hash(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference);
return Objects.hash(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, markJoinSlotReference);
}
/**
@ -173,7 +198,14 @@ public abstract class AbstractPhysicalJoin<
* @return the combination of hashJoinConjuncts and otherJoinConjuncts
*/
public Optional<Expression> getOnClauseCondition() {
return ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts);
// TODO this function is called by AggScalarSubQueryToWindowFunction and InferPredicates
// we assume they can handle mark join correctly
Optional<Expression> normalJoinConjuncts =
ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts);
return normalJoinConjuncts.isPresent()
? ExpressionUtils.optionalAnd(ImmutableList.of(normalJoinConjuncts.get()),
markJoinConjuncts)
: ExpressionUtils.optionalAnd(markJoinConjuncts);
}
@Override
@ -200,6 +232,7 @@ public abstract class AbstractPhysicalJoin<
properties.put("JoinType", joinType.toString());
properties.put("HashJoinConjuncts", hashJoinConjuncts.toString());
properties.put("OtherJoinConjuncts", otherJoinConjuncts.toString());
properties.put("MarkJoinConjuncts", markJoinConjuncts.toString());
properties.put("JoinHint", hint.toString());
properties.put("MarkJoinSlotReference", markJoinSlotReference.toString());
physicalJoin.put("Properties", properties);
@ -223,7 +256,14 @@ public abstract class AbstractPhysicalJoin<
.build();
}
/**
* getConditionSlot
*/
public Set<Slot> getConditionSlot() {
// this function is called by rules which reject mark join
// so markJoinConjuncts is not processed here
Preconditions.checkState(!isMarkJoin(),
"shouldn't call mark join's getConditionSlot method");
return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
.flatMap(expr -> expr.getInputSlots().stream()).collect(ImmutableSet.toImmutableSet());
}
@ -233,7 +273,8 @@ public abstract class AbstractPhysicalJoin<
List<Object> args = Lists.newArrayList("type", joinType,
"stats", statistics,
"hashCondition", hashJoinConjuncts,
"otherCondition", otherJoinConjuncts);
"otherCondition", otherJoinConjuncts,
"markCondition", markJoinConjuncts);
if (markJoinSlotReference.isPresent()) {
args.add("isMarkJoin");
args.add("true");

View File

@ -35,6 +35,7 @@ import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.MutableState;
import org.apache.doris.planner.RuntimeFilterId;
import org.apache.doris.qe.ConnectContext;
@ -86,14 +87,30 @@ public class PhysicalHashJoin<
Optional<GroupExpression> groupExpression,
LogicalProperties logicalProperties,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
super(PlanType.PHYSICAL_HASH_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
groupExpression, logicalProperties, leftChild, rightChild);
this(joinType, hashJoinConjuncts, otherJoinConjuncts, ExpressionUtils.EMPTY_CONDITION, hint,
markJoinSlotReference, groupExpression, logicalProperties, null, null, leftChild,
rightChild);
}
public PhysicalHashJoin(
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
List<Expression> markJoinConjuncts,
DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
LogicalProperties logicalProperties,
LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, hint, markJoinSlotReference,
Optional.empty(), logicalProperties, null, null, leftChild, rightChild);
}
private PhysicalHashJoin(
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
List<Expression> markJoinConjuncts,
DistributeHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
@ -102,8 +119,9 @@ public class PhysicalHashJoin<
Statistics statistics,
LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
super(PlanType.PHYSICAL_HASH_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
groupExpression, logicalProperties, physicalProperties, statistics, leftChild, rightChild);
super(PlanType.PHYSICAL_HASH_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts,
markJoinConjuncts, hint, markJoinSlotReference, groupExpression, logicalProperties,
physicalProperties, statistics, leftChild, rightChild);
}
/**
@ -111,6 +129,12 @@ public class PhysicalHashJoin<
* Return pair of left used slots and right used slots.
*/
public Pair<List<ExprId>, List<ExprId>> getHashConjunctsExprIds() {
// TODO this function is only called by addShuffleJoinRequestProperty
// currently standalone mark join can only allow broadcast( we can remove this limitation after implement
// something like nullaware shuffle to broadcast nulls to all instances
// mark join with non-empty hash join conjuncts allow shuffle join by hash join conjuncts
Preconditions.checkState(!(isMarkJoin() && hashJoinConjuncts.isEmpty()),
"shouldn't call mark join's getHashConjunctsExprIds method for standalone mark join");
int size = hashJoinConjuncts.size();
List<ExprId> exprIds1 = new ArrayList<>(size);
@ -143,7 +167,7 @@ public class PhysicalHashJoin<
public PhysicalHashJoin<Plan, Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
PhysicalHashJoin newJoin = new PhysicalHashJoin<>(joinType, hashJoinConjuncts,
otherJoinConjuncts, hint, markJoinSlotReference,
otherJoinConjuncts, markJoinConjuncts, hint, markJoinSlotReference,
Optional.empty(), getLogicalProperties(), physicalProperties, statistics,
children.get(0), children.get(1));
if (groupExpression.isPresent()) {
@ -155,28 +179,32 @@ public class PhysicalHashJoin<
@Override
public PhysicalHashJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withGroupExpression(
Optional<GroupExpression> groupExpression) {
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
groupExpression, getLogicalProperties(), left(), right());
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts,
markJoinConjuncts, hint, markJoinSlotReference, groupExpression,
getLogicalProperties(), null, null, left(), right());
}
@Override
public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
groupExpression, logicalProperties.get(), children.get(0), children.get(1));
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts,
markJoinConjuncts, hint, markJoinSlotReference, groupExpression,
logicalProperties.get(), null, null, children.get(0), children.get(1));
}
public PhysicalHashJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withPhysicalPropertiesAndStats(
PhysicalProperties physicalProperties, Statistics statistics) {
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
groupExpression, getLogicalProperties(), physicalProperties, statistics, left(), right());
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts,
markJoinConjuncts, hint, markJoinSlotReference, groupExpression,
getLogicalProperties(), physicalProperties, statistics, left(), right());
}
@Override
public boolean pushDownRuntimeFilter(CascadesContext context, IdGenerator<RuntimeFilterId> generator,
AbstractPhysicalJoin<?, ?> builderNode, Expression srcExpr, Expression probeExpr,
TRuntimeFilterType type, long buildSideNdv, int exprOrder) {
// currently, mark join doesn't support RF, so markJoinConjuncts is not processed here
if (RuntimeFilterGenerator.DENIED_JOIN_TYPES.contains(getJoinType()) || isMarkJoin()) {
if (builderNode instanceof PhysicalHashJoin) {
PhysicalHashJoin<?, ?> builderJoin = (PhysicalHashJoin<?, ?>) builderNode;
@ -253,6 +281,10 @@ public class PhysicalHashJoin<
.sorted().collect(Collectors.joining(" and ", " hashCondition=(", ")")));
builder.append(otherJoinConjuncts.stream().map(cond -> cond.shapeInfo())
.sorted().collect(Collectors.joining(" and ", " otherCondition=(", ")")));
if (!markJoinConjuncts.isEmpty()) {
builder.append(markJoinConjuncts.stream().map(cond -> cond.shapeInfo()).sorted()
.collect(Collectors.joining(" and ", " markCondition=(", ")")));
}
if (!runtimeFilters.isEmpty()) {
builder.append(" build RFs:").append(runtimeFilters.stream()
.map(rf -> rf.shapeInfo()).collect(Collectors.joining(";")));
@ -262,7 +294,8 @@ public class PhysicalHashJoin<
@Override
public PhysicalHashJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> resetLogicalProperties() {
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
groupExpression, null, physicalProperties, statistics, left(), right());
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts,
markJoinConjuncts, hint, markJoinSlotReference, groupExpression, null,
physicalProperties, statistics, left(), right());
}
}

View File

@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.MutableState;
import org.apache.doris.statistics.Statistics;
@ -39,6 +40,7 @@ import com.google.common.collect.Sets;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
@ -82,10 +84,8 @@ public class PhysicalNestedLoopJoin<
Optional<GroupExpression> groupExpression,
LogicalProperties logicalProperties,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
super(PlanType.PHYSICAL_NESTED_LOOP_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts,
// nested loop join ignores join hints.
new DistributeHint(DistributeType.NONE), markJoinSlotReference,
groupExpression, logicalProperties, leftChild, rightChild);
this(joinType, hashJoinConjuncts, otherJoinConjuncts, ExpressionUtils.EMPTY_CONDITION, markJoinSlotReference,
groupExpression, logicalProperties, null, null, leftChild, rightChild);
}
/**
@ -105,7 +105,36 @@ public class PhysicalNestedLoopJoin<
Statistics statistics,
LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
super(PlanType.PHYSICAL_NESTED_LOOP_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts,
this(joinType, hashJoinConjuncts, otherJoinConjuncts, ExpressionUtils.EMPTY_CONDITION, markJoinSlotReference,
groupExpression, logicalProperties, physicalProperties, statistics, leftChild, rightChild);
}
public PhysicalNestedLoopJoin(
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
List<Expression> markJoinConjuncts,
Optional<MarkJoinSlotReference> markJoinSlotReference,
LogicalProperties logicalProperties,
LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, markJoinSlotReference,
Optional.empty(), logicalProperties, null, null, leftChild, rightChild);
}
private PhysicalNestedLoopJoin(
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
List<Expression> markJoinConjuncts,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
LogicalProperties logicalProperties,
PhysicalProperties physicalProperties,
Statistics statistics,
LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
super(PlanType.PHYSICAL_NESTED_LOOP_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
// nested loop join ignores join hints.
new DistributeHint(DistributeType.NONE), markJoinSlotReference,
groupExpression, logicalProperties, physicalProperties, statistics, leftChild, rightChild);
@ -132,7 +161,7 @@ public class PhysicalNestedLoopJoin<
public PhysicalNestedLoopJoin<Plan, Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
PhysicalNestedLoopJoin newJoin = new PhysicalNestedLoopJoin<>(joinType,
hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference, Optional.empty(),
hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, markJoinSlotReference, Optional.empty(),
getLogicalProperties(), physicalProperties, statistics, children.get(0), children.get(1));
if (groupExpression.isPresent()) {
newJoin.setMutableState(MutableState.KEY_GROUP, groupExpression.get().getOwnerGroup().getGroupId().asInt());
@ -144,8 +173,8 @@ public class PhysicalNestedLoopJoin<
public PhysicalNestedLoopJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withGroupExpression(
Optional<GroupExpression> groupExpression) {
return new PhysicalNestedLoopJoin<>(joinType,
hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference,
groupExpression, getLogicalProperties(), left(), right());
hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, markJoinSlotReference,
groupExpression, getLogicalProperties(), null, null, left(), right());
}
@Override
@ -153,15 +182,15 @@ public class PhysicalNestedLoopJoin<
Optional<LogicalProperties> logicalProperties, List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new PhysicalNestedLoopJoin<>(joinType,
hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference, groupExpression,
logicalProperties.get(), children.get(0), children.get(1));
hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, markJoinSlotReference, groupExpression,
logicalProperties.get(), null, null, children.get(0), children.get(1));
}
@Override
public PhysicalNestedLoopJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withPhysicalPropertiesAndStats(
PhysicalProperties physicalProperties, Statistics statistics) {
return new PhysicalNestedLoopJoin<>(joinType,
hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference, groupExpression,
hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, markJoinSlotReference, groupExpression,
getLogicalProperties(), physicalProperties, statistics, left(), right());
}
@ -177,23 +206,39 @@ public class PhysicalNestedLoopJoin<
return bitMapRuntimeFilterConditions.isEmpty();
}
/**
* getConditionSlot
*/
public Set<Slot> getConditionSlot() {
// this function is called by rules which reject mark join
// so markJoinConjuncts is not processed here
Preconditions.checkState(!isMarkJoin(),
"shouldn't call mark join's getConditionSlot method");
return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
.flatMap(expr -> expr.getInputSlots().stream()).collect(ImmutableSet.toImmutableSet());
.flatMap(expr -> expr.getInputSlots().stream())
.collect(ImmutableSet.toImmutableSet());
}
@Override
public String shapeInfo() {
StringBuilder builder = new StringBuilder("NestedLoopJoin");
builder.append("[").append(joinType).append("]");
otherJoinConjuncts.forEach(expr -> builder.append(expr.shapeInfo()));
if (!markJoinConjuncts.isEmpty()) {
builder.append(otherJoinConjuncts.stream().map(cond -> cond.shapeInfo()).sorted()
.collect(Collectors.joining(" and ", " otherCondition=(", ")")));
builder.append(markJoinConjuncts.stream().map(cond -> cond.shapeInfo()).sorted()
.collect(Collectors.joining(" and ", " markCondition=(", ")")));
} else {
otherJoinConjuncts.forEach(expr -> builder.append(expr.shapeInfo()));
}
return builder.toString();
}
@Override
public PhysicalNestedLoopJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> resetLogicalProperties() {
return new PhysicalNestedLoopJoin<>(joinType,
hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference, groupExpression,
hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, markJoinSlotReference, groupExpression,
null, physicalProperties, statistics, left(), right());
}
}

View File

@ -34,6 +34,7 @@ import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
@ -51,6 +52,8 @@ import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewri
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.coercion.NumericType;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
@ -58,6 +61,7 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.util.Arrays;
@ -184,9 +188,9 @@ public class ExpressionUtils {
*/
public static Expression combine(Class<? extends Expression> type, Collection<Expression> expressions) {
/*
* (AB) (CD) E ((AB)(CD)) E (((AB)(CD))E)
* ▲ ▲ ▲ ▲ ▲
* │ │ │ │ │
* (AB) (CD) E ((AB)(CD)) E (((AB)(CD))E)
* ▲ ▲ ▲ ▲ ▲
* │ │ │ │ │
* A B C D E ──► A B C D E ──► (AB) (CD) E ──► ((AB)(CD)) E ──► (((AB)(CD))E)
*/
Preconditions.checkArgument(type == And.class || type == Or.class);
@ -219,7 +223,8 @@ public class ExpressionUtils {
}
/**
* Replace the slot in expressions with the lineage identifier from specifiedbaseTable sets or target table types
* Replace the slot in expressions with the lineage identifier from
* specifiedbaseTable sets or target table types
* example as following:
* select a + 10 as a1, d from (
* select b - 5 as a, d from table
@ -234,11 +239,10 @@ public class ExpressionUtils {
if (expressions.isEmpty()) {
return ImmutableList.of();
}
ExpressionLineageReplacer.ExpressionReplaceContext replaceContext =
new ExpressionLineageReplacer.ExpressionReplaceContext(
expressions.stream().map(Expression.class::cast).collect(Collectors.toList()),
targetTypes,
tableIdentifiers);
ExpressionLineageReplacer.ExpressionReplaceContext replaceContext = new ExpressionLineageReplacer.ExpressionReplaceContext(
expressions.stream().map(Expression.class::cast).collect(Collectors.toList()),
targetTypes,
tableIdentifiers);
plan.accept(ExpressionLineageReplacer.INSTANCE, replaceContext);
// Replace expressions by expression map
@ -272,8 +276,10 @@ public class ExpressionUtils {
}
/**
* Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot}
* or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot}
* Check whether the input expression is a
* {@link org.apache.doris.nereids.trees.expressions.Slot}
* or at least one {@link Cast} on a
* {@link org.apache.doris.nereids.trees.expressions.Slot}
* <p>
* for example:
* - SlotReference to a column:
@ -283,7 +289,8 @@ public class ExpressionUtils {
* cast(cast(int_col as long) as string)
*
* @param expr input expression
* @return Return Optional[ExprId] of underlying slot reference if input expression is a slot or cast on slot.
* @return Return Optional[ExprId] of underlying slot reference if input
* expression is a slot or cast on slot.
* Otherwise, return empty optional result.
*/
public static Optional<ExprId> isSlotOrCastOnSlot(Expression expr) {
@ -291,8 +298,10 @@ public class ExpressionUtils {
}
/**
* Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot}
* or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot}
* Check whether the input expression is a
* {@link org.apache.doris.nereids.trees.expressions.Slot}
* or at least one {@link Cast} on a
* {@link org.apache.doris.nereids.trees.expressions.Slot}
*/
public static Optional<Slot> extractSlotOrCastOnSlot(Expression expr) {
while (expr instanceof Cast) {
@ -307,7 +316,8 @@ public class ExpressionUtils {
}
/**
* Generate replaceMap Slot -> Expression from NamedExpression[Expression as name]
* Generate replaceMap Slot -> Expression from NamedExpression[Expression as
* name]
*/
public static Map<Slot, Expression> generateReplaceMap(List<NamedExpression> namedExpressions) {
return namedExpressions
@ -317,14 +327,14 @@ public class ExpressionUtils {
Collectors.toMap(
NamedExpression::toSlot,
// Avoid cast to alias, retrieving the first child expression.
alias -> alias.child(0)
)
);
alias -> alias.child(0)));
}
/**
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
* Replace expression node in the expression tree by `replaceMap` in top-down
* manner.
* For example.
*
* <pre>
* input expression: a > 1
* replaceMap: a -> b + c
@ -365,7 +375,8 @@ public class ExpressionUtils {
}
/**
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
* Replace expression node in the expression tree by `replaceMap` in top-down
* manner.
*/
public static List<NamedExpression> replaceNamedExpressions(List<NamedExpression> namedExpressions,
Map<? extends Expression, ? extends Expression> replaceMap) {
@ -468,6 +479,75 @@ public class ExpressionUtils {
return children.stream().allMatch(c -> c instanceof NullLiteral);
}
/**
* canInferNotNullForMarkSlot
*/
public static boolean canInferNotNullForMarkSlot(Expression predicate) {
/*
* assume predicate is from LogicalFilter
* the idea is replacing each mark join slot with null and false literal then
* run FoldConstant rule
* if the evaluate result are:
* 1. all true
* 2. all null and false (in logicalFilter, we discard both null and false
* values)
* the mark slot can be non-nullable boolean
* and in semi join, we can safely change the mark conjunct to hash conjunct
*/
ImmutableList<Literal> literals = ImmutableList.of(new NullLiteral(BooleanType.INSTANCE), BooleanLiteral.FALSE);
List<MarkJoinSlotReference> markJoinSlotReferenceList = ((Set<MarkJoinSlotReference>) predicate
.collect(MarkJoinSlotReference.class::isInstance)).stream()
.collect(Collectors.toList());
int markSlotSize = markJoinSlotReferenceList.size();
int maxMarkSlotCount = 4;
// if the conjunct has mark slot, and maximum 4 mark slots(for performance)
if (markSlotSize > 0 && markSlotSize <= maxMarkSlotCount) {
Map<Expression, Expression> replaceMap = Maps.newHashMap();
boolean meetTrue = false;
boolean meetNullOrFalse = false;
/*
* markSlotSize = 1 -> loopCount = 2 ---- 0, 1
* markSlotSize = 2 -> loopCount = 4 ---- 00, 01, 10, 11
* markSlotSize = 3 -> loopCount = 8 ---- 000, 001, 010, 011, 100, 101, 110, 111
* markSlotSize = 4 -> loopCount = 16 ---- 0000, 0001, ... 1111
*/
int loopCount = 2 << markSlotSize;
for (int i = 0; i < loopCount; ++i) {
replaceMap.clear();
/*
* replace each mark slot with null or false
* literals.get(0) -> NullLiteral(BooleanType.INSTANCE)
* literals.get(1) -> BooleanLiteral.FALSE
*/
for (int j = 0; j < markSlotSize; ++j) {
replaceMap.put(markJoinSlotReferenceList.get(j), literals.get((i >> j) & 1));
}
Expression evalResult = FoldConstantRule.INSTANCE.rewrite(
ExpressionUtils.replace(predicate, replaceMap),
new ExpressionRewriteContext(null));
if (evalResult.equals(BooleanLiteral.TRUE)) {
if (meetNullOrFalse) {
return false;
} else {
meetTrue = true;
}
} else if ((isNullOrFalse(evalResult))) {
if (meetTrue) {
return false;
} else {
meetNullOrFalse = true;
}
}
}
}
return true;
}
private static boolean isNullOrFalse(Expression expression) {
return expression.isNullLiteral() || expression.equals(BooleanLiteral.FALSE);
}
/**
* infer notNulls slot from predicate
*/
@ -502,7 +582,8 @@ public class ExpressionUtils {
}
/**
* infer notNulls slot from predicate but these slots must be in the given slots.
* infer notNulls slot from predicate but these slots must be in the given
* slots.
*/
public static Set<Expression> inferNotNull(Set<Expression> predicates, Set<Slot> slots,
CascadesContext cascadesContext) {
@ -666,18 +747,18 @@ public class ExpressionUtils {
*/
public static boolean checkSlotConstant(Slot slot, Set<Expression> predicates) {
return predicates.stream().anyMatch(predicate -> {
if (predicate instanceof EqualTo) {
EqualTo equalTo = (EqualTo) predicate;
return (equalTo.left() instanceof Literal && equalTo.right().equals(slot))
|| (equalTo.right() instanceof Literal && equalTo.left().equals(slot));
}
return false;
}
);
if (predicate instanceof EqualTo) {
EqualTo equalTo = (EqualTo) predicate;
return (equalTo.left() instanceof Literal && equalTo.right().equals(slot))
|| (equalTo.right() instanceof Literal && equalTo.left().equals(slot));
}
return false;
});
}
/**
* Check the expression is inferred or not, if inferred return true, nor return false
* Check the expression is inferred or not, if inferred return true, nor return
* false
*/
public static boolean isInferred(Expression expression) {
return expression.accept(new DefaultExpressionVisitor<Boolean, Void>() {

View File

@ -58,10 +58,17 @@ import java.util.stream.Collectors;
* Utils for join
*/
public class JoinUtils {
/**
* couldShuffle
*/
public static boolean couldShuffle(Join join) {
// Cross-join and Null-Aware-Left-Anti-Join only can be broadcast join.
// Because mark join would consider null value from both build and probe side, so must use broadcast join too.
return !(join.getJoinType().isCrossJoin() || join.getJoinType().isNullAwareLeftAntiJoin() || join.isMarkJoin());
// standalone mark join would consider null value from both build and probe side, so must use broadcast join.
// mark join with hash conjuncts can shuffle by hash conjuncts
// TODO actually standalone mark join can use shuffle, but need do nullaware shuffle to broadcast null value
// to all instances
return !(join.getJoinType().isCrossJoin() || join.getJoinType().isNullAwareLeftAntiJoin()
|| (!join.getMarkJoinConjuncts().isEmpty() && join.getHashJoinConjuncts().isEmpty()));
}
public static boolean couldBroadcast(Join join) {
@ -173,10 +180,14 @@ public class JoinUtils {
}
public static boolean shouldNestedLoopJoin(Join join) {
return join.getHashJoinConjuncts().isEmpty();
// currently, mark join conjuncts only has one conjunct, so we always get the first element here
return join.getHashJoinConjuncts().isEmpty() && (join.getMarkJoinConjuncts().isEmpty()
|| !(join.getMarkJoinConjuncts().get(0) instanceof EqualPredicate));
}
public static boolean shouldNestedLoopJoin(JoinType joinType, List<Expression> hashConjuncts) {
// this function is only called by hyper graph, which reject mark join
// so mark join is not processed here
return hashConjuncts.isEmpty();
}

View File

@ -76,6 +76,8 @@ public class HashJoinNode extends JoinNodeBase {
// join conjuncts from the JOIN clause that aren't equi-join predicates
private List<Expr> otherJoinConjuncts;
private List<Expr> markJoinConjuncts;
private DistributionMode distrMode;
private boolean isColocate = false; //the flag for colocate join
private String colocateReason = ""; // if can not do colocate join, set reason here
@ -173,6 +175,61 @@ public class HashJoinNode extends JoinNodeBase {
vSrcToOutputSMap = new ExprSubstitutionMap(srcToOutputList, Collections.emptyList());
}
public HashJoinNode(PlanNodeId id, PlanNode outer, PlanNode inner, JoinOperator joinOp,
List<Expr> eqJoinConjuncts, List<Expr> otherJoinConjuncts, List<Expr> markJoinConjuncts,
List<Expr> srcToOutputList, TupleDescriptor intermediateTuple,
TupleDescriptor outputTuple, boolean isMarkJoin) {
super(id, "HASH JOIN", StatisticalType.HASH_JOIN_NODE, joinOp, isMarkJoin);
Preconditions.checkArgument((eqJoinConjuncts != null && !eqJoinConjuncts.isEmpty())
|| (markJoinConjuncts != null && !markJoinConjuncts.isEmpty()));
Preconditions.checkArgument(otherJoinConjuncts != null);
tblRefIds.addAll(outer.getTblRefIds());
tblRefIds.addAll(inner.getTblRefIds());
if (joinOp.equals(JoinOperator.LEFT_ANTI_JOIN) || joinOp.equals(JoinOperator.LEFT_SEMI_JOIN)
|| joinOp.equals(JoinOperator.NULL_AWARE_LEFT_ANTI_JOIN)) {
tupleIds.addAll(outer.getTupleIds());
} else if (joinOp.equals(JoinOperator.RIGHT_ANTI_JOIN) || joinOp.equals(JoinOperator.RIGHT_SEMI_JOIN)) {
tupleIds.addAll(inner.getTupleIds());
} else {
tupleIds.addAll(outer.getTupleIds());
tupleIds.addAll(inner.getTupleIds());
}
for (Expr eqJoinPredicate : eqJoinConjuncts) {
Preconditions.checkArgument(eqJoinPredicate instanceof BinaryPredicate);
BinaryPredicate eqJoin = (BinaryPredicate) eqJoinPredicate;
if (eqJoin.getOp().equals(BinaryPredicate.Operator.EQ_FOR_NULL)) {
Preconditions.checkArgument(eqJoin.getChildren().size() == 2);
if (!eqJoin.getChild(0).isNullable() || !eqJoin.getChild(1).isNullable()) {
eqJoin.setOp(BinaryPredicate.Operator.EQ);
}
}
this.eqJoinConjuncts.add(eqJoin);
}
this.distrMode = DistributionMode.NONE;
this.otherJoinConjuncts = otherJoinConjuncts;
this.markJoinConjuncts = markJoinConjuncts;
children.add(outer);
children.add(inner);
// Inherits all the nullable tuple from the children
// Mark tuples that form the "nullable" side of the outer join as nullable.
nullableTupleIds.addAll(inner.getNullableTupleIds());
nullableTupleIds.addAll(outer.getNullableTupleIds());
if (joinOp.equals(JoinOperator.FULL_OUTER_JOIN)) {
nullableTupleIds.addAll(outer.getTupleIds());
nullableTupleIds.addAll(inner.getTupleIds());
} else if (joinOp.equals(JoinOperator.LEFT_OUTER_JOIN)) {
nullableTupleIds.addAll(inner.getTupleIds());
} else if (joinOp.equals(JoinOperator.RIGHT_OUTER_JOIN)) {
nullableTupleIds.addAll(outer.getTupleIds());
}
vIntermediateTupleDescList = Lists.newArrayList(intermediateTuple);
vOutputTupleDesc = outputTuple;
vSrcToOutputSMap = new ExprSubstitutionMap(srcToOutputList, Collections.emptyList());
}
public List<BinaryPredicate> getEqJoinConjuncts() {
return eqJoinConjuncts;
}
@ -717,6 +774,32 @@ public class HashJoinNode extends JoinNodeBase {
msg.hash_join_node.addToOtherJoinConjuncts(e.treeToThrift());
}
if (markJoinConjuncts != null) {
if (eqJoinConjuncts.isEmpty()) {
Preconditions.checkState(joinOp == JoinOperator.LEFT_SEMI_JOIN
|| joinOp == JoinOperator.LEFT_ANTI_JOIN);
if (joinOp == JoinOperator.LEFT_SEMI_JOIN) {
msg.hash_join_node.join_op = JoinOperator.NULL_AWARE_LEFT_SEMI_JOIN.toThrift();
} else if (joinOp == JoinOperator.LEFT_ANTI_JOIN) {
msg.hash_join_node.join_op = JoinOperator.NULL_AWARE_LEFT_ANTI_JOIN.toThrift();
}
// because eqJoinConjuncts mustn't be empty in thrift
// we have to use markJoinConjuncts instead
for (Expr e : markJoinConjuncts) {
Preconditions.checkState(e instanceof BinaryPredicate,
"mark join conjunct must be BinaryPredicate");
TEqJoinCondition eqJoinCondition = new TEqJoinCondition(
e.getChild(0).treeToThrift(), e.getChild(1).treeToThrift());
eqJoinCondition.setOpcode(((BinaryPredicate) e).getOp().getOpcode());
msg.hash_join_node.addToEqJoinConjuncts(eqJoinCondition);
}
} else {
for (Expr e : markJoinConjuncts) {
msg.hash_join_node.addToMarkJoinConjuncts(e.treeToThrift());
}
}
}
if (hashOutputSlotIds != null) {
for (SlotId slotId : hashOutputSlotIds) {
msg.hash_join_node.addToHashOutputSlotIds(slotId.asInt());
@ -772,6 +855,10 @@ public class HashJoinNode extends JoinNodeBase {
output.append(detailPrefix).append("other join predicates: ")
.append(getExplainString(otherJoinConjuncts)).append("\n");
}
if (markJoinConjuncts != null && !markJoinConjuncts.isEmpty()) {
output.append(detailPrefix).append("mark join predicates: ")
.append(getExplainString(markJoinConjuncts)).append("\n");
}
if (!conjuncts.isEmpty()) {
output.append(detailPrefix).append("other predicates: ").append(getExplainString(conjuncts)).append("\n");
}
@ -849,10 +936,18 @@ public class HashJoinNode extends JoinNodeBase {
this.otherJoinConjuncts = otherJoinConjuncts;
}
public void setMarkJoinConjuncts(List<Expr> markJoinConjuncts) {
this.markJoinConjuncts = markJoinConjuncts;
}
public List<Expr> getOtherJoinConjuncts() {
return otherJoinConjuncts;
}
public List<Expr> getMarkJoinConjuncts() {
return markJoinConjuncts;
}
SlotRef getMappedInputSlotRef(SlotRef slotRef) {
if (outputSmap != null) {
Expr mappedExpr = outputSmap.mappingForRhsExpr(slotRef);

View File

@ -65,6 +65,8 @@ public class NestedLoopJoinNode extends JoinNodeBase {
private List<Expr> runtimeFilterExpr = Lists.newArrayList();
private List<Expr> joinConjuncts;
private List<Expr> markJoinConjuncts;
public NestedLoopJoinNode(PlanNodeId id, PlanNode outer, PlanNode inner, TableRef innerRef) {
super(id, "NESTED LOOP JOIN", StatisticalType.NESTED_LOOP_JOIN_NODE, outer, inner, innerRef);
tupleIds.addAll(outer.getOutputTupleIds());
@ -81,6 +83,10 @@ public class NestedLoopJoinNode extends JoinNodeBase {
this.joinConjuncts = joinConjuncts;
}
public void setMarkJoinConjuncts(List<Expr> markJoinConjuncts) {
this.markJoinConjuncts = markJoinConjuncts;
}
@Override
protected List<SlotId> computeSlotIdsForJoinConjuncts(Analyzer analyzer) {
// conjunct
@ -171,6 +177,12 @@ public class NestedLoopJoinNode extends JoinNodeBase {
for (Expr conjunct : joinConjuncts) {
msg.nested_loop_join_node.addToJoinConjuncts(conjunct.treeToThrift());
}
if (markJoinConjuncts != null) {
for (Expr conjunct : markJoinConjuncts) {
msg.nested_loop_join_node.addToMarkJoinConjuncts(conjunct.treeToThrift());
}
}
msg.nested_loop_join_node.setIsMark(isMarkJoin());
if (vSrcToOutputSMap != null) {
for (int i = 0; i < vSrcToOutputSMap.size(); i++) {
@ -230,6 +242,11 @@ public class NestedLoopJoinNode extends JoinNodeBase {
output.append(detailPrefix).append("join conjuncts: ").append(getExplainString(joinConjuncts)).append("\n");
}
if (markJoinConjuncts != null && !markJoinConjuncts.isEmpty()) {
output.append(detailPrefix).append("mark join predicates: ")
.append(getExplainString(markJoinConjuncts)).append("\n");
}
if (!conjuncts.isEmpty()) {
output.append(detailPrefix).append("predicates: ").append(getExplainString(conjuncts)).append("\n");
}