From 3b0ddd7ae03ca29deb2ddb871593cc6ed1fe8aad Mon Sep 17 00:00:00 2001 From: shee <13843187+qzsee@users.noreply.github.com> Date: Tue, 5 Jul 2022 17:54:21 +0800 Subject: [PATCH] [Enhancement](Nereids)(Step1) prune column for filter/agg/join/sort (#10478) Column pruning for filter/agg/join/sort. #### For agg Pattern : agg() Transformed: ``` agg | project | child ``` #### For filter()/sort(): Pattern: project(filter()/join()/sort()) Transformed: ``` project | filter/sort | project | child ``` #### For join Pattern: project(join()) Transformed: ``` project | join / \ project project | | child child ``` for example: ```sql table a: k1,v1 table b: k1,k2,k3,v1 select a.k1,b.k2 from a,b on a.k1 = b.k1 where a.k1 > 1 ``` origin plan tree: ``` project(a.k1,b.k2 ) | join(a:k1,v1 b:k1,k2,k3,v1) / \ scan(a:k1,v1) scan(b:k1,k2,k3,v1) ``` transformed plan tree: ``` project(a.k1,b.k2 ) | join(a:k1 b:k1,k2) / \ project(k1) project(k1,k2) | | scan(a:k1,v1) scan(b:k1,k2,k3,v1) ``` --- .../jobs/rewrite/RewriteTopDownJob.java | 1 - .../plans/logical/LogicalAggregate.java | 2 + .../operators/plans/logical/LogicalJoin.java | 25 ++ .../apache/doris/nereids/rules/RuleType.java | 10 +- .../logical/AbstractPushDownProjectRule.java | 61 +++++ .../rules/rewrite/logical/ColumnPruning.java | 47 ++++ .../rewrite/logical/PruneAggChildColumns.java | 68 ++++++ .../logical/PruneFilterChildColumns.java | 70 ++++++ .../logical/PruneJoinChildrenColumns.java | 89 +++++++ .../logical/PruneSortChildColumns.java | 55 +++++ .../nereids/trees/expressions/Alias.java | 6 + .../doris/nereids/trees/expressions/Slot.java | 4 + .../trees/expressions/SlotReference.java | 7 + .../expressions/visitor/IterationVisitor.java | 2 +- .../doris/nereids/util/ExpressionUtils.java | 1 - .../rules/rewrite/logical/AnalyzeUtils.java | 62 +++++ .../rewrite/logical/ColumnPruningTest.java | 227 ++++++++++++++++++ 17 files changed, 732 insertions(+), 5 deletions(-) create mode 100644 
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AbstractPushDownProjectRule.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruning.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneAggChildColumns.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneFilterChildColumns.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneJoinChildrenColumns.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneSortChildColumns.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AnalyzeUtils.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruningTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteTopDownJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteTopDownJob.java index 4504109d68..5a43b7a29c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteTopDownJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteTopDownJob.java @@ -79,7 +79,6 @@ public class RewriteTopDownJob extends Job { return; } } - logicalExpression.setApplied(rule); } for (Group childGroup : group.getLogicalExpression().children()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalAggregate.java index 28ce1a62cf..5b65064dc1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalAggregate.java @@ -40,6 +40,8 @@ import java.util.Objects; *

* Each agg node only contains the select statement field of the same layer, * and other agg nodes in the subquery contain. + * Note: In general, the output of agg is a subset of the group by columns plus the aggregate columns. + * In special cases this relationship does not hold; for example: select k1+1, sum(v1) from table group by k1. */ public class LogicalAggregate extends LogicalUnaryOperator { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalJoin.java index d3a93bbf96..e5e61e7a15 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalJoin.java @@ -29,6 +29,7 @@ import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.stream.Collectors; /** * Logical join plan operator. 
@@ -72,11 +73,35 @@ public class LogicalJoin extends LogicalBinaryOperator { @Override public List computeOutput(Plan leftInput, Plan rightInput) { + + List newLeftOutput = leftInput.getOutput().stream().map(o -> o.withNullable(true)) + .collect(Collectors.toList()); + + List newRightOutput = rightInput.getOutput().stream().map(o -> o.withNullable(true)) + .collect(Collectors.toList()); + switch (joinType) { case LEFT_SEMI_JOIN: + case LEFT_ANTI_JOIN: return ImmutableList.copyOf(leftInput.getOutput()); case RIGHT_SEMI_JOIN: + case RIGHT_ANTI_JOIN: return ImmutableList.copyOf(rightInput.getOutput()); + case LEFT_OUTER_JOIN: + return ImmutableList.builder() + .addAll(leftInput.getOutput()) + .addAll(newRightOutput) + .build(); + case RIGHT_OUTER_JOIN: + return ImmutableList.builder() + .addAll(newLeftOutput) + .addAll(rightInput.getOutput()) + .build(); + case FULL_OUTER_JOIN: + return ImmutableList.builder() + .addAll(newLeftOutput) + .addAll(newRightOutput) + .build(); default: return ImmutableList.builder() .addAll(leftInput.getOutput()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 954f060009..9764e4ce95 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -36,11 +36,17 @@ public enum RuleType { RESOLVE_PROJECT_ALIAS(RuleTypeClass.REWRITE), RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE), PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE), - // rewrite rules AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE), - COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE), + // predicate push down rules PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE), + // column prune rules, + COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE), + COLUMN_PRUNE_FILTER_CHILD(RuleTypeClass.REWRITE), + COLUMN_PRUNE_SORT_CHILD(RuleTypeClass.REWRITE), + 
COLUMN_PRUNE_JOIN_CHILD(RuleTypeClass.REWRITE), + + REWRITE_SENTINEL(RuleTypeClass.REWRITE), // exploration rules LOGICAL_JOIN_COMMUTATIVE(RuleTypeClass.EXPLORATION), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AbstractPushDownProjectRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AbstractPushDownProjectRule.java new file mode 100644 index 0000000000..2b638050f9 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AbstractPushDownProjectRule.java @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.pattern.PatternDescriptor; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor; +import org.apache.doris.nereids.trees.plans.Plan; + +import com.google.common.collect.Lists; + +import java.util.List; +import java.util.Set; + +/** + * push down project base class. + */ +public abstract class AbstractPushDownProjectRule extends OneRewriteRuleFactory { + + PatternDescriptor target; + RuleType ruleType; + + @Override + public Rule build() { + return logicalProject(target).then(project -> { + List projects = Lists.newArrayList(); + projects.addAll(project.operator.getProjects()); + Set projectSlots = SlotExtractor.extractSlot(projects); + return plan(project.operator, pushDownProject(project.child(), projectSlots)); + }).toRule(ruleType); + } + + protected abstract Plan pushDownProject(C plan, Set references); + + public void setTarget(PatternDescriptor target) { + this.target = target; + } + + public void setRuleType(RuleType ruleType) { + this.ruleType = ruleType; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruning.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruning.java new file mode 100644 index 0000000000..b99ed88508 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruning.java @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.rules.PlanRuleFactory; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RulePromise; +import org.apache.doris.nereids.trees.plans.Plan; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * column prune rule set. + */ +public class ColumnPruning implements PlanRuleFactory { + @Override + public List> buildRules() { + return ImmutableList.of( + new PruneFilterChildColumns().build(), + new PruneAggChildColumns().build(), + new PruneJoinChildrenColumns().build(), + new PruneSortChildColumns().build() + ); + } + + @Override + public RulePromise defaultPromise() { + return RulePromise.REWRITE; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneAggChildColumns.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneAggChildColumns.java new file mode 100644 index 0000000000..46d5ea6ed5 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneAggChildColumns.java @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.operators.plans.logical.LogicalProject; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor; +import org.apache.doris.nereids.trees.plans.Plan; + +import com.google.common.collect.Lists; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * prune its child output according to agg. 
+ * pattern: agg() + * table a: k1,k2,k3,v1 + * select k1,sum(v1) from a group by k1 + * plan tree: + * agg + * | + * scan(k1,k2,k3,v1) + * transformed: + * agg + * | + * project(k1,v1) + * | + * scan(k1,k2,k3,v1) + */ +public class PruneAggChildColumns extends OneRewriteRuleFactory { + + @Override + public Rule build() { + return RuleType.COLUMN_PRUNE_AGGREGATION_CHILD.build(logicalAggregate().then(agg -> { + List slots = Lists.newArrayList(); + slots.addAll(agg.operator.getExpressions()); + Set outputs = SlotExtractor.extractSlot(slots); + List prunedOutputs = agg.child().getOutput().stream().filter(outputs::contains) + .collect(Collectors.toList()); + if (prunedOutputs.size() == agg.child().getOutput().size()) { + return agg; + } + return plan(agg.operator, plan(new LogicalProject(prunedOutputs), agg.child())); + })); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneFilterChildColumns.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneFilterChildColumns.java new file mode 100644 index 0000000000..9daaabd296 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneFilterChildColumns.java @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.operators.plans.logical.LogicalFilter; +import org.apache.doris.nereids.operators.plans.logical.LogicalProject; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnaryPlan; + +import com.google.common.collect.Lists; + +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * prune filter output. + * pattern: project(filter()) + * table a: k1,k2,k3,v1 + * select k1 from a where k2 > 3 + * plan tree: + * project(k1) + * | + * filter(k2 > 3) + * | + * scan(k1,k2,k3,v1) + * transformed: + * | + * filter(k2 > 3) + * | + * project(k1,k2) + * | + * scan(k1,k2,k3,v1) + */ +public class PruneFilterChildColumns extends AbstractPushDownProjectRule> { + + public PruneFilterChildColumns() { + setRuleType(RuleType.COLUMN_PRUNE_FILTER_CHILD); + setTarget(logicalFilter()); + } + + @Override + protected Plan pushDownProject(LogicalUnaryPlan filterPlan, Set references) { + Set filterSlots = SlotExtractor.extractSlot(filterPlan.operator.getPredicates()); + Set required = Stream.concat(references.stream(), filterSlots.stream()).collect(Collectors.toSet()); + if (required.containsAll(filterPlan.child().getOutput())) { + return filterPlan; + } + return plan(filterPlan.operator, plan(new LogicalProject(Lists.newArrayList(required)), filterPlan.child())); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneJoinChildrenColumns.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneJoinChildrenColumns.java new file mode 100644 index 0000000000..b9bc738768 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneJoinChildrenColumns.java @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.operators.plans.logical.LogicalJoin; +import org.apache.doris.nereids.operators.plans.logical.LogicalProject; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalBinaryPlan; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * prune join children output. 
+ * pattern: project(join()) + * table a: k1,k2,k3,v1 + * table b: k1,k2,v1,v2 + * select a.k1,b.k2 from a join b on a.k1 = b.k1 where a.k3 > 1 + * plan tree: + * project(a.k1,b.k2) + * | + * join(k1,k2,k3,v1,k1,k2,v1,v2) + * / \ + * scan(a) scan(b) + * transformed: + * project(a.k1,b.k2) + * | + * join(k1,k2,k3,v1,k1,k2,v1,v2) + * / \ + * project(a.k1,a.k3) project(b.k2,b.k1) + * | | + * scan scan + */ +public class PruneJoinChildrenColumns + extends AbstractPushDownProjectRule> { + + public PruneJoinChildrenColumns() { + setRuleType(RuleType.COLUMN_PRUNE_JOIN_CHILD); + setTarget(logicalJoin()); + } + + @Override + protected Plan pushDownProject(LogicalBinaryPlan joinPlan, + Set references) { + if (joinPlan.operator.getCondition().isPresent()) { + references.addAll(SlotExtractor.extractSlot(joinPlan.operator.getCondition().get())); + } + Set exprIds = references.stream().map(NamedExpression::getExprId).collect(Collectors.toSet()); + + List leftInputs = joinPlan.left().getOutput().stream() + .filter(r -> exprIds.contains(r.getExprId())).collect(Collectors.toList()); + List rightInputs = joinPlan.right().getOutput().stream() + .filter(r -> exprIds.contains(r.getExprId())).collect(Collectors.toList()); + + Plan leftPlan = joinPlan.left(); + Plan rightPlan = joinPlan.right(); + + if (leftInputs.size() != leftPlan.getOutput().size()) { + leftPlan = plan(new LogicalProject(leftInputs), leftPlan); + } + + if (rightInputs.size() != rightPlan.getOutput().size()) { + rightPlan = plan(new LogicalProject(rightInputs), rightPlan); + } + return plan(joinPlan.operator, leftPlan, rightPlan); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneSortChildColumns.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneSortChildColumns.java new file mode 100644 index 0000000000..333d4a7749 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneSortChildColumns.java @@ 
-0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.operators.plans.logical.LogicalProject; +import org.apache.doris.nereids.operators.plans.logical.LogicalSort; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnaryPlan; + +import com.google.common.collect.Lists; + +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * prune sort child output. 
+ * pattern: project(sort()) + */ +public class PruneSortChildColumns extends AbstractPushDownProjectRule> { + + public PruneSortChildColumns() { + setRuleType(RuleType.COLUMN_PRUNE_SORT_CHILD); + setTarget(logicalSort()); + } + + @Override + protected Plan pushDownProject(LogicalUnaryPlan sortPlan, Set references) { + Set sortSlots = SlotExtractor.extractSlot(sortPlan.operator.getExpressions()); + Set required = Stream.concat(references.stream(), sortSlots.stream()).collect(Collectors.toSet()); + if (required.containsAll(sortPlan.child().getOutput())) { + return sortPlan; + } + return plan(sortPlan.operator, plan(new LogicalProject(Lists.newArrayList(required)), sortPlan.child())); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java index 33e7f79807..10f06916b7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.NodeType; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import com.google.common.base.Preconditions; @@ -95,6 +96,11 @@ public class Alias extends NamedExpression return new Alias<>(childType, name); } + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitAlias(this, context); + } + + @Override public Expression withChildren(List children) { Preconditions.checkArgument(children.size() == 1); return new Alias<>(children.get(0), name); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Slot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Slot.java index 3a215d1c97..f77f9cc2bf 100644 
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Slot.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Slot.java @@ -32,4 +32,8 @@ public abstract class Slot extends NamedExpression implements LeafExpression { public Slot toSlot() { return this; } + + public Slot withNullable(boolean newNullable) { + throw new RuntimeException("Do not implement"); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java index 789f131a20..868d016046 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java @@ -145,4 +145,11 @@ public class SlotReference extends Slot { public SlotReference clone() { return new SlotReference(name, getDataType(), nullable, Lists.newArrayList(qualifier)); } + + public Slot withNullable(boolean newNullable) { + if (this.nullable == newNullable) { + return this; + } + return new SlotReference(exprId, name, dataType, newNullable, qualifier); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/IterationVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/IterationVisitor.java index 9462da90ce..1b05457259 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/IterationVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/IterationVisitor.java @@ -178,5 +178,5 @@ public abstract class IterationVisitor extends DefaultExpressionVisitor target; + List source; + + source = getStringList(p1); + target = Lists.newArrayList("default_cluster:test.student.name", "default_cluster:test.student.id", + "default_cluster:test.score.grade"); + Assertions.assertTrue(source.containsAll(target)); + + 
source = getStringList(p20); + target = Lists.newArrayList("default_cluster:test.student.id", "default_cluster:test.student.name"); + Assertions.assertTrue(source.containsAll(target)); + + source = getStringList(p21); + target = Lists.newArrayList("default_cluster:test.score.sid", "default_cluster:test.score.grade"); + Assertions.assertTrue(source.containsAll(target)); + + } + + @Test + public void testPruneColumns2() { + + String sql + = "select name,sex,cid,grade from student left join score on student.id = score.sid " + + "where score.grade > 60"; + Plan plan = AnalyzeUtils.analyze(sql, connectContext); + + Memo memo = new Memo(); + memo.initialize(plan); + + Plan out = process(memo); + + Plan l1 = out.child(0).child(0); + Plan l20 = l1.child(0).child(0); + Plan l21 = l1.child(0).child(1); + + LogicalProject p1 = (LogicalProject) l1.getOperator(); + LogicalProject p20 = (LogicalProject) l20.getOperator(); + Assertions.assertTrue(l21.getOperator() instanceof LogicalRelation); + + List target; + List source; + + source = getStringList(p1); + target = Lists.newArrayList("default_cluster:test.student.name", "default_cluster:test.score.cid", + "default_cluster:test.score.grade", "default_cluster:test.student.sex"); + Assertions.assertTrue(source.containsAll(target)); + + source = getStringList(p20); + target = Lists.newArrayList("default_cluster:test.student.id", "default_cluster:test.student.name", + "default_cluster:test.student.sex"); + Assertions.assertTrue(source.containsAll(target)); + } + + + @Test + public void testPruneColumns3() { + + String sql = "select id,name from student where age > 18"; + Plan plan = AnalyzeUtils.analyze(sql, connectContext); + + Memo memo = new Memo(); + memo.initialize(plan); + + Plan out = process(memo); + + Plan l1 = out.child(0).child(0); + LogicalProject p1 = (LogicalProject) l1.getOperator(); + + List target; + List source; + + source = getStringList(p1); + target = Lists.newArrayList("default_cluster:test.student.name", 
"default_cluster:test.student.id", + "default_cluster:test.student.age"); + Assertions.assertTrue(source.containsAll(target)); + + } + + @Test + public void testPruneColumns4() { + + String sql + = "select name,cname,grade from student left join score on student.id = score.sid left join course " + + "on score.cid = course.cid where score.grade > 60"; + Plan plan = AnalyzeUtils.analyze(sql, connectContext); + + Memo memo = new Memo(); + memo.initialize(plan); + + Plan out = process(memo); + + Plan l1 = out.child(0).child(0); + Plan l20 = l1.child(0).child(0); + Plan l21 = l1.child(0).child(1); + + Plan l20Left = l20.child(0).child(0); + Plan l20Right = l20.child(0).child(1); + + Assertions.assertTrue(l20.getOperator() instanceof LogicalProject); + Assertions.assertTrue(l20Left.getOperator() instanceof LogicalProject); + Assertions.assertTrue(l20Right.getOperator() instanceof LogicalRelation); + + LogicalProject p1 = (LogicalProject) l1.getOperator(); + LogicalProject p20 = (LogicalProject) l20.getOperator(); + LogicalProject p21 = (LogicalProject) l21.getOperator(); + + LogicalProject p20lo = (LogicalProject) l20Left.getOperator(); + + List target; + List source; + + source = getStringList(p1); + target = Lists.newArrayList("default_cluster:test.student.name", "default_cluster:test.course.cname", + "default_cluster:test.score.grade"); + Assertions.assertTrue(source.containsAll(target)); + + source = getStringList(p20); + target = Lists.newArrayList("default_cluster:test.student.name", "default_cluster:test.score.cid", + "default_cluster:test.score.grade"); + Assertions.assertTrue(source.containsAll(target)); + + source = getStringList(p21); + target = Lists.newArrayList("default_cluster:test.course.cid", "default_cluster:test.course.cname"); + Assertions.assertTrue(source.containsAll(target)); + + source = getStringList(p20lo); + target = Lists.newArrayList("default_cluster:test.student.id", "default_cluster:test.student.name"); + 
Assertions.assertTrue(source.containsAll(target)); + } + + private Plan process(Memo memo) { + OptimizerContext optimizerContext = new OptimizerContext(memo); + PlannerContext plannerContext = new PlannerContext(optimizerContext, connectContext, new PhysicalProperties()); + RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(), new ColumnPruning().buildRules(), + plannerContext); + plannerContext.getOptimizerContext().pushJob(rewriteTopDownJob); + plannerContext.getOptimizerContext().getJobScheduler().executeJobPool(plannerContext); + return memo.copyOut(); + } + + private List getStringList(LogicalProject p) { + return p.getProjects().stream().map(NamedExpression::getQualifiedName).collect(Collectors.toList()); + } +}