[refactor](Nereids) Refactor all rewrite logical unit tests by match-pattern (#17691)

This commit is contained in:
Weijie Guo
2023-03-12 18:49:12 +08:00
committed by GitHub
parent 9b687026bd
commit 11fbe07221
12 changed files with 317 additions and 345 deletions

View File

@ -23,30 +23,29 @@ import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.PartitionInfo;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.thrift.TStorageType;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
public class EliminateGroupByConstantTest {
/** Tests for {@link EliminateGroupByConstant}. */
class EliminateGroupByConstantTest implements MemoPatternMatchSupported {
private static final OlapTable table = new OlapTable(0L, "student",
ImmutableList.of(new Column("k1", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("k2", Type.INT, false, AggregateType.NONE, "0", ""),
@ -65,110 +64,107 @@ public class EliminateGroupByConstantTest {
}
@Test
public void testIntegerLiteral() {
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
ImmutableList.of(new IntegerLiteral(1), k2),
ImmutableList.of(k1, k2),
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)
);
void testIntegerLiteral() {
LogicalPlan aggregate = new LogicalPlanBuilder(
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table))
.agg(ImmutableList.of(new IntegerLiteral(1), k2),
ImmutableList.of(k1, k2))
.build();
CascadesContext context = MemoTestUtils.createCascadesContext(aggregate);
context.topDownRewrite(new EliminateGroupByConstant().build());
context.bottomUpRewrite(new CheckAfterRewrite().build());
LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut());
Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1);
Assertions.assertTrue(aggregate1.getGroupByExpressions().get(0) instanceof Slot);
PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
.applyTopDown(new EliminateGroupByConstant())
.applyBottomUp(new CheckAfterRewrite())
.matches(
aggregate().when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2)))
);
}
@Test
public void testOtherLiteral() {
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
ImmutableList.of(
new StringLiteral("str"), k2),
ImmutableList.of(
new Alias(new StringLiteral("str"), "str"), k1, k2),
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)
);
void testOtherLiteral() {
LogicalPlan aggregate = new LogicalPlanBuilder(
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table))
.agg(ImmutableList.of(
new StringLiteral("str"), k2),
ImmutableList.of(
new Alias(new StringLiteral("str"), "str"), k1, k2))
.build();
CascadesContext context = MemoTestUtils.createCascadesContext(aggregate);
context.topDownRewrite(new EliminateGroupByConstant().build());
context.bottomUpRewrite(new CheckAfterRewrite().build());
LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut());
Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1);
Assertions.assertTrue(aggregate1.getGroupByExpressions().get(0) instanceof Slot);
PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
.applyTopDown(new EliminateGroupByConstant())
.applyBottomUp(new CheckAfterRewrite())
.matches(
aggregate().when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2)))
);
}
@Test
public void testMixedLiteral() {
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
ImmutableList.of(
new StringLiteral("str"), k2,
new IntegerLiteral(1),
new IntegerLiteral(2),
new IntegerLiteral(3),
new Add(k1, k2)),
ImmutableList.of(
new Alias(new StringLiteral("str"), "str"),
k2, k1, new Alias(new IntegerLiteral(1), "integer")),
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)
);
void testMixedLiteral() {
LogicalPlan aggregate = new LogicalPlanBuilder(
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table))
.agg(ImmutableList.of(
new StringLiteral("str"), k2,
new IntegerLiteral(1),
new IntegerLiteral(2),
new IntegerLiteral(3),
new Add(k1, k2)),
ImmutableList.of(
new Alias(new StringLiteral("str"), "str"),
k2, k1, new Alias(new IntegerLiteral(1), "integer")))
.build();
CascadesContext context = MemoTestUtils.createCascadesContext(aggregate);
context.topDownRewrite(new EliminateGroupByConstant().build());
context.bottomUpRewrite(new CheckAfterRewrite().build());
LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut());
Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 2);
List groupByExprs = aggregate1.getGroupByExpressions();
Assertions.assertTrue(groupByExprs.get(0) instanceof Slot
&& groupByExprs.get(1) instanceof Add);
PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
.applyTopDown(new EliminateGroupByConstant())
.applyBottomUp(new CheckAfterRewrite())
.matches(
aggregate()
.when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2))))
);
}
@Test
public void testComplexGroupBy() {
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
ImmutableList.of(
new IntegerLiteral(1),
new IntegerLiteral(2),
new Add(k1, k2)),
ImmutableList.of(
new Alias(new Max(k1), "max"),
new Alias(new Min(k2), "min"),
new Alias(new Add(k1, k2), "add")),
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)
);
void testComplexGroupBy() {
LogicalPlan aggregate = new LogicalPlanBuilder(
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table))
.agg(ImmutableList.of(
new IntegerLiteral(1),
new IntegerLiteral(2),
new Add(k1, k2)),
ImmutableList.of(
new Alias(new Max(k1), "max"),
new Alias(new Min(k2), "min"),
new Alias(new Add(k1, k2), "add")))
.build();
CascadesContext context = MemoTestUtils.createCascadesContext(aggregate);
context.topDownRewrite(new EliminateGroupByConstant().build());
context.bottomUpRewrite(new CheckAfterRewrite().build());
LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut());
Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1);
PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
.applyTopDown(new EliminateGroupByConstant())
.applyBottomUp(new CheckAfterRewrite())
.matches(
aggregate()
.when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(new Add(k1, k2))))
);
}
@Test
public void testOutOfRange() {
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
ImmutableList.of(
new StringLiteral("str"), k2,
new IntegerLiteral(1),
new IntegerLiteral(2),
new IntegerLiteral(3),
new IntegerLiteral(5),
new Add(k1, k2)),
ImmutableList.of(
new Alias(new StringLiteral("str"), "str"),
k2, k1, new Alias(new IntegerLiteral(1), "integer")),
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)
);
CascadesContext context = MemoTestUtils.createCascadesContext(aggregate);
context.topDownRewrite(new EliminateGroupByConstant().build());
context.bottomUpRewrite(new CheckAfterRewrite().build());
LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut());
Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 2);
void testOutOfRange() {
LogicalPlan aggregate = new LogicalPlanBuilder(
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table))
.agg(ImmutableList.of(
new StringLiteral("str"), k2,
new IntegerLiteral(1),
new IntegerLiteral(2),
new IntegerLiteral(3),
new IntegerLiteral(5),
new Add(k1, k2)),
ImmutableList.of(
new Alias(new StringLiteral("str"), "str"),
k2, k1, new Alias(new IntegerLiteral(1), "integer")))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
.applyTopDown(new EliminateGroupByConstant())
.applyBottomUp(new CheckAfterRewrite())
.matches(
aggregate()
.when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2))))
);
}
}

