[fix](Nereids) should not infer not null from mark join (#30897)
This commit is contained in:
@ -43,6 +43,7 @@ public class InferJoinNotNull extends OneRewriteRuleFactory {
|
||||
// TODO: maybe consider ANTI?
|
||||
return logicalJoin(any(), any())
|
||||
.when(join -> join.getJoinType().isInnerJoin() || join.getJoinType().isSemiJoin())
|
||||
.whenNot(LogicalJoin::isMarkJoin)
|
||||
.thenApply(ctx -> {
|
||||
LogicalJoin<Plan, Plan> join = ctx.root;
|
||||
Set<Expression> conjuncts = new HashSet<>();
|
||||
|
||||
@ -70,6 +70,18 @@ class InferJoinNotNullTest implements MemoPatternMatchSupported {
|
||||
logicalFilter().when(f -> f.getPredicate().toString().equals("( not id#2 IS NULL)"))
|
||||
)
|
||||
);
|
||||
|
||||
LogicalPlan rightMarkSemiJoin = new LogicalPlanBuilder(scan1)
|
||||
.markJoin(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0))
|
||||
.build();
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), rightMarkSemiJoin)
|
||||
.applyTopDown(new InferJoinNotNull())
|
||||
.matches(
|
||||
rightSemiLogicalJoin(
|
||||
logicalOlapScan(),
|
||||
logicalOlapScan()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
|
||||
import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement.Assertion;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.plans.DistributeType;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
@ -49,6 +50,7 @@ import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
@ -99,6 +101,17 @@ public class LogicalPlanBuilder {
|
||||
return from(project);
|
||||
}
|
||||
|
||||
public LogicalPlanBuilder markJoin(LogicalPlan right, JoinType joinType, Pair<Integer, Integer> hashOnSlots) {
|
||||
ImmutableList<EqualTo> hashConjuncts = ImmutableList.of(
|
||||
new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second)));
|
||||
|
||||
LogicalJoin<LogicalPlan, LogicalPlan> join = new LogicalJoin<>(joinType, new ArrayList<>(hashConjuncts),
|
||||
Collections.emptyList(), Collections.emptyList(),
|
||||
new DistributeHint(DistributeType.NONE), Optional.of(new MarkJoinSlotReference("fake")),
|
||||
this.plan, right);
|
||||
return from(join);
|
||||
}
|
||||
|
||||
public LogicalPlanBuilder join(LogicalPlan right, JoinType joinType, Pair<Integer, Integer> hashOnSlots) {
|
||||
ImmutableList<EqualTo> hashConjuncts = ImmutableList.of(
|
||||
new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second)));
|
||||
|
||||
Reference in New Issue
Block a user