[fix](nereids)should distinguish hash and other conjuncts for outer join in MultiJoin (#53184)

pick from master https://github.com/apache/doris/pull/50378 and
https://github.com/apache/doris/pull/53051
### What problem does this PR solve?

Issue Number: close #xxx

Related PR: #xxx

Problem Summary:

### Release note

None

### Check List (For Author)

- Test <!-- At least one of them must be included. -->
    - [ ] Regression test
    - [ ] Unit Test
    - [ ] Manual test (add detailed scripts or steps below)
    - [ ] No need to test or manual test. Explain why:
- [ ] This is a refactor/code format and no logic has been changed.
        - [ ] Previous test can cover this change.
        - [ ] No code files have been changed.
        - [ ] Other reason <!-- Add your reason?  -->

- Behavior changed:
    - [ ] No.
    - [ ] Yes. <!-- Explain the behavior change -->

- Does this need documentation?
    - [ ] No.
- [ ] Yes. <!-- Add document PR link here. eg:
https://github.com/apache/doris-website/pull/1214 -->

### Check List (For Reviewer who merge this PR)

- [ ] Confirm the release note
- [ ] Confirm test cases
- [ ] Confirm document
- [ ] Add branch pick label <!-- Add branch pick label that this PR
should merge into -->
This commit is contained in:
starocean999
2025-07-16 09:30:04 +08:00
committed by GitHub
parent 11142eee3a
commit e2fb2566dd
17 changed files with 1758 additions and 101 deletions

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
@ -390,9 +391,24 @@ public class SubqueryToApply implements AnalysisRuleFactory {
ctx.setSubqueryExprIsAnalyzed(subquery, true);
boolean needAddScalarSubqueryOutputToProjects = isConjunctContainsScalarSubqueryOutput(
subquery, conjunct, isProject, singleSubquery);
LogicalApply.SubQueryType subQueryType;
boolean isNot = false;
Optional<Expression> compareExpr = Optional.empty();
if (subquery instanceof InSubquery) {
subQueryType = LogicalApply.SubQueryType.IN_SUBQUERY;
isNot = ((InSubquery) subquery).isNot();
compareExpr = Optional.of(((InSubquery) subquery).getCompareExpr());
} else if (subquery instanceof Exists) {
subQueryType = LogicalApply.SubQueryType.EXITS_SUBQUERY;
isNot = ((Exists) subquery).isNot();
} else if (subquery instanceof ScalarSubquery) {
subQueryType = LogicalApply.SubQueryType.SCALAR_SUBQUERY;
} else {
throw new AnalysisException(String.format("Unsupported subquery : %s", subquery.toString()));
}
LogicalApply newApply = new LogicalApply(
subquery.getCorrelateSlots(),
subquery, Optional.empty(),
subQueryType, isNot, compareExpr, subquery.getTypeCoercionExpr(), Optional.empty(),
subqueryToMarkJoinSlot.get(subquery),
needAddScalarSubqueryOutputToProjects, isProject, isMarkJoinSlotNotNull,
childPlan, subquery.getQueryPlan());

View File

@ -22,7 +22,6 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
@ -94,7 +93,7 @@ public class ExistsApplyToJoin extends OneRewriteRuleFactory {
private Plan correlatedToJoin(LogicalApply<?, ?> apply) {
Optional<Expression> correlationFilter = apply.getCorrelationFilter();
if (((Exists) apply.getSubqueryExpr()).isNot()) {
if (apply.isNot()) {
return new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, ExpressionUtils.EMPTY_CONDITION,
correlationFilter.map(ExpressionUtils::extractConjunction).orElse(ExpressionUtils.EMPTY_CONDITION),
new DistributeHint(DistributeType.NONE),
@ -110,7 +109,7 @@ public class ExistsApplyToJoin extends OneRewriteRuleFactory {
}
private Plan unCorrelatedToJoin(LogicalApply<?, ?> unapply) {
if (((Exists) unapply.getSubqueryExpr()).isNot()) {
if (unapply.isNot()) {
return unCorrelatedNotExist(unapply);
} else {
return unCorrelatedExist(unapply);

View File

@ -24,7 +24,6 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
@ -36,7 +35,6 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
@ -82,9 +80,9 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
List<NamedExpression> outputExpressions = Lists.newArrayList(alias);
LogicalAggregate agg = new LogicalAggregate(groupExpressions, outputExpressions, apply.right());
Expression compareExpr = ((InSubquery) apply.getSubqueryExpr()).getCompareExpr();
Expression compareExpr = apply.getCompareExpr().get();
Expression expr = new BitmapContains(agg.getOutput().get(0), compareExpr);
if (((InSubquery) apply.getSubqueryExpr()).isNot()) {
if (apply.isNot()) {
expr = new Not(expr);
}
return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(),
@ -95,19 +93,18 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
}
//in-predicate to equal
InSubquery inSubquery = ((InSubquery) apply.getSubqueryExpr());
Expression predicate;
Expression left = inSubquery.getCompareExpr();
Expression left = apply.getCompareExpr().get();
// TODO: this is a workaround: after a deep copy of the logical plan, the apply's right child
// is not the same as the query plan in the subquery expr, since the scan node is copied twice
Expression right = inSubquery.getSubqueryOutput((LogicalPlan) apply.right());
Expression right = apply.getSubqueryOutput();
if (apply.isMarkJoin()) {
List<Expression> joinConjuncts = apply.getCorrelationFilter().isPresent()
? ExpressionUtils.extractConjunction(apply.getCorrelationFilter().get())
: Lists.newArrayList();
predicate = new EqualTo(left, right);
List<Expression> markConjuncts = Lists.newArrayList(predicate);
if (!predicate.nullable() || (apply.isMarkJoinSlotNotNull() && !inSubquery.isNot())) {
if (!predicate.nullable() || (apply.isMarkJoinSlotNotNull() && !apply.isNot())) {
// we can merge mark conjuncts with hash conjuncts in 2 scenarios
// 1. the mark join predicate is not nullable, so no null value would be produced
// 2. semi join with non-nullable mark slot.
@ -117,7 +114,7 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
markConjuncts.clear();
}
return new LogicalJoin<>(
inSubquery.isNot() ? JoinType.LEFT_ANTI_JOIN : JoinType.LEFT_SEMI_JOIN,
apply.isNot() ? JoinType.LEFT_ANTI_JOIN : JoinType.LEFT_SEMI_JOIN,
Lists.newArrayList(), joinConjuncts, markConjuncts,
new DistributeHint(DistributeType.NONE), apply.getMarkJoinSlotReference(),
apply.children(), null);
@ -127,7 +124,7 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
// so we need check both correlated slot and correlation filter exists
// before creating LogicalJoin node
if (apply.isCorrelated() && apply.getCorrelationFilter().isPresent()) {
if (inSubquery.isNot()) {
if (apply.isNot()) {
predicate = ExpressionUtils.and(ExpressionUtils.or(new EqualTo(left, right),
new IsNull(left), new IsNull(right)),
apply.getCorrelationFilter().get());
@ -140,7 +137,7 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
}
List<Expression> conjuncts = ExpressionUtils.extractConjunction(predicate);
if (inSubquery.isNot()) {
if (apply.isNot()) {
return new LogicalJoin<>(
predicate.nullable() && !apply.isCorrelated()
? JoinType.NULL_AWARE_LEFT_ANTI_JOIN
@ -159,6 +156,6 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
private boolean needBitmapUnion(LogicalApply<Plan, Plan> apply) {
return apply.right().getOutput().get(0).getDataType().isBitmapType()
&& !((InSubquery) apply.getSubqueryExpr()).getCompareExpr().getDataType().isBitmapType();
&& !apply.getCompareExpr().get().getDataType().isBitmapType();
}
}

View File

@ -69,14 +69,16 @@ public class MultiJoin extends AbstractLogicalPlan implements BlockFuncDepsPropa
// MultiJoin just contains one OUTER/SEMI/ANTI.
private final JoinType joinType;
// When it contains one OUTER/SEMI/ANTI join, keep that join's conditions separately.
private final List<Expression> notInnerJoinConditions;
private final List<Expression> notInnerJoinHashConditions;
private final List<Expression> notInnerJoinOtherConditions;
public MultiJoin(List<Plan> inputs, List<Expression> joinFilter, JoinType joinType,
List<Expression> notInnerJoinConditions) {
List<Expression> notInnerJoinHashConditions, List<Expression> notInnerJoinOtherConditions) {
super(PlanType.LOGICAL_MULTI_JOIN, inputs);
this.joinFilter = Objects.requireNonNull(joinFilter);
this.joinType = joinType;
this.notInnerJoinConditions = Objects.requireNonNull(notInnerJoinConditions);
this.notInnerJoinHashConditions = Objects.requireNonNull(notInnerJoinHashConditions);
this.notInnerJoinOtherConditions = Objects.requireNonNull(notInnerJoinOtherConditions);
}
public JoinType getJoinType() {
@ -87,13 +89,17 @@ public class MultiJoin extends AbstractLogicalPlan implements BlockFuncDepsPropa
return joinFilter;
}
public List<Expression> getNotInnerJoinConditions() {
return notInnerJoinConditions;
public List<Expression> getNotInnerHashJoinConditions() {
return notInnerJoinHashConditions;
}
/**
 * Returns the non-hash ("other") conditions that are kept separately for the
 * single OUTER/SEMI/ANTI join this MultiJoin may contain.
 */
public List<Expression> getNotInnerOtherJoinConditions() {
    return this.notInnerJoinOtherConditions;
}
@Override
public MultiJoin withChildren(List<Plan> children) {
return new MultiJoin(children, joinFilter, joinType, notInnerJoinConditions);
return new MultiJoin(children, joinFilter, joinType, notInnerJoinHashConditions, notInnerJoinOtherConditions);
}
@Override
@ -160,7 +166,8 @@ public class MultiJoin extends AbstractLogicalPlan implements BlockFuncDepsPropa
public List<? extends Expression> getExpressions() {
return new Builder<Expression>()
.addAll(joinFilter)
.addAll(notInnerJoinConditions)
.addAll(notInnerJoinHashConditions)
.addAll(notInnerJoinOtherConditions)
.build();
}
@ -180,7 +187,8 @@ public class MultiJoin extends AbstractLogicalPlan implements BlockFuncDepsPropa
return Utils.toSqlString("MultiJoin",
"joinType", joinType,
"joinFilter", joinFilter,
"notInnerJoinConditions", notInnerJoinConditions
"notInnerHashJoinConditions", notInnerJoinHashConditions,
"notInnerOtherJoinConditions", notInnerJoinOtherConditions
);
}
}

View File

@ -18,12 +18,10 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
@ -69,7 +67,10 @@ public class PullUpCteAnchor extends DefaultPlanRewriter<List<LogicalCTEProducer
List<LogicalCTEProducer<Plan>> producers) {
List<LogicalCTEProducer<Plan>> childProducers = Lists.newArrayList();
Plan child = cteProducer.child().accept(this, childProducers);
LogicalCTEProducer<Plan> newProducer = (LogicalCTEProducer<Plan>) cteProducer.withChildren(child);
LogicalCTEProducer<Plan> newProducer = (LogicalCTEProducer<Plan>) cteProducer;
if (child != cteProducer.child()) {
newProducer = (LogicalCTEProducer<Plan>) cteProducer.withChildren(child);
}
// because the current producer relies on its child's producers, add the current producer first.
producers.add(newProducer);
producers.addAll(childProducers);
@ -79,17 +80,7 @@ public class PullUpCteAnchor extends DefaultPlanRewriter<List<LogicalCTEProducer
@Override
public Plan visitLogicalApply(LogicalApply<? extends Plan, ? extends Plan> apply,
List<LogicalCTEProducer<Plan>> producers) {
SubqueryExpr subqueryExpr = apply.getSubqueryExpr();
PullUpCteAnchor pullSubqueryExpr = new PullUpCteAnchor();
List<LogicalCTEProducer<Plan>> subqueryExprProducers = Lists.newArrayList();
Plan newPlanInExpr = pullSubqueryExpr.rewriteRoot(subqueryExpr.getQueryPlan(), subqueryExprProducers);
while (newPlanInExpr instanceof LogicalCTEAnchor) {
newPlanInExpr = ((LogicalCTEAnchor<?, ?>) newPlanInExpr).right();
}
SubqueryExpr newSubqueryExpr = subqueryExpr.withSubquery((LogicalPlan) newPlanInExpr);
Plan newApplyLeft = apply.left().accept(this, producers);
Plan applyRight = apply.right();
PullUpCteAnchor pullApplyRight = new PullUpCteAnchor();
List<LogicalCTEProducer<Plan>> childProducers = Lists.newArrayList();
@ -98,7 +89,10 @@ public class PullUpCteAnchor extends DefaultPlanRewriter<List<LogicalCTEProducer
newApplyRight = ((LogicalCTEAnchor<?, ?>) newApplyRight).right();
}
producers.addAll(childProducers);
return apply.withSubqueryExprAndChildren(newSubqueryExpr,
ImmutableList.of(newApplyLeft, newApplyRight));
if (newApplyLeft != apply.left() || newApplyRight != apply.right()) {
return apply.withChildren(ImmutableList.of(newApplyLeft, newApplyRight));
} else {
return apply;
}
}
}

View File

@ -20,7 +20,6 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
@ -60,7 +59,7 @@ public class PullUpProjectUnderApply extends OneRewriteRuleFactory {
LogicalProject<Plan> project = apply.right();
Plan newCorrelate = apply.withChildren(apply.left(), project.child());
List<NamedExpression> newProjects = new ArrayList<>(apply.left().getOutput());
if (apply.getSubqueryExpr() instanceof ScalarSubquery) {
if (apply.isScalar()) {
Preconditions.checkState(project.getProjects().size() == 1,
"ScalarSubquery should only have one output column");
newProjects.add(project.getProjects().get(0));

View File

@ -112,7 +112,8 @@ public class ReorderJoin extends OneRewriteRuleFactory {
List<Plan> inputs = Lists.newArrayList();
List<Expression> joinFilter = Lists.newArrayList();
List<Expression> notInnerJoinConditions = Lists.newArrayList();
List<Expression> notInnerHashJoinConditions = Lists.newArrayList();
List<Expression> notInnerOtherJoinConditions = Lists.newArrayList();
LogicalJoin<?, ?> join;
// Implicit rely on {rule: MergeFilters}, so don't exist filter--filter--join.
@ -128,8 +129,8 @@ public class ReorderJoin extends OneRewriteRuleFactory {
joinFilter.addAll(join.getHashJoinConjuncts());
joinFilter.addAll(join.getOtherJoinConjuncts());
} else {
notInnerJoinConditions.addAll(join.getHashJoinConjuncts());
notInnerJoinConditions.addAll(join.getOtherJoinConjuncts());
notInnerHashJoinConditions.addAll(join.getHashJoinConjuncts());
notInnerOtherJoinConditions.addAll(join.getOtherJoinConjuncts());
}
// recursively convert children.
@ -162,7 +163,8 @@ public class ReorderJoin extends OneRewriteRuleFactory {
inputs,
joinFilter,
join.getJoinType(),
notInnerJoinConditions);
notInnerHashJoinConditions,
notInnerOtherJoinConditions);
}
/**
@ -254,7 +256,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1),
pushedFilter,
JoinType.INNER_JOIN,
ExpressionUtils.EMPTY_CONDITION), planToHintType);
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION), planToHintType);
} else if (multiJoinHandleChildren.getJoinType().isRightJoin()) {
left = multiJoinHandleChildren.child(0);
Set<ExprId> leftOutputExprIdSet = left.getOutputExprIdSet();
@ -268,7 +270,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()),
pushedFilter,
JoinType.INNER_JOIN,
ExpressionUtils.EMPTY_CONDITION), planToHintType);
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION), planToHintType);
} else {
remainingFilter = multiJoin.getJoinFilter();
Preconditions.checkState(multiJoinHandleChildren.arity() == 2);
@ -285,7 +287,8 @@ public class ReorderJoin extends OneRewriteRuleFactory {
return PlanUtils.filterOrSelf(ImmutableSet.copyOf(remainingFilter), new LogicalJoin<>(
multiJoinHandleChildren.getJoinType(),
ExpressionUtils.EMPTY_CONDITION, multiJoinHandleChildren.getNotInnerJoinConditions(),
multiJoinHandleChildren.getNotInnerHashJoinConditions(),
multiJoinHandleChildren.getNotInnerOtherJoinConditions(),
new DistributeHint(DistributeType.fromRightPlanHintType(
planToHintType.getOrDefault(right, JoinDistributeType.NONE))),
Optional.empty(),

View File

@ -58,7 +58,7 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory {
private Plan unCorrelatedToJoin(LogicalApply apply) {
LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows<>(new AssertNumRowsElement(1,
apply.getSubqueryExpr().toString(), AssertNumRowsElement.Assertion.EQ),
apply.right().toString(), AssertNumRowsElement.Assertion.EQ),
(LogicalPlan) apply.right());
return new LogicalJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION,

View File

@ -115,7 +115,8 @@ public class UnCorrelatedApplyAggregateFilter implements RewriteRuleFactory {
correlatedPredicate = ExpressionUtils.replace(correlatedPredicate, unCorrelatedExprToSlot);
LogicalAggregate newAgg = new LogicalAggregate<>(newGroupby, newAggOutput,
PlanUtils.filterOrSelf(ImmutableSet.copyOf(unCorrelatedPredicate), filter.child()));
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryType(), apply.isNot(),
apply.getCompareExpr(), apply.getTypeCoercionExpr(),
ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(),
apply.isNeedAddSubOutputToProjects(), apply.isInProject(),
apply.isMarkJoinSlotNotNull(), apply.left(),

View File

@ -66,7 +66,8 @@ public class UnCorrelatedApplyFilter extends OneRewriteRuleFactory {
}
Plan child = PlanUtils.filterOrSelf(ImmutableSet.copyOf(unCorrelatedPredicate), filter.child());
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryType(), apply.isNot(),
apply.getCompareExpr(), apply.getTypeCoercionExpr(),
ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(),
apply.isNeedAddSubOutputToProjects(),
apply.isInProject(), apply.isMarkJoinSlotNotNull(), apply.left(), child);

View File

@ -87,7 +87,8 @@ public class UnCorrelatedApplyProjectFilter extends OneRewriteRuleFactory {
.map(NamedExpression.class::cast)
.forEach(projects::add);
LogicalProject newProject = project.withProjectsAndChild(projects, child);
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryType(), apply.isNot(),
apply.getCompareExpr(), apply.getTypeCoercionExpr(),
ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(),
apply.isNeedAddSubOutputToProjects(),
apply.isInProject(), apply.isMarkJoinSlotNotNull(), apply.left(), newProject);

View File

@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@ -129,13 +128,16 @@ public class LogicalPlanDeepCopier extends DefaultPlanRewriter<DeepCopierContext
List<Expression> correlationSlot = apply.getCorrelationSlot().stream()
.map(s -> ExpressionDeepCopier.INSTANCE.deepCopy(s, context))
.collect(ImmutableList.toImmutableList());
SubqueryExpr subqueryExpr = (SubqueryExpr) ExpressionDeepCopier.INSTANCE
.deepCopy(apply.getSubqueryExpr(), context);
Optional<Expression> compareExpr = apply.getCompareExpr()
.map(f -> ExpressionDeepCopier.INSTANCE.deepCopy(f, context));
Optional<Expression> typeCoercionExpr = apply.getTypeCoercionExpr()
.map(f -> ExpressionDeepCopier.INSTANCE.deepCopy(f, context));
Optional<Expression> correlationFilter = apply.getCorrelationFilter()
.map(f -> ExpressionDeepCopier.INSTANCE.deepCopy(f, context));
Optional<MarkJoinSlotReference> markJoinSlotReference = apply.getMarkJoinSlotReference()
.map(m -> (MarkJoinSlotReference) ExpressionDeepCopier.INSTANCE.deepCopy(m, context));
return new LogicalApply<>(correlationSlot, subqueryExpr, correlationFilter,
return new LogicalApply<>(correlationSlot, apply.getSubqueryType(), apply.isNot(),
compareExpr, typeCoercionExpr, correlationFilter,
markJoinSlotReference, apply.isNeedAddSubOutputToProjects(), apply.isInProject(),
apply.isMarkJoinSlotNotNull(), left, right);
}

View File

@ -19,13 +19,9 @@ package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.PropagateFuncDeps;
@ -46,11 +42,26 @@ import java.util.Optional;
*/
public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Plan>
extends LogicalBinary<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> implements PropagateFuncDeps {
/**
 * Kind of subquery expression a LogicalApply was created from. Stored on the apply node
 * so that rules can branch on the kind without holding the original subquery expression.
 */
public enum SubQueryType {
// the apply was built from an InSubquery expression
IN_SUBQUERY,
// the apply was built from an Exists expression (note: the constant is spelled "EXITS")
EXITS_SUBQUERY,
// the apply was built from a ScalarSubquery expression
SCALAR_SUBQUERY
}
private final SubQueryType subqueryType;
private final boolean isNot;
// only for InSubquery
private final Optional<Expression> compareExpr;
// only for InSubquery
private final Optional<Expression> typeCoercionExpr;
// correlation column
private final List<Expression> correlationSlot;
// original subquery
private final SubqueryExpr subqueryExpr;
// correlation Conjunction
private final Optional<Expression> correlationFilter;
// The slot replaced by the subquery in MarkJoin
@ -72,16 +83,23 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
private LogicalApply(Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties,
List<Expression> correlationSlot,
SubqueryExpr subqueryExpr, Optional<Expression> correlationFilter,
List<Expression> correlationSlot, SubQueryType subqueryType, boolean isNot,
Optional<Expression> compareExpr, Optional<Expression> typeCoercionExpr,
Optional<Expression> correlationFilter,
Optional<MarkJoinSlotReference> markJoinSlotReference,
boolean needAddSubOutputToProjects,
boolean inProject,
boolean isMarkJoinSlotNotNull,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
super(PlanType.LOGICAL_APPLY, groupExpression, logicalProperties, leftChild, rightChild);
if (subqueryType == SubQueryType.IN_SUBQUERY) {
Preconditions.checkArgument(compareExpr.isPresent(), "InSubquery must have compareExpr");
}
this.correlationSlot = correlationSlot == null ? ImmutableList.of() : ImmutableList.copyOf(correlationSlot);
this.subqueryExpr = Objects.requireNonNull(subqueryExpr, "subquery can not be null");
this.subqueryType = subqueryType;
this.isNot = isNot;
this.compareExpr = compareExpr;
this.typeCoercionExpr = typeCoercionExpr;
this.correlationFilter = correlationFilter;
this.markJoinSlotReference = markJoinSlotReference;
this.needAddSubOutputToProjects = needAddSubOutputToProjects;
@ -89,13 +107,26 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
this.isMarkJoinSlotNotNull = isMarkJoinSlotNotNull;
}
public LogicalApply(List<Expression> correlationSlot, SubqueryExpr subqueryExpr,
public LogicalApply(List<Expression> correlationSlot, SubQueryType subqueryType, boolean isNot,
Optional<Expression> compareExpr, Optional<Expression> typeCoercionExpr,
Optional<Expression> correlationFilter, Optional<MarkJoinSlotReference> markJoinSlotReference,
boolean needAddSubOutputToProjects, boolean inProject, boolean isMarkJoinSlotNotNull,
LEFT_CHILD_TYPE input, RIGHT_CHILD_TYPE subquery) {
this(Optional.empty(), Optional.empty(), correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull, input,
subquery);
this(Optional.empty(), Optional.empty(), correlationSlot, subqueryType, isNot, compareExpr, typeCoercionExpr,
correlationFilter, markJoinSlotReference, needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull,
input, subquery);
}
/**
 * Returns the compare expression of an IN subquery (the expression tested against the
 * subquery output). The constructor requires it to be present for IN_SUBQUERY.
 */
public Optional<Expression> getCompareExpr() {
    return this.compareExpr;
}
/**
 * Returns the optional type-coercion expression of this apply; when present it is used
 * as the subquery output in place of the right child's first output slot.
 */
public Optional<Expression> getTypeCoercionExpr() {
    return this.typeCoercionExpr;
}
/**
 * Returns the expression that represents the subquery's output value.
 *
 * @return the type-coercion expression if one is present, otherwise the first
 *         output slot of the right (subquery) child
 */
public Expression getSubqueryOutput() {
    if (typeCoercionExpr.isPresent()) {
        return typeCoercionExpr.get();
    }
    // no coercion recorded: fall back to the first output of the right child
    return right().getOutput().get(0);
}
public List<Expression> getCorrelationSlot() {
@ -106,20 +137,24 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
return correlationFilter;
}
public SubqueryExpr getSubqueryExpr() {
return subqueryExpr;
}
public boolean isScalar() {
return this.subqueryExpr instanceof ScalarSubquery;
return subqueryType == SubQueryType.SCALAR_SUBQUERY;
}
public boolean isIn() {
return this.subqueryExpr instanceof InSubquery;
return subqueryType == SubQueryType.IN_SUBQUERY;
}
public boolean isExist() {
return this.subqueryExpr instanceof Exists;
return subqueryType == SubQueryType.EXITS_SUBQUERY;
}
/** Returns which kind of subquery (IN / EXISTS / scalar) this apply represents. */
public SubQueryType getSubqueryType() {
    return this.subqueryType;
}
/** Returns whether the subquery predicate is negated (NOT IN / NOT EXISTS). */
public boolean isNot() {
    return this.isNot;
}
public boolean isCorrelated() {
@ -181,19 +216,22 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
}
LogicalApply<?, ?> that = (LogicalApply<?, ?>) o;
return Objects.equals(correlationSlot, that.getCorrelationSlot())
&& Objects.equals(subqueryExpr, that.getSubqueryExpr())
&& Objects.equals(subqueryType, that.subqueryType)
&& Objects.equals(compareExpr, that.compareExpr)
&& Objects.equals(typeCoercionExpr, that.typeCoercionExpr)
&& Objects.equals(correlationFilter, that.getCorrelationFilter())
&& Objects.equals(markJoinSlotReference, that.getMarkJoinSlotReference())
&& needAddSubOutputToProjects == that.needAddSubOutputToProjects
&& inProject == that.inProject
&& isMarkJoinSlotNotNull == that.isMarkJoinSlotNotNull;
&& isMarkJoinSlotNotNull == that.isMarkJoinSlotNotNull
&& isNot == that.isNot;
}
@Override
public int hashCode() {
return Objects.hash(
correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull);
correlationSlot, subqueryType, compareExpr, typeCoercionExpr, correlationFilter,
markJoinSlotReference, needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull, isNot);
}
@Override
@ -215,33 +253,27 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
}
}
public LogicalApply<Plan, Plan> withSubqueryExprAndChildren(SubqueryExpr subqueryExpr, List<Plan> children) {
return new LogicalApply<>(correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull,
children.get(0), children.get(1));
}
@Override
public LogicalApply<Plan, Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new LogicalApply<>(correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull,
return new LogicalApply<>(correlationSlot, subqueryType, isNot, compareExpr, typeCoercionExpr,
correlationFilter, markJoinSlotReference, needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull,
children.get(0), children.get(1));
}
@Override
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalApply<>(groupExpression, Optional.of(getLogicalProperties()),
correlationSlot, subqueryExpr, correlationFilter, markJoinSlotReference,
needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull, left(), right());
correlationSlot, subqueryType, isNot, compareExpr, typeCoercionExpr, correlationFilter,
markJoinSlotReference, needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull, left(), right());
}
@Override
public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new LogicalApply<>(groupExpression, logicalProperties, correlationSlot, subqueryExpr,
correlationFilter, markJoinSlotReference,
return new LogicalApply<>(groupExpression, logicalProperties, correlationSlot, subqueryType, isNot,
compareExpr, typeCoercionExpr, correlationFilter, markJoinSlotReference,
needAddSubOutputToProjects, inProject, isMarkJoinSlotNotNull, children.get(0), children.get(1));
}
}

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
@ -44,10 +43,10 @@ class ExistsApplyToJoinTest implements MemoPatternMatchSupported {
LogicalOlapScan right = PlanConstructor.newLogicalOlapScan(0, "t2", 1);
List<Slot> rightSlots = right.getOutput();
EqualTo equalTo = new EqualTo(leftSlots.get(0), rightSlots.get(0));
Exists exists = new Exists(right, false);
LogicalApply<LogicalOlapScan, LogicalOlapScan> apply =
new LogicalApply<>(ImmutableList.of(leftSlots.get(0), rightSlots.get(0)),
exists, Optional.of(equalTo), Optional.empty(),
LogicalApply.SubQueryType.EXITS_SUBQUERY, false, Optional.empty(), Optional.empty(),
Optional.of(equalTo), Optional.empty(),
false, false, false, left, right);
PlanChecker.from(MemoTestUtils.createConnectContext(), apply)
.applyTopDown(new ExistsApplyToJoin())
@ -64,10 +63,10 @@ class ExistsApplyToJoinTest implements MemoPatternMatchSupported {
LogicalOlapScan right = PlanConstructor.newLogicalOlapScan(0, "t2", 1);
List<Slot> rightSlots = right.getOutput();
EqualTo equalTo = new EqualTo(leftSlots.get(0), rightSlots.get(0));
Exists exists = new Exists(right, false);
LogicalApply<LogicalOlapScan, LogicalOlapScan> apply =
new LogicalApply<>(Collections.emptyList(),
exists, Optional.of(equalTo), Optional.empty(),
LogicalApply.SubQueryType.EXITS_SUBQUERY, false, Optional.empty(), Optional.empty(),
Optional.of(equalTo), Optional.empty(),
false, false, false, left, right);
PlanChecker.from(MemoTestUtils.createConnectContext(), apply)
.applyTopDown(new ExistsApplyToJoin())
@ -84,10 +83,10 @@ class ExistsApplyToJoinTest implements MemoPatternMatchSupported {
LogicalOlapScan right = PlanConstructor.newLogicalOlapScan(0, "t2", 1);
List<Slot> rightSlots = right.getOutput();
EqualTo equalTo = new EqualTo(leftSlots.get(0), rightSlots.get(0));
Exists exists = new Exists(right, true);
LogicalApply<LogicalOlapScan, LogicalOlapScan> apply =
new LogicalApply<>(Collections.emptyList(),
exists, Optional.of(equalTo), Optional.empty(),
LogicalApply.SubQueryType.EXITS_SUBQUERY, true, Optional.empty(), Optional.empty(),
Optional.of(equalTo), Optional.empty(),
false, false, false, left, right);
PlanChecker.from(MemoTestUtils.createConnectContext(), apply)
.applyTopDown(new ExistsApplyToJoin())
@ -105,10 +104,10 @@ class ExistsApplyToJoinTest implements MemoPatternMatchSupported {
LogicalOlapScan right = PlanConstructor.newLogicalOlapScan(0, "t2", 1);
List<Slot> rightSlots = right.getOutput();
EqualTo equalTo = new EqualTo(leftSlots.get(0), rightSlots.get(0));
Exists exists = new Exists(right, true);
LogicalApply<LogicalOlapScan, LogicalOlapScan> apply =
new LogicalApply<>(ImmutableList.of(leftSlots.get(0), rightSlots.get(0)),
exists, Optional.of(equalTo), Optional.empty(),
LogicalApply.SubQueryType.EXITS_SUBQUERY, true, Optional.empty(), Optional.empty(),
Optional.of(equalTo), Optional.empty(),
false, false, false, left, right);
PlanChecker.from(MemoTestUtils.createConnectContext(), apply)
.applyTopDown(new ExistsApplyToJoin())