From 5fad4f4c7b78ee8a7ca38b82900fb08d9782175c Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Fri, 11 Nov 2022 13:34:29 +0800 Subject: [PATCH] [feature](Nereids) replace order by keys by child output if possible (#14108) To support query like that: SELECT c1 + 1 as a, sum(c2) FROM t GROUP BY c1 + 1 ORDER BY c1 + 1 After rewrite, plan will equal to SELECT c1 + 1 as a, sum(c2) FROM t GROUP BY c1 + 1 ORDER BY a --- .../nereids/jobs/batch/AnalyzeRulesJob.java | 4 +- .../apache/doris/nereids/rules/RuleType.java | 9 +- .../rules/analysis/FillUpMissingSlots.java | 64 ++++++------ .../ReplaceExpressionByChildOutput.java | 99 +++++++++++++++++++ .../analysis}/AnalyzeFunctionTest.java | 3 +- .../analysis}/AnalyzeSubQueryTest.java | 7 +- .../analysis}/AnalyzeWhereSubqueryTest.java | 6 +- .../analysis}/FillUpMissingSlotsTest.java | 35 +++---- .../analysis}/RegisterCTETest.java | 7 +- .../ReplaceExpressionByChildOutputTest.java | 97 ++++++++++++++++++ 10 files changed, 274 insertions(+), 57 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java rename fe/fe-core/src/test/java/org/apache/doris/nereids/{util => rules/analysis}/AnalyzeFunctionTest.java (95%) rename fe/fe-core/src/test/java/org/apache/doris/nereids/{util => rules/analysis}/AnalyzeSubQueryTest.java (97%) rename fe/fe-core/src/test/java/org/apache/doris/nereids/{util => rules/analysis}/AnalyzeWhereSubqueryTest.java (99%) rename fe/fe-core/src/test/java/org/apache/doris/nereids/{parser => rules/analysis}/FillUpMissingSlotsTest.java (95%) rename fe/fe-core/src/test/java/org/apache/doris/nereids/{util => rules/analysis}/RegisterCTETest.java (98%) create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutputTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java index db7528d4bc..a7a48738ca 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.rules.analysis.BindSlotReference; import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots; import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate; import org.apache.doris.nereids.rules.analysis.RegisterCTE; +import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput; import org.apache.doris.nereids.rules.analysis.Scope; import org.apache.doris.nereids.rules.analysis.UserAuthentication; @@ -52,7 +53,8 @@ public class AnalyzeRulesJob extends BatchRulesJob { new UserAuthentication(), new BindSlotReference(scope), new BindFunction(), - new ProjectToGlobalAggregate() + new ProjectToGlobalAggregate(), + new ReplaceExpressionByChildOutput() )), topDownBatch(ImmutableList.of( new FillUpMissingSlots() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 89a00279e6..14b5b2a263 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -46,9 +46,12 @@ public enum RuleType { BINDING_FILTER_FUNCTION(RuleTypeClass.REWRITE), BINDING_HAVING_FUNCTION(RuleTypeClass.REWRITE), BINDING_SORT_FUNCTION(RuleTypeClass.REWRITE), - FILL_UP_HAVING_AGGREGATE_FUNCTIONS(RuleTypeClass.REWRITE), - FILL_UP_SORT_AGGREGATE_FUNCTIONS(RuleTypeClass.REWRITE), - FILL_UP_SORT_HAVING_AGGREGATE_FUNCTIONS(RuleTypeClass.REWRITE), + + REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT(RuleTypeClass.REWRITE), + + FILL_UP_HAVING_AGGREGATE(RuleTypeClass.REWRITE), + FILL_UP_SORT_AGGREGATE(RuleTypeClass.REWRITE), + FILL_UP_SORT_HAVING_AGGREGATE(RuleTypeClass.REWRITE), FILL_UP_SORT_PROJECT(RuleTypeClass.REWRITE), RESOLVE_PROJECT_ALIAS(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java index 9215363e95..fc385587dc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java @@ -32,6 +32,7 @@ import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.util.ExpressionUtils; @@ -50,15 +51,34 @@ import java.util.stream.Collectors; /** * Resolve having clause to the aggregation. + * need Top to Down to traverse plan, + * because we need to process FILL_UP_SORT_HAVING_AGGREGATE before FILL_UP_HAVING_AGGREGATE. */ public class FillUpMissingSlots implements AnalysisRuleFactory { @Override public List buildRules() { return ImmutableList.of( - RuleType.FILL_UP_SORT_AGGREGATE_FUNCTIONS.build( + RuleType.FILL_UP_SORT_PROJECT.build( + logicalSort(logicalProject()) + .when(this::checkSort) + .then(sort -> { + final Builder projectionsBuilder = ImmutableList.builder(); + projectionsBuilder.addAll(sort.child().getProjects()); + Set notExistedInProject = sort.getExpressions().stream() + .map(Expression::getInputSlots) + .flatMap(Set::stream) + .filter(s -> !sort.child().getOutputSet().contains(s)) + .collect(Collectors.toSet()); + projectionsBuilder.addAll(notExistedInProject); + return new LogicalProject(sort.child().getOutput(), + new LogicalSort<>(sort.getOrderKeys(), + new LogicalProject<>(projectionsBuilder.build(), + sort.child().child()))); + }) + ), + RuleType.FILL_UP_SORT_AGGREGATE.build( logicalSort(logicalAggregate()) - .when(sort -> sort.getExpressions().stream() - .anyMatch(e -> e.containsType(AggregateFunction.class))) + .when(this::checkSort) .then(sort -> { LogicalAggregate aggregate = sort.child(); Resolver resolver = new Resolver(aggregate); @@ -74,10 +94,9 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { }); }) ), - RuleType.FILL_UP_SORT_HAVING_AGGREGATE_FUNCTIONS.build( + RuleType.FILL_UP_SORT_HAVING_AGGREGATE.build( logicalSort(logicalHaving(logicalAggregate())) - .when(sort -> sort.getExpressions().stream() - .anyMatch(e -> e.containsType(AggregateFunction.class))) + .when(this::checkSort) .then(sort -> { LogicalAggregate aggregate = sort.child().child(); Resolver resolver = new Resolver(aggregate); @@ -93,7 +112,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { }); }) ), - RuleType.FILL_UP_HAVING_AGGREGATE_FUNCTIONS.build( + RuleType.FILL_UP_HAVING_AGGREGATE.build( logicalHaving(logicalAggregate()).then(having -> { LogicalAggregate aggregate = having.child(); Resolver resolver = new Resolver(aggregate); @@ -104,27 +123,6 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { return new LogicalFilter<>(newPredicates, a); }); }) - ), - RuleType.FILL_UP_SORT_PROJECT.build( - logicalSort(logicalProject()) - .when(sort -> sort.getExpressions().stream() - .map(Expression::getInputSlots) - .flatMap(Set::stream) - .anyMatch(s -> !sort.child().getOutputSet().contains(s))) - .then(sort -> { - final Builder projectionsBuilder = ImmutableList.builder(); - projectionsBuilder.addAll(sort.child().getProjects()); - Set notExistedInProject = sort.getExpressions().stream() - .map(Expression::getInputSlots) - .flatMap(Set::stream) - .filter(s -> !sort.child().getOutputSet().contains(s)) - .collect(Collectors.toSet()); - projectionsBuilder.addAll(notExistedInProject); - return new LogicalProject(sort.child().getOutput(), - new LogicalSort<>(sort.getOrderKeys(), - new LogicalProject<>(projectionsBuilder.build(), - sort.child().child()))); - }) ) ); } @@ -253,4 +251,14 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { Plan plan = planGenerator.apply(resolver, newAggregate); return new LogicalProject<>(projections, plan); } + + private boolean checkSort(LogicalSort logicalSort) { + return logicalSort.getExpressions().stream() + .map(Expression::getInputSlots) + .flatMap(Set::stream) + .anyMatch(s -> !logicalSort.child().getOutputSet().contains(s)) + || logicalSort.getOrderKeys().stream() + .map(OrderKey::getExpr) + .anyMatch(e -> e.containsType(AggregateFunction.class)); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java new file mode 100644 index 0000000000..fb04a9c8de --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.analysis; + +import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; + +/** + * replace. + */ +public class ReplaceExpressionByChildOutput implements AnalysisRuleFactory { + @Override + public List buildRules() { + return ImmutableList.builder() + .add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build( + logicalSort(logicalProject()).then(sort -> { + LogicalProject project = sort.child(); + Map sMap = Maps.newHashMap(); + project.getProjects().stream() + .filter(Alias.class::isInstance) + .map(Alias.class::cast) + .forEach(p -> sMap.put(p.child(), p.toSlot())); + return replaceSortExpression(sort, sMap); + }) + )) + .add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build( + logicalSort(logicalAggregate()).then(sort -> { + LogicalAggregate aggregate = sort.child(); + Map sMap = Maps.newHashMap(); + aggregate.getOutputExpressions().stream() + .filter(Alias.class::isInstance) + .map(Alias.class::cast) + .forEach(p -> sMap.put(p.child(), p.toSlot())); + return replaceSortExpression(sort, sMap); + }) + )).add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build( + logicalSort(logicalHaving(logicalAggregate())).then(sort -> { + LogicalAggregate aggregate = sort.child().child(); + Map sMap = Maps.newHashMap(); + aggregate.getOutputExpressions().stream() + .filter(Alias.class::isInstance) + .map(Alias.class::cast) + .forEach(p -> sMap.put(p.child(), p.toSlot())); + return replaceSortExpression(sort, sMap); + }) + )) + .build(); + } + + private LogicalPlan replaceSortExpression(LogicalSort sort, Map sMap) { + List orderKeys = sort.getOrderKeys(); + AtomicBoolean changed = new AtomicBoolean(false); + List newKeys = orderKeys.stream().map(k -> { + Expression newExpr = ExpressionUtils.replace(k.getExpr(), sMap); + if (newExpr != k.getExpr()) { + changed.set(true); + } + return new OrderKey(newExpr, k.isAsc(), k.isNullFirst()); + }).collect(Collectors.toList()); + if (changed.get()) { + return new LogicalSort<>(newKeys, sort.child()); + } else { + return sort; + } + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeFunctionTest.java similarity index 95% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeFunctionTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeFunctionTest.java index 878d57fe22..c60a63ce78 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeFunctionTest.java @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.util; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; import org.junit.jupiter.api.Assertions; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeSubQueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java similarity index 97% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeSubQueryTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java index 43ca065228..e0a07abc2b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeSubQueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.util; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.StatementContext; @@ -23,7 +23,6 @@ import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator; import org.apache.doris.nereids.glue.translator.PlanTranslatorContext; import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.properties.PhysicalProperties; -import org.apache.doris.nereids.rules.analysis.EliminateAliasNode; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil; @@ -31,6 +30,10 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.util.FieldChecker; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; import com.google.common.collect.ImmutableList; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java similarity index 99% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java index f3e45631e3..1059d9bd3c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.util; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.StatementContext; @@ -44,6 +44,10 @@ import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.VarcharType; +import org.apache.doris.nereids.util.FieldChecker; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; import com.google.common.collect.ImmutableList; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/FillUpMissingSlotsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java similarity index 95% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/parser/FillUpMissingSlotsTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java index c63310e518..642b1d90b4 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/FillUpMissingSlotsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.parser; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.common.ExceptionChecker; import org.apache.doris.nereids.datasets.tpch.AnalyzeCheckTestBase; @@ -408,12 +408,11 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt sumA2 = new Alias(new ExprId(3), new Sum(a2), "SUM(a2)"); PlanChecker.from(connectContext).analyze(sql) .matchesFromRoot( - logicalProject( - logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) - ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true)))))); + logicalSort( + logicalAggregate( + logicalOlapScan() + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) + ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true))))); sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 ORDER BY SUM(a2)"; a1 = new SlotReference( @@ -427,12 +426,11 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt Alias value = new Alias(new ExprId(3), new Sum(a2), "value"); PlanChecker.from(connectContext).analyze(sql) .matchesFromRoot( - logicalProject( - logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) - ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true)))))); + logicalSort( + logicalAggregate( + logicalOlapScan() + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) + ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true))))); sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 ORDER BY MIN(pk)"; a1 = new SlotReference( @@ -463,12 +461,11 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt Alias sumA1A2 = new Alias(new ExprId(3), new Sum(new Add(a1, a2)), "SUM((a1 + a2))"); PlanChecker.from(connectContext).analyze(sql) .matchesFromRoot( - logicalProject( - logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))) - ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A2.toSlot(), true, true)))))); + logicalSort( + logicalAggregate( + logicalOlapScan() + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))) + ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A2.toSlot(), true, true))))); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 ORDER BY SUM(a1 + a2 + 3)"; Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new SmallIntLiteral((short) 3))), diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/RegisterCTETest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/RegisterCTETest.java similarity index 98% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/util/RegisterCTETest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/RegisterCTETest.java index 85ac67b453..5ac061ed58 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/RegisterCTETest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/RegisterCTETest.java @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.util; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.StatementContext; @@ -29,7 +29,6 @@ import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleSet; -import org.apache.doris.nereids.rules.analysis.CTEContext; import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; import org.apache.doris.nereids.rules.rewrite.logical.InApplyToJoin; import org.apache.doris.nereids.rules.rewrite.logical.PushApplyUnderFilter; @@ -47,6 +46,10 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.VarcharType; +import org.apache.doris.nereids.util.FieldChecker; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; import com.google.common.collect.ImmutableList; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutputTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutputTest.java new file mode 100644 index 0000000000..5206dcb402 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutputTest.java @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.analysis; + +import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +import java.util.List; + +public class ReplaceExpressionByChildOutputTest implements PatternMatchSupported { + + @Test + void testSortProject() { + SlotReference slotReference = new SlotReference("col1", IntegerType.INSTANCE); + Alias alias = new Alias(slotReference, "a"); + LogicalOlapScan logicalOlapScan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + LogicalProject logicalProject = new LogicalProject<>(ImmutableList.of(alias), logicalOlapScan); + List orderKeys = ImmutableList.of(new OrderKey(slotReference, true, true)); + LogicalSort> logicalSort = new LogicalSort<>(orderKeys, logicalProject); + + PlanChecker.from(MemoTestUtils.createConnectContext(), logicalSort) + .applyBottomUp(new ReplaceExpressionByChildOutput()) + .matchesFromRoot( + logicalSort(logicalProject()).when(sort -> + ((Slot) (sort.getOrderKeys().get(0).getExpr())).getExprId().equals(alias.getExprId())) + ); + } + + @Test + void testSortAggregate() { + SlotReference slotReference = new SlotReference("col1", IntegerType.INSTANCE); + Alias alias = new Alias(slotReference, "a"); + LogicalOlapScan logicalOlapScan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + LogicalAggregate logicalAggregate = new LogicalAggregate<>( + ImmutableList.of(alias), ImmutableList.of(alias), logicalOlapScan); + List orderKeys = ImmutableList.of(new OrderKey(slotReference, true, true)); + LogicalSort> logicalSort = new LogicalSort<>(orderKeys, logicalAggregate); + + PlanChecker.from(MemoTestUtils.createConnectContext(), logicalSort) + .applyBottomUp(new ReplaceExpressionByChildOutput()) + .matchesFromRoot( + logicalSort(logicalAggregate()).when(sort -> + ((Slot) (sort.getOrderKeys().get(0).getExpr())).getExprId().equals(alias.getExprId())) + ); + } + + @Test + void testSortHavingAggregate() { + SlotReference slotReference = new SlotReference("col1", IntegerType.INSTANCE); + Alias alias = new Alias(slotReference, "a"); + LogicalOlapScan logicalOlapScan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + LogicalAggregate logicalAggregate = new LogicalAggregate<>( + ImmutableList.of(alias), ImmutableList.of(alias), logicalOlapScan); + LogicalHaving logicalHaving = new LogicalHaving<>(BooleanLiteral.TRUE, logicalAggregate); + List orderKeys = ImmutableList.of(new OrderKey(slotReference, true, true)); + LogicalSort logicalSort = new LogicalSort<>(orderKeys, logicalHaving); + + PlanChecker.from(MemoTestUtils.createConnectContext(), logicalSort) + .applyBottomUp(new ReplaceExpressionByChildOutput()) + .matchesFromRoot( + logicalSort(logicalHaving(logicalAggregate())).when(sort -> + ((Slot) (sort.getOrderKeys().get(0).getExpr())).getExprId().equals(alias.getExprId())) + ); + } +}