[fix](Nereids): don't transpose agg and join if join is mark join (#33312)

This commit is contained in:
谢健
2024-04-08 11:08:05 +08:00
committed by yiguolei
parent 0d0a96d097
commit 045dd05f2a
4 changed files with 37 additions and 4 deletions

View File

@ -31,8 +31,8 @@ public class TransposeAggSemiJoin extends OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalAggregate(logicalJoin())
.when(agg -> agg.child().getJoinType().isLeftSemiOrAntiJoin())
return logicalAggregate(
logicalJoin().when(join -> join.getJoinType().isLeftSemiOrAntiJoin() && !join.isMarkJoin()))
.then(agg -> {
LogicalJoin<GroupPlan, GroupPlan> join = agg.child();
if (!TransposeSemiJoinAgg.canTranspose(agg, join)) {

View File

@ -32,8 +32,8 @@ public class TransposeAggSemiJoinProject extends OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalAggregate(logicalProject(logicalJoin()))
.when(agg -> agg.child().child().getJoinType().isLeftSemiOrAntiJoin())
return logicalAggregate(logicalProject(
logicalJoin().when(join -> join.getJoinType().isLeftSemiOrAntiJoin() && !join.isMarkJoin())))
.then(agg -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = agg.child();
LogicalJoin<GroupPlan, GroupPlan> join = project.child();

View File

@ -30,6 +30,7 @@ import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
class TransposeAggSemiJoinTest implements MemoPatternMatchSupported {
@ -57,4 +58,21 @@ class TransposeAggSemiJoinTest implements MemoPatternMatchSupported {
)
);
}
@Test
void markJoin() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
.aggGroupUsingIndex(ImmutableList.of(0),
ImmutableList.of(
scan1.getOutput().get(0),
new Alias(new Sum(scan1.getOutput().get(1)), "sum")
)
)
.build();
int size = PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyExploration(TransposeAggSemiJoin.INSTANCE.build())
.getAllPlan().size();
Assertions.assertEquals(1, size);
}
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.exploration.TransposeAggSemiJoin;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@ -28,6 +29,7 @@ import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
class TransposeSemiJoinAggProjectTest implements MemoPatternMatchSupported {
@ -53,4 +55,17 @@ class TransposeSemiJoinAggProjectTest implements MemoPatternMatchSupported {
);
}
@Test
void markJoin() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.aggAllUsingIndex(ImmutableList.of(0, 1), ImmutableList.of(0, 1))
.project(ImmutableList.of(0))
.markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
.build();
int size = PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyExploration(TransposeAggSemiJoin.INSTANCE.build())
.getAllPlan().size();
Assertions.assertEquals(1, size);
}
}