[refactor](Nereids): refactor UT by using Pattern and rename to remove consecutive (#13337)

* rename

* refactor UT
This commit is contained in:
jakevin
2022-10-13 16:41:51 +08:00
committed by GitHub
parent baf2689610
commit 4a6eb01ccb
15 changed files with 179 additions and 174 deletions

View File

@ -228,8 +228,7 @@ public class GroupExpression {
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append(ownerGroup.getGroupId())
.append("(plan=" + plan.toString() + ") children=[");
builder.append(ownerGroup.getGroupId()).append("(plan=").append(plan).append(") children=[");
for (Group group : children) {
builder.append(group.getGroupId()).append(" ");
}

View File

@ -39,9 +39,9 @@ import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickS
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuter;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeLimits;
import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownExpressionsInHashCondition;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughJoin;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject;
@ -68,7 +68,7 @@ public class RuleSet {
.add(SemiJoinSemiJoinTranspose.INSTANCE)
.add(new AggregateDisassemble())
.add(new PushdownFilterThroughProject())
.add(new MergeConsecutiveProjects())
.add(new MergeProjects())
.build();
public static final List<RuleFactory> PUSH_DOWN_JOIN_CONDITION_RULES = ImmutableList.of(
@ -78,9 +78,9 @@ public class RuleSet {
new PushdownProjectThroughLimit(),
new PushdownFilterThroughProject(),
EliminateOuter.INSTANCE,
new MergeConsecutiveProjects(),
new MergeConsecutiveFilters(),
new MergeConsecutiveLimits());
new MergeProjects(),
new MergeFilters(),
new MergeLimits());
public static final List<Rule> IMPLEMENTATION_RULES = planRuleFactories()
.add(new LogicalAggToPhysicalHashAgg())

View File

@ -100,9 +100,9 @@ public enum RuleType {
REWRITE_JOIN_EXPRESSION(RuleTypeClass.REWRITE),
REORDER_JOIN(RuleTypeClass.REWRITE),
// Merge Consecutive plan
MERGE_CONSECUTIVE_FILTERS(RuleTypeClass.REWRITE),
MERGE_CONSECUTIVE_PROJECTS(RuleTypeClass.REWRITE),
MERGE_CONSECUTIVE_LIMITS(RuleTypeClass.REWRITE),
MERGE_FILTERS(RuleTypeClass.REWRITE),
MERGE_PROJECTS(RuleTypeClass.REWRITE),
MERGE_LIMITS(RuleTypeClass.REWRITE),
// Eliminate plan
ELIMINATE_LIMIT(RuleTypeClass.REWRITE),
ELIMINATE_FILTER(RuleTypeClass.REWRITE),

View File

@ -43,7 +43,7 @@ import org.apache.doris.nereids.util.ExpressionUtils;
* |
* scan
*/
public class MergeConsecutiveFilters extends OneRewriteRuleFactory {
public class MergeFilters extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalFilter(logicalFilter()).then(filter -> {
@ -52,7 +52,7 @@ public class MergeConsecutiveFilters extends OneRewriteRuleFactory {
Expression childPredicates = childFilter.getPredicates();
Expression mergedPredicates = ExpressionUtils.and(predicates, childPredicates);
return new LogicalFilter<>(mergedPredicates, childFilter.child());
}).toRule(RuleType.MERGE_CONSECUTIVE_FILTERS);
}).toRule(RuleType.MERGE_FILTERS);
}
}

View File

@ -40,7 +40,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
* </pre>
* Note that the top limit should not have valid offset info.
*/
public class MergeConsecutiveLimits extends OneRewriteRuleFactory {
public class MergeLimits extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalLimit(logicalLimit()).whenNot(Limit::hasValidOffset).then(upperLimit -> {
@ -50,6 +50,6 @@ public class MergeConsecutiveLimits extends OneRewriteRuleFactory {
bottomLimit.getOffset(),
bottomLimit.child()
);
}).toRule(RuleType.MERGE_CONSECUTIVE_LIMITS);
}).toRule(RuleType.MERGE_LIMITS);
}
}

View File

