[fix](Nereids): don't transpose agg and join if join is mark join (#33312)

This commit is contained in:
谢健
2024-04-08 11:08:05 +08:00
committed by yiguolei
parent 0d0a96d097
commit 045dd05f2a
4 changed files with 37 additions and 4 deletions

View File

@ -30,6 +30,7 @@ import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
class TransposeAggSemiJoinTest implements MemoPatternMatchSupported {
@ -57,4 +58,21 @@ class TransposeAggSemiJoinTest implements MemoPatternMatchSupported {
)
);
}
@Test
void markJoin() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
.aggGroupUsingIndex(ImmutableList.of(0),
ImmutableList.of(
scan1.getOutput().get(0),
new Alias(new Sum(scan1.getOutput().get(1)), "sum")
)
)
.build();
int size = PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyExploration(TransposeAggSemiJoin.INSTANCE.build())
.getAllPlan().size();
Assertions.assertEquals(1, size);
}
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.exploration.TransposeAggSemiJoin;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@ -28,6 +29,7 @@ import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
class TransposeSemiJoinAggProjectTest implements MemoPatternMatchSupported {
@ -53,4 +55,17 @@ class TransposeSemiJoinAggProjectTest implements MemoPatternMatchSupported {
);
}
@Test
void markJoin() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.aggAllUsingIndex(ImmutableList.of(0, 1), ImmutableList.of(0, 1))
.project(ImmutableList.of(0))
.markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
.build();
int size = PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyExploration(TransposeAggSemiJoin.INSTANCE.build())
.getAllPlan().size();
Assertions.assertEquals(1, size);
}
}