[fix](Nereids) should not infer not null from mark join (#30897)

This commit is contained in:
morrySnow
2024-02-06 16:55:10 +08:00
committed by yiguolei
parent 08508d65fd
commit f95d0cf802
3 changed files with 26 additions and 0 deletions

View File

@ -43,6 +43,7 @@ public class InferJoinNotNull extends OneRewriteRuleFactory {
// TODO: maybe consider ANTI?
return logicalJoin(any(), any())
.when(join -> join.getJoinType().isInnerJoin() || join.getJoinType().isSemiJoin())
.whenNot(LogicalJoin::isMarkJoin)
.thenApply(ctx -> {
LogicalJoin<Plan, Plan> join = ctx.root;
Set<Expression> conjuncts = new HashSet<>();

View File

@ -70,6 +70,18 @@ class InferJoinNotNullTest implements MemoPatternMatchSupported {
logicalFilter().when(f -> f.getPredicate().toString().equals("( not id#2 IS NULL)"))
)
);
LogicalPlan rightMarkSemiJoin = new LogicalPlanBuilder(scan1)
.markJoin(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), rightMarkSemiJoin)
.applyTopDown(new InferJoinNotNull())
.matches(
rightSemiLogicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
);
}
@Test

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement.Assertion;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.DistributeType;
import org.apache.doris.nereids.trees.plans.JoinType;
@ -49,6 +50,7 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
@ -99,6 +101,17 @@ public class LogicalPlanBuilder {
return from(project);
}
public LogicalPlanBuilder markJoin(LogicalPlan right, JoinType joinType, Pair<Integer, Integer> hashOnSlots) {
ImmutableList<EqualTo> hashConjuncts = ImmutableList.of(
new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second)));
LogicalJoin<LogicalPlan, LogicalPlan> join = new LogicalJoin<>(joinType, new ArrayList<>(hashConjuncts),
Collections.emptyList(), Collections.emptyList(),
new DistributeHint(DistributeType.NONE), Optional.of(new MarkJoinSlotReference("fake")),
this.plan, right);
return from(join);
}
public LogicalPlanBuilder join(LogicalPlan right, JoinType joinType, Pair<Integer, Integer> hashOnSlots) {
ImmutableList<EqualTo> hashConjuncts = ImmutableList.of(
new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second)));