[feat](Nereids): check Memo Plan for Unit Test. (#18082)

This commit is contained in:
jakevin
2023-03-24 18:31:33 +08:00
committed by GitHub
parent eb7b59c1c6
commit 354d109130
11 changed files with 135 additions and 141 deletions

View File

@ -26,15 +26,15 @@ import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.UnboundLogicalProperties;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.FakePlan;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.LeafPlan;
import org.apache.doris.nereids.trees.plans.LimitPhase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
@ -42,6 +42,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
@ -57,8 +58,10 @@ import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
class MemoTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private final ConnectContext connectContext = MemoTestUtils.createConnectContext();
@ -186,11 +189,9 @@ class MemoTest implements MemoPatternMatchSupported {
@Test
public void initByTwoLevelChainPlan() {
OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
LogicalOlapScan scan = new LogicalOlapScan(RelationUtil.newRelationId(), table);
LogicalProject<LogicalOlapScan> topProject = new LogicalProject<>(
ImmutableList.of(scan.computeOutput().get(0)), scan);
Plan topProject = new LogicalPlanBuilder(scan)
.project(ImmutableList.of(0))
.build();
PlanChecker.from(connectContext, topProject)
.checkGroupNum(2)
@ -250,13 +251,11 @@ class MemoTest implements MemoPatternMatchSupported {
@Test
public void initByThreeLevelChainPlan() {
OlapTable table = PlanConstructor.newOlapTable(0, "a", 1);
LogicalOlapScan scan = new LogicalOlapScan(RelationUtil.newRelationId(), table);
LogicalProject<LogicalOlapScan> project = new LogicalProject<>(
ImmutableList.of(scan.computeOutput().get(0)), scan);
LogicalFilter<LogicalProject<LogicalOlapScan>> filter = new LogicalFilter<>(ImmutableSet.of(
new EqualTo(scan.computeOutput().get(0), new IntegerLiteral(1))), project);
Set<Expression> exprs = ImmutableSet.of(new EqualTo(scan.getOutput().get(0), Literal.of(1)));
Plan filter = new LogicalPlanBuilder(scan)
.project(ImmutableList.of(0))
.filter(exprs)
.build();
PlanChecker.from(connectContext, filter)
.checkGroupNum(3)
@ -264,8 +263,8 @@ class MemoTest implements MemoPatternMatchSupported {
logicalFilter(
logicalProject(
any().when(child -> Objects.equals(child, scan))
).when(root -> Objects.equals(root, project))
).when(root -> Objects.equals(root, filter))
).when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).equals(scan.getOutput().get(0)))
).when(f -> Objects.equals(f, filter))
);
}

View File

