[Enhancement](Nereids)(Step1) prune column for filter/agg/join/sort (#10478)

Column pruning for filter/agg/join/sort.

#### For agg
Pattern : agg()
Transformed:
```
agg
  |
project
  |
child
```
#### For filter()/sort():
Pattern: project(filter()/join()/sort())
Transformed:
```
project
    |
filter/sort
   |
project
   |
child
```
#### For join
Pattern: project(join())
Transformed:
```
        project
             |
           join
       /          \
project    project
   |              |
child        child
```

for example:
```sql
table a: k1,v1
table b: k1,k2,k3,v1
select a.k1,b.k2 from a,b on a.k1 = b.k1 where a.k1 > 1
```

origin plan tree:
```
         project(a.k1,b.k2 )
                        |
          join(a:k1,v1 b:k1,k2,k3,v1)
                /                   \
 scan(a:k1,v1)         scan(b:k1,k2,k3,v1)
```

transformed plan tree:

```
              project(a.k1,b.k2 )
                        |
               join(a:k1 b:k1,k2)
               /                  \
          project(k1)   project(k1,k2)
               |                      |
 scan(a:k1,v1)       scan(b:k1,k2,k3,v1)
```
This commit is contained in:
shee
2022-07-05 17:54:21 +08:00
committed by GitHub
parent 86502b014d
commit 3b0ddd7ae0
17 changed files with 732 additions and 5 deletions

View File

@ -79,7 +79,6 @@ public class RewriteTopDownJob extends Job<Plan> {
return;
}
}
logicalExpression.setApplied(rule);
}
for (Group childGroup : group.getLogicalExpression().children()) {

View File

@ -40,6 +40,8 @@ import java.util.Objects;
* <p>
* Each agg node only contains the select statement field of the same layer,
* and other agg nodes in the subquery contain.
* Note: In general, the output of agg is a subset of the group by column plus aggregate column.
* In special cases. this relationship does not hold. for example, select k1+1, sum(v1) from table group by k1.
*/
public class LogicalAggregate extends LogicalUnaryOperator {

View File

@ -29,6 +29,7 @@ import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* Logical join plan operator.
@ -72,11 +73,35 @@ public class LogicalJoin extends LogicalBinaryOperator {
@Override
public List<Slot> computeOutput(Plan leftInput, Plan rightInput) {
List<Slot> newLeftOutput = leftInput.getOutput().stream().map(o -> o.withNullable(true))
.collect(Collectors.toList());
List<Slot> newRightOutput = rightInput.getOutput().stream().map(o -> o.withNullable(true))
.collect(Collectors.toList());
switch (joinType) {
case LEFT_SEMI_JOIN:
case LEFT_ANTI_JOIN:
return ImmutableList.copyOf(leftInput.getOutput());
case RIGHT_SEMI_JOIN:
case RIGHT_ANTI_JOIN:
return ImmutableList.copyOf(rightInput.getOutput());
case LEFT_OUTER_JOIN:
return ImmutableList.<Slot>builder()
.addAll(leftInput.getOutput())
.addAll(newRightOutput)
.build();
case RIGHT_OUTER_JOIN:
return ImmutableList.<Slot>builder()
.addAll(newLeftOutput)
.addAll(rightInput.getOutput())
.build();
case FULL_OUTER_JOIN:
return ImmutableList.<Slot>builder()
.addAll(newLeftOutput)
.addAll(newRightOutput)
.build();
default:
return ImmutableList.<Slot>builder()
.addAll(leftInput.getOutput())

View File

@ -36,11 +36,17 @@ public enum RuleType {
RESOLVE_PROJECT_ALIAS(RuleTypeClass.REWRITE),
RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),
PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
// rewrite rules
AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE),
// predicate push down rules
PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE),
// column prune rules,
COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE),
COLUMN_PRUNE_FILTER_CHILD(RuleTypeClass.REWRITE),
COLUMN_PRUNE_SORT_CHILD(RuleTypeClass.REWRITE),
COLUMN_PRUNE_JOIN_CHILD(RuleTypeClass.REWRITE),
REWRITE_SENTINEL(RuleTypeClass.REWRITE),
// exploration rules
LOGICAL_JOIN_COMMUTATIVE(RuleTypeClass.EXPLORATION),

View File

@ -0,0 +1,61 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.pattern.PatternDescriptor;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.Plan;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.Set;
/**
* push down project base class.
*/
public abstract class AbstractPushDownProjectRule<C extends Plan> extends OneRewriteRuleFactory {
PatternDescriptor<C, Plan> target;
RuleType ruleType;
@Override
public Rule<Plan> build() {
return logicalProject(target).then(project -> {
List<Expression> projects = Lists.newArrayList();
projects.addAll(project.operator.getProjects());
Set<Slot> projectSlots = SlotExtractor.extractSlot(projects);
return plan(project.operator, pushDownProject(project.child(), projectSlots));
}).toRule(ruleType);
}
protected abstract Plan pushDownProject(C plan, Set<Slot> references);
public void setTarget(PatternDescriptor<C, Plan> target) {
this.target = target;
}
public void setRuleType(RuleType ruleType) {
this.ruleType = ruleType;
}
}

View File

@ -0,0 +1,47 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.PlanRuleFactory;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RulePromise;
import org.apache.doris.nereids.trees.plans.Plan;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* column prune rule set.
*/
public class ColumnPruning implements PlanRuleFactory {
@Override
public List<Rule<Plan>> buildRules() {
return ImmutableList.of(
new PruneFilterChildColumns().build(),
new PruneAggChildColumns().build(),
new PruneJoinChildrenColumns().build(),
new PruneSortChildColumns().build()
);
}
@Override
public RulePromise defaultPromise() {
return RulePromise.REWRITE;
}
}

View File

@ -0,0 +1,68 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.operators.plans.logical.LogicalProject;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.Plan;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
* prune its child output according to agg.
* pattern: agg()
* table a: k1,k2,k3,v1
* select k1,sum(v1) from a group by k1
* plan tree:
* agg
* |
* scan(k1,k2,k3,v1)
* transformed:
* agg
* |
* project(k1,v1)
* |
* scan(k1,k2,k3,v1)
*/
public class PruneAggChildColumns extends OneRewriteRuleFactory {
@Override
public Rule<Plan> build() {
return RuleType.COLUMN_PRUNE_AGGREGATION_CHILD.build(logicalAggregate().then(agg -> {
List<Expression> slots = Lists.newArrayList();
slots.addAll(agg.operator.getExpressions());
Set<Slot> outputs = SlotExtractor.extractSlot(slots);
List<NamedExpression> prunedOutputs = agg.child().getOutput().stream().filter(outputs::contains)
.collect(Collectors.toList());
if (prunedOutputs.size() == agg.child().getOutput().size()) {
return agg;
}
return plan(agg.operator, plan(new LogicalProject(prunedOutputs), agg.child()));
}));
}
}

View File

@ -0,0 +1,70 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.operators.plans.logical.LogicalFilter;
import org.apache.doris.nereids.operators.plans.logical.LogicalProject;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnaryPlan;
import com.google.common.collect.Lists;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* prune filter output.
* pattern: project(filter())
* table a: k1,k2,k3,v1
* select k1 from a where k2 > 3
* plan tree:
* project(k1)
* |
* filter(k2 > 3)
* |
* scan(k1,k2,k3,v1)
* transformed:
* |
* filter(k2 > 3)
* |
* project(k1,k2)
* |
* scan(k1,k2,k3,v1)
*/
public class PruneFilterChildColumns extends AbstractPushDownProjectRule<LogicalUnaryPlan<LogicalFilter, GroupPlan>> {
public PruneFilterChildColumns() {
setRuleType(RuleType.COLUMN_PRUNE_FILTER_CHILD);
setTarget(logicalFilter());
}
@Override
protected Plan pushDownProject(LogicalUnaryPlan<LogicalFilter, GroupPlan> filterPlan, Set<Slot> references) {
Set<Slot> filterSlots = SlotExtractor.extractSlot(filterPlan.operator.getPredicates());
Set<Slot> required = Stream.concat(references.stream(), filterSlots.stream()).collect(Collectors.toSet());
if (required.containsAll(filterPlan.child().getOutput())) {
return filterPlan;
}
return plan(filterPlan.operator, plan(new LogicalProject(Lists.newArrayList(required)), filterPlan.child()));
}
}

View File

@ -0,0 +1,89 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.operators.plans.logical.LogicalJoin;
import org.apache.doris.nereids.operators.plans.logical.LogicalProject;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalBinaryPlan;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
* prune join children output.
* pattern: project(join())
* table a: k1,k2,k3,v1
* table b: k1,k2,v1,v2
* select a.k1,b.k2 from a join b on a.k1 = b.k1 where a.k3 > 1
* plan tree:
* project(a.k1,b.k2)
* |
* join(k1,k2,k3,v1,k1,k2,v1,v2)
* / \
* scan(a) scan(b)
* transformed:
* project(a.k1,b.k2)
* |
* join(k1,k2,k3,v1,k1,k2,v1,v2)
* / \
* project(a.k1,a.k3) project(b.k2,b.k1)
* | |
* scan scan
*/
public class PruneJoinChildrenColumns
extends AbstractPushDownProjectRule<LogicalBinaryPlan<LogicalJoin, GroupPlan, GroupPlan>> {
public PruneJoinChildrenColumns() {
setRuleType(RuleType.COLUMN_PRUNE_JOIN_CHILD);
setTarget(logicalJoin());
}
@Override
protected Plan pushDownProject(LogicalBinaryPlan<LogicalJoin, GroupPlan, GroupPlan> joinPlan,
Set<Slot> references) {
if (joinPlan.operator.getCondition().isPresent()) {
references.addAll(SlotExtractor.extractSlot(joinPlan.operator.getCondition().get()));
}
Set<ExprId> exprIds = references.stream().map(NamedExpression::getExprId).collect(Collectors.toSet());
List<NamedExpression> leftInputs = joinPlan.left().getOutput().stream()
.filter(r -> exprIds.contains(r.getExprId())).collect(Collectors.toList());
List<NamedExpression> rightInputs = joinPlan.right().getOutput().stream()
.filter(r -> exprIds.contains(r.getExprId())).collect(Collectors.toList());
Plan leftPlan = joinPlan.left();
Plan rightPlan = joinPlan.right();
if (leftInputs.size() != leftPlan.getOutput().size()) {
leftPlan = plan(new LogicalProject(leftInputs), leftPlan);
}
if (rightInputs.size() != rightPlan.getOutput().size()) {
rightPlan = plan(new LogicalProject(rightInputs), rightPlan);
}
return plan(joinPlan.operator, leftPlan, rightPlan);
}
}

View File

@ -0,0 +1,55 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.operators.plans.logical.LogicalProject;
import org.apache.doris.nereids.operators.plans.logical.LogicalSort;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnaryPlan;
import com.google.common.collect.Lists;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* prune join children output.
* pattern: project(sort())
*/
public class PruneSortChildColumns extends AbstractPushDownProjectRule<LogicalUnaryPlan<LogicalSort, GroupPlan>> {
public PruneSortChildColumns() {
setRuleType(RuleType.COLUMN_PRUNE_SORT_CHILD);
setTarget(logicalSort());
}
@Override
protected Plan pushDownProject(LogicalUnaryPlan<LogicalSort, GroupPlan> sortPlan, Set<Slot> references) {
Set<Slot> sortSlots = SlotExtractor.extractSlot(sortPlan.operator.getExpressions());
Set<Slot> required = Stream.concat(references.stream(), sortSlots.stream()).collect(Collectors.toSet());
if (required.containsAll(sortPlan.child().getOutput())) {
return sortPlan;
}
return plan(sortPlan.operator, plan(new LogicalProject(Lists.newArrayList(required)), sortPlan.child()));
}
}

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.NodeType;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
@ -95,6 +96,11 @@ public class Alias<CHILD_TYPE extends Expression> extends NamedExpression
return new Alias<>(childType, name);
}
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitAlias(this, context);
}
@Override
public Expression withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new Alias<>(children.get(0), name);

