[feature](Nereids) replace order by keys by child output if possible (#14108)
To support query like that: SELECT c1 + 1 as a, sum(c2) FROM t GROUP BY c1 + 1 ORDER BY c1 + 1 After rewrite, plan will equal to SELECT c1 + 1 as a, sum(c2) FROM t GROUP BY c1 + 1 ORDER BY a
This commit is contained in:
@ -24,6 +24,7 @@ import org.apache.doris.nereids.rules.analysis.BindSlotReference;
|
||||
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
|
||||
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
|
||||
import org.apache.doris.nereids.rules.analysis.RegisterCTE;
|
||||
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
|
||||
import org.apache.doris.nereids.rules.analysis.Scope;
|
||||
import org.apache.doris.nereids.rules.analysis.UserAuthentication;
|
||||
|
||||
@ -52,7 +53,8 @@ public class AnalyzeRulesJob extends BatchRulesJob {
|
||||
new UserAuthentication(),
|
||||
new BindSlotReference(scope),
|
||||
new BindFunction(),
|
||||
new ProjectToGlobalAggregate()
|
||||
new ProjectToGlobalAggregate(),
|
||||
new ReplaceExpressionByChildOutput()
|
||||
)),
|
||||
topDownBatch(ImmutableList.of(
|
||||
new FillUpMissingSlots()
|
||||
|
||||
@ -46,9 +46,12 @@ public enum RuleType {
|
||||
BINDING_FILTER_FUNCTION(RuleTypeClass.REWRITE),
|
||||
BINDING_HAVING_FUNCTION(RuleTypeClass.REWRITE),
|
||||
BINDING_SORT_FUNCTION(RuleTypeClass.REWRITE),
|
||||
FILL_UP_HAVING_AGGREGATE_FUNCTIONS(RuleTypeClass.REWRITE),
|
||||
FILL_UP_SORT_AGGREGATE_FUNCTIONS(RuleTypeClass.REWRITE),
|
||||
FILL_UP_SORT_HAVING_AGGREGATE_FUNCTIONS(RuleTypeClass.REWRITE),
|
||||
|
||||
REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT(RuleTypeClass.REWRITE),
|
||||
|
||||
FILL_UP_HAVING_AGGREGATE(RuleTypeClass.REWRITE),
|
||||
FILL_UP_SORT_AGGREGATE(RuleTypeClass.REWRITE),
|
||||
FILL_UP_SORT_HAVING_AGGREGATE(RuleTypeClass.REWRITE),
|
||||
FILL_UP_SORT_PROJECT(RuleTypeClass.REWRITE),
|
||||
|
||||
RESOLVE_PROJECT_ALIAS(RuleTypeClass.REWRITE),
|
||||
|
||||
@ -32,6 +32,7 @@ import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
@ -50,15 +51,34 @@ import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Resolve having clause to the aggregation.
|
||||
* need Top to Down to traverse plan,
|
||||
* because we need to process FILL_UP_SORT_HAVING_AGGREGATE before FILL_UP_HAVING_AGGREGATE.
|
||||
*/
|
||||
public class FillUpMissingSlots implements AnalysisRuleFactory {
|
||||
@Override
|
||||
public List<Rule> buildRules() {
|
||||
return ImmutableList.of(
|
||||
RuleType.FILL_UP_SORT_AGGREGATE_FUNCTIONS.build(
|
||||
RuleType.FILL_UP_SORT_PROJECT.build(
|
||||
logicalSort(logicalProject())
|
||||
.when(this::checkSort)
|
||||
.then(sort -> {
|
||||
final Builder<NamedExpression> projectionsBuilder = ImmutableList.builder();
|
||||
projectionsBuilder.addAll(sort.child().getProjects());
|
||||
Set<Slot> notExistedInProject = sort.getExpressions().stream()
|
||||
.map(Expression::getInputSlots)
|
||||
.flatMap(Set::stream)
|
||||
.filter(s -> !sort.child().getOutputSet().contains(s))
|
||||
.collect(Collectors.toSet());
|
||||
projectionsBuilder.addAll(notExistedInProject);
|
||||
return new LogicalProject(sort.child().getOutput(),
|
||||
new LogicalSort<>(sort.getOrderKeys(),
|
||||
new LogicalProject<>(projectionsBuilder.build(),
|
||||
sort.child().child())));
|
||||
})
|
||||
),
|
||||
RuleType.FILL_UP_SORT_AGGREGATE.build(
|
||||
logicalSort(logicalAggregate())
|
||||
.when(sort -> sort.getExpressions().stream()
|
||||
.anyMatch(e -> e.containsType(AggregateFunction.class)))
|
||||
.when(this::checkSort)
|
||||
.then(sort -> {
|
||||
LogicalAggregate<GroupPlan> aggregate = sort.child();
|
||||
Resolver resolver = new Resolver(aggregate);
|
||||
@ -74,10 +94,9 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
|
||||
});
|
||||
})
|
||||
),
|
||||
RuleType.FILL_UP_SORT_HAVING_AGGREGATE_FUNCTIONS.build(
|
||||
RuleType.FILL_UP_SORT_HAVING_AGGREGATE.build(
|
||||
logicalSort(logicalHaving(logicalAggregate()))
|
||||
.when(sort -> sort.getExpressions().stream()
|
||||
.anyMatch(e -> e.containsType(AggregateFunction.class)))
|
||||
.when(this::checkSort)
|
||||
.then(sort -> {
|
||||
LogicalAggregate<GroupPlan> aggregate = sort.child().child();
|
||||
Resolver resolver = new Resolver(aggregate);
|
||||
@ -93,7 +112,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
|
||||
});
|
||||
})
|
||||
),
|
||||
RuleType.FILL_UP_HAVING_AGGREGATE_FUNCTIONS.build(
|
||||
RuleType.FILL_UP_HAVING_AGGREGATE.build(
|
||||
logicalHaving(logicalAggregate()).then(having -> {
|
||||
LogicalAggregate<GroupPlan> aggregate = having.child();
|
||||
Resolver resolver = new Resolver(aggregate);
|
||||
@ -104,27 +123,6 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
|
||||
return new LogicalFilter<>(newPredicates, a);
|
||||
});
|
||||
})
|
||||
),
|
||||
RuleType.FILL_UP_SORT_PROJECT.build(
|
||||
logicalSort(logicalProject())
|
||||
.when(sort -> sort.getExpressions().stream()
|
||||
.map(Expression::getInputSlots)
|
||||
.flatMap(Set::stream)
|
||||
.anyMatch(s -> !sort.child().getOutputSet().contains(s)))
|
||||
.then(sort -> {
|
||||
final Builder<NamedExpression> projectionsBuilder = ImmutableList.builder();
|
||||
projectionsBuilder.addAll(sort.child().getProjects());
|
||||
Set<Slot> notExistedInProject = sort.getExpressions().stream()
|
||||
.map(Expression::getInputSlots)
|
||||
.flatMap(Set::stream)
|
||||
.filter(s -> !sort.child().getOutputSet().contains(s))
|
||||
.collect(Collectors.toSet());
|
||||
projectionsBuilder.addAll(notExistedInProject);
|
||||
return new LogicalProject(sort.child().getOutput(),
|
||||
new LogicalSort<>(sort.getOrderKeys(),
|
||||
new LogicalProject<>(projectionsBuilder.build(),
|
||||
sort.child().child())));
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
@ -253,4 +251,14 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
|
||||
Plan plan = planGenerator.apply(resolver, newAggregate);
|
||||
return new LogicalProject<>(projections, plan);
|
||||
}
|
||||
|
||||
private boolean checkSort(LogicalSort<? extends LogicalPlan> logicalSort) {
|
||||
return logicalSort.getExpressions().stream()
|
||||
.map(Expression::getInputSlots)
|
||||
.flatMap(Set::stream)
|
||||
.anyMatch(s -> !logicalSort.child().getOutputSet().contains(s))
|
||||
|| logicalSort.getOrderKeys().stream()
|
||||
.map(OrderKey::getExpr)
|
||||
.anyMatch(e -> e.containsType(AggregateFunction.class));
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,99 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.nereids.properties.OrderKey;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Maps;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* replace.
|
||||
*/
|
||||
public class ReplaceExpressionByChildOutput implements AnalysisRuleFactory {
|
||||
@Override
|
||||
public List<Rule> buildRules() {
|
||||
return ImmutableList.<Rule>builder()
|
||||
.add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
|
||||
logicalSort(logicalProject()).then(sort -> {
|
||||
LogicalProject<GroupPlan> project = sort.child();
|
||||
Map<Expression, Slot> sMap = Maps.newHashMap();
|
||||
project.getProjects().stream()
|
||||
.filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast)
|
||||
.forEach(p -> sMap.put(p.child(), p.toSlot()));
|
||||
return replaceSortExpression(sort, sMap);
|
||||
})
|
||||
))
|
||||
.add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
|
||||
logicalSort(logicalAggregate()).then(sort -> {
|
||||
LogicalAggregate<GroupPlan> aggregate = sort.child();
|
||||
Map<Expression, Slot> sMap = Maps.newHashMap();
|
||||
aggregate.getOutputExpressions().stream()
|
||||
.filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast)
|
||||
.forEach(p -> sMap.put(p.child(), p.toSlot()));
|
||||
return replaceSortExpression(sort, sMap);
|
||||
})
|
||||
)).add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
|
||||
logicalSort(logicalHaving(logicalAggregate())).then(sort -> {
|
||||
LogicalAggregate<GroupPlan> aggregate = sort.child().child();
|
||||
Map<Expression, Slot> sMap = Maps.newHashMap();
|
||||
aggregate.getOutputExpressions().stream()
|
||||
.filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast)
|
||||
.forEach(p -> sMap.put(p.child(), p.toSlot()));
|
||||
return replaceSortExpression(sort, sMap);
|
||||
})
|
||||
))
|
||||
.build();
|
||||
}
|
||||
|
||||
private LogicalPlan replaceSortExpression(LogicalSort<? extends LogicalPlan> sort, Map<Expression, Slot> sMap) {
|
||||
List<OrderKey> orderKeys = sort.getOrderKeys();
|
||||
AtomicBoolean changed = new AtomicBoolean(false);
|
||||
List<OrderKey> newKeys = orderKeys.stream().map(k -> {
|
||||
Expression newExpr = ExpressionUtils.replace(k.getExpr(), sMap);
|
||||
if (newExpr != k.getExpr()) {
|
||||
changed.set(true);
|
||||
}
|
||||
return new OrderKey(newExpr, k.isAsc(), k.isNullFirst());
|
||||
}).collect(Collectors.toList());
|
||||
if (changed.get()) {
|
||||
return new LogicalSort<>(newKeys, sort.child());
|
||||
} else {
|
||||
return sort;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -15,12 +15,13 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.util;
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.nereids.parser.NereidsParser;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.utframe.TestWithFeService;
|
||||
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
@ -15,7 +15,7 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.util;
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.nereids.NereidsPlanner;
|
||||
import org.apache.doris.nereids.StatementContext;
|
||||
@ -23,7 +23,6 @@ import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
|
||||
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
|
||||
import org.apache.doris.nereids.parser.NereidsParser;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.rules.analysis.EliminateAliasNode;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil;
|
||||
@ -31,6 +30,10 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.util.FieldChecker;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.utframe.TestWithFeService;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -15,7 +15,7 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.util;
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.nereids.NereidsPlanner;
|
||||
import org.apache.doris.nereids.StatementContext;
|
||||
@ -44,6 +44,10 @@ import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.VarcharType;
|
||||
import org.apache.doris.nereids.util.FieldChecker;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.utframe.TestWithFeService;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -15,7 +15,7 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.parser;
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.common.ExceptionChecker;
|
||||
import org.apache.doris.nereids.datasets.tpch.AnalyzeCheckTestBase;
|
||||
@ -408,12 +408,11 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt
|
||||
sumA2 = new Alias(new ExprId(3), new Sum(a2), "SUM(a2)");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matchesFromRoot(
|
||||
logicalProject(
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true))))));
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true)))));
|
||||
|
||||
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 ORDER BY SUM(a2)";
|
||||
a1 = new SlotReference(
|
||||
@ -427,12 +426,11 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt
|
||||
Alias value = new Alias(new ExprId(3), new Sum(a2), "value");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matchesFromRoot(
|
||||
logicalProject(
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true))))));
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true)))));
|
||||
|
||||
sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 ORDER BY MIN(pk)";
|
||||
a1 = new SlotReference(
|
||||
@ -463,12 +461,11 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt
|
||||
Alias sumA1A2 = new Alias(new ExprId(3), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matchesFromRoot(
|
||||
logicalProject(
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A2.toSlot(), true, true))))));
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A2.toSlot(), true, true)))));
|
||||
|
||||
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 ORDER BY SUM(a1 + a2 + 3)";
|
||||
Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new SmallIntLiteral((short) 3))),
|
||||
@ -15,7 +15,7 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.util;
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.nereids.NereidsPlanner;
|
||||
import org.apache.doris.nereids.StatementContext;
|
||||
@ -29,7 +29,6 @@ import org.apache.doris.nereids.parser.NereidsParser;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleSet;
|
||||
import org.apache.doris.nereids.rules.analysis.CTEContext;
|
||||
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.InApplyToJoin;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushApplyUnderFilter;
|
||||
@ -47,6 +46,10 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.types.VarcharType;
|
||||
import org.apache.doris.nereids.util.FieldChecker;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.utframe.TestWithFeService;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -0,0 +1,97 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.nereids.properties.OrderKey;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class ReplaceExpressionByChildOutputTest implements PatternMatchSupported {
|
||||
|
||||
@Test
|
||||
void testSortProject() {
|
||||
SlotReference slotReference = new SlotReference("col1", IntegerType.INSTANCE);
|
||||
Alias alias = new Alias(slotReference, "a");
|
||||
LogicalOlapScan logicalOlapScan = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
LogicalProject<Plan> logicalProject = new LogicalProject<>(ImmutableList.of(alias), logicalOlapScan);
|
||||
List<OrderKey> orderKeys = ImmutableList.of(new OrderKey(slotReference, true, true));
|
||||
LogicalSort<LogicalProject<Plan>> logicalSort = new LogicalSort<>(orderKeys, logicalProject);
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), logicalSort)
|
||||
.applyBottomUp(new ReplaceExpressionByChildOutput())
|
||||
.matchesFromRoot(
|
||||
logicalSort(logicalProject()).when(sort ->
|
||||
((Slot) (sort.getOrderKeys().get(0).getExpr())).getExprId().equals(alias.getExprId()))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSortAggregate() {
|
||||
SlotReference slotReference = new SlotReference("col1", IntegerType.INSTANCE);
|
||||
Alias alias = new Alias(slotReference, "a");
|
||||
LogicalOlapScan logicalOlapScan = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
LogicalAggregate<Plan> logicalAggregate = new LogicalAggregate<>(
|
||||
ImmutableList.of(alias), ImmutableList.of(alias), logicalOlapScan);
|
||||
List<OrderKey> orderKeys = ImmutableList.of(new OrderKey(slotReference, true, true));
|
||||
LogicalSort<LogicalAggregate<Plan>> logicalSort = new LogicalSort<>(orderKeys, logicalAggregate);
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), logicalSort)
|
||||
.applyBottomUp(new ReplaceExpressionByChildOutput())
|
||||
.matchesFromRoot(
|
||||
logicalSort(logicalAggregate()).when(sort ->
|
||||
((Slot) (sort.getOrderKeys().get(0).getExpr())).getExprId().equals(alias.getExprId()))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSortHavingAggregate() {
|
||||
SlotReference slotReference = new SlotReference("col1", IntegerType.INSTANCE);
|
||||
Alias alias = new Alias(slotReference, "a");
|
||||
LogicalOlapScan logicalOlapScan = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
LogicalAggregate<Plan> logicalAggregate = new LogicalAggregate<>(
|
||||
ImmutableList.of(alias), ImmutableList.of(alias), logicalOlapScan);
|
||||
LogicalHaving<Plan> logicalHaving = new LogicalHaving<>(BooleanLiteral.TRUE, logicalAggregate);
|
||||
List<OrderKey> orderKeys = ImmutableList.of(new OrderKey(slotReference, true, true));
|
||||
LogicalSort<Plan> logicalSort = new LogicalSort<>(orderKeys, logicalHaving);
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), logicalSort)
|
||||
.applyBottomUp(new ReplaceExpressionByChildOutput())
|
||||
.matchesFromRoot(
|
||||
logicalSort(logicalHaving(logicalAggregate())).when(sort ->
|
||||
((Slot) (sort.getOrderKeys().get(0).getExpr())).getExprId().equals(alias.getExprId()))
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user