@ -51,7 +51,7 @@ import java.util.stream.Collectors;
* scan
*/
public class MergeConsecutiveProjects extends OneRewriteRuleFactory {
public class MergeProjects extends OneRewriteRuleFactory {
private static class ExpressionReplacer extends DefaultExpressionRewriter<Map<Expression, Expression>> {
public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();
@ -114,10 +114,10 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory {
);
projectExpressions = projectExpressions.stream()
.map(e -> MergeConsecutiveProjects.ExpressionReplacer.INSTANCE.replace(e, childAliasMap))
.map(e -> MergeProjects.ExpressionReplacer.INSTANCE.replace(e, childAliasMap))
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
return new LogicalProject<>(projectExpressions, childProject.children().get(0));
}).toRule(RuleType.MERGE_CONSECUTIVE_PROJECTS);
}).toRule(RuleType.MERGE_PROJECTS);
}
}

View File

@ -51,7 +51,7 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory {
}
private Plan unCorrelatedToJoin(LogicalApply apply) {
LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows(
LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows<>(
new AssertNumRowsElement(
1, apply.getSubqueryExpr().toString(),
AssertNumRowsElement.Assertion.EQ),

View File

@ -126,14 +126,14 @@ public class ExtractSingleTableExpressionFromDisjunctionTest implements PatternM
)
);
Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course);
LogicalFilter root = new LogicalFilter(expr, join);
LogicalFilter root = new LogicalFilter<>(expr, join);
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new ExtractSingleTableExpressionFromDisjunction())
.matchesFromRoot(
logicalFilter()
.when(filter -> verifySingleTableExpression2(filter.getPredicates()))
);
Assertions.assertTrue(studentGender != null);
Assertions.assertNotNull(studentGender);
}
private boolean verifySingleTableExpression2(Expression expr) {
@ -163,14 +163,14 @@ public class ExtractSingleTableExpressionFromDisjunctionTest implements PatternM
new EqualTo(studentGender, new IntegerLiteral(1))
);
Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course);
LogicalFilter root = new LogicalFilter(expr, join);
LogicalFilter root = new LogicalFilter<>(expr, join);
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new ExtractSingleTableExpressionFromDisjunction())
.matchesFromRoot(
logicalFilter()
.when(filter -> verifySingleTableExpression3(filter.getPredicates()))
);
Assertions.assertTrue(studentGender != null);
Assertions.assertNotNull(studentGender);
}
private boolean verifySingleTableExpression3(Expression expr) {

View File

@ -36,7 +36,7 @@ import java.util.List;
/**
* MergeConsecutiveFilter ut
*/
public class MergeConsecutiveFilterTest {
public class MergeFiltersTest {
@Test
public void testMergeConsecutiveFilters() {
UnboundRelation relation = new UnboundRelation(Lists.newArrayList("db", "table"));
@ -48,7 +48,7 @@ public class MergeConsecutiveFilterTest {
LogicalFilter filter3 = new LogicalFilter<>(expression3, filter2);
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter3);
List<Rule> rules = Lists.newArrayList(new MergeConsecutiveFilters().build());
List<Rule> rules = Lists.newArrayList(new MergeFilters().build());
cascadesContext.bottomUpRewrite(rules);
//check transformed plan
Plan resultPlan = cascadesContext.getMemo().copyOut();

View File

@ -29,7 +29,7 @@ import org.junit.jupiter.api.Test;
import java.util.List;
public class MergeConsecutiveLimitsTest {
public class MergeLimitsTest {
@Test
public void testMergeConsecutiveLimits() {
LogicalLimit limit3 = new LogicalLimit<>(3, 5, new UnboundRelation(Lists.newArrayList("db", "t")));
@ -37,7 +37,7 @@ public class MergeConsecutiveLimitsTest {
LogicalLimit limit1 = new LogicalLimit<>(10, 0, limit2);
CascadesContext context = MemoTestUtils.createCascadesContext(limit1);
List<Rule> rules = Lists.newArrayList(new MergeConsecutiveLimits().build());
List<Rule> rules = Lists.newArrayList(new MergeLimits().build());
context.topDownRewrite(rules);
LogicalLimit limit = (LogicalLimit) context.getMemo().copyOut();

View File

@ -40,7 +40,7 @@ import java.util.List;
/**
* MergeConsecutiveProjects ut
*/
public class MergeConsecutiveProjectsTest {
public class MergeProjectsTest {
@Test
public void testMergeConsecutiveProjects() {
UnboundRelation relation = new UnboundRelation(Lists.newArrayList("db", "table"));
@ -52,7 +52,7 @@ public class MergeConsecutiveProjectsTest {
LogicalProject project3 = new LogicalProject<>(Lists.newArrayList(colA), project2);
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(project3);
List<Rule> rules = Lists.newArrayList(new MergeConsecutiveProjects().build());
List<Rule> rules = Lists.newArrayList(new MergeProjects().build());
cascadesContext.bottomUpRewrite(rules);
Plan plan = cascadesContext.getMemo().copyOut();
System.out.println(plan.treeString());
@ -95,7 +95,7 @@ public class MergeConsecutiveProjectsTest {
project1);
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(project2);
List<Rule> rules = Lists.newArrayList(new MergeConsecutiveProjects().build());
List<Rule> rules = Lists.newArrayList(new MergeProjects().build());
cascadesContext.bottomUpRewrite(rules);
Plan plan = cascadesContext.getMemo().copyOut();
System.out.println(plan.treeString());

View File

@ -30,7 +30,9 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.FieldChecker;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
@ -46,14 +48,15 @@ import java.util.List;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class NormalizeAggregateTest implements PatternMatchSupported {
private Plan rStudent;
private LogicalPlan rStudent;
@BeforeAll
public final void beforeAll() {
rStudent = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, ImmutableList.of(""));
rStudent = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student,
ImmutableList.of(""));
}
/**
/*-
* original plan:
* LogicalAggregate (phase: [GLOBAL], output: [name#2, sum(id#0) AS `sum`#4], groupBy: [name#2])
* +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3])
@ -79,16 +82,18 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
logicalOlapScan()
).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key)))
.when(aggregate -> aggregate.getOutputExpressions().get(0).equals(key))
.when(aggregate -> aggregate.getOutputExpressions().get(1).child(0).equals(aggregateFunction.child(0)))
.when(aggregate -> aggregate.getOutputExpressions().get(1).child(0)
.equals(aggregateFunction.child(0)))
.when(FieldChecker.check("normalized", true))
).when(project -> project.getProjects().get(0).equals(key))
.when(project -> project.getProjects().get(1) instanceof Alias)
.when(project -> ((Alias) (project.getProjects().get(1))).getExprId().equals(aggregateFunction.getExprId()))
.when(project -> (project.getProjects().get(1)).getExprId()
.equals(aggregateFunction.getExprId()))
.when(project -> project.getProjects().get(1).child(0) instanceof SlotReference)
);
}
/**
/*-
* original plan:
* LogicalAggregate (phase: [GLOBAL], output: [(sum((id#0 * 1)) + 2) AS `(sum((id * 1)) + 2)`#4], groupBy: [name#2])
* +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3])
@ -101,14 +106,15 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
*/
@Test
public void testComplexFuncWithComplexOutputOfFunc() {
NamedExpression key = rStudent.getOutput().get(2).toSlot();
List<Expression> groupExpressionList = Lists.newArrayList(key);
Expression multiply = new Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1));
Expression aggregateFunction = new Sum(multiply);
Expression complexOutput = new Add(aggregateFunction, new IntegerLiteral(2));
Alias output = new Alias(complexOutput, complexOutput.toSql());
List<NamedExpression> outputExpressionList = Lists.newArrayList(output);
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
LogicalPlan root = new LogicalPlanBuilder(rStudent)
.aggGroupUsingIndex(ImmutableList.of(2), outputExpressionList)
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new NormalizeAggregate())
@ -116,24 +122,27 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
logicalProject(
logicalAggregate(
logicalProject(
logicalOlapScan()
logicalOlapScan()
).when(project -> project.getProjects().size() == 2)
.when(project -> project.getProjects().get(0) instanceof SlotReference)
.when(project -> project.getProjects().get(1).child(0).equals(multiply))
).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key)))
).when(FieldChecker.check("groupByExpressions",
ImmutableList.of(rStudent.getOutput().get(2))))
.when(aggregate -> aggregate.getOutputExpressions().size() == 1)
.when(aggregate -> aggregate.getOutputExpressions().get(0).child(0) instanceof AggregateFunction)
.when(aggregate -> aggregate.getOutputExpressions().get(0)
.child(0) instanceof AggregateFunction)
).when(project -> project.getProjects().size() == 1)
.when(project -> project.getProjects().get(0) instanceof Alias)
.when(project -> project.getProjects().get(0).getExprId().equals(output.getExprId()))
.when(project -> project.getProjects().get(0).child(0) instanceof Add)
.when(project -> project.getProjects().get(0).child(0).child(0) instanceof SlotReference)
.when(project -> project.getProjects().get(0).child(0).child(1).equals(new IntegerLiteral(2)))
.when(project -> project.getProjects().get(0).child(0)
.child(0) instanceof SlotReference)
.when(project -> project.getProjects().get(0).child(0).child(1)
.equals(new IntegerLiteral(2)))
);
}
/**
/*-
* original plan:
* LogicalAggregate (phase: [GLOBAL], output: [((gender#1 + 1) + 2) AS `((gender + 1) + 2)`#4], groupBy: [(gender#1 + 1)])
* +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3])
@ -146,12 +155,16 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
*/
@Test
public void testComplexKeyWithComplexOutputOfKey() {
Expression key = new Add(rStudent.getOutput().get(1).toSlot(), new IntegerLiteral(1));
Expression key = new Add(rStudent.getOutput().get(1), new IntegerLiteral(1));
Expression complexKeyOutput = new Add(key, new IntegerLiteral(2));
NamedExpression keyOutput = new Alias(complexKeyOutput, complexKeyOutput.toSql());
List<Expression> groupExpressionList = Lists.newArrayList(key);
List<NamedExpression> outputExpressionList = Lists.newArrayList(keyOutput);
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
LogicalPlan root = new LogicalPlanBuilder(rStudent)
.agg(groupExpressionList, outputExpressionList)
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new NormalizeAggregate())
@ -164,12 +177,16 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
.when(project -> project.getProjects().get(0) instanceof Alias)
.when(project -> project.getProjects().get(0).child(0).equals(key))
).when(aggregate -> aggregate.getGroupByExpressions().get(0) instanceof SlotReference)
.when(aggregate -> aggregate.getOutputExpressions().get(0) instanceof SlotReference)
.when(aggregate -> aggregate.getGroupByExpressions().equals(aggregate.getOutputExpressions()))
.when(aggregate -> aggregate.getOutputExpressions()
.get(0) instanceof SlotReference)
.when(aggregate -> aggregate.getGroupByExpressions()
.equals(aggregate.getOutputExpressions()))
).when(project -> project.getProjects().get(0).getExprId().equals(keyOutput.getExprId()))
.when(project -> project.getProjects().get(0).child(0) instanceof Add)
.when(project -> project.getProjects().get(0).child(0).child(0) instanceof SlotReference)
.when(project -> project.getProjects().get(0).child(0).child(1).equals(new IntegerLiteral(2)))
.when(project -> project.getProjects().get(0).child(0)
.child(0) instanceof SlotReference)
.when(project -> project.getProjects().get(0).child(0).child(1)
.equals(new IntegerLiteral(2)))
);
}

