[refactor](Nereids): refactor UT by using Pattern and rename to remove consecutive (#13337)
* rename
* refactor UT
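For context: the rename drops the "Consecutive" infix from the merge rules (MergeConsecutiveFilters -> MergeFilters, MergeConsecutiveProjects -> MergeProjects, MergeConsecutiveLimits -> MergeLimits), and the unit tests are reworked to assert plan shapes through the PlanChecker / PatternMatchSupported pattern API instead of walking Memo and GroupExpression by hand. Below is a minimal sketch of that test style, condensed from the hunks in this commit; the class and method names are illustrative, everything else mirrors PushdownFilterThroughAggregationTest as rewritten here.

// Illustrative sketch only (not part of this diff): the pattern-based UT style the refactor adopts.
// Imports mirror PushdownFilterThroughAggregationTest further down in this commit.
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughAggregation;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;

import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;

public class ExamplePatternTest implements PatternMatchSupported {
    @Test
    public void pushDownFilterThroughAggregation() {
        LogicalPlan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(),
                PlanConstructor.student, ImmutableList.of(""));
        Expression filterPredicate = new GreaterThan(scan.getOutput().get(1), Literal.of(1));

        // Build project(filter(agg(scan))) with the LogicalPlanBuilder helpers added in this commit.
        LogicalPlan plan = new LogicalPlanBuilder(scan)
                .aggAllUsingIndex(ImmutableList.of(3, 1), ImmutableList.of(1, 3))
                .filter(filterPredicate)
                .project(ImmutableList.of(0))
                .build();

        // Assert the rewritten shape with a pattern instead of walking Memo/GroupExpression by hand.
        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
                .applyTopDown(new PushdownFilterThroughAggregation())
                .matches(
                        logicalProject(
                                logicalAggregate(
                                        logicalFilter(
                                                logicalOlapScan()
                                        ).when(filter -> filter.getPredicates().equals(filterPredicate))
                                )
                        )
                );
    }
}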
@@ -228,8 +228,7 @@ public class GroupExpression {
     @Override
     public String toString() {
         StringBuilder builder = new StringBuilder();
-        builder.append(ownerGroup.getGroupId())
-                .append("(plan=" + plan.toString() + ") children=[");
+        builder.append(ownerGroup.getGroupId()).append("(plan=").append(plan).append(") children=[");
         for (Group group : children) {
             builder.append(group.getGroupId()).append(" ");
         }
@@ -39,9 +39,9 @@ import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort;
 import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
 import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
 import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuter;
-import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
-import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
-import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
+import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters;
+import org.apache.doris.nereids.rules.rewrite.logical.MergeLimits;
+import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects;
 import org.apache.doris.nereids.rules.rewrite.logical.PushdownExpressionsInHashCondition;
 import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughJoin;
 import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject;
@@ -68,7 +68,7 @@ public class RuleSet {
             .add(SemiJoinSemiJoinTranspose.INSTANCE)
             .add(new AggregateDisassemble())
             .add(new PushdownFilterThroughProject())
-            .add(new MergeConsecutiveProjects())
+            .add(new MergeProjects())
             .build();
 
     public static final List<RuleFactory> PUSH_DOWN_JOIN_CONDITION_RULES = ImmutableList.of(
@@ -78,9 +78,9 @@ public class RuleSet {
             new PushdownProjectThroughLimit(),
             new PushdownFilterThroughProject(),
             EliminateOuter.INSTANCE,
-            new MergeConsecutiveProjects(),
-            new MergeConsecutiveFilters(),
-            new MergeConsecutiveLimits());
+            new MergeProjects(),
+            new MergeFilters(),
+            new MergeLimits());
 
     public static final List<Rule> IMPLEMENTATION_RULES = planRuleFactories()
             .add(new LogicalAggToPhysicalHashAgg())
@@ -100,9 +100,9 @@ public enum RuleType {
     REWRITE_JOIN_EXPRESSION(RuleTypeClass.REWRITE),
     REORDER_JOIN(RuleTypeClass.REWRITE),
     // Merge Consecutive plan
-    MERGE_CONSECUTIVE_FILTERS(RuleTypeClass.REWRITE),
-    MERGE_CONSECUTIVE_PROJECTS(RuleTypeClass.REWRITE),
-    MERGE_CONSECUTIVE_LIMITS(RuleTypeClass.REWRITE),
+    MERGE_FILTERS(RuleTypeClass.REWRITE),
+    MERGE_PROJECTS(RuleTypeClass.REWRITE),
+    MERGE_LIMITS(RuleTypeClass.REWRITE),
     // Eliminate plan
     ELIMINATE_LIMIT(RuleTypeClass.REWRITE),
     ELIMINATE_FILTER(RuleTypeClass.REWRITE),
@@ -43,7 +43,7 @@ import org.apache.doris.nereids.util.ExpressionUtils;
  *    |
  *   scan
  */
-public class MergeConsecutiveFilters extends OneRewriteRuleFactory {
+public class MergeFilters extends OneRewriteRuleFactory {
     @Override
     public Rule build() {
         return logicalFilter(logicalFilter()).then(filter -> {
@@ -52,7 +52,7 @@ public class MergeConsecutiveFilters extends OneRewriteRuleFactory {
             Expression childPredicates = childFilter.getPredicates();
             Expression mergedPredicates = ExpressionUtils.and(predicates, childPredicates);
             return new LogicalFilter<>(mergedPredicates, childFilter.child());
-        }).toRule(RuleType.MERGE_CONSECUTIVE_FILTERS);
+        }).toRule(RuleType.MERGE_FILTERS);
     }
 
 }
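The rename above is mechanical; the rule still collapses two stacked LogicalFilters into a single filter over the conjunction of both predicates. A hedged sketch of exercising the renamed rule through the pattern API (this would sit in a test class implementing PatternMatchSupported with the usual test imports; the predicates and the And check are illustrative assumptions, not taken from this diff):

// Sketch only: two stacked filters over a scan should collapse into one conjunctive filter.
@Test
public void mergeStackedFilters() {
    LogicalPlan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(),
            PlanConstructor.student, ImmutableList.of(""));
    Expression first = new GreaterThan(scan.getOutput().get(3), Literal.of(18));  // age > 18 (illustrative)
    Expression second = new EqualTo(scan.getOutput().get(1), Literal.of(1));      // gender = 1 (illustrative)
    LogicalFilter inner = new LogicalFilter<>(first, scan);
    LogicalFilter outer = new LogicalFilter<>(second, inner);

    PlanChecker.from(MemoTestUtils.createConnectContext(), outer)
            .applyTopDown(new MergeFilters())
            .matchesFromRoot(
                    // A single remaining filter whose predicate is the And of both input predicates.
                    logicalFilter().when(filter -> filter.getPredicates() instanceof And)
            );
}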
@@ -40,7 +40,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
  * </pre>
  * Note that the top limit should not have valid offset info.
  */
-public class MergeConsecutiveLimits extends OneRewriteRuleFactory {
+public class MergeLimits extends OneRewriteRuleFactory {
     @Override
     public Rule build() {
         return logicalLimit(logicalLimit()).whenNot(Limit::hasValidOffset).then(upperLimit -> {
@@ -50,6 +50,6 @@ public class MergeConsecutiveLimits extends OneRewriteRuleFactory {
                     bottomLimit.getOffset(),
                     bottomLimit.child()
             );
-        }).toRule(RuleType.MERGE_CONSECUTIVE_LIMITS);
+        }).toRule(RuleType.MERGE_LIMITS);
     }
 }
@@ -51,7 +51,7 @@ import java.util.stream.Collectors;
  *   scan
  */
 
-public class MergeConsecutiveProjects extends OneRewriteRuleFactory {
+public class MergeProjects extends OneRewriteRuleFactory {
 
     private static class ExpressionReplacer extends DefaultExpressionRewriter<Map<Expression, Expression>> {
         public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();
@@ -114,10 +114,10 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory {
             );
 
             projectExpressions = projectExpressions.stream()
-                    .map(e -> MergeConsecutiveProjects.ExpressionReplacer.INSTANCE.replace(e, childAliasMap))
+                    .map(e -> MergeProjects.ExpressionReplacer.INSTANCE.replace(e, childAliasMap))
                    .map(NamedExpression.class::cast)
                    .collect(Collectors.toList());
             return new LogicalProject<>(projectExpressions, childProject.children().get(0));
-        }).toRule(RuleType.MERGE_CONSECUTIVE_PROJECTS);
+        }).toRule(RuleType.MERGE_PROJECTS);
     }
 }
@@ -51,7 +51,7 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory {
     }
 
     private Plan unCorrelatedToJoin(LogicalApply apply) {
-        LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows(
+        LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows<>(
                 new AssertNumRowsElement(
                         1, apply.getSubqueryExpr().toString(),
                         AssertNumRowsElement.Assertion.EQ),
@@ -126,14 +126,14 @@ public class ExtractSingleTableExpressionFromDisjunctionTest implements PatternMatchSupported {
                 )
         );
         Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course);
-        LogicalFilter root = new LogicalFilter(expr, join);
+        LogicalFilter root = new LogicalFilter<>(expr, join);
         PlanChecker.from(MemoTestUtils.createConnectContext(), root)
                 .applyTopDown(new ExtractSingleTableExpressionFromDisjunction())
                 .matchesFromRoot(
                         logicalFilter()
                                 .when(filter -> verifySingleTableExpression2(filter.getPredicates()))
                 );
-        Assertions.assertTrue(studentGender != null);
+        Assertions.assertNotNull(studentGender);
     }
 
     private boolean verifySingleTableExpression2(Expression expr) {
@@ -163,14 +163,14 @@ public class ExtractSingleTableExpressionFromDisjunctionTest implements PatternMatchSupported {
                 new EqualTo(studentGender, new IntegerLiteral(1))
         );
         Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course);
-        LogicalFilter root = new LogicalFilter(expr, join);
+        LogicalFilter root = new LogicalFilter<>(expr, join);
         PlanChecker.from(MemoTestUtils.createConnectContext(), root)
                 .applyTopDown(new ExtractSingleTableExpressionFromDisjunction())
                 .matchesFromRoot(
                         logicalFilter()
                                 .when(filter -> verifySingleTableExpression3(filter.getPredicates()))
                 );
-        Assertions.assertTrue(studentGender != null);
+        Assertions.assertNotNull(studentGender);
     }
 
     private boolean verifySingleTableExpression3(Expression expr) {
@@ -36,7 +36,7 @@ import java.util.List;
 /**
  * MergeConsecutiveFilter ut
  */
-public class MergeConsecutiveFilterTest {
+public class MergeFiltersTest {
     @Test
     public void testMergeConsecutiveFilters() {
         UnboundRelation relation = new UnboundRelation(Lists.newArrayList("db", "table"));
@@ -48,7 +48,7 @@ public class MergeConsecutiveFilterTest {
         LogicalFilter filter3 = new LogicalFilter<>(expression3, filter2);
 
         CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter3);
-        List<Rule> rules = Lists.newArrayList(new MergeConsecutiveFilters().build());
+        List<Rule> rules = Lists.newArrayList(new MergeFilters().build());
         cascadesContext.bottomUpRewrite(rules);
         //check transformed plan
         Plan resultPlan = cascadesContext.getMemo().copyOut();
@@ -29,7 +29,7 @@ import org.junit.jupiter.api.Test;
 
 import java.util.List;
 
-public class MergeConsecutiveLimitsTest {
+public class MergeLimitsTest {
     @Test
     public void testMergeConsecutiveLimits() {
         LogicalLimit limit3 = new LogicalLimit<>(3, 5, new UnboundRelation(Lists.newArrayList("db", "t")));
@@ -37,7 +37,7 @@ public class MergeConsecutiveLimitsTest {
         LogicalLimit limit1 = new LogicalLimit<>(10, 0, limit2);
 
         CascadesContext context = MemoTestUtils.createCascadesContext(limit1);
-        List<Rule> rules = Lists.newArrayList(new MergeConsecutiveLimits().build());
+        List<Rule> rules = Lists.newArrayList(new MergeLimits().build());
         context.topDownRewrite(rules);
         LogicalLimit limit = (LogicalLimit) context.getMemo().copyOut();
 
@@ -40,7 +40,7 @@ import java.util.List;
 /**
  * MergeConsecutiveProjects ut
  */
-public class MergeConsecutiveProjectsTest {
+public class MergeProjectsTest {
     @Test
     public void testMergeConsecutiveProjects() {
         UnboundRelation relation = new UnboundRelation(Lists.newArrayList("db", "table"));
@@ -52,7 +52,7 @@ public class MergeConsecutiveProjectsTest {
         LogicalProject project3 = new LogicalProject<>(Lists.newArrayList(colA), project2);
 
         CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(project3);
-        List<Rule> rules = Lists.newArrayList(new MergeConsecutiveProjects().build());
+        List<Rule> rules = Lists.newArrayList(new MergeProjects().build());
         cascadesContext.bottomUpRewrite(rules);
         Plan plan = cascadesContext.getMemo().copyOut();
         System.out.println(plan.treeString());
@@ -95,7 +95,7 @@ public class MergeConsecutiveProjectsTest {
                 project1);
 
         CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(project2);
-        List<Rule> rules = Lists.newArrayList(new MergeConsecutiveProjects().build());
+        List<Rule> rules = Lists.newArrayList(new MergeProjects().build());
         cascadesContext.bottomUpRewrite(rules);
         Plan plan = cascadesContext.getMemo().copyOut();
         System.out.println(plan.treeString());
@@ -30,7 +30,9 @@ import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.RelationId;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.util.FieldChecker;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
 import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PatternMatchSupported;
 import org.apache.doris.nereids.util.PlanChecker;
@@ -46,14 +48,15 @@ import java.util.List;
 
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
 public class NormalizeAggregateTest implements PatternMatchSupported {
-    private Plan rStudent;
+    private LogicalPlan rStudent;
 
     @BeforeAll
     public final void beforeAll() {
-        rStudent = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, ImmutableList.of(""));
+        rStudent = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student,
+                ImmutableList.of(""));
     }
 
-    /**
+    /*-
      * original plan:
     * LogicalAggregate (phase: [GLOBAL], output: [name#2, sum(id#0) AS `sum`#4], groupBy: [name#2])
     * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3])
@@ -79,16 +82,18 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
                                 logicalOlapScan()
                         ).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key)))
                                 .when(aggregate -> aggregate.getOutputExpressions().get(0).equals(key))
-                                .when(aggregate -> aggregate.getOutputExpressions().get(1).child(0).equals(aggregateFunction.child(0)))
+                                .when(aggregate -> aggregate.getOutputExpressions().get(1).child(0)
+                                        .equals(aggregateFunction.child(0)))
                                 .when(FieldChecker.check("normalized", true))
                 ).when(project -> project.getProjects().get(0).equals(key))
                         .when(project -> project.getProjects().get(1) instanceof Alias)
-                        .when(project -> ((Alias) (project.getProjects().get(1))).getExprId().equals(aggregateFunction.getExprId()))
+                        .when(project -> (project.getProjects().get(1)).getExprId()
+                                .equals(aggregateFunction.getExprId()))
                         .when(project -> project.getProjects().get(1).child(0) instanceof SlotReference)
         );
     }
 
-    /**
+    /*-
     * original plan:
     * LogicalAggregate (phase: [GLOBAL], output: [(sum((id#0 * 1)) + 2) AS `(sum((id * 1)) + 2)`#4], groupBy: [name#2])
     * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3])
@@ -101,14 +106,15 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
     */
    @Test
    public void testComplexFuncWithComplexOutputOfFunc() {
-        NamedExpression key = rStudent.getOutput().get(2).toSlot();
-        List<Expression> groupExpressionList = Lists.newArrayList(key);
        Expression multiply = new Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1));
        Expression aggregateFunction = new Sum(multiply);
        Expression complexOutput = new Add(aggregateFunction, new IntegerLiteral(2));
        Alias output = new Alias(complexOutput, complexOutput.toSql());
        List<NamedExpression> outputExpressionList = Lists.newArrayList(output);
-        Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
+
+        LogicalPlan root = new LogicalPlanBuilder(rStudent)
+                .aggGroupUsingIndex(ImmutableList.of(2), outputExpressionList)
+                .build();
 
        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
                .applyTopDown(new NormalizeAggregate())
@@ -116,24 +122,27 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
                 logicalProject(
                         logicalAggregate(
                                 logicalProject(
-                                        logicalOlapScan()
+                                        logicalOlapScan()
                                 ).when(project -> project.getProjects().size() == 2)
                                         .when(project -> project.getProjects().get(0) instanceof SlotReference)
                                         .when(project -> project.getProjects().get(1).child(0).equals(multiply))
-                        ).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key)))
+                        ).when(FieldChecker.check("groupByExpressions",
+                                ImmutableList.of(rStudent.getOutput().get(2))))
                                 .when(aggregate -> aggregate.getOutputExpressions().size() == 1)
-                                .when(aggregate -> aggregate.getOutputExpressions().get(0).child(0) instanceof AggregateFunction)
+                                .when(aggregate -> aggregate.getOutputExpressions().get(0)
+                                        .child(0) instanceof AggregateFunction)
                 ).when(project -> project.getProjects().size() == 1)
                         .when(project -> project.getProjects().get(0) instanceof Alias)
                         .when(project -> project.getProjects().get(0).getExprId().equals(output.getExprId()))
                         .when(project -> project.getProjects().get(0).child(0) instanceof Add)
-                        .when(project -> project.getProjects().get(0).child(0).child(0) instanceof SlotReference)
-                        .when(project -> project.getProjects().get(0).child(0).child(1).equals(new IntegerLiteral(2)))
+                        .when(project -> project.getProjects().get(0).child(0)
+                                .child(0) instanceof SlotReference)
+                        .when(project -> project.getProjects().get(0).child(0).child(1)
+                                .equals(new IntegerLiteral(2)))
         );
     }
 
-
-    /**
+    /*-
     * original plan:
     * LogicalAggregate (phase: [GLOBAL], output: [((gender#1 + 1) + 2) AS `((gender + 1) + 2)`#4], groupBy: [(gender#1 + 1)])
     * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3])
@@ -146,12 +155,16 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
     */
    @Test
    public void testComplexKeyWithComplexOutputOfKey() {
-        Expression key = new Add(rStudent.getOutput().get(1).toSlot(), new IntegerLiteral(1));
+        Expression key = new Add(rStudent.getOutput().get(1), new IntegerLiteral(1));
        Expression complexKeyOutput = new Add(key, new IntegerLiteral(2));
        NamedExpression keyOutput = new Alias(complexKeyOutput, complexKeyOutput.toSql());
+
        List<Expression> groupExpressionList = Lists.newArrayList(key);
        List<NamedExpression> outputExpressionList = Lists.newArrayList(keyOutput);
-        Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
+
+        LogicalPlan root = new LogicalPlanBuilder(rStudent)
+                .agg(groupExpressionList, outputExpressionList)
+                .build();
 
        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
                .applyTopDown(new NormalizeAggregate())
@@ -164,12 +177,16 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
                         .when(project -> project.getProjects().get(0) instanceof Alias)
                         .when(project -> project.getProjects().get(0).child(0).equals(key))
                 ).when(aggregate -> aggregate.getGroupByExpressions().get(0) instanceof SlotReference)
-                        .when(aggregate -> aggregate.getOutputExpressions().get(0) instanceof SlotReference)
-                        .when(aggregate -> aggregate.getGroupByExpressions().equals(aggregate.getOutputExpressions()))
+                        .when(aggregate -> aggregate.getOutputExpressions()
+                                .get(0) instanceof SlotReference)
+                        .when(aggregate -> aggregate.getGroupByExpressions()
+                                .equals(aggregate.getOutputExpressions()))
                 ).when(project -> project.getProjects().get(0).getExprId().equals(keyOutput.getExprId()))
                         .when(project -> project.getProjects().get(0).child(0) instanceof Add)
-                        .when(project -> project.getProjects().get(0).child(0).child(0) instanceof SlotReference)
-                        .when(project -> project.getProjects().get(0).child(0).child(1).equals(new IntegerLiteral(2)))
+                        .when(project -> project.getProjects().get(0).child(0)
+                                .child(0) instanceof SlotReference)
+                        .when(project -> project.getProjects().get(0).child(0).child(1)
+                                .equals(new IntegerLiteral(2)))
 
         );
     }
@@ -17,97 +17,74 @@
 
 package org.apache.doris.nereids.rules.rewrite.logical;
 
-import org.apache.doris.nereids.memo.Group;
-import org.apache.doris.nereids.memo.GroupExpression;
-import org.apache.doris.nereids.memo.Memo;
 import org.apache.doris.nereids.trees.expressions.Add;
-import org.apache.doris.nereids.trees.expressions.And;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.GreaterThan;
 import org.apache.doris.nereids.trees.expressions.LessThanEqual;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
-import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.RelationId;
-import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.nereids.util.PlanRewriter;
-import org.apache.doris.qe.ConnectContext;
 
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
-import java.util.List;
-
-public class PushdownFilterThroughAggregationTest {
+public class PushdownFilterThroughAggregationTest implements PatternMatchSupported {
 
-/**
- * origin plan:
- * project
- *     |
- * filter gender=1
- *     |
- * aggregation group by gender
- *     |
- * scan(student)
- *
- * transformed plan:
- * project
- *     |
- * aggregation group by gender
- *     |
- * filter gender=1
- *     |
- * scan(student)
- */
+    /*-
+     * origin plan:
+     * project
+     *     |
+     * filter gender=1
+     *     |
+     * aggregation group by gender
+     *     |
+     * scan(student)
+     *
+     * transformed plan:
+     * project
+     *     |
+     * aggregation group by gender
+     *     |
+     * filter gender=1
+     *     |
+     * scan(student)
+     */
     @Test
     public void pushDownPredicateOneFilterTest() {
-        Plan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, ImmutableList.of(""));
+        LogicalPlan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student,
+                ImmutableList.of(""));
         Slot gender = scan.getOutput().get(1);
         Slot age = scan.getOutput().get(3);
 
-        List<Expression> groupByKeys = Lists.newArrayList(age, gender);
-        List<NamedExpression> outputExpressionList = Lists.newArrayList(gender, age);
-        Plan aggregation = new LogicalAggregate<>(groupByKeys, outputExpressionList, scan);
         Expression filterPredicate = new GreaterThan(gender, Literal.of(1));
-        LogicalFilter filter = new LogicalFilter(filterPredicate, aggregation);
-        Plan root = new LogicalProject<>(
-                Lists.newArrayList(gender),
-                filter
-        );
+        LogicalPlan plan = new LogicalPlanBuilder(scan)
+                .aggAllUsingIndex(ImmutableList.of(3, 1), ImmutableList.of(1, 3))
+                .filter(filterPredicate)
+                .project(ImmutableList.of(0))
+                .build();
 
-        Memo memo = rewrite(root);
-        System.out.println(memo.copyOut().treeString());
-        Group rootGroup = memo.getRoot();
-
-        GroupExpression groupExpression = rootGroup
-                .getLogicalExpression().child(0)
-                .getLogicalExpression();
-        aggregation = groupExpression.getPlan();
-        Assertions.assertTrue(aggregation instanceof LogicalAggregate);
-
-        groupExpression = groupExpression.child(0).getLogicalExpression();
-        Plan bottomFilter = groupExpression.getPlan();
-        Assertions.assertTrue(bottomFilter instanceof LogicalFilter);
-        Expression greater = ((LogicalFilter<?>) bottomFilter).getPredicates();
-        Assertions.assertTrue(greater instanceof GreaterThan);
-        Assertions.assertTrue(greater.child(0) instanceof Slot);
-        Assertions.assertEquals("gender", ((Slot) greater.child(0)).getName());
-
-        groupExpression = groupExpression.child(0).getLogicalExpression();
-        Plan scan2 = groupExpression.getPlan();
-        Assertions.assertTrue(scan2 instanceof LogicalOlapScan);
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new PushdownFilterThroughAggregation())
+                .matches(
+                        logicalProject(
+                                logicalAggregate(
+                                        logicalFilter(
+                                                logicalOlapScan()
+                                        ).when(filter -> filter.getPredicates().equals(filterPredicate))
+                                )
+                        )
+                );
     }
 
-    /**
+    /*-
     * origin plan:
     * project
     *     |
@@ -130,63 +107,40 @@ public class PushdownFilterThroughAggregationTest {
     */
    @Test
    public void pushDownPredicateTwoFilterTest() {
-        Plan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, ImmutableList.of(""));
+        LogicalPlan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student,
+                ImmutableList.of(""));
        Slot gender = scan.getOutput().get(1);
        Slot name = scan.getOutput().get(2);
        Slot age = scan.getOutput().get(3);
 
-        List<Expression> groupByKeys = Lists.newArrayList(age, gender);
-        List<NamedExpression> outputExpressionList = Lists.newArrayList(gender, age);
-        Plan aggregation = new LogicalAggregate<>(groupByKeys, outputExpressionList, scan);
        Expression filterPredicate = ExpressionUtils.and(
                new GreaterThan(gender, Literal.of(1)),
                new LessThanEqual(
-                        new Add(
-                                gender,
-                                Literal.of(10)
-                        ),
-                        Literal.of(100)
-                ),
-                new EqualTo(name, Literal.of("abc"))
-        );
-        LogicalFilter filter = new LogicalFilter(filterPredicate, aggregation);
-        Plan root = new LogicalProject<>(
-                Lists.newArrayList(gender),
-                filter
-        );
+                        new Add(gender, Literal.of(10)),
+                        Literal.of(100)),
+                new EqualTo(name, Literal.of("abc")));
 
-        Memo memo = rewrite(root);
-        System.out.println(memo.copyOut().treeString());
-        Group rootGroup = memo.getRoot();
-        GroupExpression groupExpression = rootGroup.getLogicalExpression().child(0).getLogicalExpression();
-        Plan upperFilter = groupExpression.getPlan();
-        Assertions.assertTrue(upperFilter instanceof LogicalFilter);
-        Expression upperPredicates = ((LogicalFilter<?>) upperFilter).getPredicates();
-        Assertions.assertTrue(upperPredicates instanceof EqualTo);
-        Assertions.assertTrue(upperPredicates.child(0) instanceof Slot);
-        groupExpression = groupExpression.child(0).getLogicalExpression();
-        aggregation = groupExpression.getPlan();
-        Assertions.assertTrue(aggregation instanceof LogicalAggregate);
-        groupExpression = groupExpression.child(0).getLogicalExpression();
-        Plan bottomFilter = groupExpression.getPlan();
-        Assertions.assertTrue(bottomFilter instanceof LogicalFilter);
-        Expression bottomPredicates = ((LogicalFilter<?>) bottomFilter).getPredicates();
-        Assertions.assertTrue(bottomPredicates instanceof And);
-        Assertions.assertEquals(2, bottomPredicates.children().size());
-        Expression greater = bottomPredicates.child(0);
-        Assertions.assertTrue(greater instanceof GreaterThan);
-        Assertions.assertTrue(greater.child(0) instanceof Slot);
-        Assertions.assertEquals("gender", ((Slot) greater.child(0)).getName());
-        Expression less = bottomPredicates.child(1);
-        Assertions.assertTrue(less instanceof LessThanEqual);
-        Assertions.assertTrue(less.child(0) instanceof Add);
+        LogicalPlan plan = new LogicalPlanBuilder(scan)
+                .aggAllUsingIndex(ImmutableList.of(3, 1), ImmutableList.of(1, 3))
+                .filter(filterPredicate)
+                .project(ImmutableList.of(0))
+                .build();
 
-        groupExpression = groupExpression.child(0).getLogicalExpression();
-        Plan scan2 = groupExpression.getPlan();
-        Assertions.assertTrue(scan2 instanceof LogicalOlapScan);
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new PushdownFilterThroughAggregation())
+                .printlnTree()
+                .matches(
+                        logicalProject(
+                                logicalFilter(
+                                        logicalAggregate(
+                                                logicalFilter(
+                                                        logicalOlapScan()
+                                                ).when(filter -> filter.getPredicates().child(0) instanceof GreaterThan)
+                                                        .when(filter -> filter.getPredicates()
+                                                                .child(1) instanceof LessThanEqual)
+                                        )
+                                ).when(filter -> filter.getPredicates() instanceof EqualTo)
+                        )
+                );
    }
-
-    private Memo rewrite(Plan plan) {
-        return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushdownFilterThroughAggregation());
-    }
 }
@@ -24,7 +24,7 @@ import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
 import org.apache.doris.nereids.parser.NereidsParser;
 import org.apache.doris.nereids.properties.PhysicalProperties;
 import org.apache.doris.nereids.rules.analysis.EliminateAliasNode;
-import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
+import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
 import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PatternMatchSupported;
@@ -114,7 +114,7 @@ public class ViewTest extends TestWithFeService implements PatternMatchSupported
         PlanChecker.from(connectContext)
                 .analyze("SELECT * FROM V1")
                 .applyTopDown(new EliminateAliasNode())
-                .applyTopDown(new MergeConsecutiveProjects())
+                .applyTopDown(new MergeProjects())
                 .matchesFromRoot(
                         logicalProject(
                                 logicalOlapScan()
@@ -141,7 +141,7 @@ public class ViewTest extends TestWithFeService implements PatternMatchSupported
                         + "ON X.ID1 = Y.ID3"
                 )
                 .applyTopDown(new EliminateAliasNode())
-                .applyTopDown(new MergeConsecutiveProjects())
+                .applyTopDown(new MergeProjects())
                 .matchesFromRoot(
                         logicalProject(
                                 logicalJoin(
@@ -23,6 +23,8 @@ import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
@@ -32,6 +34,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableList.Builder;
 import com.google.common.collect.Lists;
 
 import java.util.ArrayList;
@@ -118,8 +121,40 @@ public class LogicalPlanBuilder {
     }
 
     public LogicalPlanBuilder filter(Expression predicate) {
-        LogicalFilter<LogicalPlan> limitPlan = new LogicalFilter<>(predicate, this.plan);
-        return from(limitPlan);
+        LogicalFilter<LogicalPlan> filter = new LogicalFilter<>(predicate, this.plan);
+        return from(filter);
     }
 
+    public LogicalPlanBuilder aggAllUsingIndex(List<Integer> groupByKeysIndex, List<Integer> outputExprsIndex) {
+        Builder<Expression> groupByBuilder = ImmutableList.builder();
+        for (Integer index : groupByKeysIndex) {
+            groupByBuilder.add(this.plan.getOutput().get(index));
+        }
+        ImmutableList<Expression> groupByKeys = groupByBuilder.build();
+
+        Builder<NamedExpression> outputBuilder = ImmutableList.builder();
+        for (Integer index : outputExprsIndex) {
+            outputBuilder.add(this.plan.getOutput().get(index));
+        }
+        ImmutableList<NamedExpression> outputExpresList = outputBuilder.build();
+
+        LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExpresList, this.plan);
+        return from(agg);
+    }
+
+    public LogicalPlanBuilder aggGroupUsingIndex(List<Integer> groupByKeysIndex, List<NamedExpression> outputExpresList) {
+        Builder<Expression> groupByBuilder = ImmutableList.builder();
+        for (Integer index : groupByKeysIndex) {
+            groupByBuilder.add(this.plan.getOutput().get(index));
+        }
+        ImmutableList<Expression> groupByKeys = groupByBuilder.build();
+
+        LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExpresList, this.plan);
+        return from(agg);
+    }
+
+    public LogicalPlanBuilder agg(List<Expression> groupByKeys, List<NamedExpression> outputExpresList) {
+        LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExpresList, this.plan);
+        return from(agg);
+    }
 }
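The helpers added above (filter, aggAllUsingIndex, aggGroupUsingIndex, agg) are what the rewritten tests use to assemble input plans by slot index instead of constructing LogicalAggregate / LogicalFilter / LogicalProject nodes by hand. A brief usage sketch in that style; it would live inside a test method with the same imports as the tests above, the table is the PlanConstructor.student schema used throughout these tests, and the predicate value is an illustrative assumption:

// Sketch only: assemble filter(agg(scan)) with the new LogicalPlanBuilder helpers.
LogicalPlan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(),
        PlanConstructor.student, ImmutableList.of(""));
Slot name = scan.getOutput().get(2);

List<Expression> groupByKeys = Lists.newArrayList(name);
List<NamedExpression> outputs = Lists.newArrayList(name);

LogicalPlan plan = new LogicalPlanBuilder(scan)
        .agg(groupByKeys, outputs)                     // group by name, output name
        .filter(new EqualTo(name, Literal.of("abc")))  // illustrative predicate on the grouped column
        .build();

The same plan can be built purely by index, as the PushdownFilterThroughAggregation tests do, with aggAllUsingIndex(ImmutableList.of(2), ImmutableList.of(2)) in place of the explicit agg(...) call.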