From f95d0cf8024bf26920772a30db11397bcc3a73f9 Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Tue, 6 Feb 2024 16:55:10 +0800 Subject: [PATCH] [fix](Nereids) should not infer not null from mark join (#30897) --- .../nereids/rules/rewrite/InferJoinNotNull.java | 1 + .../nereids/rules/rewrite/InferJoinNotNullTest.java | 12 ++++++++++++ .../doris/nereids/util/LogicalPlanBuilder.java | 13 +++++++++++++ 3 files changed, 26 insertions(+) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java index d583512145..e7168ca0e9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java @@ -43,6 +43,7 @@ public class InferJoinNotNull extends OneRewriteRuleFactory { // TODO: maybe consider ANTI? return logicalJoin(any(), any()) .when(join -> join.getJoinType().isInnerJoin() || join.getJoinType().isSemiJoin()) + .whenNot(LogicalJoin::isMarkJoin) .thenApply(ctx -> { LogicalJoin join = ctx.root; Set conjuncts = new HashSet<>(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNullTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNullTest.java index 867c2000c3..a8ac045d80 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNullTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNullTest.java @@ -70,6 +70,18 @@ class InferJoinNotNullTest implements MemoPatternMatchSupported { logicalFilter().when(f -> f.getPredicate().toString().equals("( not id#2 IS NULL)")) ) ); + + LogicalPlan rightMarkSemiJoin = new LogicalPlanBuilder(scan1) + .markJoin(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0)) + .build(); + PlanChecker.from(MemoTestUtils.createConnectContext(), rightMarkSemiJoin) + .applyTopDown(new InferJoinNotNull()) + .matches( + rightSemiLogicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + ); } @Test diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java index 971ccd90ef..697a830279 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement; import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement.Assertion; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.DistributeType; import org.apache.doris.nereids.trees.plans.JoinType; @@ -49,6 +50,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.Set; @@ -99,6 +101,17 @@ public class LogicalPlanBuilder { return from(project); } + public LogicalPlanBuilder markJoin(LogicalPlan right, JoinType joinType, Pair hashOnSlots) { + ImmutableList hashConjuncts = ImmutableList.of( + new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second))); + + LogicalJoin join = new LogicalJoin<>(joinType, new ArrayList<>(hashConjuncts), + Collections.emptyList(), Collections.emptyList(), + new DistributeHint(DistributeType.NONE), Optional.of(new MarkJoinSlotReference("fake")), + this.plan, right); + return from(join); + } + public LogicalPlanBuilder join(LogicalPlan right, JoinType joinType, Pair hashOnSlots) { ImmutableList hashConjuncts = ImmutableList.of( new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second)));