View File

@ -17,97 +17,74 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.PlanRewriter;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
public class PushdownFilterThroughAggregationTest implements PatternMatchSupported {
public class PushdownFilterThroughAggregationTest {
/**
* origin plan:
* project
* |
* filter gender=1
* |
* aggregation group by gender
* |
* scan(student)
*
* transformed plan:
* project
* |
* aggregation group by gender
* |
* filter gender=1
* |
* scan(student)
*/
/*-
* origin plan:
* project
* |
* filter gender=1
* |
* aggregation group by gender
* |
* scan(student)
*
* transformed plan:
* project
* |
* aggregation group by gender
* |
* filter gender=1
* |
* scan(student)
*/
@Test
public void pushDownPredicateOneFilterTest() {
Plan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, ImmutableList.of(""));
LogicalPlan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student,
ImmutableList.of(""));
Slot gender = scan.getOutput().get(1);
Slot age = scan.getOutput().get(3);
List<Expression> groupByKeys = Lists.newArrayList(age, gender);
List<NamedExpression> outputExpressionList = Lists.newArrayList(gender, age);
Plan aggregation = new LogicalAggregate<>(groupByKeys, outputExpressionList, scan);
Expression filterPredicate = new GreaterThan(gender, Literal.of(1));
LogicalFilter filter = new LogicalFilter(filterPredicate, aggregation);
Plan root = new LogicalProject<>(
Lists.newArrayList(gender),
filter
);
LogicalPlan plan = new LogicalPlanBuilder(scan)
.aggAllUsingIndex(ImmutableList.of(3, 1), ImmutableList.of(1, 3))
.filter(filterPredicate)
.project(ImmutableList.of(0))
.build();
Memo memo = rewrite(root);
System.out.println(memo.copyOut().treeString());
Group rootGroup = memo.getRoot();
GroupExpression groupExpression = rootGroup
.getLogicalExpression().child(0)
.getLogicalExpression();
aggregation = groupExpression.getPlan();
Assertions.assertTrue(aggregation instanceof LogicalAggregate);
groupExpression = groupExpression.child(0).getLogicalExpression();
Plan bottomFilter = groupExpression.getPlan();
Assertions.assertTrue(bottomFilter instanceof LogicalFilter);
Expression greater = ((LogicalFilter<?>) bottomFilter).getPredicates();
Assertions.assertTrue(greater instanceof GreaterThan);
Assertions.assertTrue(greater.child(0) instanceof Slot);
Assertions.assertEquals("gender", ((Slot) greater.child(0)).getName());
groupExpression = groupExpression.child(0).getLogicalExpression();
Plan scan2 = groupExpression.getPlan();
Assertions.assertTrue(scan2 instanceof LogicalOlapScan);
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new PushdownFilterThroughAggregation())
.matches(
logicalProject(
logicalAggregate(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().equals(filterPredicate))
)
)
);
}
/**
/*-
* origin plan:
* project
* |
@ -130,63 +107,40 @@ public class PushdownFilterThroughAggregationTest {
*/
@Test
public void pushDownPredicateTwoFilterTest() {
Plan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, ImmutableList.of(""));
LogicalPlan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student,
ImmutableList.of(""));
Slot gender = scan.getOutput().get(1);
Slot name = scan.getOutput().get(2);
Slot age = scan.getOutput().get(3);
List<Expression> groupByKeys = Lists.newArrayList(age, gender);
List<NamedExpression> outputExpressionList = Lists.newArrayList(gender, age);
Plan aggregation = new LogicalAggregate<>(groupByKeys, outputExpressionList, scan);
Expression filterPredicate = ExpressionUtils.and(
new GreaterThan(gender, Literal.of(1)),
new LessThanEqual(
new Add(
gender,
Literal.of(10)
),
Literal.of(100)
),
new EqualTo(name, Literal.of("abc"))
);
LogicalFilter filter = new LogicalFilter(filterPredicate, aggregation);
Plan root = new LogicalProject<>(
Lists.newArrayList(gender),
filter
);
new Add(gender, Literal.of(10)),
Literal.of(100)),
new EqualTo(name, Literal.of("abc")));
Memo memo = rewrite(root);
System.out.println(memo.copyOut().treeString());
Group rootGroup = memo.getRoot();
GroupExpression groupExpression = rootGroup.getLogicalExpression().child(0).getLogicalExpression();
Plan upperFilter = groupExpression.getPlan();
Assertions.assertTrue(upperFilter instanceof LogicalFilter);
Expression upperPredicates = ((LogicalFilter<?>) upperFilter).getPredicates();
Assertions.assertTrue(upperPredicates instanceof EqualTo);
Assertions.assertTrue(upperPredicates.child(0) instanceof Slot);
groupExpression = groupExpression.child(0).getLogicalExpression();
aggregation = groupExpression.getPlan();
Assertions.assertTrue(aggregation instanceof LogicalAggregate);
groupExpression = groupExpression.child(0).getLogicalExpression();
Plan bottomFilter = groupExpression.getPlan();
Assertions.assertTrue(bottomFilter instanceof LogicalFilter);
Expression bottomPredicates = ((LogicalFilter<?>) bottomFilter).getPredicates();
Assertions.assertTrue(bottomPredicates instanceof And);
Assertions.assertEquals(2, bottomPredicates.children().size());
Expression greater = bottomPredicates.child(0);
Assertions.assertTrue(greater instanceof GreaterThan);
Assertions.assertTrue(greater.child(0) instanceof Slot);
Assertions.assertEquals("gender", ((Slot) greater.child(0)).getName());
Expression less = bottomPredicates.child(1);
Assertions.assertTrue(less instanceof LessThanEqual);
Assertions.assertTrue(less.child(0) instanceof Add);
LogicalPlan plan = new LogicalPlanBuilder(scan)
.aggAllUsingIndex(ImmutableList.of(3, 1), ImmutableList.of(1, 3))
.filter(filterPredicate)
.project(ImmutableList.of(0))
.build();
groupExpression = groupExpression.child(0).getLogicalExpression();
Plan scan2 = groupExpression.getPlan();
Assertions.assertTrue(scan2 instanceof LogicalOlapScan);
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new PushdownFilterThroughAggregation())
.printlnTree()
.matches(
logicalProject(
logicalFilter(
logicalAggregate(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().child(0) instanceof GreaterThan)
.when(filter -> filter.getPredicates()
.child(1) instanceof LessThanEqual)
)
).when(filter -> filter.getPredicates() instanceof EqualTo)
)
);
}
private Memo rewrite(Plan plan) {
return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushdownFilterThroughAggregation());
}
}