@ -153,7 +153,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
// after analyze
PlanChecker.from(connectContext)
.analyze(sql2)
.matches(
.matchesNotCheck(
logicalApply(
any(),
logicalAggregate(
@ -181,7 +181,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
PlanChecker.from(connectContext)
.analyze(sql2)
.applyBottomUp(new UnCorrelatedApplyAggregateFilter())
.matches(
.matchesNotCheck(
logicalApply(
any(),
logicalAggregate().when(FieldChecker.check("outputExpressions", ImmutableList.of(
@ -216,7 +216,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
.analyze(sql2)
.applyBottomUp(new UnCorrelatedApplyAggregateFilter())
.applyBottomUp(new ScalarApplyToJoin())
.matches(
.matchesNotCheck(
logicalJoin(
any(),
logicalAggregate()
@ -240,7 +240,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
//"select * from t6 where t6.k1 in (select t7.k3 from t7 where t7.v2 = t6.k2)"
PlanChecker.from(connectContext)
.analyze(sql4)
.matches(
.matchesNotCheck(
logicalApply(
any(),
logicalProject().when(FieldChecker.check("projects", ImmutableList.of(
@ -303,7 +303,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
//"select * from t6 where exists (select t7.k3 from t7 where t6.k2 = t7.v2)"
PlanChecker.from(connectContext)
.analyze(sql6)
.matches(
.matchesNotCheck(
logicalApply(
any(),
logicalProject().when(FieldChecker.check("projects", ImmutableList.of(
@ -321,7 +321,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
PlanChecker.from(connectContext)
.analyze(sql6)
.applyBottomUp(new PullUpProjectUnderApply())
.matches(
.matchesNotCheck(
logicalProject(
logicalApply().when(FieldChecker.check("correlationFilter", Optional.empty()))
.when(FieldChecker.check("correlationSlot", ImmutableList.of(
@ -418,7 +418,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
.applyBottomUp(new LogicalSubQueryAliasToLogicalProject())
.applyTopDown(new MergeProjects())
.applyBottomUp(new PullUpCorrelatedFilterUnderApplyAggregateProject())
.matches(
.matchesNotCheck(
logicalApply(
any(),
logicalAggregate(
@ -451,7 +451,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
.applyTopDown(new MergeProjects())
.applyBottomUp(new PullUpCorrelatedFilterUnderApplyAggregateProject())
.applyBottomUp(new UnCorrelatedApplyAggregateFilter())
.matches(
.matchesNotCheck(
logicalApply(
any(),
logicalAggregate(
@ -478,14 +478,13 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
.applyBottomUp(new PullUpCorrelatedFilterUnderApplyAggregateProject())
.applyBottomUp(new UnCorrelatedApplyAggregateFilter())
.applyBottomUp(new ScalarApplyToJoin())
.matches(
logicalJoin(
.matchesNotCheck(
leftSemiLogicalJoin(
any(),
logicalAggregate(
logicalProject()
)
)
.when(j -> j.getJoinType().equals(JoinType.LEFT_SEMI_JOIN))
.when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of(
new LessThan(new SlotReference(new ExprId(0), "k1", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t6")),

View File

@ -18,14 +18,11 @@
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
@ -35,17 +32,18 @@ import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
public class NormalizeRepeatTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
@Test
public void testKeepNullableAfterNormalizeRepeat() {
SlotReference slot1 = new SlotReference("id", IntegerType.INSTANCE, false);
SlotReference slot2 = slot1.withNullable(true);
SlotReference slot3 = new SlotReference("name", StringType.INSTANCE, false);
Alias alias = new Alias(new Sum(slot3), "sum(name)");
Slot id = scan1.getOutput().get(0);
Slot idNotNull = id.withNullable(true);
Slot name = scan1.getOutput().get(1);
Alias alias = new Alias(new Sum(name), "sum(name)");
Plan plan = new LogicalRepeat<>(
ImmutableList.of(ImmutableList.of(slot1)),
ImmutableList.of(slot2, alias),
new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.newOlapTable(0, "t", 0))
ImmutableList.of(ImmutableList.of(id)),
ImmutableList.of(idNotNull, alias),
scan1
);
PlanChecker.from(MemoTestUtils.createCascadesContext(plan))
.applyTopDown(new NormalizeRepeat())

View File

@ -17,32 +17,32 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Test;
/**
* Tests for {@link MergeFilters}.
*/
class MergeFiltersTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
@Test
void testMergeFilters() {
Expression expression1 = new IntegerLiteral(1);
Expression expression2 = new IntegerLiteral(2);
Expression expression3 = new IntegerLiteral(3);
LogicalPlan logicalFilter = new LogicalPlanBuilder(
new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "table")))
LogicalPlan logicalFilter = new LogicalPlanBuilder(scan)
.filter(ImmutableSet.of(expression1))
.filter(ImmutableSet.of(expression2))
.filter(ImmutableSet.of(expression3))
@ -50,9 +50,7 @@ class MergeFiltersTest implements MemoPatternMatchSupported {
PlanChecker.from(new ConnectContext(), logicalFilter).applyBottomUp(new MergeFilters())
.matches(
logicalFilter(
unboundRelation()
).when(filter -> filter.getConjuncts()
logicalFilter().when(filter -> filter.getConjuncts()
.equals(ImmutableSet.of(expression1, expression2, expression3))));
}
}

View File

@ -17,33 +17,31 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Test;
/**
* Tests for {@link MergeLimits}.
*/
class MergeLimitsTest implements MemoPatternMatchSupported {
private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
@Test
void testMergeLimits() {
LogicalPlan logicalLimit = new LogicalPlanBuilder(
new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "t")))
LogicalPlan logicalLimit = new LogicalPlanBuilder(scan1)
.limit(3, 5)
.limit(2, 0)
.limit(10, 0).build();
PlanChecker.from(new ConnectContext(), logicalLimit).applyTopDown(new MergeLimits())
.matches(
logicalLimit(
unboundRelation()
).when(limit -> limit.getLimit() == 2).when(limit -> limit.getOffset() == 5));
logicalLimit().when(limit -> limit.getLimit() == 2).when(limit -> limit.getOffset() == 5));
}
}

View File

@ -28,18 +28,15 @@ import org.apache.doris.catalog.MaterializedIndex;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Partition;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
@ -79,24 +76,24 @@ class PruneOlapScanTabletTest implements MemoPatternMatchSupported {
k0Filter.setUpperBound(new StringLiteral("2019-08-22"), true);
PartitionColumnFilter k1Filter = new PartitionColumnFilter();
List<Expr> inList = Lists.newArrayList();
inList.add(new IntLiteral(100));
inList.add(new IntLiteral(200));
inList.add(new IntLiteral(300));
inList.add(new IntLiteral(400));
inList.add(new IntLiteral(500));
List<Expr> inList = ImmutableList.of(
new IntLiteral(100),
new IntLiteral(200),
new IntLiteral(300),
new IntLiteral(400),
new IntLiteral(500));
k1Filter.setInPredicate(new org.apache.doris.analysis.InPredicate(new SlotRef(null, "k1"), inList, false));
PartitionColumnFilter k2Filter = new PartitionColumnFilter();
List<Expr> inList2 = Lists.newArrayList();
inList2.add(new IntLiteral(900));
inList2.add(new IntLiteral(1100));
List<Expr> inList2 = ImmutableList.of(
new IntLiteral(900),
new IntLiteral(1100));
k2Filter.setInPredicate(new org.apache.doris.analysis.InPredicate(new SlotRef(null, "k2"), inList2, false));
PartitionColumnFilter k3Filter = new PartitionColumnFilter();
List<Expr> inList3 = Lists.newArrayList();
inList3.add(new IntLiteral(1));
inList3.add(new IntLiteral(3));
List<Expr> inList3 = ImmutableList.of(
new IntLiteral(1),
new IntLiteral(3));
k3Filter.setInPredicate(new org.apache.doris.analysis.InPredicate(new SlotRef(null, "k3"), inList3, false));
PartitionColumnFilter k4Filter = new PartitionColumnFilter();
@ -104,31 +101,14 @@ class PruneOlapScanTabletTest implements MemoPatternMatchSupported {
inList4.add(new IntLiteral(2));
k4Filter.setInPredicate(new org.apache.doris.analysis.InPredicate(new SlotRef(null, "k4"), inList4, false));
SlotReference k0 = new SlotReference("k0", DataType.fromCatalogType(Type.INT), false, ImmutableList.of());
SlotReference k1 = new SlotReference("k1", DataType.fromCatalogType(Type.INT), false, ImmutableList.of());
SlotReference k2 = new SlotReference("k2", DataType.fromCatalogType(Type.INT), false, ImmutableList.of());
SlotReference k3 = new SlotReference("k3", DataType.fromCatalogType(Type.INT), false, ImmutableList.of());
SlotReference k4 = new SlotReference("k4", DataType.fromCatalogType(Type.INT), false, ImmutableList.of());
GreaterThanEqual greaterThanEqual = new GreaterThanEqual(k0, new DateLiteral("2019-08-22"));
LessThanEqual lessThanEqual = new LessThanEqual(k0, new DateLiteral("2019-08-22"));
InPredicate inPredicate1 = new InPredicate(k1, ImmutableList.of(new IntegerLiteral(101),
new IntegerLiteral(201),
new IntegerLiteral(301),
new IntegerLiteral(401),
new IntegerLiteral(500)));
InPredicate inPredicate2 = new InPredicate(k2, ImmutableList.of(new IntegerLiteral(901),
new IntegerLiteral(1101)));
InPredicate inPredicate3 = new InPredicate(k3, ImmutableList.of(new IntegerLiteral(1),
new IntegerLiteral(3)));
EqualTo equalTo = new EqualTo(k4, new IntegerLiteral(10));
new Expectations() {
{
olapTable.getPartitionIds();
result = ImmutableList.of(1L);
olapTable.getBaseSchema();
result = columns;
olapTable.getName();
result = "t1";
olapTable.getPartition(anyLong);
@ -148,9 +128,21 @@ class PruneOlapScanTabletTest implements MemoPatternMatchSupported {
}
};
LogicalOlapScan scan = new LogicalOlapScan(ObjectId.createGenerator().getNextId(), olapTable);
GreaterThanEqual greaterThanEqual = new GreaterThanEqual(scan.getOutput().get(0),
new DateLiteral("2019-08-22"));
LessThanEqual lessThanEqual = new LessThanEqual(scan.getOutput().get(0), new DateLiteral("2019-08-22"));
InPredicate inPredicate1 = new InPredicate(scan.getOutput().get(1), ImmutableList.of(Literal.of(101),
Literal.of(201), Literal.of(301), Literal.of(401), Literal.of(500)));
InPredicate inPredicate2 = new InPredicate(scan.getOutput().get(2), ImmutableList.of(Literal.of(901),
Literal.of(1101)));
InPredicate inPredicate3 = new InPredicate(scan.getOutput().get(3), ImmutableList.of(Literal.of(1),
Literal.of(3)));
EqualTo equalTo = new EqualTo(scan.getOutput().get(4), Literal.of(10));
LogicalFilter<LogicalOlapScan> filter = new LogicalFilter<>(
ImmutableSet.of(greaterThanEqual, lessThanEqual, inPredicate1, inPredicate2, inPredicate3, equalTo),
new LogicalOlapScan(ObjectId.createGenerator().getNextId(), olapTable));
scan);
Assertions.assertEquals(0, filter.child().getSelectedTabletIds().size());
@ -158,7 +150,7 @@ class PruneOlapScanTabletTest implements MemoPatternMatchSupported {
.applyTopDown(new PruneOlapScanTablet())
.matches(
logicalFilter(
logicalOlapScan().when(scan -> scan.getSelectedTabletIds().size() == 19)
logicalOlapScan().when(s -> s.getSelectedTabletIds().size() == 19)
)
);
}

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
@ -30,12 +31,14 @@ import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported {
private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
@Test
void testPushdown() {
// condition don't use alias slot
LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.scan1)
.join(PlanConstructor.scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
.alias(ImmutableList.of(1, 3), ImmutableList.of("1name", "2name"))
.build();
@ -54,8 +57,8 @@ class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported {
@Test
void testCondition() {
// condition use alias slot
LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.scan1)
.join(PlanConstructor.scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
.alias(ImmutableList.of(0, 1, 3), ImmutableList.of("1id", "1name", "2name"))
.build();
@ -79,8 +82,8 @@ class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported {
@Test
void testJustRightSide() {
// condition use alias slot
LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.scan1)
.join(PlanConstructor.scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
.alias(ImmutableList.of(2, 3), ImmutableList.of("2id", "2name"))
.build();

View File

@ -18,11 +18,14 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@ -39,6 +42,8 @@ import com.google.common.collect.ImmutableSet;
import org.junit.jupiter.api.Test;
public class PushdownFilterThroughAggregationTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student,
ImmutableList.of(""));
/*-
* origin plan:
@ -61,8 +66,6 @@ public class PushdownFilterThroughAggregationTest implements MemoPatternMatchSup
*/
@Test
public void pushDownPredicateOneFilterTest() {
LogicalPlan scan = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student,
ImmutableList.of(""));
Slot gender = scan.getOutput().get(1);
Expression filterPredicate = new GreaterThan(gender, Literal.of(1));
@ -89,7 +92,7 @@ public class PushdownFilterThroughAggregationTest implements MemoPatternMatchSup
* origin plan:
* project
* |
* filter gender=1 and name="abc" and (gender+10)<100
* filter gender=1 and nameMax="abc" and (gender+10)<100
* |
* aggregation group by gender
* |
@ -98,7 +101,7 @@ public class PushdownFilterThroughAggregationTest implements MemoPatternMatchSup
* transformed plan:
* project
* |
* filter name="abc"
* filter nameMax="abc"
* |
* aggregation group by gender
* |
@ -108,20 +111,19 @@ public class PushdownFilterThroughAggregationTest implements MemoPatternMatchSup
*/
@Test
public void pushDownPredicateTwoFilterTest() {
LogicalPlan scan = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student,
ImmutableList.of(""));
Slot gender = scan.getOutput().get(1);
Slot name = scan.getOutput().get(2);
NamedExpression nameMax = new Alias(new Max(scan.getOutput().get(2)), "nameMax");
Expression filterPredicate = ExpressionUtils.and(
new GreaterThan(gender, Literal.of(1)),
new LessThanEqual(
new Add(gender, Literal.of(10)),
Literal.of(100)),
new EqualTo(name, Literal.of("abc")));
new EqualTo(nameMax.toSlot(), Literal.of("abc")));
LogicalPlan plan = new LogicalPlanBuilder(scan)
.aggAllUsingIndex(ImmutableList.of(3, 1), ImmutableList.of(1, 3))
.aggGroupUsingIndex(ImmutableList.of(3, 1), ImmutableList.of(
scan.getOutput().get(1), scan.getOutput().get(3), nameMax))
.filter(filterPredicate)
.project(ImmutableList.of(0))
.build();

View File

@ -105,8 +105,8 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
logicalLimit(
logicalProject(
logicalJoin(
logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))),
logicalOlapScan().when(s -> s.equals(scanStudent))
logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("score"))),
logicalOlapScan().when(s -> s.getTable().getName().equals("student"))
).when(j -> j.getJoinType() == JoinType.LEFT_OUTER_JOIN)
)
)
@ -114,8 +114,8 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
test(JoinType.LEFT_OUTER_JOIN, false,
logicalLimit(
logicalJoin(
logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))),
logicalOlapScan().when(s -> s.equals(scanStudent))
logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("score"))),
logicalOlapScan().when(s -> s.getTable().getName().equals("student"))
).when(j -> j.getJoinType() == JoinType.LEFT_OUTER_JOIN)
)
);
@ -127,19 +127,19 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
test(JoinType.RIGHT_OUTER_JOIN, true,
logicalLimit(
logicalProject(
logicalJoin(
logicalOlapScan().when(s -> s.equals(scanScore)),
logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent)))
).when(j -> j.getJoinType() == JoinType.RIGHT_OUTER_JOIN)
rightOuterLogicalJoin(
logicalOlapScan().when(s -> s.getTable().getName().equals("score")),
logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("student")))
)
)
)
);
test(JoinType.RIGHT_OUTER_JOIN, false,
logicalLimit(
logicalJoin(
logicalOlapScan().when(s -> s.equals(scanScore)),
logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent)))
).when(j -> j.getJoinType() == JoinType.RIGHT_OUTER_JOIN)
rightOuterLogicalJoin(
logicalOlapScan().when(s -> s.getTable().getName().equals("score")),
logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("student")))
)
)
);
}
@ -149,19 +149,19 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
test(JoinType.CROSS_JOIN, true,
logicalLimit(
logicalProject(
logicalJoin(
logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))),
logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent)))
).when(j -> j.getJoinType() == JoinType.CROSS_JOIN)
crossLogicalJoin(
logicalLimit(logicalOlapScan()),
logicalLimit(logicalOlapScan())
)
)
)
);
test(JoinType.CROSS_JOIN, false,
logicalLimit(
logicalJoin(
logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))),
logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent)))
).when(j -> j.getJoinType() == JoinType.CROSS_JOIN)
crossLogicalJoin(
logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("score"))),
logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("student")))
)
)
);
}
@ -172,8 +172,8 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
logicalLimit(
logicalProject(
logicalJoin(
logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))),
logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent)))
logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("score"))),
logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("student")))
)
)
)
@ -181,8 +181,8 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
test(JoinType.INNER_JOIN, false,
logicalLimit(
logicalJoin(
logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))),
logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent)))
logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("score"))),
logicalLimit(logicalOlapScan().when(s -> s.getTable().getName().equals("student")))
)
)
);

