[refactor](Nereids) Refactor all rewrite logical unit tests by match-pattern (#17691)
This commit is contained in:
@ -23,30 +23,29 @@ import org.apache.doris.catalog.KeysType;
|
||||
import org.apache.doris.catalog.OlapTable;
|
||||
import org.apache.doris.catalog.PartitionInfo;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.ObjectId;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.thrift.TStorageType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class EliminateGroupByConstantTest {
|
||||
/** Tests for {@link EliminateGroupByConstant}. */
|
||||
class EliminateGroupByConstantTest implements MemoPatternMatchSupported {
|
||||
private static final OlapTable table = new OlapTable(0L, "student",
|
||||
ImmutableList.of(new Column("k1", Type.INT, true, AggregateType.NONE, "0", ""),
|
||||
new Column("k2", Type.INT, false, AggregateType.NONE, "0", ""),
|
||||
@ -65,110 +64,107 @@ public class EliminateGroupByConstantTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIntegerLiteral() {
|
||||
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
|
||||
ImmutableList.of(new IntegerLiteral(1), k2),
|
||||
ImmutableList.of(k1, k2),
|
||||
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)
|
||||
);
|
||||
void testIntegerLiteral() {
|
||||
LogicalPlan aggregate = new LogicalPlanBuilder(
|
||||
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table))
|
||||
.agg(ImmutableList.of(new IntegerLiteral(1), k2),
|
||||
ImmutableList.of(k1, k2))
|
||||
.build();
|
||||
|
||||
CascadesContext context = MemoTestUtils.createCascadesContext(aggregate);
|
||||
context.topDownRewrite(new EliminateGroupByConstant().build());
|
||||
context.bottomUpRewrite(new CheckAfterRewrite().build());
|
||||
|
||||
LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut());
|
||||
Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1);
|
||||
Assertions.assertTrue(aggregate1.getGroupByExpressions().get(0) instanceof Slot);
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
|
||||
.applyTopDown(new EliminateGroupByConstant())
|
||||
.applyBottomUp(new CheckAfterRewrite())
|
||||
.matches(
|
||||
aggregate().when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2)))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOtherLiteral() {
|
||||
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
|
||||
ImmutableList.of(
|
||||
new StringLiteral("str"), k2),
|
||||
ImmutableList.of(
|
||||
new Alias(new StringLiteral("str"), "str"), k1, k2),
|
||||
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)
|
||||
);
|
||||
void testOtherLiteral() {
|
||||
LogicalPlan aggregate = new LogicalPlanBuilder(
|
||||
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table))
|
||||
.agg(ImmutableList.of(
|
||||
new StringLiteral("str"), k2),
|
||||
ImmutableList.of(
|
||||
new Alias(new StringLiteral("str"), "str"), k1, k2))
|
||||
.build();
|
||||
|
||||
CascadesContext context = MemoTestUtils.createCascadesContext(aggregate);
|
||||
context.topDownRewrite(new EliminateGroupByConstant().build());
|
||||
context.bottomUpRewrite(new CheckAfterRewrite().build());
|
||||
|
||||
LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut());
|
||||
Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1);
|
||||
Assertions.assertTrue(aggregate1.getGroupByExpressions().get(0) instanceof Slot);
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
|
||||
.applyTopDown(new EliminateGroupByConstant())
|
||||
.applyBottomUp(new CheckAfterRewrite())
|
||||
.matches(
|
||||
aggregate().when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2)))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMixedLiteral() {
|
||||
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
|
||||
ImmutableList.of(
|
||||
new StringLiteral("str"), k2,
|
||||
new IntegerLiteral(1),
|
||||
new IntegerLiteral(2),
|
||||
new IntegerLiteral(3),
|
||||
new Add(k1, k2)),
|
||||
ImmutableList.of(
|
||||
new Alias(new StringLiteral("str"), "str"),
|
||||
k2, k1, new Alias(new IntegerLiteral(1), "integer")),
|
||||
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)
|
||||
);
|
||||
void testMixedLiteral() {
|
||||
LogicalPlan aggregate = new LogicalPlanBuilder(
|
||||
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table))
|
||||
.agg(ImmutableList.of(
|
||||
new StringLiteral("str"), k2,
|
||||
new IntegerLiteral(1),
|
||||
new IntegerLiteral(2),
|
||||
new IntegerLiteral(3),
|
||||
new Add(k1, k2)),
|
||||
ImmutableList.of(
|
||||
new Alias(new StringLiteral("str"), "str"),
|
||||
k2, k1, new Alias(new IntegerLiteral(1), "integer")))
|
||||
.build();
|
||||
|
||||
CascadesContext context = MemoTestUtils.createCascadesContext(aggregate);
|
||||
context.topDownRewrite(new EliminateGroupByConstant().build());
|
||||
context.bottomUpRewrite(new CheckAfterRewrite().build());
|
||||
|
||||
LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut());
|
||||
Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 2);
|
||||
List groupByExprs = aggregate1.getGroupByExpressions();
|
||||
Assertions.assertTrue(groupByExprs.get(0) instanceof Slot
|
||||
&& groupByExprs.get(1) instanceof Add);
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
|
||||
.applyTopDown(new EliminateGroupByConstant())
|
||||
.applyBottomUp(new CheckAfterRewrite())
|
||||
.matches(
|
||||
aggregate()
|
||||
.when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2))))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComplexGroupBy() {
|
||||
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
|
||||
ImmutableList.of(
|
||||
new IntegerLiteral(1),
|
||||
new IntegerLiteral(2),
|
||||
new Add(k1, k2)),
|
||||
ImmutableList.of(
|
||||
new Alias(new Max(k1), "max"),
|
||||
new Alias(new Min(k2), "min"),
|
||||
new Alias(new Add(k1, k2), "add")),
|
||||
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)
|
||||
);
|
||||
void testComplexGroupBy() {
|
||||
LogicalPlan aggregate = new LogicalPlanBuilder(
|
||||
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table))
|
||||
.agg(ImmutableList.of(
|
||||
new IntegerLiteral(1),
|
||||
new IntegerLiteral(2),
|
||||
new Add(k1, k2)),
|
||||
ImmutableList.of(
|
||||
new Alias(new Max(k1), "max"),
|
||||
new Alias(new Min(k2), "min"),
|
||||
new Alias(new Add(k1, k2), "add")))
|
||||
.build();
|
||||
|
||||
CascadesContext context = MemoTestUtils.createCascadesContext(aggregate);
|
||||
context.topDownRewrite(new EliminateGroupByConstant().build());
|
||||
context.bottomUpRewrite(new CheckAfterRewrite().build());
|
||||
|
||||
LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut());
|
||||
Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1);
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
|
||||
.applyTopDown(new EliminateGroupByConstant())
|
||||
.applyBottomUp(new CheckAfterRewrite())
|
||||
.matches(
|
||||
aggregate()
|
||||
.when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(new Add(k1, k2))))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOutOfRange() {
|
||||
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
|
||||
ImmutableList.of(
|
||||
new StringLiteral("str"), k2,
|
||||
new IntegerLiteral(1),
|
||||
new IntegerLiteral(2),
|
||||
new IntegerLiteral(3),
|
||||
new IntegerLiteral(5),
|
||||
new Add(k1, k2)),
|
||||
ImmutableList.of(
|
||||
new Alias(new StringLiteral("str"), "str"),
|
||||
k2, k1, new Alias(new IntegerLiteral(1), "integer")),
|
||||
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)
|
||||
);
|
||||
|
||||
CascadesContext context = MemoTestUtils.createCascadesContext(aggregate);
|
||||
context.topDownRewrite(new EliminateGroupByConstant().build());
|
||||
context.bottomUpRewrite(new CheckAfterRewrite().build());
|
||||
|
||||
LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut());
|
||||
Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 2);
|
||||
void testOutOfRange() {
|
||||
LogicalPlan aggregate = new LogicalPlanBuilder(
|
||||
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table))
|
||||
.agg(ImmutableList.of(
|
||||
new StringLiteral("str"), k2,
|
||||
new IntegerLiteral(1),
|
||||
new IntegerLiteral(2),
|
||||
new IntegerLiteral(3),
|
||||
new IntegerLiteral(5),
|
||||
new Add(k1, k2)),
|
||||
ImmutableList.of(
|
||||
new Alias(new StringLiteral("str"), "str"),
|
||||
k2, k1, new Alias(new IntegerLiteral(1), "integer")))
|
||||
.build();
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
|
||||
.applyTopDown(new EliminateGroupByConstant())
|
||||
.applyBottomUp(new CheckAfterRewrite())
|
||||
.matches(
|
||||
aggregate()
|
||||
.when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2))))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,28 +17,25 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.utframe.TestWithFeService;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
/**
|
||||
* test ELIMINATE_UNNECESSARY_PROJECT rule.
|
||||
*/
|
||||
public class EliminateUnnecessaryProjectTest extends TestWithFeService {
|
||||
class EliminateUnnecessaryProjectTest extends TestWithFeService implements MemoPatternMatchSupported {
|
||||
|
||||
@Override
|
||||
protected void runBeforeAll() throws Exception {
|
||||
@ -55,57 +52,49 @@ public class EliminateUnnecessaryProjectTest extends TestWithFeService {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEliminateNonTopUnnecessaryProject() {
|
||||
void testEliminateNonTopUnnecessaryProject() {
|
||||
LogicalPlan unnecessaryProject = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
|
||||
.project(ImmutableList.of(1, 0))
|
||||
.filter(BooleanLiteral.FALSE)
|
||||
.build();
|
||||
|
||||
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(unnecessaryProject);
|
||||
cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
|
||||
|
||||
Plan actual = cascadesContext.getMemo().copyOut();
|
||||
Assertions.assertTrue(actual.child(0) instanceof LogicalProject);
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), unnecessaryProject)
|
||||
.applyTopDown(new EliminateUnnecessaryProject())
|
||||
.matchesFromRoot(logicalFilter(logicalProject()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEliminateTopUnnecessaryProject() {
|
||||
void testEliminateTopUnnecessaryProject() {
|
||||
LogicalPlan unnecessaryProject = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
|
||||
.project(ImmutableList.of(0, 1))
|
||||
.build();
|
||||
|
||||
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(unnecessaryProject);
|
||||
cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
|
||||
|
||||
Plan actual = cascadesContext.getMemo().copyOut();
|
||||
Assertions.assertTrue(actual instanceof LogicalOlapScan);
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), unnecessaryProject)
|
||||
.applyTopDown(new EliminateUnnecessaryProject())
|
||||
.matchesFromRoot(logicalOlapScan());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNotEliminateTopProjectWhenOutputNotEquals() {
|
||||
void testNotEliminateTopProjectWhenOutputNotEquals() {
|
||||
LogicalPlan necessaryProject = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
|
||||
.project(ImmutableList.of(1, 0))
|
||||
.build();
|
||||
|
||||
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(necessaryProject);
|
||||
cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
|
||||
|
||||
Plan actual = cascadesContext.getMemo().copyOut();
|
||||
Assertions.assertTrue(actual instanceof LogicalProject);
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), necessaryProject)
|
||||
.applyTopDown(new EliminateUnnecessaryProject())
|
||||
.matchesFromRoot(logicalProject());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEliminateProjectWhenEmptyRelationChild() {
|
||||
void testEliminateProjectWhenEmptyRelationChild() {
|
||||
LogicalPlan unnecessaryProject = new LogicalPlanBuilder(new LogicalEmptyRelation(ImmutableList.of(
|
||||
new SlotReference("k1", IntegerType.INSTANCE),
|
||||
new SlotReference("k2", IntegerType.INSTANCE))))
|
||||
.project(ImmutableList.of(1, 0))
|
||||
.build();
|
||||
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(unnecessaryProject);
|
||||
cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
|
||||
|
||||
Plan actual = cascadesContext.getMemo().copyOut();
|
||||
Assertions.assertTrue(actual instanceof LogicalEmptyRelation);
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), unnecessaryProject)
|
||||
.applyTopDown(new EliminateUnnecessaryProject())
|
||||
.matchesFromRoot(logicalEmptyRelation());
|
||||
}
|
||||
|
||||
// TODO: uncomment this after the Elimination project rule is correctly implemented
|
||||
|
||||
@ -17,8 +17,6 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
@ -31,12 +29,12 @@ import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@ -54,9 +52,9 @@ import java.util.Optional;
|
||||
* -hashJoinConjuncts={A.x=B.x, A.y+1=B.y}
|
||||
* -otherJoinCondition="A.x=1 and (A.x=1 or B.x=A.x) and A.x>B.x"
|
||||
*/
|
||||
class FindHashConditionForJoinTest {
|
||||
class FindHashConditionForJoinTest implements MemoPatternMatchSupported {
|
||||
@Test
|
||||
public void testFindHashCondition() {
|
||||
void testFindHashCondition() {
|
||||
Plan student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
|
||||
ImmutableList.of(""));
|
||||
Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score,
|
||||
@ -77,20 +75,12 @@ class FindHashConditionForJoinTest {
|
||||
List<Expression> expr = ImmutableList.of(eq1, eq2, eq3, or, less);
|
||||
LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, new ArrayList<>(),
|
||||
expr, JoinHint.NONE, Optional.empty(), student, score);
|
||||
CascadesContext context = MemoTestUtils.createCascadesContext(join);
|
||||
List<Rule> rules = Lists.newArrayList(new FindHashConditionForJoin().build());
|
||||
|
||||
context.topDownRewrite(rules);
|
||||
Plan plan = context.getMemo().copyOut();
|
||||
LogicalJoin after = (LogicalJoin) plan;
|
||||
Assertions.assertEquals(after.getHashJoinConjuncts().size(), 2);
|
||||
Assertions.assertTrue(after.getHashJoinConjuncts().contains(eq1));
|
||||
Assertions.assertTrue(after.getHashJoinConjuncts().contains(eq3));
|
||||
List<Expression> others = after.getOtherJoinConjuncts();
|
||||
Assertions.assertEquals(others.size(), 3);
|
||||
Assertions.assertTrue(others.contains(less));
|
||||
Assertions.assertTrue(others.contains(eq2));
|
||||
Assertions.assertTrue(others.contains(less));
|
||||
PlanChecker.from(new ConnectContext(), join)
|
||||
.applyTopDown(new FindHashConditionForJoin())
|
||||
.matches(
|
||||
logicalJoin()
|
||||
.when(j -> j.getHashJoinConjuncts().equals(ImmutableList.of(eq1, eq3)))
|
||||
.when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of(eq2, or, less))));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -17,47 +17,42 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.analyzer.UnboundRelation;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* MergeConsecutiveFilter ut
|
||||
* Tests for {@link MergeFilters}.
|
||||
*/
|
||||
public class MergeFiltersTest {
|
||||
class MergeFiltersTest implements MemoPatternMatchSupported {
|
||||
@Test
|
||||
public void testMergeConsecutiveFilters() {
|
||||
UnboundRelation relation = new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "table"));
|
||||
void testMergeFilters() {
|
||||
Expression expression1 = new IntegerLiteral(1);
|
||||
LogicalFilter filter1 = new LogicalFilter<>(ImmutableSet.of(expression1), relation);
|
||||
Expression expression2 = new IntegerLiteral(2);
|
||||
LogicalFilter filter2 = new LogicalFilter<>(ImmutableSet.of(expression2), filter1);
|
||||
Expression expression3 = new IntegerLiteral(3);
|
||||
LogicalFilter filter3 = new LogicalFilter<>(ImmutableSet.of(expression3), filter2);
|
||||
|
||||
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter3);
|
||||
List<Rule> rules = Lists.newArrayList(new MergeFilters().build());
|
||||
cascadesContext.bottomUpRewrite(rules);
|
||||
//check transformed plan
|
||||
Plan resultPlan = cascadesContext.getMemo().copyOut();
|
||||
System.out.println(resultPlan.treeString());
|
||||
Assertions.assertTrue(resultPlan instanceof LogicalFilter);
|
||||
Set<Expression> allPredicates = ImmutableSet.of(expression1, expression2, expression3);
|
||||
Assertions.assertEquals(ImmutableSet.copyOf(((LogicalFilter<?>) resultPlan).getConjuncts()), allPredicates);
|
||||
Assertions.assertTrue(resultPlan.child(0) instanceof UnboundRelation);
|
||||
LogicalPlan logicalFilter = new LogicalPlanBuilder(
|
||||
new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "table")))
|
||||
.filter(ImmutableSet.of(expression1))
|
||||
.filter(ImmutableSet.of(expression2))
|
||||
.filter(ImmutableSet.of(expression3))
|
||||
.build();
|
||||
|
||||
PlanChecker.from(new ConnectContext(), logicalFilter).applyBottomUp(new MergeFilters())
|
||||
.matches(
|
||||
logicalFilter(
|
||||
unboundRelation()
|
||||
).when(filter -> filter.getConjuncts()
|
||||
.equals(ImmutableSet.of(expression1, expression2, expression3))));
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,37 +17,33 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.analyzer.UnboundRelation;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.trees.plans.LimitPhase;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class MergeLimitsTest {
|
||||
/**
|
||||
* Tests for {@link MergeLimits}.
|
||||
*/
|
||||
class MergeLimitsTest implements MemoPatternMatchSupported {
|
||||
@Test
|
||||
public void testMergeConsecutiveLimits() {
|
||||
LogicalLimit limit3 = new LogicalLimit<>(3, 5, LimitPhase.ORIGIN, new UnboundRelation(
|
||||
RelationUtil.newRelationId(), Lists.newArrayList("db", "t")));
|
||||
LogicalLimit limit2 = new LogicalLimit<>(2, 0, LimitPhase.ORIGIN, limit3);
|
||||
LogicalLimit limit1 = new LogicalLimit<>(10, 0, LimitPhase.ORIGIN, limit2);
|
||||
|
||||
CascadesContext context = MemoTestUtils.createCascadesContext(limit1);
|
||||
List<Rule> rules = Lists.newArrayList(new MergeLimits().build());
|
||||
context.topDownRewrite(rules);
|
||||
LogicalLimit limit = (LogicalLimit) context.getMemo().copyOut();
|
||||
|
||||
Assertions.assertEquals(2, limit.getLimit());
|
||||
Assertions.assertEquals(5, limit.getOffset());
|
||||
Assertions.assertEquals(1, limit.children().size());
|
||||
Assertions.assertTrue(limit.child(0) instanceof UnboundRelation);
|
||||
void testMergeLimits() {
|
||||
LogicalPlan logicalLimit = new LogicalPlanBuilder(
|
||||
new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "t")))
|
||||
.limit(3, 5)
|
||||
.limit(2, 0)
|
||||
.limit(10, 0).build();
|
||||
|
||||
PlanChecker.from(new ConnectContext(), logicalLimit).applyTopDown(new MergeLimits())
|
||||
.matches(
|
||||
logicalLimit(
|
||||
unboundRelation()
|
||||
).when(limit -> limit.getLimit() == 2).when(limit -> limit.getOffset() == 5));
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,7 +18,6 @@
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.pattern.GeneratedMemoPatterns;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RulePromise;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
@ -32,6 +31,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate.PushDownAggOp;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
@ -42,7 +42,7 @@ import org.junit.jupiter.api.Test;
|
||||
import java.util.Collections;
|
||||
import java.util.Optional;
|
||||
|
||||
public class PhysicalStorageLayerAggregateTest implements GeneratedMemoPatterns {
|
||||
public class PhysicalStorageLayerAggregateTest implements MemoPatternMatchSupported {
|
||||
|
||||
@Test
|
||||
public void testWithoutProject() {
|
||||
|
||||
@ -26,8 +26,6 @@ import org.apache.doris.catalog.RangePartitionInfo;
|
||||
import org.apache.doris.catalog.RangePartitionItem;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.common.jmockit.Deencapsulation;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
|
||||
@ -36,20 +34,19 @@ import org.apache.doris.nereids.trees.expressions.LessThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.Or;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.BoundType;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Range;
|
||||
import mockit.Expectations;
|
||||
import mockit.Mocked;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@ -58,10 +55,10 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
class PruneOlapScanPartitionTest {
|
||||
class PruneOlapScanPartitionTest implements MemoPatternMatchSupported {
|
||||
|
||||
@Test
|
||||
public void testOlapScanPartitionWithSingleColumnCase(@Mocked OlapTable olapTable) throws Exception {
|
||||
void testOlapScanPartitionWithSingleColumnCase(@Mocked OlapTable olapTable) throws Exception {
|
||||
List<Column> columnNameList = new ArrayList<>();
|
||||
columnNameList.add(new Column("col1", Type.INT.getPrimitiveType()));
|
||||
columnNameList.add(new Column("col2", Type.INT.getPrimitiveType()));
|
||||
@ -92,24 +89,29 @@ class PruneOlapScanPartitionTest {
|
||||
Expression expression = new LessThan(slotRef, new IntegerLiteral(4));
|
||||
LogicalFilter<LogicalOlapScan> filter = new LogicalFilter<>(ImmutableSet.of(expression), scan);
|
||||
|
||||
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter);
|
||||
List<Rule> rules = Lists.newArrayList(new PruneOlapScanPartition().build());
|
||||
cascadesContext.topDownRewrite(rules);
|
||||
Plan resultPlan = cascadesContext.getMemo().copyOut();
|
||||
LogicalOlapScan rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0);
|
||||
Assertions.assertEquals(0L, rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
|
||||
.applyTopDown(new PruneOlapScanPartition())
|
||||
.matches(
|
||||
logicalFilter(
|
||||
logicalOlapScan().when(
|
||||
olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 0L)
|
||||
)
|
||||
);
|
||||
|
||||
Expression lessThan0 = new LessThan(slotRef, new IntegerLiteral(0));
|
||||
Expression greaterThan6 = new GreaterThan(slotRef, new IntegerLiteral(6));
|
||||
Or lessThan0OrGreaterThan6 = new Or(lessThan0, greaterThan6);
|
||||
filter = new LogicalFilter<>(ImmutableSet.of(lessThan0OrGreaterThan6), scan);
|
||||
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
|
||||
cascadesContext = MemoTestUtils.createCascadesContext(filter);
|
||||
rules = Lists.newArrayList(new PruneOlapScanPartition().build());
|
||||
cascadesContext.topDownRewrite(rules);
|
||||
resultPlan = cascadesContext.getMemo().copyOut();
|
||||
rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0);
|
||||
Assertions.assertEquals(1L, rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
|
||||
.applyTopDown(new PruneOlapScanPartition())
|
||||
.matches(
|
||||
logicalFilter(
|
||||
logicalOlapScan().when(
|
||||
olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 1L)
|
||||
)
|
||||
);
|
||||
|
||||
Expression greaterThanEqual0 =
|
||||
new GreaterThanEqual(
|
||||
@ -118,17 +120,20 @@ class PruneOlapScanPartitionTest {
|
||||
new LessThanEqual(slotRef, new IntegerLiteral(5));
|
||||
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
|
||||
filter = new LogicalFilter<>(ImmutableSet.of(greaterThanEqual0, lessThanEqual5), scan);
|
||||
cascadesContext = MemoTestUtils.createCascadesContext(filter);
|
||||
rules = Lists.newArrayList(new PruneOlapScanPartition().build());
|
||||
cascadesContext.topDownRewrite(rules);
|
||||
resultPlan = cascadesContext.getMemo().copyOut();
|
||||
rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0);
|
||||
Assertions.assertEquals(0L, rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
|
||||
Assertions.assertEquals(2, rewrittenOlapScan.getSelectedPartitionIds().toArray().length);
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
|
||||
.applyTopDown(new PruneOlapScanPartition())
|
||||
.matches(
|
||||
logicalFilter(
|
||||
logicalOlapScan().when(
|
||||
olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 0L)
|
||||
.when(olapScan -> olapScan.getSelectedPartitionIds().size() == 2)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOlapScanPartitionPruneWithMultiColumnCase(@Mocked OlapTable olapTable) throws Exception {
|
||||
void testOlapScanPartitionPruneWithMultiColumnCase(@Mocked OlapTable olapTable) throws Exception {
|
||||
List<Column> columnNameList = new ArrayList<>();
|
||||
columnNameList.add(new Column("col1", Type.INT.getPrimitiveType()));
|
||||
columnNameList.add(new Column("col2", Type.INT.getPrimitiveType()));
|
||||
@ -155,11 +160,14 @@ class PruneOlapScanPartitionTest {
|
||||
Expression left = new LessThan(new SlotReference("col1", IntegerType.INSTANCE), new IntegerLiteral(4));
|
||||
Expression right = new GreaterThan(new SlotReference("col2", IntegerType.INSTANCE), new IntegerLiteral(11));
|
||||
LogicalFilter<LogicalOlapScan> filter = new LogicalFilter<>(ImmutableSet.of(left, right), scan);
|
||||
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter);
|
||||
List<Rule> rules = Lists.newArrayList(new PruneOlapScanPartition().build());
|
||||
cascadesContext.topDownRewrite(rules);
|
||||
Plan resultPlan = cascadesContext.getMemo().copyOut();
|
||||
LogicalOlapScan rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0);
|
||||
Assertions.assertEquals(0L, rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
|
||||
.applyTopDown(new PruneOlapScanPartition())
|
||||
.matches(
|
||||
logicalFilter(
|
||||
logicalOlapScan()
|
||||
.when(
|
||||
olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 0L)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -29,7 +29,6 @@ import org.apache.doris.catalog.OlapTable;
|
||||
import org.apache.doris.catalog.Partition;
|
||||
import org.apache.doris.catalog.PrimitiveType;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
@ -41,7 +40,9 @@ import org.apache.doris.nereids.trees.plans.ObjectId;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.planner.PartitionColumnFilter;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -54,10 +55,10 @@ import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class PruneOlapScanTabletTest {
|
||||
class PruneOlapScanTabletTest implements MemoPatternMatchSupported {
|
||||
|
||||
@Test
|
||||
public void testPruneOlapScanTablet(@Mocked OlapTable olapTable,
|
||||
void testPruneOlapScanTablet(@Mocked OlapTable olapTable,
|
||||
@Mocked Partition partition, @Mocked MaterializedIndex index,
|
||||
@Mocked HashDistributionInfo distributionInfo) {
|
||||
List<Long> tabletIds = Lists.newArrayListWithExpectedSize(300);
|
||||
@ -153,11 +154,12 @@ public class PruneOlapScanTabletTest {
|
||||
|
||||
Assertions.assertEquals(0, filter.child().getSelectedTabletIds().size());
|
||||
|
||||
CascadesContext context = MemoTestUtils.createCascadesContext(filter);
|
||||
context.topDownRewrite(ImmutableList.of(new PruneOlapScanTablet().build()));
|
||||
|
||||
LogicalFilter<LogicalOlapScan> filter1 = ((LogicalFilter<LogicalOlapScan>) context.getMemo().copyOut());
|
||||
LogicalOlapScan olapScan = filter1.child();
|
||||
Assertions.assertEquals(19, olapScan.getSelectedTabletIds().size());
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
|
||||
.applyTopDown(new PruneOlapScanTablet())
|
||||
.matches(
|
||||
logicalFilter(
|
||||
logicalOlapScan().when(scan -> scan.getSelectedTabletIds().size() == 19)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,52 +17,53 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.Memo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.plans.JoinHint;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.nereids.util.PlanRewriter;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
public class PushdownJoinOtherConditionTest {
|
||||
class PushdownJoinOtherConditionTest implements MemoPatternMatchSupported {
|
||||
|
||||
private Plan rStudent;
|
||||
private Plan rScore;
|
||||
private LogicalOlapScan rStudent;
|
||||
private LogicalOlapScan rScore;
|
||||
|
||||
private List<Slot> rStudentSlots;
|
||||
|
||||
private List<Slot> rScoreSlots;
|
||||
|
||||
/**
|
||||
* ut before.
|
||||
*/
|
||||
@BeforeAll
|
||||
public final void beforeAll() {
|
||||
final void beforeAll() {
|
||||
rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
|
||||
ImmutableList.of(""));
|
||||
rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of(""));
|
||||
rStudentSlots = rStudent.getOutput();
|
||||
rScoreSlots = rScore.getOutput();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void oneSide() {
|
||||
void oneSide() {
|
||||
oneSide(JoinType.CROSS_JOIN, false);
|
||||
oneSide(JoinType.INNER_JOIN, false);
|
||||
oneSide(JoinType.LEFT_OUTER_JOIN, true);
|
||||
@ -74,46 +75,43 @@ public class PushdownJoinOtherConditionTest {
|
||||
}
|
||||
|
||||
private void oneSide(JoinType joinType, boolean testRight) {
|
||||
|
||||
Expression pushSide1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression pushSide2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50));
|
||||
Expression pushSide1 = new GreaterThan(rStudentSlots.get(1), Literal.of(18));
|
||||
Expression pushSide2 = new GreaterThan(rStudentSlots.get(1), Literal.of(50));
|
||||
List<Expression> condition = ImmutableList.of(pushSide1, pushSide2);
|
||||
|
||||
Plan left = rStudent;
|
||||
Plan right = rScore;
|
||||
LogicalOlapScan left = rStudent;
|
||||
LogicalOlapScan right = rScore;
|
||||
if (testRight) {
|
||||
left = rScore;
|
||||
right = rStudent;
|
||||
}
|
||||
|
||||
Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), left, right);
|
||||
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
|
||||
LogicalPlan root = new LogicalPlanBuilder(left)
|
||||
.join(right, joinType, ExpressionUtils.EMPTY_CONDITION, condition)
|
||||
.project(Lists.newArrayList())
|
||||
.build();
|
||||
|
||||
Memo memo = rewrite(root);
|
||||
Group rootGroup = memo.getRoot();
|
||||
PlanChecker planChecker = PlanChecker.from(MemoTestUtils.createConnectContext(), root)
|
||||
.applyTopDown(new PushdownJoinOtherCondition());
|
||||
|
||||
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(1).getLogicalExpression().getPlan();
|
||||
if (testRight) {
|
||||
shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(1).getLogicalExpression().getPlan();
|
||||
shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
planChecker.matches(
|
||||
logicalJoin(
|
||||
logicalOlapScan(),
|
||||
logicalFilter()
|
||||
.when(filter -> ImmutableList.copyOf(filter.getConjuncts()).equals(condition))));
|
||||
} else {
|
||||
planChecker.matches(
|
||||
logicalJoin(
|
||||
logicalFilter().when(
|
||||
filter -> ImmutableList.copyOf(filter.getConjuncts()).equals(condition)),
|
||||
logicalOlapScan()));
|
||||
|
||||
}
|
||||
|
||||
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
|
||||
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
|
||||
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
|
||||
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
|
||||
|
||||
Assertions.assertEquals(condition, ImmutableList.copyOf(actualFilter.getConjuncts()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void bothSideToBothSide() {
|
||||
void bothSideToBothSide() {
|
||||
bothSideToBothSide(JoinType.CROSS_JOIN);
|
||||
bothSideToBothSide(JoinType.INNER_JOIN);
|
||||
bothSideToBothSide(JoinType.LEFT_SEMI_JOIN);
|
||||
@ -122,34 +120,26 @@ public class PushdownJoinOtherConditionTest {
|
||||
|
||||
private void bothSideToBothSide(JoinType joinType) {
|
||||
|
||||
Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
|
||||
Expression leftSide = new GreaterThan(rStudentSlots.get(1), Literal.of(18));
|
||||
Expression rightSide = new GreaterThan(rScoreSlots.get(2), Literal.of(60));
|
||||
List<Expression> condition = ImmutableList.of(leftSide, rightSide);
|
||||
|
||||
Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), rStudent,
|
||||
rScore);
|
||||
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
|
||||
LogicalPlan root = new LogicalPlanBuilder(rStudent)
|
||||
.join(rScore, joinType, ExpressionUtils.EMPTY_CONDITION, condition)
|
||||
.project(Lists.newArrayList())
|
||||
.build();
|
||||
|
||||
Memo memo = rewrite(root);
|
||||
Group rootGroup = memo.getRoot();
|
||||
|
||||
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
Plan leftFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
Plan rightFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
|
||||
.child(1).getLogicalExpression().getPlan();
|
||||
|
||||
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
|
||||
Assertions.assertTrue(leftFilter instanceof LogicalFilter);
|
||||
Assertions.assertTrue(rightFilter instanceof LogicalFilter);
|
||||
LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter;
|
||||
LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter;
|
||||
Assertions.assertEquals(ImmutableSet.of(leftSide), actualLeft.getConjuncts());
|
||||
Assertions.assertEquals(ImmutableSet.of(rightSide), actualRight.getConjuncts());
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
|
||||
.applyTopDown(new PushdownJoinOtherCondition())
|
||||
.matches(
|
||||
logicalJoin(
|
||||
logicalFilter().when(left -> left.getConjuncts().equals(ImmutableSet.of(leftSide))),
|
||||
logicalFilter().when(right -> right.getConjuncts().equals(ImmutableSet.of(rightSide)))
|
||||
));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void bothSideToOneSide() {
|
||||
void bothSideToOneSide() {
|
||||
bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, true);
|
||||
bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, true);
|
||||
bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, false);
|
||||
@ -157,45 +147,37 @@ public class PushdownJoinOtherConditionTest {
|
||||
}
|
||||
|
||||
private void bothSideToOneSide(JoinType joinType, boolean testRight) {
|
||||
|
||||
Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
|
||||
Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
|
||||
Expression pushSide = new GreaterThan(rStudentSlots.get(1), Literal.of(18));
|
||||
Expression reserveSide = new GreaterThan(rScoreSlots.get(2), Literal.of(60));
|
||||
List<Expression> condition = ImmutableList.of(pushSide, reserveSide);
|
||||
|
||||
Plan left = rStudent;
|
||||
Plan right = rScore;
|
||||
LogicalOlapScan left = rStudent;
|
||||
LogicalOlapScan right = rScore;
|
||||
if (testRight) {
|
||||
left = rScore;
|
||||
right = rStudent;
|
||||
}
|
||||
|
||||
Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), left, right);
|
||||
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
|
||||
LogicalPlan root = new LogicalPlanBuilder(left)
|
||||
.join(right, joinType, ExpressionUtils.EMPTY_CONDITION, condition)
|
||||
.project(Lists.newArrayList())
|
||||
.build();
|
||||
|
||||
Memo memo = rewrite(root);
|
||||
Group rootGroup = memo.getRoot();
|
||||
PlanChecker planChecker = PlanChecker.from(MemoTestUtils.createConnectContext(), root)
|
||||
.applyTopDown(new PushdownJoinOtherCondition());
|
||||
|
||||
Plan shouldJoin = rootGroup.getLogicalExpression()
|
||||
.child(0).getLogicalExpression().getPlan();
|
||||
Plan shouldFilter = rootGroup.getLogicalExpression()
|
||||
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
Plan shouldScan = rootGroup.getLogicalExpression()
|
||||
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
|
||||
if (testRight) {
|
||||
shouldFilter = rootGroup.getLogicalExpression()
|
||||
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
|
||||
shouldScan = rootGroup.getLogicalExpression()
|
||||
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
|
||||
planChecker.matches(
|
||||
logicalJoin(
|
||||
logicalOlapScan(),
|
||||
logicalFilter().when(filter -> filter.getConjuncts().equals(ImmutableSet.of(pushSide)))
|
||||
));
|
||||
} else {
|
||||
planChecker.matches(
|
||||
logicalJoin(
|
||||
logicalFilter().when(filter -> filter.getConjuncts().equals(ImmutableSet.of(pushSide))),
|
||||
logicalOlapScan()
|
||||
));
|
||||
}
|
||||
|
||||
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
|
||||
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
|
||||
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
|
||||
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
|
||||
Assertions.assertEquals(ImmutableSet.of(pushSide), actualFilter.getConjuncts());
|
||||
}
|
||||
|
||||
private Memo rewrite(Plan plan) {
|
||||
return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushdownJoinOtherCondition());
|
||||
}
|
||||
}
|
||||
|
||||
@ -27,10 +27,10 @@ import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class PushdownProjectThroughLimitTest implements MemoPatternMatchSupported {
|
||||
class PushdownProjectThroughLimitTest implements MemoPatternMatchSupported {
|
||||
|
||||
@Test
|
||||
public void testPushdownProjectThroughLimit() {
|
||||
void testPushdownProjectThroughLimit() {
|
||||
LogicalPlan project = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
|
||||
.limit(1, 1)
|
||||
.project(ImmutableList.of(0)) // id
|
||||
|
||||
@ -17,25 +17,32 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.trees.plans.LimitPhase;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class SplitLimitTest {
|
||||
/**
|
||||
* Tests for {@link SplitLimit}.
|
||||
*/
|
||||
class SplitLimitTest implements MemoPatternMatchSupported {
|
||||
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
|
||||
@Test
|
||||
void testSplitLimit() {
|
||||
Plan plan = new LogicalLimit<>(0, 0, LimitPhase.ORIGIN, scan1);
|
||||
plan = PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.rewrite()
|
||||
.getPlan();
|
||||
plan.anyMatch(x -> x instanceof LogicalLimit && ((LogicalLimit<?>) x).isSplit());
|
||||
LogicalPlan limit = new LogicalPlanBuilder(scan1)
|
||||
.limit(0, 0)
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), limit)
|
||||
.applyTopDown(new SplitLimit())
|
||||
.matches(
|
||||
globalLogicalLimit(localLogicalLimit())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,6 +55,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
|
||||
import org.apache.doris.planner.PlanFragment;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.qe.OriginStatement;
|
||||
@ -126,6 +127,12 @@ public class PlanChecker {
|
||||
return applyTopDown(ruleFactory.buildRules());
|
||||
}
|
||||
|
||||
public PlanChecker applyTopDown(CustomRewriter customRewriter) {
|
||||
cascadesContext.topDownRewrite(customRewriter);
|
||||
MemoValidator.validate(cascadesContext.getMemo());
|
||||
return this;
|
||||
}
|
||||
|
||||
public PlanChecker applyTopDown(List<Rule> rule) {
|
||||
cascadesContext.topDownRewrite(rule);
|
||||
MemoValidator.validate(cascadesContext.getMemo());
|
||||
|
||||
Reference in New Issue
Block a user