View File

@ -24,7 +24,7 @@ import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.analysis.EliminateAliasNode;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PatternMatchSupported;
@ -114,7 +114,7 @@ public class ViewTest extends TestWithFeService implements PatternMatchSupported
PlanChecker.from(connectContext)
.analyze("SELECT * FROM V1")
.applyTopDown(new EliminateAliasNode())
.applyTopDown(new MergeConsecutiveProjects())
.applyTopDown(new MergeProjects())
.matchesFromRoot(
logicalProject(
logicalOlapScan()
@ -141,7 +141,7 @@ public class ViewTest extends TestWithFeService implements PatternMatchSupported
+ "ON X.ID1 = Y.ID3"
)
.applyTopDown(new EliminateAliasNode())
.applyTopDown(new MergeConsecutiveProjects())
.applyTopDown(new MergeProjects())
.matchesFromRoot(
logicalProject(
logicalJoin(

View File

@ -23,6 +23,8 @@ import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
@ -32,6 +34,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;
import java.util.ArrayList;
@ -118,8 +121,40 @@ public class LogicalPlanBuilder {
}
public LogicalPlanBuilder filter(Expression predicate) {
LogicalFilter<LogicalPlan> limitPlan = new LogicalFilter<>(predicate, this.plan);
return from(limitPlan);
LogicalFilter<LogicalPlan> filter = new LogicalFilter<>(predicate, this.plan);
return from(filter);
}
public LogicalPlanBuilder aggAllUsingIndex(List<Integer> groupByKeysIndex, List<Integer> outputExprsIndex) {
Builder<Expression> groupByBuilder = ImmutableList.builder();
for (Integer index : groupByKeysIndex) {
groupByBuilder.add(this.plan.getOutput().get(index));
}
ImmutableList<Expression> groupByKeys = groupByBuilder.build();
Builder<NamedExpression> outputBuilder = ImmutableList.builder();
for (Integer index : outputExprsIndex) {
outputBuilder.add(this.plan.getOutput().get(index));
}
ImmutableList<NamedExpression> outputExpresList = outputBuilder.build();
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExpresList, this.plan);
return from(agg);
}
public LogicalPlanBuilder aggGroupUsingIndex(List<Integer> groupByKeysIndex, List<NamedExpression> outputExpresList) {
Builder<Expression> groupByBuilder = ImmutableList.builder();
for (Integer index : groupByKeysIndex) {
groupByBuilder.add(this.plan.getOutput().get(index));
}
ImmutableList<Expression> groupByKeys = groupByBuilder.build();
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExpresList, this.plan);
return from(agg);
}
public LogicalPlanBuilder agg(List<Expression> groupByKeys, List<NamedExpression> outputExpresList) {
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExpresList, this.plan);
return from(agg);
}
}