View File

@ -40,6 +40,7 @@ import org.apache.doris.nereids.pattern.GroupExpressionMatching;
import org.apache.doris.nereids.pattern.MatchingContext;
import org.apache.doris.nereids.pattern.PatternDescriptor;
import org.apache.doris.nereids.pattern.PatternMatcher;
import org.apache.doris.nereids.processor.post.Validator;
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
@ -413,6 +414,14 @@ public class PlanChecker {
}
public PlanChecker matches(PatternDescriptor<? extends Plan> patternDesc) {
Memo memo = cascadesContext.getMemo();
checkSlotFromChildren(memo);
assertMatches(memo, () -> MatchingUtils.topDownFindMatching(memo.getRoot(), patternDesc.pattern));
return this;
}
// TODO: remove it.
public PlanChecker matchesNotCheck(PatternDescriptor<? extends Plan> patternDesc) {
Memo memo = cascadesContext.getMemo();
assertMatches(memo, () -> MatchingUtils.topDownFindMatching(memo.getRoot(), patternDesc.pattern));
return this;
@ -420,6 +429,7 @@ public class PlanChecker {
public PlanChecker matchesExploration(PatternDescriptor<? extends Plan> patternDesc) {
Memo memo = cascadesContext.getMemo();
checkSlotFromChildren(memo);
Supplier<Boolean> asserter = () -> new GroupExpressionMatching(patternDesc.pattern,
memo.getRoot().getLogicalExpressions().get(1)).iterator().hasNext();
Assertions.assertTrue(asserter.get(),
@ -429,6 +439,11 @@ public class PlanChecker {
return this;
}
private void checkSlotFromChildren(Memo memo) {
Validator validator = new Validator();
memo.getGroupExpressions().forEach((key, value) -> validator.visit(value.getPlan(), null));
}
private PlanChecker assertMatches(Memo memo, Supplier<Boolean> asserter) {
Assertions.assertTrue(asserter.get(),
() -> "pattern not match, plan :\n"

View File

@ -39,11 +39,6 @@ public class PlanConstructor {
public static final OlapTable score;
public static final OlapTable course;
public static final LogicalOlapScan scan1;
public static final LogicalOlapScan scan2;
public static final LogicalOlapScan scan3;
public static final LogicalOlapScan scan4;
private static final IdGenerator<ObjectId> RELATION_ID_GENERATOR = ObjectId.createGenerator();
static {
@ -81,11 +76,6 @@ public class PlanConstructor {
0, 0, (short) 0,
TStorageType.COLUMN,
KeysType.PRIMARY_KEYS);
scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0);
}
public static OlapTable newOlapTable(long tableId, String tableName, int hashColumn) {