[fix](Nereids): don't transpose agg and join if join is mark join (#33312)
This commit is contained in:
@ -31,8 +31,8 @@ public class TransposeAggSemiJoin extends OneExplorationRuleFactory {
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalAggregate(logicalJoin())
|
||||
.when(agg -> agg.child().getJoinType().isLeftSemiOrAntiJoin())
|
||||
return logicalAggregate(
|
||||
logicalJoin().when(join -> join.getJoinType().isLeftSemiOrAntiJoin() && !join.isMarkJoin()))
|
||||
.then(agg -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> join = agg.child();
|
||||
if (!TransposeSemiJoinAgg.canTranspose(agg, join)) {
|
||||
|
||||
@ -32,8 +32,8 @@ public class TransposeAggSemiJoinProject extends OneExplorationRuleFactory {
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalAggregate(logicalProject(logicalJoin()))
|
||||
.when(agg -> agg.child().child().getJoinType().isLeftSemiOrAntiJoin())
|
||||
return logicalAggregate(logicalProject(
|
||||
logicalJoin().when(join -> join.getJoinType().isLeftSemiOrAntiJoin() && !join.isMarkJoin())))
|
||||
.then(agg -> {
|
||||
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = agg.child();
|
||||
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
|
||||
|
||||
@ -30,6 +30,7 @@ import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class TransposeAggSemiJoinTest implements MemoPatternMatchSupported {
|
||||
@ -57,4 +58,21 @@ class TransposeAggSemiJoinTest implements MemoPatternMatchSupported {
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void markJoin() {
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
.markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
|
||||
.aggGroupUsingIndex(ImmutableList.of(0),
|
||||
ImmutableList.of(
|
||||
scan1.getOutput().get(0),
|
||||
new Alias(new Sum(scan1.getOutput().get(1)), "sum")
|
||||
)
|
||||
)
|
||||
.build();
|
||||
int size = PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyExploration(TransposeAggSemiJoin.INSTANCE.build())
|
||||
.getAllPlan().size();
|
||||
Assertions.assertEquals(1, size);
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.rules.exploration.TransposeAggSemiJoin;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
@ -28,6 +29,7 @@ import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class TransposeSemiJoinAggProjectTest implements MemoPatternMatchSupported {
|
||||
@ -53,4 +55,17 @@ class TransposeSemiJoinAggProjectTest implements MemoPatternMatchSupported {
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void markJoin() {
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
.aggAllUsingIndex(ImmutableList.of(0, 1), ImmutableList.of(0, 1))
|
||||
.project(ImmutableList.of(0))
|
||||
.markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
|
||||
.build();
|
||||
int size = PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyExploration(TransposeAggSemiJoin.INSTANCE.build())
|
||||
.getAllPlan().size();
|
||||
Assertions.assertEquals(1, size);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user