diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java index ee270eccac..00d864ab01 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java @@ -26,15 +26,15 @@ import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.properties.UnboundLogicalProperties; import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.FakePlan; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.LeafPlan; import org.apache.doris.nereids.trees.plans.LimitPhase; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; @@ -42,6 +42,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.RelationUtil; import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.util.LogicalPlanBuilder; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; @@ -57,8 +58,10 @@ import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.Map; import java.util.Objects; +import java.util.Set; class MemoTest implements MemoPatternMatchSupported { + private final LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); private final ConnectContext connectContext = MemoTestUtils.createConnectContext(); @@ -186,11 +189,9 @@ class MemoTest implements MemoPatternMatchSupported { @Test public void initByTwoLevelChainPlan() { - OlapTable table = PlanConstructor.newOlapTable(0, "a", 1); - LogicalOlapScan scan = new LogicalOlapScan(RelationUtil.newRelationId(), table); - - LogicalProject topProject = new LogicalProject<>( - ImmutableList.of(scan.computeOutput().get(0)), scan); + Plan topProject = new LogicalPlanBuilder(scan) + .project(ImmutableList.of(0)) + .build(); PlanChecker.from(connectContext, topProject) .checkGroupNum(2) @@ -250,13 +251,11 @@ class MemoTest implements MemoPatternMatchSupported { @Test public void initByThreeLevelChainPlan() { - OlapTable table = PlanConstructor.newOlapTable(0, "a", 1); - LogicalOlapScan scan = new LogicalOlapScan(RelationUtil.newRelationId(), table); - - LogicalProject project = new LogicalProject<>( - ImmutableList.of(scan.computeOutput().get(0)), scan); - LogicalFilter> filter = new LogicalFilter<>(ImmutableSet.of( - new EqualTo(scan.computeOutput().get(0), new IntegerLiteral(1))), project); + Set exprs = ImmutableSet.of(new EqualTo(scan.getOutput().get(0), Literal.of(1))); + Plan filter = new LogicalPlanBuilder(scan) + .project(ImmutableList.of(0)) + .filter(exprs) + .build(); PlanChecker.from(connectContext, filter) .checkGroupNum(3) @@ -264,8 +263,8 @@ class MemoTest implements MemoPatternMatchSupported { logicalFilter( logicalProject( any().when(child -> Objects.equals(child, scan)) - ).when(root -> Objects.equals(root, project)) - ).when(root -> Objects.equals(root, filter)) + ).when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).equals(scan.getOutput().get(0))) + ).when(f -> Objects.equals(f, filter)) ); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java index 7425615fab..488036f461 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java @@ -153,7 +153,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP // after analyze PlanChecker.from(connectContext) .analyze(sql2) - .matches( + .matchesNotCheck( logicalApply( any(), logicalAggregate( @@ -181,7 +181,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP PlanChecker.from(connectContext) .analyze(sql2) .applyBottomUp(new UnCorrelatedApplyAggregateFilter()) - .matches( + .matchesNotCheck( logicalApply( any(), logicalAggregate().when(FieldChecker.check("outputExpressions", ImmutableList.of( @@ -216,7 +216,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP .analyze(sql2) .applyBottomUp(new UnCorrelatedApplyAggregateFilter()) .applyBottomUp(new ScalarApplyToJoin()) - .matches( + .matchesNotCheck( logicalJoin( any(), logicalAggregate() @@ -240,7 +240,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP //"select * from t6 where t6.k1 in (select t7.k3 from t7 where t7.v2 = t6.k2)" PlanChecker.from(connectContext) .analyze(sql4) - .matches( + .matchesNotCheck( logicalApply( any(), logicalProject().when(FieldChecker.check("projects", ImmutableList.of( @@ -303,7 +303,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP //"select * from t6 where exists (select t7.k3 from t7 where t6.k2 = t7.v2)" PlanChecker.from(connectContext) .analyze(sql6) - .matches( + .matchesNotCheck( logicalApply( any(), logicalProject().when(FieldChecker.check("projects", ImmutableList.of( @@ -321,7 +321,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP PlanChecker.from(connectContext) .analyze(sql6) .applyBottomUp(new PullUpProjectUnderApply()) - .matches( + .matchesNotCheck( logicalProject( logicalApply().when(FieldChecker.check("correlationFilter", Optional.empty())) .when(FieldChecker.check("correlationSlot", ImmutableList.of( @@ -418,7 +418,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP .applyBottomUp(new LogicalSubQueryAliasToLogicalProject()) .applyTopDown(new MergeProjects()) .applyBottomUp(new PullUpCorrelatedFilterUnderApplyAggregateProject()) - .matches( + .matchesNotCheck( logicalApply( any(), logicalAggregate( @@ -451,7 +451,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP .applyTopDown(new MergeProjects()) .applyBottomUp(new PullUpCorrelatedFilterUnderApplyAggregateProject()) .applyBottomUp(new UnCorrelatedApplyAggregateFilter()) - .matches( + .matchesNotCheck( logicalApply( any(), logicalAggregate( @@ -478,14 +478,13 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP .applyBottomUp(new PullUpCorrelatedFilterUnderApplyAggregateProject()) .applyBottomUp(new UnCorrelatedApplyAggregateFilter()) .applyBottomUp(new ScalarApplyToJoin()) - .matches( - logicalJoin( + .matchesNotCheck( + leftSemiLogicalJoin( any(), logicalAggregate( logicalProject() ) ) - .when(j -> j.getJoinType().equals(JoinType.LEFT_SEMI_JOIN)) .when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of( new LessThan(new SlotReference(new ExprId(0), "k1", BigIntType.INSTANCE, true, ImmutableList.of("default_cluster:test", "t6")), diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java index 37508d5387..3fc2fec9a6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java @@ -18,14 +18,11 @@ package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; -import org.apache.doris.nereids.trees.plans.logical.RelationUtil; -import org.apache.doris.nereids.types.IntegerType; -import org.apache.doris.nereids.types.StringType; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; @@ -35,17 +32,18 @@ import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; public class NormalizeRepeatTest implements MemoPatternMatchSupported { + private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); @Test public void testKeepNullableAfterNormalizeRepeat() { - SlotReference slot1 = new SlotReference("id", IntegerType.INSTANCE, false); - SlotReference slot2 = slot1.withNullable(true); - SlotReference slot3 = new SlotReference("name", StringType.INSTANCE, false); - Alias alias = new Alias(new Sum(slot3), "sum(name)"); + Slot id = scan1.getOutput().get(0); + Slot idNotNull = id.withNullable(true); + Slot name = scan1.getOutput().get(1); + Alias alias = new Alias(new Sum(name), "sum(name)"); Plan plan = new LogicalRepeat<>( - ImmutableList.of(ImmutableList.of(slot1)), - ImmutableList.of(slot2, alias), - new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.newOlapTable(0, "t", 0)) + ImmutableList.of(ImmutableList.of(id)), + ImmutableList.of(idNotNull, alias), + scan1 ); PlanChecker.from(MemoTestUtils.createCascadesContext(plan)) .applyTopDown(new NormalizeRepeat()) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java index e706151e0e..af8e208034 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java @@ -17,32 +17,32 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.nereids.analyzer.UnboundRelation; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.trees.plans.logical.RelationUtil; import org.apache.doris.nereids.util.LogicalPlanBuilder; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; import org.junit.jupiter.api.Test; /** * Tests for {@link MergeFilters}. */ class MergeFiltersTest implements MemoPatternMatchSupported { + private final LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + @Test void testMergeFilters() { Expression expression1 = new IntegerLiteral(1); Expression expression2 = new IntegerLiteral(2); Expression expression3 = new IntegerLiteral(3); - LogicalPlan logicalFilter = new LogicalPlanBuilder( - new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "table"))) + LogicalPlan logicalFilter = new LogicalPlanBuilder(scan) .filter(ImmutableSet.of(expression1)) .filter(ImmutableSet.of(expression2)) .filter(ImmutableSet.of(expression3)) @@ -50,9 +50,7 @@ class MergeFiltersTest implements MemoPatternMatchSupported { PlanChecker.from(new ConnectContext(), logicalFilter).applyBottomUp(new MergeFilters()) .matches( - logicalFilter( - unboundRelation() - ).when(filter -> filter.getConjuncts() + logicalFilter().when(filter -> filter.getConjuncts() .equals(ImmutableSet.of(expression1, expression2, expression3)))); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java index ae608f1d26..bae4abbfe3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java @@ -17,33 +17,31 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.nereids.analyzer.UnboundRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.trees.plans.logical.RelationUtil; import org.apache.doris.nereids.util.LogicalPlanBuilder; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; import org.apache.doris.qe.ConnectContext; -import com.google.common.collect.Lists; import org.junit.jupiter.api.Test; /** * Tests for {@link MergeLimits}. */ class MergeLimitsTest implements MemoPatternMatchSupported { + private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + @Test void testMergeLimits() { - LogicalPlan logicalLimit = new LogicalPlanBuilder( - new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "t"))) + LogicalPlan logicalLimit = new LogicalPlanBuilder(scan1) .limit(3, 5) .limit(2, 0) .limit(10, 0).build(); PlanChecker.from(new ConnectContext(), logicalLimit).applyTopDown(new MergeLimits()) .matches( - logicalLimit( - unboundRelation() - ).when(limit -> limit.getLimit() == 2).when(limit -> limit.getOffset() == 5)); + logicalLimit().when(limit -> limit.getLimit() == 2).when(limit -> limit.getOffset() == 5)); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java index 5d34aec5db..50aff4011c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanTabletTest.java @@ -28,18 +28,15 @@ import org.apache.doris.catalog.MaterializedIndex; import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.Partition; import org.apache.doris.catalog.PrimitiveType; -import org.apache.doris.catalog.Type; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.LessThanEqual; -import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.DateLiteral; -import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.ObjectId; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; @@ -79,24 +76,24 @@ class PruneOlapScanTabletTest implements MemoPatternMatchSupported { k0Filter.setUpperBound(new StringLiteral("2019-08-22"), true); PartitionColumnFilter k1Filter = new PartitionColumnFilter(); - List inList = Lists.newArrayList(); - inList.add(new IntLiteral(100)); - inList.add(new IntLiteral(200)); - inList.add(new IntLiteral(300)); - inList.add(new IntLiteral(400)); - inList.add(new IntLiteral(500)); + List inList = ImmutableList.of( + new IntLiteral(100), + new IntLiteral(200), + new IntLiteral(300), + new IntLiteral(400), + new IntLiteral(500)); k1Filter.setInPredicate(new org.apache.doris.analysis.InPredicate(new SlotRef(null, "k1"), inList, false)); PartitionColumnFilter k2Filter = new PartitionColumnFilter(); - List inList2 = Lists.newArrayList(); - inList2.add(new IntLiteral(900)); - inList2.add(new IntLiteral(1100)); + List inList2 = ImmutableList.of( + new IntLiteral(900), + new IntLiteral(1100)); k2Filter.setInPredicate(new org.apache.doris.analysis.InPredicate(new SlotRef(null, "k2"), inList2, false)); PartitionColumnFilter k3Filter = new PartitionColumnFilter(); - List inList3 = Lists.newArrayList(); - inList3.add(new IntLiteral(1)); - inList3.add(new IntLiteral(3)); + List inList3 = ImmutableList.of( + new IntLiteral(1), + new IntLiteral(3)); k3Filter.setInPredicate(new org.apache.doris.analysis.InPredicate(new SlotRef(null, "k3"), inList3, false)); PartitionColumnFilter k4Filter = new PartitionColumnFilter(); @@ -104,31 +101,14 @@ class PruneOlapScanTabletTest implements MemoPatternMatchSupported { inList4.add(new IntLiteral(2)); k4Filter.setInPredicate(new org.apache.doris.analysis.InPredicate(new SlotRef(null, "k4"), inList4, false)); - SlotReference k0 = new SlotReference("k0", DataType.fromCatalogType(Type.INT), false, ImmutableList.of()); - SlotReference k1 = new SlotReference("k1", DataType.fromCatalogType(Type.INT), false, ImmutableList.of()); - SlotReference k2 = new SlotReference("k2", DataType.fromCatalogType(Type.INT), false, ImmutableList.of()); - SlotReference k3 = new SlotReference("k3", DataType.fromCatalogType(Type.INT), false, ImmutableList.of()); - SlotReference k4 = new SlotReference("k4", DataType.fromCatalogType(Type.INT), false, ImmutableList.of()); - - GreaterThanEqual greaterThanEqual = new GreaterThanEqual(k0, new DateLiteral("2019-08-22")); - LessThanEqual lessThanEqual = new LessThanEqual(k0, new DateLiteral("2019-08-22")); - - InPredicate inPredicate1 = new InPredicate(k1, ImmutableList.of(new IntegerLiteral(101), - new IntegerLiteral(201), - new IntegerLiteral(301), - new IntegerLiteral(401), - new IntegerLiteral(500))); - InPredicate inPredicate2 = new InPredicate(k2, ImmutableList.of(new IntegerLiteral(901), - new IntegerLiteral(1101))); - InPredicate inPredicate3 = new InPredicate(k3, ImmutableList.of(new IntegerLiteral(1), - new IntegerLiteral(3))); - EqualTo equalTo = new EqualTo(k4, new IntegerLiteral(10)); - new Expectations() { { olapTable.getPartitionIds(); result = ImmutableList.of(1L); + olapTable.getBaseSchema(); + result = columns; + olapTable.getName(); result = "t1"; olapTable.getPartition(anyLong); @@ -148,9 +128,21 @@ class PruneOlapScanTabletTest implements MemoPatternMatchSupported { } }; + LogicalOlapScan scan = new LogicalOlapScan(ObjectId.createGenerator().getNextId(), olapTable); + + GreaterThanEqual greaterThanEqual = new GreaterThanEqual(scan.getOutput().get(0), + new DateLiteral("2019-08-22")); + LessThanEqual lessThanEqual = new LessThanEqual(scan.getOutput().get(0), new DateLiteral("2019-08-22")); + InPredicate inPredicate1 = new InPredicate(scan.getOutput().get(1), ImmutableList.of(Literal.of(101), + Literal.of(201), Literal.of(301), Literal.of(401), Literal.of(500))); + InPredicate inPredicate2 = new InPredicate(scan.getOutput().get(2), ImmutableList.of(Literal.of(901), + Literal.of(1101))); + InPredicate inPredicate3 = new InPredicate(scan.getOutput().get(3), ImmutableList.of(Literal.of(1), + Literal.of(3))); + EqualTo equalTo = new EqualTo(scan.getOutput().get(4), Literal.of(10)); LogicalFilter filter = new LogicalFilter<>( ImmutableSet.of(greaterThanEqual, lessThanEqual, inPredicate1, inPredicate2, inPredicate3, equalTo), - new LogicalOlapScan(ObjectId.createGenerator().getNextId(), olapTable)); + scan); Assertions.assertEquals(0, filter.child().getSelectedTabletIds().size()); @@ -158,7 +150,7 @@ class PruneOlapScanTabletTest implements MemoPatternMatchSupported { .applyTopDown(new PruneOlapScanTablet()) .matches( logicalFilter( - logicalOlapScan().when(scan -> scan.getSelectedTabletIds().size() == 19) + logicalOlapScan().when(s -> s.getSelectedTabletIds().size() == 19) ) ); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownAliasThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownAliasThroughJoinTest.java index 9f08d7f073..778159d232 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownAliasThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownAliasThroughJoinTest.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite.logical; import org.apache.doris.common.Pair; import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.util.LogicalPlanBuilder; import org.apache.doris.nereids.util.MemoPatternMatchSupported; @@ -30,12 +31,14 @@ import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported { + private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); @Test void testPushdown() { // condition don't use alias slot - LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.scan1) - .join(PlanConstructor.scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) .alias(ImmutableList.of(1, 3), ImmutableList.of("1name", "2name")) .build(); @@ -54,8 +57,8 @@ class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported { @Test void testCondition() { // condition use alias slot - LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.scan1) - .join(PlanConstructor.scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) .alias(ImmutableList.of(0, 1, 3), ImmutableList.of("1id", "1name", "2name")) .build(); @@ -79,8 +82,8 @@ class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported { @Test void testJustRightSide() { // condition use alias slot - LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.scan1) - .join(PlanConstructor.scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) .alias(ImmutableList.of(2, 3), ImmutableList.of("2id", "2name")) .build(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregationTest.java index 474f1acd49..d141877844 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregationTest.java @@ -18,11 +18,14 @@ package org.apache.doris.nereids.rules.rewrite.logical; import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -39,6 +42,8 @@ import com.google.common.collect.ImmutableSet; import org.junit.jupiter.api.Test; public class PushdownFilterThroughAggregationTest implements MemoPatternMatchSupported { + private final LogicalOlapScan scan = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student, + ImmutableList.of("")); /*- * origin plan: @@ -61,8 +66,6 @@ public class PushdownFilterThroughAggregationTest implements MemoPatternMatchSup */ @Test public void pushDownPredicateOneFilterTest() { - LogicalPlan scan = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student, - ImmutableList.of("")); Slot gender = scan.getOutput().get(1); Expression filterPredicate = new GreaterThan(gender, Literal.of(1)); @@ -89,7 +92,7 @@ public class PushdownFilterThroughAggregationTest implements MemoPatternMatchSup * origin plan: * project * | - * filter gender=1 and name="abc" and (gender+10)<100 + * filter gender=1 and nameMax="abc" and (gender+10)<100 * | * aggregation group by gender * | @@ -98,7 +101,7 @@ public class PushdownFilterThroughAggregationTest implements MemoPatternMatchSup * transformed plan: * project * | - * filter name="abc" + * filter nameMax="abc" * | * aggregation group by gender * | @@ -108,20 +111,19 @@ public class PushdownFilterThroughAggregationTest implements MemoPatternMatchSup */ @Test public void pushDownPredicateTwoFilterTest() { - LogicalPlan scan = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student, - ImmutableList.of("")); Slot gender = scan.getOutput().get(1); - Slot name = scan.getOutput().get(2); + NamedExpression nameMax = new Alias(new Max(scan.getOutput().get(2)), "nameMax"); Expression filterPredicate = ExpressionUtils.and( new GreaterThan(gender, Literal.of(1)), new LessThanEqual( new Add(gender, Literal.of(10)), Literal.of(100)), - new EqualTo(name, Literal.of("abc"))); + new EqualTo(nameMax.toSlot(), Literal.of("abc"))); LogicalPlan plan = new LogicalPlanBuilder(scan) - .aggAllUsingIndex(ImmutableList.of(3, 1), ImmutableList.of(1, 3)) + .aggGroupUsingIndex(ImmutableList.of(3, 1), ImmutableList.of( + scan.getOutput().get(1), scan.getOutput().get(3), nameMax)) .filter(filterPredicate) .project(ImmutableList.of(0)) .build(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownLimitTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownLimitTest.java index 61271f49b0..7a4c518d28 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownLimitTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownLimitTest.java @@ -105,8 +105,8 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup logicalLimit( logicalProject( logicalJoin( - logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))), - logicalOlapScan().when(s -> s.equals(scanStudent)) + logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("score"))), + logicalOlapScan().when(s -> s.getTable().getName().equals("student")) ).when(j -> j.getJoinType() == JoinType.LEFT_OUTER_JOIN) ) ) @@ -114,8 +114,8 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup test(JoinType.LEFT_OUTER_JOIN, false, logicalLimit( logicalJoin( - logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))), - logicalOlapScan().when(s -> s.equals(scanStudent)) + logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("score"))), + logicalOlapScan().when(s -> s.getTable().getName().equals("student")) ).when(j -> j.getJoinType() == JoinType.LEFT_OUTER_JOIN) ) ); @@ -127,19 +127,19 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup test(JoinType.RIGHT_OUTER_JOIN, true, logicalLimit( logicalProject( - logicalJoin( - logicalOlapScan().when(s -> s.equals(scanScore)), - logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent))) - ).when(j -> j.getJoinType() == JoinType.RIGHT_OUTER_JOIN) + rightOuterLogicalJoin( + logicalOlapScan().when(s -> s.getTable().getName().equals("score")), + logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("student"))) + ) ) ) ); test(JoinType.RIGHT_OUTER_JOIN, false, logicalLimit( - logicalJoin( - logicalOlapScan().when(s -> s.equals(scanScore)), - logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent))) - ).when(j -> j.getJoinType() == JoinType.RIGHT_OUTER_JOIN) + rightOuterLogicalJoin( + logicalOlapScan().when(s -> s.getTable().getName().equals("score")), + logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("student"))) + ) ) ); } @@ -149,19 +149,19 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup test(JoinType.CROSS_JOIN, true, logicalLimit( logicalProject( - logicalJoin( - logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))), - logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent))) - ).when(j -> j.getJoinType() == JoinType.CROSS_JOIN) + crossLogicalJoin( + logicalLimit(logicalOlapScan()), + logicalLimit(logicalOlapScan()) + ) ) ) ); test(JoinType.CROSS_JOIN, false, logicalLimit( - logicalJoin( - logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))), - logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent))) - ).when(j -> j.getJoinType() == JoinType.CROSS_JOIN) + crossLogicalJoin( + logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("score"))), + logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("student"))) + ) ) ); } @@ -172,8 +172,8 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup logicalLimit( logicalProject( logicalJoin( - logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))), - logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent))) + logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("score"))), + logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("student"))) ) ) ) @@ -181,8 +181,8 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup test(JoinType.INNER_JOIN, false, logicalLimit( logicalJoin( - logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))), - logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent))) + logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("score"))), + logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("student"))) ) ) ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index c994ace23b..a9267b5390 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -40,6 +40,7 @@ import org.apache.doris.nereids.pattern.GroupExpressionMatching; import org.apache.doris.nereids.pattern.MatchingContext; import org.apache.doris.nereids.pattern.PatternDescriptor; import org.apache.doris.nereids.pattern.PatternMatcher; +import org.apache.doris.nereids.processor.post.Validator; import org.apache.doris.nereids.properties.DistributionSpecGather; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.Rule; @@ -413,6 +414,14 @@ public class PlanChecker { } public PlanChecker matches(PatternDescriptor patternDesc) { + Memo memo = cascadesContext.getMemo(); + checkSlotFromChildren(memo); + assertMatches(memo, () -> MatchingUtils.topDownFindMatching(memo.getRoot(), patternDesc.pattern)); + return this; + } + + // TODO: remove it. + public PlanChecker matchesNotCheck(PatternDescriptor patternDesc) { Memo memo = cascadesContext.getMemo(); assertMatches(memo, () -> MatchingUtils.topDownFindMatching(memo.getRoot(), patternDesc.pattern)); return this; @@ -420,6 +429,7 @@ public class PlanChecker { public PlanChecker matchesExploration(PatternDescriptor patternDesc) { Memo memo = cascadesContext.getMemo(); + checkSlotFromChildren(memo); Supplier asserter = () -> new GroupExpressionMatching(patternDesc.pattern, memo.getRoot().getLogicalExpressions().get(1)).iterator().hasNext(); Assertions.assertTrue(asserter.get(), @@ -429,6 +439,11 @@ public class PlanChecker { return this; } + private void checkSlotFromChildren(Memo memo) { + Validator validator = new Validator(); + memo.getGroupExpressions().forEach((key, value) -> validator.visit(value.getPlan(), null)); + } + private PlanChecker assertMatches(Memo memo, Supplier asserter) { Assertions.assertTrue(asserter.get(), () -> "pattern not match, plan :\n" diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanConstructor.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanConstructor.java index c1d9dcf08d..51948be52b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanConstructor.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanConstructor.java @@ -39,11 +39,6 @@ public class PlanConstructor { public static final OlapTable score; public static final OlapTable course; - public static final LogicalOlapScan scan1; - public static final LogicalOlapScan scan2; - public static final LogicalOlapScan scan3; - public static final LogicalOlapScan scan4; - private static final IdGenerator RELATION_ID_GENERATOR = ObjectId.createGenerator(); static { @@ -81,11 +76,6 @@ public class PlanConstructor { 0, 0, (short) 0, TStorageType.COLUMN, KeysType.PRIMARY_KEYS); - - scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); - scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); - scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); - scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0); } public static OlapTable newOlapTable(long tableId, String tableName, int hashColumn) {