From 11fbe072213e1bb007b2e9d1b97217e22045df27 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Sun, 12 Mar 2023 18:49:12 +0800 Subject: [PATCH] [refactor](Nereids) Refactor all rewrite logical unit tests by match-pattern (#17691) --- .../logical/EliminateGroupByConstantTest.java | 188 +++++++++--------- .../EliminateUnnecessaryProjectTest.java | 49 ++--- .../logical/FindHashConditionForJoinTest.java | 32 +-- .../rewrite/logical/MergeFiltersTest.java | 47 ++--- .../rewrite/logical/MergeLimitsTest.java | 44 ++-- .../PhysicalStorageLayerAggregateTest.java | 4 +- .../logical/PruneOlapScanPartitionTest.java | 74 ++++--- .../logical/PruneOlapScanTabletTest.java | 20 +- .../PushdownJoinOtherConditionTest.java | 168 +++++++--------- .../PushdownProjectThroughLimitTest.java | 4 +- .../rules/rewrite/logical/SplitLimitTest.java | 25 ++- .../doris/nereids/util/PlanChecker.java | 7 + 12 files changed, 317 insertions(+), 345 deletions(-) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstantTest.java index 9aa4a027a8..4af6f2234b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstantTest.java @@ -23,30 +23,29 @@ import org.apache.doris.catalog.KeysType; import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.PartitionInfo; import org.apache.doris.catalog.Type; -import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; import org.apache.doris.nereids.trees.plans.ObjectId; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.thrift.TStorageType; import com.google.common.collect.ImmutableList; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.List; - -public class EliminateGroupByConstantTest { +/** Tests for {@link EliminateGroupByConstant}. */ +class EliminateGroupByConstantTest implements MemoPatternMatchSupported { private static final OlapTable table = new OlapTable(0L, "student", ImmutableList.of(new Column("k1", Type.INT, true, AggregateType.NONE, "0", ""), new Column("k2", Type.INT, false, AggregateType.NONE, "0", ""), @@ -65,110 +64,107 @@ public class EliminateGroupByConstantTest { } @Test - public void testIntegerLiteral() { - LogicalAggregate aggregate = new LogicalAggregate<>( - ImmutableList.of(new IntegerLiteral(1), k2), - ImmutableList.of(k1, k2), - new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table) - ); + void testIntegerLiteral() { + LogicalPlan aggregate = new LogicalPlanBuilder( + new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)) + .agg(ImmutableList.of(new IntegerLiteral(1), k2), + ImmutableList.of(k1, k2)) + .build(); - CascadesContext context = MemoTestUtils.createCascadesContext(aggregate); - context.topDownRewrite(new EliminateGroupByConstant().build()); - context.bottomUpRewrite(new CheckAfterRewrite().build()); - - LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut()); - Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1); - Assertions.assertTrue(aggregate1.getGroupByExpressions().get(0) instanceof Slot); + PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate) + .applyTopDown(new EliminateGroupByConstant()) + .applyBottomUp(new CheckAfterRewrite()) + .matches( + aggregate().when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2))) + ); } @Test - public void testOtherLiteral() { - LogicalAggregate aggregate = new LogicalAggregate<>( - ImmutableList.of( - new StringLiteral("str"), k2), - ImmutableList.of( - new Alias(new StringLiteral("str"), "str"), k1, k2), - new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table) - ); + void testOtherLiteral() { + LogicalPlan aggregate = new LogicalPlanBuilder( + new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)) + .agg(ImmutableList.of( + new StringLiteral("str"), k2), + ImmutableList.of( + new Alias(new StringLiteral("str"), "str"), k1, k2)) + .build(); - CascadesContext context = MemoTestUtils.createCascadesContext(aggregate); - context.topDownRewrite(new EliminateGroupByConstant().build()); - context.bottomUpRewrite(new CheckAfterRewrite().build()); - - LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut()); - Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1); - Assertions.assertTrue(aggregate1.getGroupByExpressions().get(0) instanceof Slot); + PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate) + .applyTopDown(new EliminateGroupByConstant()) + .applyBottomUp(new CheckAfterRewrite()) + .matches( + aggregate().when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2))) + ); } @Test - public void testMixedLiteral() { - LogicalAggregate aggregate = new LogicalAggregate<>( - ImmutableList.of( - new StringLiteral("str"), k2, - new IntegerLiteral(1), - new IntegerLiteral(2), - new IntegerLiteral(3), - new Add(k1, k2)), - ImmutableList.of( - new Alias(new StringLiteral("str"), "str"), - k2, k1, new Alias(new IntegerLiteral(1), "integer")), - new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table) - ); + void testMixedLiteral() { + LogicalPlan aggregate = new LogicalPlanBuilder( + new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)) + .agg(ImmutableList.of( + new StringLiteral("str"), k2, + new IntegerLiteral(1), + new IntegerLiteral(2), + new IntegerLiteral(3), + new Add(k1, k2)), + ImmutableList.of( + new Alias(new StringLiteral("str"), "str"), + k2, k1, new Alias(new IntegerLiteral(1), "integer"))) + .build(); - CascadesContext context = MemoTestUtils.createCascadesContext(aggregate); - context.topDownRewrite(new EliminateGroupByConstant().build()); - context.bottomUpRewrite(new CheckAfterRewrite().build()); - - LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut()); - Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 2); - List groupByExprs = aggregate1.getGroupByExpressions(); - Assertions.assertTrue(groupByExprs.get(0) instanceof Slot - && groupByExprs.get(1) instanceof Add); + PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate) + .applyTopDown(new EliminateGroupByConstant()) + .applyBottomUp(new CheckAfterRewrite()) + .matches( + aggregate() + .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2)))) + ); } @Test - public void testComplexGroupBy() { - LogicalAggregate aggregate = new LogicalAggregate<>( - ImmutableList.of( - new IntegerLiteral(1), - new IntegerLiteral(2), - new Add(k1, k2)), - ImmutableList.of( - new Alias(new Max(k1), "max"), - new Alias(new Min(k2), "min"), - new Alias(new Add(k1, k2), "add")), - new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table) - ); + void testComplexGroupBy() { + LogicalPlan aggregate = new LogicalPlanBuilder( + new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)) + .agg(ImmutableList.of( + new IntegerLiteral(1), + new IntegerLiteral(2), + new Add(k1, k2)), + ImmutableList.of( + new Alias(new Max(k1), "max"), + new Alias(new Min(k2), "min"), + new Alias(new Add(k1, k2), "add"))) + .build(); - CascadesContext context = MemoTestUtils.createCascadesContext(aggregate); - context.topDownRewrite(new EliminateGroupByConstant().build()); - context.bottomUpRewrite(new CheckAfterRewrite().build()); - - LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut()); - Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 1); + PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate) + .applyTopDown(new EliminateGroupByConstant()) + .applyBottomUp(new CheckAfterRewrite()) + .matches( + aggregate() + .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(new Add(k1, k2)))) + ); } @Test - public void testOutOfRange() { - LogicalAggregate aggregate = new LogicalAggregate<>( - ImmutableList.of( - new StringLiteral("str"), k2, - new IntegerLiteral(1), - new IntegerLiteral(2), - new IntegerLiteral(3), - new IntegerLiteral(5), - new Add(k1, k2)), - ImmutableList.of( - new Alias(new StringLiteral("str"), "str"), - k2, k1, new Alias(new IntegerLiteral(1), "integer")), - new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table) - ); - - CascadesContext context = MemoTestUtils.createCascadesContext(aggregate); - context.topDownRewrite(new EliminateGroupByConstant().build()); - context.bottomUpRewrite(new CheckAfterRewrite().build()); - - LogicalAggregate aggregate1 = ((LogicalAggregate) context.getMemo().copyOut()); - Assertions.assertEquals(aggregate1.getGroupByExpressions().size(), 2); + void testOutOfRange() { + LogicalPlan aggregate = new LogicalPlanBuilder( + new LogicalOlapScan(ObjectId.createGenerator().getNextId(), table)) + .agg(ImmutableList.of( + new StringLiteral("str"), k2, + new IntegerLiteral(1), + new IntegerLiteral(2), + new IntegerLiteral(3), + new IntegerLiteral(5), + new Add(k1, k2)), + ImmutableList.of( + new Alias(new StringLiteral("str"), "str"), + k2, k1, new Alias(new IntegerLiteral(1), "integer"))) + .build(); + PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate) + .applyTopDown(new EliminateGroupByConstant()) + .applyBottomUp(new CheckAfterRewrite()) + .matches( + aggregate() + .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2)))) + ); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java index ecd6de1cd5..4435eb16f8 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java @@ -17,28 +17,25 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; -import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import org.apache.doris.utframe.TestWithFeService; import com.google.common.collect.ImmutableList; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; /** * test ELIMINATE_UNNECESSARY_PROJECT rule. */ -public class EliminateUnnecessaryProjectTest extends TestWithFeService { +class EliminateUnnecessaryProjectTest extends TestWithFeService implements MemoPatternMatchSupported { @Override protected void runBeforeAll() throws Exception { @@ -55,57 +52,49 @@ public class EliminateUnnecessaryProjectTest extends TestWithFeService { } @Test - public void testEliminateNonTopUnnecessaryProject() { + void testEliminateNonTopUnnecessaryProject() { LogicalPlan unnecessaryProject = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0)) .project(ImmutableList.of(1, 0)) .filter(BooleanLiteral.FALSE) .build(); - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(unnecessaryProject); - cascadesContext.topDownRewrite(new EliminateUnnecessaryProject()); - - Plan actual = cascadesContext.getMemo().copyOut(); - Assertions.assertTrue(actual.child(0) instanceof LogicalProject); + PlanChecker.from(MemoTestUtils.createConnectContext(), unnecessaryProject) + .applyTopDown(new EliminateUnnecessaryProject()) + .matchesFromRoot(logicalFilter(logicalProject())); } @Test - public void testEliminateTopUnnecessaryProject() { + void testEliminateTopUnnecessaryProject() { LogicalPlan unnecessaryProject = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0)) .project(ImmutableList.of(0, 1)) .build(); - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(unnecessaryProject); - cascadesContext.topDownRewrite(new EliminateUnnecessaryProject()); - - Plan actual = cascadesContext.getMemo().copyOut(); - Assertions.assertTrue(actual instanceof LogicalOlapScan); + PlanChecker.from(MemoTestUtils.createConnectContext(), unnecessaryProject) + .applyTopDown(new EliminateUnnecessaryProject()) + .matchesFromRoot(logicalOlapScan()); } @Test - public void testNotEliminateTopProjectWhenOutputNotEquals() { + void testNotEliminateTopProjectWhenOutputNotEquals() { LogicalPlan necessaryProject = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0)) .project(ImmutableList.of(1, 0)) .build(); - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(necessaryProject); - cascadesContext.topDownRewrite(new EliminateUnnecessaryProject()); - - Plan actual = cascadesContext.getMemo().copyOut(); - Assertions.assertTrue(actual instanceof LogicalProject); + PlanChecker.from(MemoTestUtils.createConnectContext(), necessaryProject) + .applyTopDown(new EliminateUnnecessaryProject()) + .matchesFromRoot(logicalProject()); } @Test - public void testEliminateProjectWhenEmptyRelationChild() { + void testEliminateProjectWhenEmptyRelationChild() { LogicalPlan unnecessaryProject = new LogicalPlanBuilder(new LogicalEmptyRelation(ImmutableList.of( new SlotReference("k1", IntegerType.INSTANCE), new SlotReference("k2", IntegerType.INSTANCE)))) .project(ImmutableList.of(1, 0)) .build(); - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(unnecessaryProject); - cascadesContext.topDownRewrite(new EliminateUnnecessaryProject()); - - Plan actual = cascadesContext.getMemo().copyOut(); - Assertions.assertTrue(actual instanceof LogicalEmptyRelation); + PlanChecker.from(MemoTestUtils.createConnectContext(), unnecessaryProject) + .applyTopDown(new EliminateUnnecessaryProject()) + .matchesFromRoot(logicalEmptyRelation()); } // TODO: uncomment this after the Elimination project rule is correctly implemented diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java index 020e1d79d6..8bf0bada3f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java @@ -17,8 +17,6 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.nereids.CascadesContext; -import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; @@ -31,12 +29,12 @@ import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.ArrayList; @@ -54,9 +52,9 @@ import java.util.Optional; * -hashJoinConjuncts={A.x=B.x, A.y+1=B.y} * -otherJoinCondition="A.x=1 and (A.x=1 or B.x=A.x) and A.x>B.x" */ -class FindHashConditionForJoinTest { +class FindHashConditionForJoinTest implements MemoPatternMatchSupported { @Test - public void testFindHashCondition() { + void testFindHashCondition() { Plan student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("")); Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, @@ -77,20 +75,12 @@ class FindHashConditionForJoinTest { List expr = ImmutableList.of(eq1, eq2, eq3, or, less); LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, new ArrayList<>(), expr, JoinHint.NONE, Optional.empty(), student, score); - CascadesContext context = MemoTestUtils.createCascadesContext(join); - List rules = Lists.newArrayList(new FindHashConditionForJoin().build()); - context.topDownRewrite(rules); - Plan plan = context.getMemo().copyOut(); - LogicalJoin after = (LogicalJoin) plan; - Assertions.assertEquals(after.getHashJoinConjuncts().size(), 2); - Assertions.assertTrue(after.getHashJoinConjuncts().contains(eq1)); - Assertions.assertTrue(after.getHashJoinConjuncts().contains(eq3)); - List others = after.getOtherJoinConjuncts(); - Assertions.assertEquals(others.size(), 3); - Assertions.assertTrue(others.contains(less)); - Assertions.assertTrue(others.contains(eq2)); - Assertions.assertTrue(others.contains(less)); + PlanChecker.from(new ConnectContext(), join) + .applyTopDown(new FindHashConditionForJoin()) + .matches( + logicalJoin() + .when(j -> j.getHashJoinConjuncts().equals(ImmutableList.of(eq1, eq3))) + .when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of(eq2, or, less)))); } - } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java index 31bf853bdf..e706151e0e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java @@ -17,47 +17,42 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.analyzer.UnboundRelation; -import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.RelationUtil; -import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.List; -import java.util.Set; - /** - * MergeConsecutiveFilter ut + * Tests for {@link MergeFilters}. */ -public class MergeFiltersTest { +class MergeFiltersTest implements MemoPatternMatchSupported { @Test - public void testMergeConsecutiveFilters() { - UnboundRelation relation = new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "table")); + void testMergeFilters() { Expression expression1 = new IntegerLiteral(1); - LogicalFilter filter1 = new LogicalFilter<>(ImmutableSet.of(expression1), relation); Expression expression2 = new IntegerLiteral(2); - LogicalFilter filter2 = new LogicalFilter<>(ImmutableSet.of(expression2), filter1); Expression expression3 = new IntegerLiteral(3); - LogicalFilter filter3 = new LogicalFilter<>(ImmutableSet.of(expression3), filter2); - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter3); - List rules = Lists.newArrayList(new MergeFilters().build()); - cascadesContext.bottomUpRewrite(rules); - //check transformed plan - Plan resultPlan = cascadesContext.getMemo().copyOut(); - System.out.println(resultPlan.treeString()); - Assertions.assertTrue(resultPlan instanceof LogicalFilter); - Set allPredicates = ImmutableSet.of(expression1, expression2, expression3); - Assertions.assertEquals(ImmutableSet.copyOf(((LogicalFilter) resultPlan).getConjuncts()), allPredicates); - Assertions.assertTrue(resultPlan.child(0) instanceof UnboundRelation); + LogicalPlan logicalFilter = new LogicalPlanBuilder( + new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "table"))) + .filter(ImmutableSet.of(expression1)) + .filter(ImmutableSet.of(expression2)) + .filter(ImmutableSet.of(expression3)) + .build(); + + PlanChecker.from(new ConnectContext(), logicalFilter).applyBottomUp(new MergeFilters()) + .matches( + logicalFilter( + unboundRelation() + ).when(filter -> filter.getConjuncts() + .equals(ImmutableSet.of(expression1, expression2, expression3)))); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java index 869dec982f..ae608f1d26 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java @@ -17,37 +17,33 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.analyzer.UnboundRelation; -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.trees.plans.LimitPhase; -import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.RelationUtil; -import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.qe.ConnectContext; import com.google.common.collect.Lists; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.List; - -public class MergeLimitsTest { +/** + * Tests for {@link MergeLimits}. + */ +class MergeLimitsTest implements MemoPatternMatchSupported { @Test - public void testMergeConsecutiveLimits() { - LogicalLimit limit3 = new LogicalLimit<>(3, 5, LimitPhase.ORIGIN, new UnboundRelation( - RelationUtil.newRelationId(), Lists.newArrayList("db", "t"))); - LogicalLimit limit2 = new LogicalLimit<>(2, 0, LimitPhase.ORIGIN, limit3); - LogicalLimit limit1 = new LogicalLimit<>(10, 0, LimitPhase.ORIGIN, limit2); - - CascadesContext context = MemoTestUtils.createCascadesContext(limit1); - List rules = Lists.newArrayList(new MergeLimits().build()); - context.topDownRewrite(rules); - LogicalLimit limit = (LogicalLimit) context.getMemo().copyOut(); - - Assertions.assertEquals(2, limit.getLimit()); - Assertions.assertEquals(5, limit.getOffset()); - Assertions.assertEquals(1, limit.children().size()); - Assertions.assertTrue(limit.child(0) instanceof UnboundRelation); + void testMergeLimits() { + LogicalPlan logicalLimit = new LogicalPlanBuilder( + new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "t"))) + .limit(3, 5) + .limit(2, 0) + .limit(10, 0).build(); + PlanChecker.from(new ConnectContext(), logicalLimit).applyTopDown(new MergeLimits()) + .matches( + logicalLimit( + unboundRelation() + ).when(limit -> limit.getLimit() == 2).when(limit -> limit.getOffset() == 5)); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java index 1556438652..69e07eebc7 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PhysicalStorageLayerAggregateTest.java @@ -18,7 +18,6 @@ package org.apache.doris.nereids.rules.rewrite.logical; import org.apache.doris.nereids.CascadesContext; -import org.apache.doris.nereids.pattern.GeneratedMemoPatterns; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RulePromise; import org.apache.doris.nereids.rules.RuleType; @@ -32,6 +31,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate.PushDownAggOp; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; @@ -42,7 +42,7 @@ import org.junit.jupiter.api.Test; import java.util.Collections; import java.util.Optional; -public class PhysicalStorageLayerAggregateTest implements GeneratedMemoPatterns { +public class PhysicalStorageLayerAggregateTest implements MemoPatternMatchSupported { @Test public void testWithoutProject() { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java index a22d2b29f5..73156a6cfd 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java @@ -26,8 +26,6 @@ import org.apache.doris.catalog.RangePartitionInfo; import org.apache.doris.catalog.RangePartitionItem; import org.apache.doris.catalog.Type; import org.apache.doris.common.jmockit.Deencapsulation; -import org.apache.doris.nereids.CascadesContext; -import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; @@ -36,20 +34,19 @@ import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; -import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.BoundType; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; import com.google.common.collect.Range; import mockit.Expectations; import mockit.Mocked; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.ArrayList; @@ -58,10 +55,10 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; -class PruneOlapScanPartitionTest { +class PruneOlapScanPartitionTest implements MemoPatternMatchSupported { @Test - public void testOlapScanPartitionWithSingleColumnCase(@Mocked OlapTable olapTable) throws Exception { + void testOlapScanPartitionWithSingleColumnCase(@Mocked OlapTable olapTable) throws Exception { List columnNameList = new ArrayList<>(); columnNameList.add(new Column("col1", Type.INT.getPrimitiveType())); columnNameList.add(new Column("col2", Type.INT.getPrimitiveType())); @@ -92,24 +89,29 @@ class PruneOlapScanPartitionTest { Expression expression = new LessThan(slotRef, new IntegerLiteral(4)); LogicalFilter filter = new LogicalFilter<>(ImmutableSet.of(expression), scan); - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter); - List rules = Lists.newArrayList(new PruneOlapScanPartition().build()); - cascadesContext.topDownRewrite(rules); - Plan resultPlan = cascadesContext.getMemo().copyOut(); - LogicalOlapScan rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0); - Assertions.assertEquals(0L, rewrittenOlapScan.getSelectedPartitionIds().iterator().next()); + PlanChecker.from(MemoTestUtils.createConnectContext(), filter) + .applyTopDown(new PruneOlapScanPartition()) + .matches( + logicalFilter( + logicalOlapScan().when( + olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 0L) + ) + ); Expression lessThan0 = new LessThan(slotRef, new IntegerLiteral(0)); Expression greaterThan6 = new GreaterThan(slotRef, new IntegerLiteral(6)); Or lessThan0OrGreaterThan6 = new Or(lessThan0, greaterThan6); filter = new LogicalFilter<>(ImmutableSet.of(lessThan0OrGreaterThan6), scan); scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable); - cascadesContext = MemoTestUtils.createCascadesContext(filter); - rules = Lists.newArrayList(new PruneOlapScanPartition().build()); - cascadesContext.topDownRewrite(rules); - resultPlan = cascadesContext.getMemo().copyOut(); - rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0); - Assertions.assertEquals(1L, rewrittenOlapScan.getSelectedPartitionIds().iterator().next()); + + PlanChecker.from(MemoTestUtils.createConnectContext(), filter) + .applyTopDown(new PruneOlapScanPartition()) + .matches( + logicalFilter( + logicalOlapScan().when( + olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 1L) + ) + ); Expression greaterThanEqual0 = new GreaterThanEqual( @@ -118,17 +120,20 @@ class PruneOlapScanPartitionTest { new LessThanEqual(slotRef, new IntegerLiteral(5)); scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable); filter = new LogicalFilter<>(ImmutableSet.of(greaterThanEqual0, lessThanEqual5), scan); - cascadesContext = MemoTestUtils.createCascadesContext(filter); - rules = Lists.newArrayList(new PruneOlapScanPartition().build()); - cascadesContext.topDownRewrite(rules); - resultPlan = cascadesContext.getMemo().copyOut(); - rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0); - Assertions.assertEquals(0L, rewrittenOlapScan.getSelectedPartitionIds().iterator().next()); - Assertions.assertEquals(2, rewrittenOlapScan.getSelectedPartitionIds().toArray().length); + + PlanChecker.from(MemoTestUtils.createConnectContext(), filter) + .applyTopDown(new PruneOlapScanPartition()) + .matches( + logicalFilter( + logicalOlapScan().when( + olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 0L) + .when(olapScan -> olapScan.getSelectedPartitionIds().size() == 2) + ) + ); } @Test - public void testOlapScanPartitionPruneWithMultiColumnCase(@Mocked OlapTable olapTable) throws Exception { + void testOlapScanPartitionPruneWithMultiColumnCase(@Mocked OlapTable olapTable) throws Exception { List columnNameList = new ArrayList<>(); columnNameList.add(new Column("col1", Type.INT.getPrimitiveType())); columnNameList.add(new Column("col2", Type.INT.getPrimitiveType())); @@ -155,11 +160,14 @@ class PruneOlapScanPartitionTest { Expression left = new LessThan(new SlotReference("col1", IntegerType.INSTANCE), new IntegerLiteral(4)); Expression right = new GreaterThan(new SlotReference("col2", IntegerType.INSTANCE), new IntegerLiteral(11)); LogicalFilter filter = new LogicalFilter<>(ImmutableSet.of(left, right), scan); - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter); - List rules = Lists.newArrayList(new PruneOlapScanPartition().build()); - cascadesContext.topDownRewrite(rules); - Plan resultPlan = cascadesContext.getMemo().copyOut(); - LogicalOlapScan rewrittenOlapScan = (LogicalOlapScan) resultPlan.child(0); - Assertions.assertEquals(0L, rewrittenOlapScan.getSelectedPartitionIds().iterator().next()); + PlanChecker.from(MemoTestUtils.createConnectContext(), filter) + .applyTopDown(new PruneOlapScanPartition()) + .matches( + logicalFilter( + logicalOlapScan() + .when( + olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 0L) + ) + ); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java index 7ee589a422..5d34aec5db 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java @@ -29,7 +29,6 @@ import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.Partition; import org.apache.doris.catalog.PrimitiveType; import org.apache.doris.catalog.Type; -import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; import org.apache.doris.nereids.trees.expressions.InPredicate; @@ -41,7 +40,9 @@ import org.apache.doris.nereids.trees.plans.ObjectId; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.planner.PartitionColumnFilter; import com.google.common.collect.ImmutableList; @@ -54,10 +55,10 @@ import org.junit.jupiter.api.Test; import java.util.List; -public class PruneOlapScanTabletTest { +class PruneOlapScanTabletTest implements MemoPatternMatchSupported { @Test - public void testPruneOlapScanTablet(@Mocked OlapTable olapTable, + void testPruneOlapScanTablet(@Mocked OlapTable olapTable, @Mocked Partition partition, @Mocked MaterializedIndex index, @Mocked HashDistributionInfo distributionInfo) { List tabletIds = Lists.newArrayListWithExpectedSize(300); @@ -153,11 +154,12 @@ public class PruneOlapScanTabletTest { Assertions.assertEquals(0, filter.child().getSelectedTabletIds().size()); - CascadesContext context = MemoTestUtils.createCascadesContext(filter); - context.topDownRewrite(ImmutableList.of(new PruneOlapScanTablet().build())); - - LogicalFilter filter1 = ((LogicalFilter) context.getMemo().copyOut()); - LogicalOlapScan olapScan = filter1.child(); - Assertions.assertEquals(19, olapScan.getSelectedTabletIds().size()); + PlanChecker.from(MemoTestUtils.createConnectContext(), filter) + .applyTopDown(new PruneOlapScanTablet()) + .matches( + logicalFilter( + logicalOlapScan().when(scan -> scan.getSelectedTabletIds().size() == 19) + ) + ); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java index 48dbbb7621..b04318b072 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java @@ -17,52 +17,53 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.nereids.memo.Group; -import org.apache.doris.nereids.memo.Memo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.plans.JoinHint; import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; -import org.apache.doris.nereids.util.PlanRewriter; -import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import java.util.List; -import java.util.Optional; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class PushdownJoinOtherConditionTest { +class PushdownJoinOtherConditionTest implements MemoPatternMatchSupported { - private Plan rStudent; - private Plan rScore; + private LogicalOlapScan rStudent; + private LogicalOlapScan rScore; + + private List rStudentSlots; + + private List rScoreSlots; /** * ut before. */ @BeforeAll - public final void beforeAll() { + final void beforeAll() { rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("")); rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of("")); + rStudentSlots = rStudent.getOutput(); + rScoreSlots = rScore.getOutput(); } @Test - public void oneSide() { + void oneSide() { oneSide(JoinType.CROSS_JOIN, false); oneSide(JoinType.INNER_JOIN, false); oneSide(JoinType.LEFT_OUTER_JOIN, true); @@ -74,46 +75,43 @@ public class PushdownJoinOtherConditionTest { } private void oneSide(JoinType joinType, boolean testRight) { - - Expression pushSide1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); - Expression pushSide2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50)); + Expression pushSide1 = new GreaterThan(rStudentSlots.get(1), Literal.of(18)); + Expression pushSide2 = new GreaterThan(rStudentSlots.get(1), Literal.of(50)); List condition = ImmutableList.of(pushSide1, pushSide2); - Plan left = rStudent; - Plan right = rScore; + LogicalOlapScan left = rStudent; + LogicalOlapScan right = rScore; if (testRight) { left = rScore; right = rStudent; } - Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), left, right); - Plan root = new LogicalProject<>(Lists.newArrayList(), join); + LogicalPlan root = new LogicalPlanBuilder(left) + .join(right, joinType, ExpressionUtils.EMPTY_CONDITION, condition) + .project(Lists.newArrayList()) + .build(); - Memo memo = rewrite(root); - Group rootGroup = memo.getRoot(); + PlanChecker planChecker = PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new PushdownJoinOtherCondition()); - Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan(); - Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(0).getLogicalExpression().getPlan(); - Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(1).getLogicalExpression().getPlan(); if (testRight) { - shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(1).getLogicalExpression().getPlan(); - shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(0).getLogicalExpression().getPlan(); + planChecker.matches( + logicalJoin( + logicalOlapScan(), + logicalFilter() + .when(filter -> ImmutableList.copyOf(filter.getConjuncts()).equals(condition)))); + } else { + planChecker.matches( + logicalJoin( + logicalFilter().when( + filter -> ImmutableList.copyOf(filter.getConjuncts()).equals(condition)), + logicalOlapScan())); + } - - Assertions.assertTrue(shouldJoin instanceof LogicalJoin); - Assertions.assertTrue(shouldFilter instanceof LogicalFilter); - Assertions.assertTrue(shouldScan instanceof LogicalOlapScan); - LogicalFilter actualFilter = (LogicalFilter) shouldFilter; - - Assertions.assertEquals(condition, ImmutableList.copyOf(actualFilter.getConjuncts())); } @Test - public void bothSideToBothSide() { + void bothSideToBothSide() { bothSideToBothSide(JoinType.CROSS_JOIN); bothSideToBothSide(JoinType.INNER_JOIN); bothSideToBothSide(JoinType.LEFT_SEMI_JOIN); @@ -122,34 +120,26 @@ public class PushdownJoinOtherConditionTest { private void bothSideToBothSide(JoinType joinType) { - Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); - Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); + Expression leftSide = new GreaterThan(rStudentSlots.get(1), Literal.of(18)); + Expression rightSide = new GreaterThan(rScoreSlots.get(2), Literal.of(60)); List condition = ImmutableList.of(leftSide, rightSide); - Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), rStudent, - rScore); - Plan root = new LogicalProject<>(Lists.newArrayList(), join); + LogicalPlan root = new LogicalPlanBuilder(rStudent) + .join(rScore, joinType, ExpressionUtils.EMPTY_CONDITION, condition) + .project(Lists.newArrayList()) + .build(); - Memo memo = rewrite(root); - Group rootGroup = memo.getRoot(); - - Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan(); - Plan leftFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(0).getLogicalExpression().getPlan(); - Plan rightFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() - .child(1).getLogicalExpression().getPlan(); - - Assertions.assertTrue(shouldJoin instanceof LogicalJoin); - Assertions.assertTrue(leftFilter instanceof LogicalFilter); - Assertions.assertTrue(rightFilter instanceof LogicalFilter); - LogicalFilter actualLeft = (LogicalFilter) leftFilter; - LogicalFilter actualRight = (LogicalFilter) rightFilter; - Assertions.assertEquals(ImmutableSet.of(leftSide), actualLeft.getConjuncts()); - Assertions.assertEquals(ImmutableSet.of(rightSide), actualRight.getConjuncts()); + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new PushdownJoinOtherCondition()) + .matches( + logicalJoin( + logicalFilter().when(left -> left.getConjuncts().equals(ImmutableSet.of(leftSide))), + logicalFilter().when(right -> right.getConjuncts().equals(ImmutableSet.of(rightSide))) + )); } @Test - public void bothSideToOneSide() { + void bothSideToOneSide() { bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, true); bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, true); bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, false); @@ -157,45 +147,37 @@ public class PushdownJoinOtherConditionTest { } private void bothSideToOneSide(JoinType joinType, boolean testRight) { - - Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); - Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); + Expression pushSide = new GreaterThan(rStudentSlots.get(1), Literal.of(18)); + Expression reserveSide = new GreaterThan(rScoreSlots.get(2), Literal.of(60)); List condition = ImmutableList.of(pushSide, reserveSide); - Plan left = rStudent; - Plan right = rScore; + LogicalOlapScan left = rStudent; + LogicalOlapScan right = rScore; if (testRight) { left = rScore; right = rStudent; } - Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), left, right); - Plan root = new LogicalProject<>(Lists.newArrayList(), join); + LogicalPlan root = new LogicalPlanBuilder(left) + .join(right, joinType, ExpressionUtils.EMPTY_CONDITION, condition) + .project(Lists.newArrayList()) + .build(); - Memo memo = rewrite(root); - Group rootGroup = memo.getRoot(); + PlanChecker planChecker = PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new PushdownJoinOtherCondition()); - Plan shouldJoin = rootGroup.getLogicalExpression() - .child(0).getLogicalExpression().getPlan(); - Plan shouldFilter = rootGroup.getLogicalExpression() - .child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan(); - Plan shouldScan = rootGroup.getLogicalExpression() - .child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan(); if (testRight) { - shouldFilter = rootGroup.getLogicalExpression() - .child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan(); - shouldScan = rootGroup.getLogicalExpression() - .child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan(); + planChecker.matches( + logicalJoin( + logicalOlapScan(), + logicalFilter().when(filter -> filter.getConjuncts().equals(ImmutableSet.of(pushSide))) + )); + } else { + planChecker.matches( + logicalJoin( + logicalFilter().when(filter -> filter.getConjuncts().equals(ImmutableSet.of(pushSide))), + logicalOlapScan() + )); } - - Assertions.assertTrue(shouldJoin instanceof LogicalJoin); - Assertions.assertTrue(shouldFilter instanceof LogicalFilter); - Assertions.assertTrue(shouldScan instanceof LogicalOlapScan); - LogicalFilter actualFilter = (LogicalFilter) shouldFilter; - Assertions.assertEquals(ImmutableSet.of(pushSide), actualFilter.getConjuncts()); - } - - private Memo rewrite(Plan plan) { - return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushdownJoinOtherCondition()); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimitTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimitTest.java index 77bf76af4c..d66ede574f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimitTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimitTest.java @@ -27,10 +27,10 @@ import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; -public class PushdownProjectThroughLimitTest implements MemoPatternMatchSupported { +class PushdownProjectThroughLimitTest implements MemoPatternMatchSupported { @Test - public void testPushdownProjectThroughLimit() { + void testPushdownProjectThroughLimit() { LogicalPlan project = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0)) .limit(1, 1) .project(ImmutableList.of(0)) // id diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SplitLimitTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SplitLimitTest.java index 174f5a90b4..3948406abd 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SplitLimitTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SplitLimitTest.java @@ -17,25 +17,32 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.nereids.trees.plans.LimitPhase; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import org.junit.jupiter.api.Test; -public class SplitLimitTest { +/** + * Tests for {@link SplitLimit}. + */ +class SplitLimitTest implements MemoPatternMatchSupported { private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); @Test void testSplitLimit() { - Plan plan = new LogicalLimit<>(0, 0, LimitPhase.ORIGIN, scan1); - plan = PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .rewrite() - .getPlan(); - plan.anyMatch(x -> x instanceof LogicalLimit && ((LogicalLimit) x).isSplit()); + LogicalPlan limit = new LogicalPlanBuilder(scan1) + .limit(0, 0) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), limit) + .applyTopDown(new SplitLimit()) + .matches( + globalLogicalLimit(localLogicalLimit()) + ); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index f49711aa89..ce713c2f4d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -55,6 +55,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import org.apache.doris.planner.PlanFragment; import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.OriginStatement; @@ -126,6 +127,12 @@ public class PlanChecker { return applyTopDown(ruleFactory.buildRules()); } + public PlanChecker applyTopDown(CustomRewriter customRewriter) { + cascadesContext.topDownRewrite(customRewriter); + MemoValidator.validate(cascadesContext.getMemo()); + return this; + } + public PlanChecker applyTopDown(List rule) { cascadesContext.topDownRewrite(rule); MemoValidator.validate(cascadesContext.getMemo());