View File

@ -17,28 +17,25 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/**
* test ELIMINATE_UNNECESSARY_PROJECT rule.
*/
public class EliminateUnnecessaryProjectTest extends TestWithFeService {
class EliminateUnnecessaryProjectTest extends TestWithFeService implements MemoPatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
@ -55,57 +52,49 @@ public class EliminateUnnecessaryProjectTest extends TestWithFeService {
}
@Test
public void testEliminateNonTopUnnecessaryProject() {
void testEliminateNonTopUnnecessaryProject() {
LogicalPlan unnecessaryProject = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
.project(ImmutableList.of(1, 0))
.filter(BooleanLiteral.FALSE)
.build();
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(unnecessaryProject);
cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
Plan actual = cascadesContext.getMemo().copyOut();
Assertions.assertTrue(actual.child(0) instanceof LogicalProject);
PlanChecker.from(MemoTestUtils.createConnectContext(), unnecessaryProject)
.applyTopDown(new EliminateUnnecessaryProject())
.matchesFromRoot(logicalFilter(logicalProject()));
}
@Test
public void testEliminateTopUnnecessaryProject() {
void testEliminateTopUnnecessaryProject() {
LogicalPlan unnecessaryProject = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
.project(ImmutableList.of(0, 1))
.build();
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(unnecessaryProject);
cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
Plan actual = cascadesContext.getMemo().copyOut();
Assertions.assertTrue(actual instanceof LogicalOlapScan);
PlanChecker.from(MemoTestUtils.createConnectContext(), unnecessaryProject)
.applyTopDown(new EliminateUnnecessaryProject())
.matchesFromRoot(logicalOlapScan());
}
@Test
public void testNotEliminateTopProjectWhenOutputNotEquals() {
void testNotEliminateTopProjectWhenOutputNotEquals() {
LogicalPlan necessaryProject = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
.project(ImmutableList.of(1, 0))
.build();
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(necessaryProject);
cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
Plan actual = cascadesContext.getMemo().copyOut();
Assertions.assertTrue(actual instanceof LogicalProject);
PlanChecker.from(MemoTestUtils.createConnectContext(), necessaryProject)
.applyTopDown(new EliminateUnnecessaryProject())
.matchesFromRoot(logicalProject());
}
@Test
public void testEliminateProjectWhenEmptyRelationChild() {
void testEliminateProjectWhenEmptyRelationChild() {
LogicalPlan unnecessaryProject = new LogicalPlanBuilder(new LogicalEmptyRelation(ImmutableList.of(
new SlotReference("k1", IntegerType.INSTANCE),
new SlotReference("k2", IntegerType.INSTANCE))))
.project(ImmutableList.of(1, 0))
.build();
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(unnecessaryProject);
cascadesContext.topDownRewrite(new EliminateUnnecessaryProject());
Plan actual = cascadesContext.getMemo().copyOut();
Assertions.assertTrue(actual instanceof LogicalEmptyRelation);
PlanChecker.from(MemoTestUtils.createConnectContext(), unnecessaryProject)
.applyTopDown(new EliminateUnnecessaryProject())
.matchesFromRoot(logicalEmptyRelation());
}
// TODO: uncomment this after the Elimination project rule is correctly implemented

View File

@ -17,8 +17,6 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -31,12 +29,12 @@ import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
@ -54,9 +52,9 @@ import java.util.Optional;
* -hashJoinConjuncts={A.x=B.x, A.y+1=B.y}
* -otherJoinCondition="A.x=1 and (A.x=1 or B.x=A.x) and A.x>B.x"
*/
class FindHashConditionForJoinTest {
class FindHashConditionForJoinTest implements MemoPatternMatchSupported {
@Test
public void testFindHashCondition() {
void testFindHashCondition() {
Plan student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
ImmutableList.of(""));
Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score,
@ -77,20 +75,12 @@ class FindHashConditionForJoinTest {
List<Expression> expr = ImmutableList.of(eq1, eq2, eq3, or, less);
LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, new ArrayList<>(),
expr, JoinHint.NONE, Optional.empty(), student, score);
CascadesContext context = MemoTestUtils.createCascadesContext(join);
List<Rule> rules = Lists.newArrayList(new FindHashConditionForJoin().build());
context.topDownRewrite(rules);
Plan plan = context.getMemo().copyOut();
LogicalJoin after = (LogicalJoin) plan;
Assertions.assertEquals(after.getHashJoinConjuncts().size(), 2);
Assertions.assertTrue(after.getHashJoinConjuncts().contains(eq1));
Assertions.assertTrue(after.getHashJoinConjuncts().contains(eq3));
List<Expression> others = after.getOtherJoinConjuncts();
Assertions.assertEquals(others.size(), 3);
Assertions.assertTrue(others.contains(less));
Assertions.assertTrue(others.contains(eq2));
Assertions.assertTrue(others.contains(less));
PlanChecker.from(new ConnectContext(), join)
.applyTopDown(new FindHashConditionForJoin())
.matches(
logicalJoin()
.when(j -> j.getHashJoinConjuncts().equals(ImmutableList.of(eq1, eq3)))
.when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of(eq2, or, less))));
}
}

View File

@ -17,47 +17,42 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Set;
/**
* MergeConsecutiveFilter ut
* Tests for {@link MergeFilters}.
*/
public class MergeFiltersTest {
class MergeFiltersTest implements MemoPatternMatchSupported {
@Test
public void testMergeConsecutiveFilters() {
UnboundRelation relation = new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "table"));
void testMergeFilters() {
Expression expression1 = new IntegerLiteral(1);
LogicalFilter filter1 = new LogicalFilter<>(ImmutableSet.of(expression1), relation);
Expression expression2 = new IntegerLiteral(2);
LogicalFilter filter2 = new LogicalFilter<>(ImmutableSet.of(expression2), filter1);
Expression expression3 = new IntegerLiteral(3);
LogicalFilter filter3 = new LogicalFilter<>(ImmutableSet.of(expression3), filter2);
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter3);
List<Rule> rules = Lists.newArrayList(new MergeFilters().build());
cascadesContext.bottomUpRewrite(rules);
//check transformed plan
Plan resultPlan = cascadesContext.getMemo().copyOut();
System.out.println(resultPlan.treeString());
Assertions.assertTrue(resultPlan instanceof LogicalFilter);
Set<Expression> allPredicates = ImmutableSet.of(expression1, expression2, expression3);
Assertions.assertEquals(ImmutableSet.copyOf(((LogicalFilter<?>) resultPlan).getConjuncts()), allPredicates);
Assertions.assertTrue(resultPlan.child(0) instanceof UnboundRelation);
LogicalPlan logicalFilter = new LogicalPlanBuilder(
new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "table")))
.filter(ImmutableSet.of(expression1))
.filter(ImmutableSet.of(expression2))
.filter(ImmutableSet.of(expression3))
.build();
PlanChecker.from(new ConnectContext(), logicalFilter).applyBottomUp(new MergeFilters())
.matches(
logicalFilter(
unboundRelation()
).when(filter -> filter.getConjuncts()
.equals(ImmutableSet.of(expression1, expression2, expression3))));
}
}

View File

@ -17,37 +17,33 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.LimitPhase;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
public class MergeLimitsTest {
/**
* Tests for {@link MergeLimits}.
*/
class MergeLimitsTest implements MemoPatternMatchSupported {
@Test
public void testMergeConsecutiveLimits() {
LogicalLimit limit3 = new LogicalLimit<>(3, 5, LimitPhase.ORIGIN, new UnboundRelation(
RelationUtil.newRelationId(), Lists.newArrayList("db", "t")));
LogicalLimit limit2 = new LogicalLimit<>(2, 0, LimitPhase.ORIGIN, limit3);
LogicalLimit limit1 = new LogicalLimit<>(10, 0, LimitPhase.ORIGIN, limit2);
CascadesContext context = MemoTestUtils.createCascadesContext(limit1);
List<Rule> rules = Lists.newArrayList(new MergeLimits().build());
context.topDownRewrite(rules);
LogicalLimit limit = (LogicalLimit) context.getMemo().copyOut();
Assertions.assertEquals(2, limit.getLimit());
Assertions.assertEquals(5, limit.getOffset());
Assertions.assertEquals(1, limit.children().size());
Assertions.assertTrue(limit.child(0) instanceof UnboundRelation);
void testMergeLimits() {
LogicalPlan logicalLimit = new LogicalPlanBuilder(
new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "t")))
.limit(3, 5)
.limit(2, 0)
.limit(10, 0).build();
PlanChecker.from(new ConnectContext(), logicalLimit).applyTopDown(new MergeLimits())
.matches(
logicalLimit(
unboundRelation()
).when(limit -> limit.getLimit() == 2).when(limit -> limit.getOffset() == 5));
}
}

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.pattern.GeneratedMemoPatterns;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RulePromise;
import org.apache.doris.nereids.rules.RuleType;
@ -32,6 +31,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate.PushDownAggOp;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
@ -42,7 +42,7 @@ import org.junit.jupiter.api.Test;
import java.util.Collections;
import java.util.Optional;
public class PhysicalStorageLayerAggregateTest implements GeneratedMemoPatterns {
public class PhysicalStorageLayerAggregateTest implements MemoPatternMatchSupported {
@Test
public void testWithoutProject() {

View File

@ -26,8 +26,6 @@ import org.apache.doris.catalog.RangePartitionInfo;
import org.apache.doris.catalog.RangePartitionItem;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.jmockit.Deencapsulation;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
@ -36,20 +34,19 @@ import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Range;
import mockit.Expectations;
import mockit.Mocked;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
@ -58,10 +55,10 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
class PruneOlapScanPartitionTest {
class PruneOlapScanPartitionTest implements MemoPatternMatchSupported {
@Test
public void testOlapScanPartitionWithSingleColumnCase(@Mocked OlapTable olapTable) throws Exception {
void testOlapScanPartitionWithSingleColumnCase(@Mocked OlapTable olapTable) throws Exception {
List<Column> columnNameList = new ArrayList<>();
columnNameList.add(new Column("col1", Type.INT.getPrimitiveType()));
columnNameList.add(new Column("col2", Type.INT.getPrimitiveType()));
@ -92,24 +89,29 @@ class PruneOlapScanPartitionTest {
Expression expression = new LessThan(slotRef, new IntegerLiteral(4));
LogicalFilter<LogicalOlapScan> filter = new LogicalFilter<>(ImmutableSet.of(expression), scan);
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter);
List<Rule> rules = Lists.newArrayList(new PruneOlapScanPartition().build());
cascadesContext.topDownRewrite(rules);
Plan resultPlan = cascadesContext.getMemo().copyOut();
LogicalOlapScan rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0);
Assertions.assertEquals(0L, rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
.applyTopDown(new PruneOlapScanPartition())
.matches(
logicalFilter(
logicalOlapScan().when(
olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 0L)
)
);
Expression lessThan0 = new LessThan(slotRef, new IntegerLiteral(0));
Expression greaterThan6 = new GreaterThan(slotRef, new IntegerLiteral(6));
Or lessThan0OrGreaterThan6 = new Or(lessThan0, greaterThan6);
filter = new LogicalFilter<>(ImmutableSet.of(lessThan0OrGreaterThan6), scan);
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
cascadesContext = MemoTestUtils.createCascadesContext(filter);
rules = Lists.newArrayList(new PruneOlapScanPartition().build());
cascadesContext.topDownRewrite(rules);
resultPlan = cascadesContext.getMemo().copyOut();
rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0);
Assertions.assertEquals(1L, rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
.applyTopDown(new PruneOlapScanPartition())
.matches(
logicalFilter(
logicalOlapScan().when(
olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 1L)
)
);
Expression greaterThanEqual0 =
new GreaterThanEqual(
@ -118,17 +120,20 @@ class PruneOlapScanPartitionTest {
new LessThanEqual(slotRef, new IntegerLiteral(5));
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
filter = new LogicalFilter<>(ImmutableSet.of(greaterThanEqual0, lessThanEqual5), scan);
cascadesContext = MemoTestUtils.createCascadesContext(filter);
rules = Lists.newArrayList(new PruneOlapScanPartition().build());
cascadesContext.topDownRewrite(rules);
resultPlan = cascadesContext.getMemo().copyOut();
rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0);
Assertions.assertEquals(0L, rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
Assertions.assertEquals(2, rewrittenOlapScan.getSelectedPartitionIds().toArray().length);
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
.applyTopDown(new PruneOlapScanPartition())
.matches(
logicalFilter(
logicalOlapScan().when(
olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 0L)
.when(olapScan -> olapScan.getSelectedPartitionIds().size() == 2)
)
);
}
@Test
public void testOlapScanPartitionPruneWithMultiColumnCase(@Mocked OlapTable olapTable) throws Exception {
void testOlapScanPartitionPruneWithMultiColumnCase(@Mocked OlapTable olapTable) throws Exception {
List<Column> columnNameList = new ArrayList<>();
columnNameList.add(new Column("col1", Type.INT.getPrimitiveType()));
columnNameList.add(new Column("col2", Type.INT.getPrimitiveType()));
@ -155,11 +160,14 @@ class PruneOlapScanPartitionTest {
Expression left = new LessThan(new SlotReference("col1", IntegerType.INSTANCE), new IntegerLiteral(4));
Expression right = new GreaterThan(new SlotReference("col2", IntegerType.INSTANCE), new IntegerLiteral(11));
LogicalFilter<LogicalOlapScan> filter = new LogicalFilter<>(ImmutableSet.of(left, right), scan);
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter);
List<Rule> rules = Lists.newArrayList(new PruneOlapScanPartition().build());
cascadesContext.topDownRewrite(rules);
Plan resultPlan = cascadesContext.getMemo().copyOut();
LogicalOlapScan rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0);
Assertions.assertEquals(0L, rewrittenOlapScan.getSelectedPartitionIds().iterator().next());
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
.applyTopDown(new PruneOlapScanPartition())
.matches(
logicalFilter(
logicalOlapScan()
.when(
olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 0L)
)
);
}
}

View File

@ -29,7 +29,6 @@ import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Partition;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
@ -41,7 +40,9 @@ import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.planner.PartitionColumnFilter;
import com.google.common.collect.ImmutableList;
@ -54,10 +55,10 @@ import org.junit.jupiter.api.Test;
import java.util.List;
public class PruneOlapScanTabletTest {
class PruneOlapScanTabletTest implements MemoPatternMatchSupported {
@Test
public void testPruneOlapScanTablet(@Mocked OlapTable olapTable,
void testPruneOlapScanTablet(@Mocked OlapTable olapTable,
@Mocked Partition partition, @Mocked MaterializedIndex index,
@Mocked HashDistributionInfo distributionInfo) {
List<Long> tabletIds = Lists.newArrayListWithExpectedSize(300);
@ -153,11 +154,12 @@ public class PruneOlapScanTabletTest {
Assertions.assertEquals(0, filter.child().getSelectedTabletIds().size());
CascadesContext context = MemoTestUtils.createCascadesContext(filter);
context.topDownRewrite(ImmutableList.of(new PruneOlapScanTablet().build()));
LogicalFilter<LogicalOlapScan> filter1 = ((LogicalFilter<LogicalOlapScan>) context.getMemo().copyOut());
LogicalOlapScan olapScan = filter1.child();
Assertions.assertEquals(19, olapScan.getSelectedTabletIds().size());
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
.applyTopDown(new PruneOlapScanTablet())
.matches(
logicalFilter(
logicalOlapScan().when(scan -> scan.getSelectedTabletIds().size() == 19)
)
);
}
}

View File

@ -17,52 +17,53 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.PlanRewriter;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.List;
import java.util.Optional;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class PushdownJoinOtherConditionTest {
class PushdownJoinOtherConditionTest implements MemoPatternMatchSupported {
private Plan rStudent;
private Plan rScore;
private LogicalOlapScan rStudent;
private LogicalOlapScan rScore;
private List<Slot> rStudentSlots;
private List<Slot> rScoreSlots;
/**
* ut before.
*/
@BeforeAll
public final void beforeAll() {
final void beforeAll() {
rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
ImmutableList.of(""));
rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of(""));
rStudentSlots = rStudent.getOutput();
rScoreSlots = rScore.getOutput();
}
@Test
public void oneSide() {
void oneSide() {
oneSide(JoinType.CROSS_JOIN, false);
oneSide(JoinType.INNER_JOIN, false);
oneSide(JoinType.LEFT_OUTER_JOIN, true);
@ -74,46 +75,43 @@ public class PushdownJoinOtherConditionTest {
}
private void oneSide(JoinType joinType, boolean testRight) {
Expression pushSide1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression pushSide2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50));
Expression pushSide1 = new GreaterThan(rStudentSlots.get(1), Literal.of(18));
Expression pushSide2 = new GreaterThan(rStudentSlots.get(1), Literal.of(50));
List<Expression> condition = ImmutableList.of(pushSide1, pushSide2);
Plan left = rStudent;
Plan right = rScore;
LogicalOlapScan left = rStudent;
LogicalOlapScan right = rScore;
if (testRight) {
left = rScore;
right = rStudent;
}
Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), left, right);
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
LogicalPlan root = new LogicalPlanBuilder(left)
.join(right, joinType, ExpressionUtils.EMPTY_CONDITION, condition)
.project(Lists.newArrayList())
.build();
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
PlanChecker planChecker = PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new PushdownJoinOtherCondition());
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
if (testRight) {
shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
planChecker.matches(
logicalJoin(
logicalOlapScan(),
logicalFilter()
.when(filter -> ImmutableList.copyOf(filter.getConjuncts()).equals(condition))));
} else {
planChecker.matches(
logicalJoin(
logicalFilter().when(
filter -> ImmutableList.copyOf(filter.getConjuncts()).equals(condition)),
logicalOlapScan()));
}
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
Assertions.assertEquals(condition, ImmutableList.copyOf(actualFilter.getConjuncts()));
}
@Test
public void bothSideToBothSide() {
void bothSideToBothSide() {
bothSideToBothSide(JoinType.CROSS_JOIN);
bothSideToBothSide(JoinType.INNER_JOIN);
bothSideToBothSide(JoinType.LEFT_SEMI_JOIN);
@ -122,34 +120,26 @@ public class PushdownJoinOtherConditionTest {
private void bothSideToBothSide(JoinType joinType) {
Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression leftSide = new GreaterThan(rStudentSlots.get(1), Literal.of(18));
Expression rightSide = new GreaterThan(rScoreSlots.get(2), Literal.of(60));
List<Expression> condition = ImmutableList.of(leftSide, rightSide);
Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), rStudent,
rScore);
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
LogicalPlan root = new LogicalPlanBuilder(rStudent)
.join(rScore, joinType, ExpressionUtils.EMPTY_CONDITION, condition)
.project(Lists.newArrayList())
.build();
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan leftFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan rightFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression()
.child(1).getLogicalExpression().getPlan();
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
Assertions.assertTrue(leftFilter instanceof LogicalFilter);
Assertions.assertTrue(rightFilter instanceof LogicalFilter);
LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter;
LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter;
Assertions.assertEquals(ImmutableSet.of(leftSide), actualLeft.getConjuncts());
Assertions.assertEquals(ImmutableSet.of(rightSide), actualRight.getConjuncts());
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new PushdownJoinOtherCondition())
.matches(
logicalJoin(
logicalFilter().when(left -> left.getConjuncts().equals(ImmutableSet.of(leftSide))),
logicalFilter().when(right -> right.getConjuncts().equals(ImmutableSet.of(rightSide)))
));
}
@Test
public void bothSideToOneSide() {
void bothSideToOneSide() {
bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, true);
bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, true);
bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, false);
@ -157,45 +147,37 @@ public class PushdownJoinOtherConditionTest {
}
private void bothSideToOneSide(JoinType joinType, boolean testRight) {
Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression pushSide = new GreaterThan(rStudentSlots.get(1), Literal.of(18));
Expression reserveSide = new GreaterThan(rScoreSlots.get(2), Literal.of(60));
List<Expression> condition = ImmutableList.of(pushSide, reserveSide);
Plan left = rStudent;
Plan right = rScore;
LogicalOlapScan left = rStudent;
LogicalOlapScan right = rScore;
if (testRight) {
left = rScore;
right = rStudent;
}
Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), left, right);
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
LogicalPlan root = new LogicalPlanBuilder(left)
.join(right, joinType, ExpressionUtils.EMPTY_CONDITION, condition)
.project(Lists.newArrayList())
.build();
Memo memo = rewrite(root);
Group rootGroup = memo.getRoot();
PlanChecker planChecker = PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new PushdownJoinOtherCondition());
Plan shouldJoin = rootGroup.getLogicalExpression()
.child(0).getLogicalExpression().getPlan();
Plan shouldFilter = rootGroup.getLogicalExpression()
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
Plan shouldScan = rootGroup.getLogicalExpression()
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
if (testRight) {
shouldFilter = rootGroup.getLogicalExpression()
.child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan();
shouldScan = rootGroup.getLogicalExpression()
.child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan();
planChecker.matches(
logicalJoin(
logicalOlapScan(),
logicalFilter().when(filter -> filter.getConjuncts().equals(ImmutableSet.of(pushSide)))
));
} else {
planChecker.matches(
logicalJoin(
logicalFilter().when(filter -> filter.getConjuncts().equals(ImmutableSet.of(pushSide))),
logicalOlapScan()
));
}
Assertions.assertTrue(shouldJoin instanceof LogicalJoin);
Assertions.assertTrue(shouldFilter instanceof LogicalFilter);
Assertions.assertTrue(shouldScan instanceof LogicalOlapScan);
LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter;
Assertions.assertEquals(ImmutableSet.of(pushSide), actualFilter.getConjuncts());
}
private Memo rewrite(Plan plan) {
return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushdownJoinOtherCondition());
}
}

View File

@ -27,10 +27,10 @@ import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
public class PushdownProjectThroughLimitTest implements MemoPatternMatchSupported {
class PushdownProjectThroughLimitTest implements MemoPatternMatchSupported {
@Test
public void testPushdownProjectThroughLimit() {
void testPushdownProjectThroughLimit() {
LogicalPlan project = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
.limit(1, 1)
.project(ImmutableList.of(0)) // id

View File

@ -17,25 +17,32 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.trees.plans.LimitPhase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.junit.jupiter.api.Test;
public class SplitLimitTest {
/**
* Tests for {@link SplitLimit}.
*/
class SplitLimitTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
@Test
void testSplitLimit() {
Plan plan = new LogicalLimit<>(0, 0, LimitPhase.ORIGIN, scan1);
plan = PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.rewrite()
.getPlan();
plan.anyMatch(x -> x instanceof LogicalLimit && ((LogicalLimit<?>) x).isSplit());
LogicalPlan limit = new LogicalPlanBuilder(scan1)
.limit(0, 0)
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), limit)
.applyTopDown(new SplitLimit())
.matches(
globalLogicalLimit(localLogicalLimit())
);
}
}

View File

@ -55,6 +55,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
@ -126,6 +127,12 @@ public class PlanChecker {
return applyTopDown(ruleFactory.buildRules());
}
public PlanChecker applyTopDown(CustomRewriter customRewriter) {
cascadesContext.topDownRewrite(customRewriter);
MemoValidator.validate(cascadesContext.getMemo());
return this;
}
public PlanChecker applyTopDown(List<Rule> rule) {
cascadesContext.topDownRewrite(rule);
MemoValidator.validate(cascadesContext.getMemo());