diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java index dddfc29392..a6f555b649 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java @@ -228,8 +228,7 @@ public class GroupExpression { @Override public String toString() { StringBuilder builder = new StringBuilder(); - builder.append(ownerGroup.getGroupId()) - .append("(plan=" + plan.toString() + ") children=["); + builder.append(ownerGroup.getGroupId()).append("(plan=").append(plan).append(") children=["); for (Group group : children) { builder.append(group.getGroupId()).append(" "); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index d88b81ffed..e2aac0ae37 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -39,9 +39,9 @@ import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickS import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN; import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuter; -import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters; -import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits; -import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects; +import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters; +import org.apache.doris.nereids.rules.rewrite.logical.MergeLimits; +import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects; import org.apache.doris.nereids.rules.rewrite.logical.PushdownExpressionsInHashCondition; import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughJoin; import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject; @@ -68,7 +68,7 @@ public class RuleSet { .add(SemiJoinSemiJoinTranspose.INSTANCE) .add(new AggregateDisassemble()) .add(new PushdownFilterThroughProject()) - .add(new MergeConsecutiveProjects()) + .add(new MergeProjects()) .build(); public static final List PUSH_DOWN_JOIN_CONDITION_RULES = ImmutableList.of( @@ -78,9 +78,9 @@ public class RuleSet { new PushdownProjectThroughLimit(), new PushdownFilterThroughProject(), EliminateOuter.INSTANCE, - new MergeConsecutiveProjects(), - new MergeConsecutiveFilters(), - new MergeConsecutiveLimits()); + new MergeProjects(), + new MergeFilters(), + new MergeLimits()); public static final List IMPLEMENTATION_RULES = planRuleFactories() .add(new LogicalAggToPhysicalHashAgg()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 2e85180097..642d6bfd0e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -100,9 +100,9 @@ public enum RuleType { REWRITE_JOIN_EXPRESSION(RuleTypeClass.REWRITE), REORDER_JOIN(RuleTypeClass.REWRITE), // Merge Consecutive plan - MERGE_CONSECUTIVE_FILTERS(RuleTypeClass.REWRITE), - MERGE_CONSECUTIVE_PROJECTS(RuleTypeClass.REWRITE), - MERGE_CONSECUTIVE_LIMITS(RuleTypeClass.REWRITE), + MERGE_FILTERS(RuleTypeClass.REWRITE), + MERGE_PROJECTS(RuleTypeClass.REWRITE), + MERGE_LIMITS(RuleTypeClass.REWRITE), // Eliminate plan ELIMINATE_LIMIT(RuleTypeClass.REWRITE), ELIMINATE_FILTER(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveFilters.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFilters.java similarity index 94% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveFilters.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFilters.java index dd7fca74f8..4a3aba00ed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveFilters.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFilters.java @@ -43,7 +43,7 @@ import org.apache.doris.nereids.util.ExpressionUtils; * | * scan */ -public class MergeConsecutiveFilters extends OneRewriteRuleFactory { +public class MergeFilters extends OneRewriteRuleFactory { @Override public Rule build() { return logicalFilter(logicalFilter()).then(filter -> { @@ -52,7 +52,7 @@ public class MergeConsecutiveFilters extends OneRewriteRuleFactory { Expression childPredicates = childFilter.getPredicates(); Expression mergedPredicates = ExpressionUtils.and(predicates, childPredicates); return new LogicalFilter<>(mergedPredicates, childFilter.child()); - }).toRule(RuleType.MERGE_CONSECUTIVE_FILTERS); + }).toRule(RuleType.MERGE_FILTERS); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveLimits.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimits.java similarity index 94% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveLimits.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimits.java index f3e36b1513..0c0c222737 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveLimits.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimits.java @@ -40,7 +40,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; * * Note that the top limit should not have valid offset info. */ -public class MergeConsecutiveLimits extends OneRewriteRuleFactory { +public class MergeLimits extends OneRewriteRuleFactory { @Override public Rule build() { return logicalLimit(logicalLimit()).whenNot(Limit::hasValidOffset).then(upperLimit -> { @@ -50,6 +50,6 @@ public class MergeConsecutiveLimits extends OneRewriteRuleFactory { bottomLimit.getOffset(), bottomLimit.child() ); - }).toRule(RuleType.MERGE_CONSECUTIVE_LIMITS); + }).toRule(RuleType.MERGE_LIMITS); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjects.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeProjects.java similarity index 95% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjects.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeProjects.java index 626d6d22cd..773feff136 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjects.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeProjects.java @@ -51,7 +51,7 @@ import java.util.stream.Collectors; * scan */ -public class MergeConsecutiveProjects extends OneRewriteRuleFactory { +public class MergeProjects extends OneRewriteRuleFactory { private static class ExpressionReplacer extends DefaultExpressionRewriter> { public static final ExpressionReplacer INSTANCE = new ExpressionReplacer(); @@ -114,10 +114,10 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory { ); projectExpressions = projectExpressions.stream() - .map(e -> MergeConsecutiveProjects.ExpressionReplacer.INSTANCE.replace(e, childAliasMap)) + .map(e -> MergeProjects.ExpressionReplacer.INSTANCE.replace(e, childAliasMap)) .map(NamedExpression.class::cast) .collect(Collectors.toList()); return new LogicalProject<>(projectExpressions, childProject.children().get(0)); - }).toRule(RuleType.MERGE_CONSECUTIVE_PROJECTS); + }).toRule(RuleType.MERGE_PROJECTS); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java index 4110c15e3c..f77ca4ac07 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java @@ -51,7 +51,7 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory { } private Plan unCorrelatedToJoin(LogicalApply apply) { - LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows( + LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows<>( new AssertNumRowsElement( 1, apply.getSubqueryExpr().toString(), AssertNumRowsElement.Assertion.EQ), diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java index 1a1e75ad6e..d23b0a191d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractSingleTableExpressionFromDisjunctionTest.java @@ -126,14 +126,14 @@ public class ExtractSingleTableExpressionFromDisjunctionTest implements PatternM ) ); Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course); - LogicalFilter root = new LogicalFilter(expr, join); + LogicalFilter root = new LogicalFilter<>(expr, join); PlanChecker.from(MemoTestUtils.createConnectContext(), root) .applyTopDown(new ExtractSingleTableExpressionFromDisjunction()) .matchesFromRoot( logicalFilter() .when(filter -> verifySingleTableExpression2(filter.getPredicates())) ); - Assertions.assertTrue(studentGender != null); + Assertions.assertNotNull(studentGender); } private boolean verifySingleTableExpression2(Expression expr) { @@ -163,14 +163,14 @@ public class ExtractSingleTableExpressionFromDisjunctionTest implements PatternM new EqualTo(studentGender, new IntegerLiteral(1)) ); Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course); - LogicalFilter root = new LogicalFilter(expr, join); + LogicalFilter root = new LogicalFilter<>(expr, join); PlanChecker.from(MemoTestUtils.createConnectContext(), root) .applyTopDown(new ExtractSingleTableExpressionFromDisjunction()) .matchesFromRoot( logicalFilter() .when(filter -> verifySingleTableExpression3(filter.getPredicates())) ); - Assertions.assertTrue(studentGender != null); + Assertions.assertNotNull(studentGender); } private boolean verifySingleTableExpression3(Expression expr) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveFilterTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java similarity index 95% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveFilterTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java index 3989f12ac8..15db4cdef4 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveFilterTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeFiltersTest.java @@ -36,7 +36,7 @@ import java.util.List; /** * MergeConsecutiveFilter ut */ -public class MergeConsecutiveFilterTest { +public class MergeFiltersTest { @Test public void testMergeConsecutiveFilters() { UnboundRelation relation = new UnboundRelation(Lists.newArrayList("db", "table")); @@ -48,7 +48,7 @@ public class MergeConsecutiveFilterTest { LogicalFilter filter3 = new LogicalFilter<>(expression3, filter2); CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter3); - List rules = Lists.newArrayList(new MergeConsecutiveFilters().build()); + List rules = Lists.newArrayList(new MergeFilters().build()); cascadesContext.bottomUpRewrite(rules); //check transformed plan Plan resultPlan = cascadesContext.getMemo().copyOut(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveLimitsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java similarity index 94% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveLimitsTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java index 7f96a954d6..3056f1ac34 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveLimitsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeLimitsTest.java @@ -29,7 +29,7 @@ import org.junit.jupiter.api.Test; import java.util.List; -public class MergeConsecutiveLimitsTest { +public class MergeLimitsTest { @Test public void testMergeConsecutiveLimits() { LogicalLimit limit3 = new LogicalLimit<>(3, 5, new UnboundRelation(Lists.newArrayList("db", "t"))); @@ -37,7 +37,7 @@ public class MergeConsecutiveLimitsTest { LogicalLimit limit1 = new LogicalLimit<>(10, 0, limit2); CascadesContext context = MemoTestUtils.createCascadesContext(limit1); - List rules = Lists.newArrayList(new MergeConsecutiveLimits().build()); + List rules = Lists.newArrayList(new MergeLimits().build()); context.topDownRewrite(rules); LogicalLimit limit = (LogicalLimit) context.getMemo().copyOut(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjectsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeProjectsTest.java similarity index 95% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjectsTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeProjectsTest.java index 08c9af2946..d15c87660d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjectsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeProjectsTest.java @@ -40,7 +40,7 @@ import java.util.List; /** * MergeConsecutiveProjects ut */ -public class MergeConsecutiveProjectsTest { +public class MergeProjectsTest { @Test public void testMergeConsecutiveProjects() { UnboundRelation relation = new UnboundRelation(Lists.newArrayList("db", "table")); @@ -52,7 +52,7 @@ public class MergeConsecutiveProjectsTest { LogicalProject project3 = new LogicalProject<>(Lists.newArrayList(colA), project2); CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(project3); - List rules = Lists.newArrayList(new MergeConsecutiveProjects().build()); + List rules = Lists.newArrayList(new MergeProjects().build()); cascadesContext.bottomUpRewrite(rules); Plan plan = cascadesContext.getMemo().copyOut(); System.out.println(plan.treeString()); @@ -95,7 +95,7 @@ public class MergeConsecutiveProjectsTest { project1); CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(project2); - List rules = Lists.newArrayList(new MergeConsecutiveProjects().build()); + List rules = Lists.newArrayList(new MergeProjects().build()); cascadesContext.bottomUpRewrite(rules); Plan plan = cascadesContext.getMemo().copyOut(); System.out.println(plan.treeString()); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java index ef819a66f6..8b6e49a44a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java @@ -30,7 +30,9 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.RelationId; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.util.FieldChecker; +import org.apache.doris.nereids.util.LogicalPlanBuilder; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; @@ -46,14 +48,15 @@ import java.util.List; @TestInstance(TestInstance.Lifecycle.PER_CLASS) public class NormalizeAggregateTest implements PatternMatchSupported { - private Plan rStudent; + private LogicalPlan rStudent; @BeforeAll public final void beforeAll() { - rStudent = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, ImmutableList.of("")); + rStudent = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, + ImmutableList.of("")); } - /** + /*- * original plan: * LogicalAggregate (phase: [GLOBAL], output: [name#2, sum(id#0) AS `sum`#4], groupBy: [name#2]) * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3]) @@ -79,16 +82,18 @@ public class NormalizeAggregateTest implements PatternMatchSupported { logicalOlapScan() ).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key))) .when(aggregate -> aggregate.getOutputExpressions().get(0).equals(key)) - .when(aggregate -> aggregate.getOutputExpressions().get(1).child(0).equals(aggregateFunction.child(0))) + .when(aggregate -> aggregate.getOutputExpressions().get(1).child(0) + .equals(aggregateFunction.child(0))) .when(FieldChecker.check("normalized", true)) ).when(project -> project.getProjects().get(0).equals(key)) .when(project -> project.getProjects().get(1) instanceof Alias) - .when(project -> ((Alias) (project.getProjects().get(1))).getExprId().equals(aggregateFunction.getExprId())) + .when(project -> (project.getProjects().get(1)).getExprId() + .equals(aggregateFunction.getExprId())) .when(project -> project.getProjects().get(1).child(0) instanceof SlotReference) ); } - /** + /*- * original plan: * LogicalAggregate (phase: [GLOBAL], output: [(sum((id#0 * 1)) + 2) AS `(sum((id * 1)) + 2)`#4], groupBy: [name#2]) * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3]) @@ -101,14 +106,15 @@ public class NormalizeAggregateTest implements PatternMatchSupported { */ @Test public void testComplexFuncWithComplexOutputOfFunc() { - NamedExpression key = rStudent.getOutput().get(2).toSlot(); - List groupExpressionList = Lists.newArrayList(key); Expression multiply = new Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1)); Expression aggregateFunction = new Sum(multiply); Expression complexOutput = new Add(aggregateFunction, new IntegerLiteral(2)); Alias output = new Alias(complexOutput, complexOutput.toSql()); List outputExpressionList = Lists.newArrayList(output); - Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); + + LogicalPlan root = new LogicalPlanBuilder(rStudent) + .aggGroupUsingIndex(ImmutableList.of(2), outputExpressionList) + .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), root) .applyTopDown(new NormalizeAggregate()) @@ -116,24 +122,27 @@ public class NormalizeAggregateTest implements PatternMatchSupported { logicalProject( logicalAggregate( logicalProject( - logicalOlapScan() + logicalOlapScan() ).when(project -> project.getProjects().size() == 2) .when(project -> project.getProjects().get(0) instanceof SlotReference) .when(project -> project.getProjects().get(1).child(0).equals(multiply)) - ).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key))) + ).when(FieldChecker.check("groupByExpressions", + ImmutableList.of(rStudent.getOutput().get(2)))) .when(aggregate -> aggregate.getOutputExpressions().size() == 1) - .when(aggregate -> aggregate.getOutputExpressions().get(0).child(0) instanceof AggregateFunction) + .when(aggregate -> aggregate.getOutputExpressions().get(0) + .child(0) instanceof AggregateFunction) ).when(project -> project.getProjects().size() == 1) .when(project -> project.getProjects().get(0) instanceof Alias) .when(project -> project.getProjects().get(0).getExprId().equals(output.getExprId())) .when(project -> project.getProjects().get(0).child(0) instanceof Add) - .when(project -> project.getProjects().get(0).child(0).child(0) instanceof SlotReference) - .when(project -> project.getProjects().get(0).child(0).child(1).equals(new IntegerLiteral(2))) + .when(project -> project.getProjects().get(0).child(0) + .child(0) instanceof SlotReference) + .when(project -> project.getProjects().get(0).child(0).child(1) + .equals(new IntegerLiteral(2))) ); } - - /** + /*- * original plan: * LogicalAggregate (phase: [GLOBAL], output: [((gender#1 + 1) + 2) AS `((gender + 1) + 2)`#4], groupBy: [(gender#1 + 1)]) * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3]) @@ -146,12 +155,16 @@ public class NormalizeAggregateTest implements PatternMatchSupported { */ @Test public void testComplexKeyWithComplexOutputOfKey() { - Expression key = new Add(rStudent.getOutput().get(1).toSlot(), new IntegerLiteral(1)); + Expression key = new Add(rStudent.getOutput().get(1), new IntegerLiteral(1)); Expression complexKeyOutput = new Add(key, new IntegerLiteral(2)); NamedExpression keyOutput = new Alias(complexKeyOutput, complexKeyOutput.toSql()); + List groupExpressionList = Lists.newArrayList(key); List outputExpressionList = Lists.newArrayList(keyOutput); - Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); + + LogicalPlan root = new LogicalPlanBuilder(rStudent) + .agg(groupExpressionList, outputExpressionList) + .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), root) .applyTopDown(new NormalizeAggregate()) @@ -164,12 +177,16 @@ public class NormalizeAggregateTest implements PatternMatchSupported { .when(project -> project.getProjects().get(0) instanceof Alias) .when(project -> project.getProjects().get(0).child(0).equals(key)) ).when(aggregate -> aggregate.getGroupByExpressions().get(0) instanceof SlotReference) - .when(aggregate -> aggregate.getOutputExpressions().get(0) instanceof SlotReference) - .when(aggregate -> aggregate.getGroupByExpressions().equals(aggregate.getOutputExpressions())) + .when(aggregate -> aggregate.getOutputExpressions() + .get(0) instanceof SlotReference) + .when(aggregate -> aggregate.getGroupByExpressions() + .equals(aggregate.getOutputExpressions())) ).when(project -> project.getProjects().get(0).getExprId().equals(keyOutput.getExprId())) .when(project -> project.getProjects().get(0).child(0) instanceof Add) - .when(project -> project.getProjects().get(0).child(0).child(0) instanceof SlotReference) - .when(project -> project.getProjects().get(0).child(0).child(1).equals(new IntegerLiteral(2))) + .when(project -> project.getProjects().get(0).child(0) + .child(0) instanceof SlotReference) + .when(project -> project.getProjects().get(0).child(0).child(1) + .equals(new IntegerLiteral(2))) ); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregationTest.java index 208ffb5b44..851f10e5e9 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughAggregationTest.java @@ -17,97 +17,74 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.nereids.memo.Group; -import org.apache.doris.nereids.memo.GroupExpression; -import org.apache.doris.nereids.memo.Memo; import org.apache.doris.nereids.trees.expressions.Add; -import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.RelationId; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; -import org.apache.doris.nereids.util.PlanRewriter; -import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.List; +public class PushdownFilterThroughAggregationTest implements PatternMatchSupported { -public class PushdownFilterThroughAggregationTest { - - /** - * origin plan: - * project - * | - * filter gender=1 - * | - * aggregation group by gender - * | - * scan(student) - * - * transformed plan: - * project - * | - * aggregation group by gender - * | - * filter gender=1 - * | - * scan(student) - */ + /*- + * origin plan: + * project + * | + * filter gender=1 + * | + * aggregation group by gender + * | + * scan(student) + * + * transformed plan: + * project + * | + * aggregation group by gender + * | + * filter gender=1 + * | + * scan(student) + */ @Test public void pushDownPredicateOneFilterTest() { - Plan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, ImmutableList.of("")); + LogicalPlan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, + ImmutableList.of("")); Slot gender = scan.getOutput().get(1); - Slot age = scan.getOutput().get(3); - List groupByKeys = Lists.newArrayList(age, gender); - List outputExpressionList = Lists.newArrayList(gender, age); - Plan aggregation = new LogicalAggregate<>(groupByKeys, outputExpressionList, scan); Expression filterPredicate = new GreaterThan(gender, Literal.of(1)); - LogicalFilter filter = new LogicalFilter(filterPredicate, aggregation); - Plan root = new LogicalProject<>( - Lists.newArrayList(gender), - filter - ); + LogicalPlan plan = new LogicalPlanBuilder(scan) + .aggAllUsingIndex(ImmutableList.of(3, 1), ImmutableList.of(1, 3)) + .filter(filterPredicate) + .project(ImmutableList.of(0)) + .build(); - Memo memo = rewrite(root); - System.out.println(memo.copyOut().treeString()); - Group rootGroup = memo.getRoot(); - - GroupExpression groupExpression = rootGroup - .getLogicalExpression().child(0) - .getLogicalExpression(); - aggregation = groupExpression.getPlan(); - Assertions.assertTrue(aggregation instanceof LogicalAggregate); - - groupExpression = groupExpression.child(0).getLogicalExpression(); - Plan bottomFilter = groupExpression.getPlan(); - Assertions.assertTrue(bottomFilter instanceof LogicalFilter); - Expression greater = ((LogicalFilter) bottomFilter).getPredicates(); - Assertions.assertTrue(greater instanceof GreaterThan); - Assertions.assertTrue(greater.child(0) instanceof Slot); - Assertions.assertEquals("gender", ((Slot) greater.child(0)).getName()); - - groupExpression = groupExpression.child(0).getLogicalExpression(); - Plan scan2 = groupExpression.getPlan(); - Assertions.assertTrue(scan2 instanceof LogicalOlapScan); + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushdownFilterThroughAggregation()) + .matches( + logicalProject( + logicalAggregate( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().equals(filterPredicate)) + ) + ) + ); } - /** + /*- * origin plan: * project * | @@ -130,63 +107,40 @@ public class PushdownFilterThroughAggregationTest { */ @Test public void pushDownPredicateTwoFilterTest() { - Plan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, ImmutableList.of("")); + LogicalPlan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, + ImmutableList.of("")); Slot gender = scan.getOutput().get(1); Slot name = scan.getOutput().get(2); - Slot age = scan.getOutput().get(3); - List groupByKeys = Lists.newArrayList(age, gender); - List outputExpressionList = Lists.newArrayList(gender, age); - Plan aggregation = new LogicalAggregate<>(groupByKeys, outputExpressionList, scan); Expression filterPredicate = ExpressionUtils.and( new GreaterThan(gender, Literal.of(1)), new LessThanEqual( - new Add( - gender, - Literal.of(10) - ), - Literal.of(100) - ), - new EqualTo(name, Literal.of("abc")) - ); - LogicalFilter filter = new LogicalFilter(filterPredicate, aggregation); - Plan root = new LogicalProject<>( - Lists.newArrayList(gender), - filter - ); + new Add(gender, Literal.of(10)), + Literal.of(100)), + new EqualTo(name, Literal.of("abc"))); - Memo memo = rewrite(root); - System.out.println(memo.copyOut().treeString()); - Group rootGroup = memo.getRoot(); - GroupExpression groupExpression = rootGroup.getLogicalExpression().child(0).getLogicalExpression(); - Plan upperFilter = groupExpression.getPlan(); - Assertions.assertTrue(upperFilter instanceof LogicalFilter); - Expression upperPredicates = ((LogicalFilter) upperFilter).getPredicates(); - Assertions.assertTrue(upperPredicates instanceof EqualTo); - Assertions.assertTrue(upperPredicates.child(0) instanceof Slot); - groupExpression = groupExpression.child(0).getLogicalExpression(); - aggregation = groupExpression.getPlan(); - Assertions.assertTrue(aggregation instanceof LogicalAggregate); - groupExpression = groupExpression.child(0).getLogicalExpression(); - Plan bottomFilter = groupExpression.getPlan(); - Assertions.assertTrue(bottomFilter instanceof LogicalFilter); - Expression bottomPredicates = ((LogicalFilter) bottomFilter).getPredicates(); - Assertions.assertTrue(bottomPredicates instanceof And); - Assertions.assertEquals(2, bottomPredicates.children().size()); - Expression greater = bottomPredicates.child(0); - Assertions.assertTrue(greater instanceof GreaterThan); - Assertions.assertTrue(greater.child(0) instanceof Slot); - Assertions.assertEquals("gender", ((Slot) greater.child(0)).getName()); - Expression less = bottomPredicates.child(1); - Assertions.assertTrue(less instanceof LessThanEqual); - Assertions.assertTrue(less.child(0) instanceof Add); + LogicalPlan plan = new LogicalPlanBuilder(scan) + .aggAllUsingIndex(ImmutableList.of(3, 1), ImmutableList.of(1, 3)) + .filter(filterPredicate) + .project(ImmutableList.of(0)) + .build(); - groupExpression = groupExpression.child(0).getLogicalExpression(); - Plan scan2 = groupExpression.getPlan(); - Assertions.assertTrue(scan2 instanceof LogicalOlapScan); + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushdownFilterThroughAggregation()) + .printlnTree() + .matches( + logicalProject( + logicalFilter( + logicalAggregate( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().child(0) instanceof GreaterThan) + .when(filter -> filter.getPredicates() + .child(1) instanceof LessThanEqual) + ) + ).when(filter -> filter.getPredicates() instanceof EqualTo) + ) + ); } - private Memo rewrite(Plan plan) { - return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushdownFilterThroughAggregation()); - } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ViewTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ViewTest.java index f5c5a3460c..d1c7fcb452 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ViewTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ViewTest.java @@ -24,7 +24,7 @@ import org.apache.doris.nereids.glue.translator.PlanTranslatorContext; import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.analysis.EliminateAliasNode; -import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects; +import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PatternMatchSupported; @@ -114,7 +114,7 @@ public class ViewTest extends TestWithFeService implements PatternMatchSupported PlanChecker.from(connectContext) .analyze("SELECT * FROM V1") .applyTopDown(new EliminateAliasNode()) - .applyTopDown(new MergeConsecutiveProjects()) + .applyTopDown(new MergeProjects()) .matchesFromRoot( logicalProject( logicalOlapScan() @@ -141,7 +141,7 @@ public class ViewTest extends TestWithFeService implements PatternMatchSupported + "ON X.ID1 = Y.ID3" ) .applyTopDown(new EliminateAliasNode()) - .applyTopDown(new MergeConsecutiveProjects()) + .applyTopDown(new MergeProjects()) .matchesFromRoot( logicalProject( logicalJoin( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java index f0bfa95d73..bdfa0d03d7 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java @@ -23,6 +23,8 @@ import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; @@ -32,6 +34,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import java.util.ArrayList; @@ -118,8 +121,40 @@ public class LogicalPlanBuilder { } public LogicalPlanBuilder filter(Expression predicate) { - LogicalFilter limitPlan = new LogicalFilter<>(predicate, this.plan); - return from(limitPlan); + LogicalFilter filter = new LogicalFilter<>(predicate, this.plan); + return from(filter); } + public LogicalPlanBuilder aggAllUsingIndex(List groupByKeysIndex, List outputExprsIndex) { + Builder groupByBuilder = ImmutableList.builder(); + for (Integer index : groupByKeysIndex) { + groupByBuilder.add(this.plan.getOutput().get(index)); + } + ImmutableList groupByKeys = groupByBuilder.build(); + + Builder outputBuilder = ImmutableList.builder(); + for (Integer index : outputExprsIndex) { + outputBuilder.add(this.plan.getOutput().get(index)); + } + ImmutableList outputExpresList = outputBuilder.build(); + + LogicalAggregate agg = new LogicalAggregate<>(groupByKeys, outputExpresList, this.plan); + return from(agg); + } + + public LogicalPlanBuilder aggGroupUsingIndex(List groupByKeysIndex, List outputExpresList) { + Builder groupByBuilder = ImmutableList.builder(); + for (Integer index : groupByKeysIndex) { + groupByBuilder.add(this.plan.getOutput().get(index)); + } + ImmutableList groupByKeys = groupByBuilder.build(); + + LogicalAggregate agg = new LogicalAggregate<>(groupByKeys, outputExpresList, this.plan); + return from(agg); + } + + public LogicalPlanBuilder agg(List groupByKeys, List outputExpresList) { + LogicalAggregate agg = new LogicalAggregate<>(groupByKeys, outputExpresList, this.plan); + return from(agg); + } }