View File

@ -32,4 +32,8 @@ public abstract class Slot extends NamedExpression implements LeafExpression {
public Slot toSlot() {
return this;
}
public Slot withNullable(boolean newNullable) {
throw new RuntimeException("Do not implement");
}
}

View File

@ -145,4 +145,11 @@ public class SlotReference extends Slot {
public SlotReference clone() {
return new SlotReference(name, getDataType(), nullable, Lists.newArrayList(qualifier));
}
public Slot withNullable(boolean newNullable) {
if (this.nullable == newNullable) {
return this;
}
return new SlotReference(exprId, name, dataType, newNullable, qualifier);
}
}

View File

@ -178,5 +178,5 @@ public abstract class IterationVisitor<C> extends DefaultExpressionVisitor<Void,
public Void visitMod(Mod mod, C context) {
return visitArithmetic(mod, context);
}
}

View File

@ -137,4 +137,3 @@ public class ExpressionUtils {
return combine(op, result);
}
}

View File

@ -0,0 +1,62 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.OptimizerContext;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.rewrite.RewriteBottomUpJob;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindSlotReference;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.qe.ConnectContext;
/**
* sql parse util.
*/
public class AnalyzeUtils {
private static final NereidsParser parser = new NereidsParser();
/**
* analyze sql.
*/
public static LogicalPlan analyze(String sql, ConnectContext connectContext) {
try {
LogicalPlan parsed = parser.parseSingle(sql);
return analyze(parsed, connectContext);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private static LogicalPlan analyze(LogicalPlan inputPlan, ConnectContext connectContext) {
Memo memo = new Memo();
memo.initialize(inputPlan);
OptimizerContext optimizerContext = new OptimizerContext(memo);
PlannerContext plannerContext = new PlannerContext(optimizerContext, connectContext, new PhysicalProperties());
optimizerContext.pushJob(
new RewriteBottomUpJob(memo.getRoot(), new BindSlotReference().buildRules(), plannerContext));
optimizerContext.pushJob(
new RewriteBottomUpJob(memo.getRoot(), new BindRelation().buildRules(), plannerContext));
plannerContext.getOptimizerContext().getJobScheduler().executeJobPool(plannerContext);
return (LogicalPlan) memo.copyOut();
}
}

View File

@ -0,0 +1,227 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.OptimizerContext;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.rewrite.RewriteTopDownJob;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.operators.plans.logical.LogicalProject;
import org.apache.doris.nereids.operators.plans.logical.LogicalRelation;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.stream.Collectors;
/**
* column prune ut.
*/
public class ColumnPruningTest extends TestWithFeService {
@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
createTable("create table test.student (\n" + "id int not null,\n" + "name varchar(128),\n"
+ "age int,sex int)\n" + "distributed by hash(id) buckets 10\n"
+ "properties('replication_num' = '1');");
createTable("create table test.score (\n" + "sid int not null, \n" + "cid int not null, \n" + "grade double)\n"
+ "distributed by hash(sid,cid) buckets 10\n" + "properties('replication_num' = '1');");
createTable("create table test.course (\n" + "cid int not null, \n" + "cname varchar(128), \n"
+ "teacher varchar(128))\n" + "distributed by hash(cid) buckets 10\n"
+ "properties('replication_num' = '1');");
connectContext.setDatabase("default_cluster:test");
}
@Test
public void testPruneColumns1() {
String sql
= "select id,name,grade from student left join score on student.id = score.sid where score.grade > 60";
Plan plan = AnalyzeUtils.analyze(sql, connectContext);
Memo memo = new Memo();
memo.initialize(plan);
Plan out = process(memo);
System.out.println(out.treeString());
Plan l1 = out.child(0).child(0);
Plan l20 = l1.child(0).child(0);
Plan l21 = l1.child(0).child(1);
LogicalProject p1 = (LogicalProject) l1.getOperator();
LogicalProject p20 = (LogicalProject) l20.getOperator();
LogicalProject p21 = (LogicalProject) l21.getOperator();
List<String> target;
List<String> source;
source = getStringList(p1);
target = Lists.newArrayList("default_cluster:test.student.name", "default_cluster:test.student.id",
"default_cluster:test.score.grade");
Assertions.assertTrue(source.containsAll(target));
source = getStringList(p20);
target = Lists.newArrayList("default_cluster:test.student.id", "default_cluster:test.student.name");
Assertions.assertTrue(source.containsAll(target));
source = getStringList(p21);
target = Lists.newArrayList("default_cluster:test.score.sid", "default_cluster:test.score.grade");
Assertions.assertTrue(source.containsAll(target));
}
@Test
public void testPruneColumns2() {
String sql
= "select name,sex,cid,grade from student left join score on student.id = score.sid "
+ "where score.grade > 60";
Plan plan = AnalyzeUtils.analyze(sql, connectContext);
Memo memo = new Memo();
memo.initialize(plan);
Plan out = process(memo);
Plan l1 = out.child(0).child(0);
Plan l20 = l1.child(0).child(0);
Plan l21 = l1.child(0).child(1);
LogicalProject p1 = (LogicalProject) l1.getOperator();
LogicalProject p20 = (LogicalProject) l20.getOperator();
Assertions.assertTrue(l21.getOperator() instanceof LogicalRelation);
List<String> target;
List<String> source;
source = getStringList(p1);
target = Lists.newArrayList("default_cluster:test.student.name", "default_cluster:test.score.cid",
"default_cluster:test.score.grade", "default_cluster:test.student.sex");
Assertions.assertTrue(source.containsAll(target));
source = getStringList(p20);
target = Lists.newArrayList("default_cluster:test.student.id", "default_cluster:test.student.name",
"default_cluster:test.student.sex");
Assertions.assertTrue(source.containsAll(target));
}
@Test
public void testPruneColumns3() {
String sql = "select id,name from student where age > 18";
Plan plan = AnalyzeUtils.analyze(sql, connectContext);
Memo memo = new Memo();
memo.initialize(plan);
Plan out = process(memo);
Plan l1 = out.child(0).child(0);
LogicalProject p1 = (LogicalProject) l1.getOperator();
List<String> target;
List<String> source;
source = getStringList(p1);
target = Lists.newArrayList("default_cluster:test.student.name", "default_cluster:test.student.id",
"default_cluster:test.student.age");
Assertions.assertTrue(source.containsAll(target));
}
@Test
public void testPruneColumns4() {
String sql
= "select name,cname,grade from student left join score on student.id = score.sid left join course "
+ "on score.cid = course.cid where score.grade > 60";
Plan plan = AnalyzeUtils.analyze(sql, connectContext);
Memo memo = new Memo();
memo.initialize(plan);
Plan out = process(memo);
Plan l1 = out.child(0).child(0);
Plan l20 = l1.child(0).child(0);
Plan l21 = l1.child(0).child(1);
Plan l20Left = l20.child(0).child(0);
Plan l20Right = l20.child(0).child(1);
Assertions.assertTrue(l20.getOperator() instanceof LogicalProject);
Assertions.assertTrue(l20Left.getOperator() instanceof LogicalProject);
Assertions.assertTrue(l20Right.getOperator() instanceof LogicalRelation);
LogicalProject p1 = (LogicalProject) l1.getOperator();
LogicalProject p20 = (LogicalProject) l20.getOperator();
LogicalProject p21 = (LogicalProject) l21.getOperator();
LogicalProject p20lo = (LogicalProject) l20Left.getOperator();
List<String> target;
List<String> source;
source = getStringList(p1);
target = Lists.newArrayList("default_cluster:test.student.name", "default_cluster:test.course.cname",
"default_cluster:test.score.grade");
Assertions.assertTrue(source.containsAll(target));
source = getStringList(p20);
target = Lists.newArrayList("default_cluster:test.student.name", "default_cluster:test.score.cid",
"default_cluster:test.score.grade");
Assertions.assertTrue(source.containsAll(target));
source = getStringList(p21);
target = Lists.newArrayList("default_cluster:test.course.cid", "default_cluster:test.course.cname");
Assertions.assertTrue(source.containsAll(target));
source = getStringList(p20lo);
target = Lists.newArrayList("default_cluster:test.student.id", "default_cluster:test.student.name");
Assertions.assertTrue(source.containsAll(target));
}
private Plan process(Memo memo) {
OptimizerContext optimizerContext = new OptimizerContext(memo);
PlannerContext plannerContext = new PlannerContext(optimizerContext, connectContext, new PhysicalProperties());
RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(), new ColumnPruning().buildRules(),
plannerContext);
plannerContext.getOptimizerContext().pushJob(rewriteTopDownJob);
plannerContext.getOptimizerContext().getJobScheduler().executeJobPool(plannerContext);
return memo.copyOut();
}
private List<String> getStringList(LogicalProject p) {
return p.getProjects().stream().map(NamedExpression::getQualifiedName).collect(Collectors.toList());
}
}