[feature](nereids) merge proj-proj in post process (#14730)

* merge proj-proj

* v2: this PR guarantees that the physical plan does not contain consecutive physical projects.
Like the rewrite rule "merge projects", it merges adjacent projects, but it works on the physical plan, not the logical plan.

* move merge-proj code into Project.java
This commit is contained in:
minghong
2022-12-03 23:41:02 +08:00
committed by GitHub
parent 7bb2343505
commit 97dcd2b13a
5 changed files with 242 additions and 70 deletions

View File

@ -0,0 +1,42 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.processor.post;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import java.util.List;
/**
 * Post processor that merges consecutive physical projects into a single physical project.
 */
public class MergeProjectPostProcessor extends PlanPostProcessor {
    /**
     * Rewrites the subtree below {@code project} first, then merges {@code project} with its
     * (rewritten) child when that child is itself a {@link PhysicalProject}.
     *
     * @param project the upper project being visited
     * @param ctx     the context handed to every visited node
     * @return a project that is never directly on top of another project; the input instance
     *         is returned untouched only when nothing below it changed
     */
    @Override
    public PhysicalProject visitPhysicalProject(PhysicalProject<? extends Plan> project, CascadesContext ctx) {
        Plan originalChild = project.child();
        Plan rewrittenChild = originalChild.accept(this, ctx);
        if (rewrittenChild instanceof PhysicalProject) {
            PhysicalProject<?> lowerProject = (PhysicalProject<?>) rewrittenChild;
            List<NamedExpression> projections = project.mergeProjections(lowerProject);
            return project.withProjectionsAndChild(projections, lowerProject.child());
        }
        if (rewrittenChild != originalChild) {
            // The subtree was rewritten further down (e.g. projects merged below a non-project
            // node). Re-attach the rewritten child; returning `project` as-is would silently
            // drop that rewrite.
            return project.withProjectionsAndChild(project.getProjects(), rewrittenChild);
        }
        return project;
    }
}

View File

@ -20,19 +20,11 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* this rule aims to merge consecutive projects.
@ -50,74 +42,14 @@ import java.util.stream.Collectors;
* |
* scan
*/
public class MergeProjects extends OneRewriteRuleFactory {

    /**
     * Builds the rule that matches a logical project sitting directly on top of another
     * logical project and replaces the pair with a single project.
     * The merged projection list is computed by {@code mergeProjections} on the upper project,
     * and the new project is placed over the bottom project's child.
     *
     * @return the rewrite rule, registered as {@link RuleType#MERGE_PROJECTS}
     */
    @Override
    public Rule build() {
        return logicalProject(logicalProject()).then(project -> {
            LogicalProject<GroupPlan> childProject = project.child();
            List<NamedExpression> projectExpressions = project.mergeProjections(childProject);
            return new LogicalProject<>(projectExpressions, childProject.child(0));
        }).toRule(RuleType.MERGE_PROJECTS);
    }
}

View File

@ -21,7 +21,10 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@ -55,4 +58,79 @@ public interface Project {
)
);
}
/**
 * Merge the projections of this (upper) project with those of the project below it.
 * 1. alias combination, for example
 *    proj(x as y, b) --> proj(a as x, b, c) => (a as y, b)
 * 2. bottom projections that the upper project does not reference are not part of the
 *    result, because the result is derived from the upper projections only
 *
 * @param childProject bottom project
 * @return project list for merged project
 */
default List<NamedExpression> mergeProjections(Project childProject) {
    List<NamedExpression> upperProjections = getProjects();
    List<NamedExpression> lowerProjections = childProject.getProjects();
    // output slot of each bottom alias -> the alias expression that produces it
    Map<Expression, Expression> aliasBySlot = lowerProjections.stream()
            .filter(Alias.class::isInstance)
            .collect(Collectors.toMap(NamedExpression::toSlot, alias -> alias));
    return upperProjections.stream()
            .map(upper -> (NamedExpression) ExpressionReplacer.INSTANCE.replace(upper, aliasBySlot))
            .collect(Collectors.toList());
}
/**
 * Rewrites expressions of an upper project by substituting slots that are keys of a
 * substitution map (slot -> alias expression of the bottom project).
 */
public static class ExpressionReplacer extends DefaultExpressionRewriter<Map<Expression, Expression>> {
    public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();

    /**
     * Entry point for a top-level projection.
     * A bare slot reference is looked up (with its qualifier cleared) and, when present,
     * replaced by the mapped expression itself; everything else is handed to
     * {@link #visit(Expression, Map)}.
     *
     * @param expr            projection to rewrite
     * @param substitutionMap slot -> replacement expression
     * @return the rewritten expression, or {@code expr} when a bare slot has no mapping
     */
    public Expression replace(Expression expr, Map<Expression, Expression> substitutionMap) {
        if (!(expr instanceof SlotReference)) {
            return visit(expr, substitutionMap);
        }
        Slot unqualified = ((SlotReference) expr).withQualifier(Collections.emptyList());
        return substitutionMap.getOrDefault(unqualified, expr);
    }

    /**
     * case 1: alias over a slot reference
     *      project(alias(c) as d, alias(x) as y)
     *                  |
     *                  |                          ===>  project(alias(a) as d, alias(b) as y)
     *                  |
     *      project(slotRef(a) as c, slotRef(b) as x)
     * case 2: slot reference
     *      project(slotRef(x.c), slotRef(x.d))
     *                  |                          ===>  project(slotRef(a) as x.c, slotRef(b) as x.d)
     *      project(slotRef(a) as c, slotRef(b) as d)
     * case 3: any other expression that is itself a key of the map
     * Anything that is not substituted falls back to the default rewriter.
     */
    @Override
    public Expression visit(Expression expr, Map<Expression, Expression> substitutionMap) {
        boolean aliasOverSlot = expr instanceof Alias && expr.child(0) instanceof SlotReference;
        if (aliasOverSlot) {
            // case 1: keep the alias, swap its child for the children of the mapped expression.
            // Alias doesn't contain qualifier, so the lookup key has its qualifier cleared.
            Slot key = ((SlotReference) expr.child(0)).withQualifier(Collections.emptyList());
            if (substitutionMap.containsKey(key)) {
                return expr.withChildren(substitutionMap.get(key).children());
            }
            return super.visit(expr, substitutionMap);
        }
        if (expr instanceof SlotReference) {
            // case 2: replace the slot by the expression wrapped in the mapped alias
            Slot key = ((SlotReference) expr).withQualifier(Collections.emptyList());
            if (substitutionMap.containsKey(key)) {
                Alias mapped = (Alias) substitutionMap.get(key);
                return mapped.child();
            }
            return super.visit(expr, substitutionMap);
        }
        if (substitutionMap.containsKey(expr)) {
            // case 3
            return substitutionMap.get(expr).child(0);
        }
        return super.visit(expr, substitutionMap);
    }
}
}

View File

@ -122,4 +122,20 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL
return new PhysicalProject<>(projects, Optional.empty(), getLogicalProperties(), physicalProperties,
statsDeriveResult, child());
}
/**
 * Create a project with the given projections and child, reusing this project's group
 * expression, logical properties, physical properties and stats derive result.
 * It is used to merge consecutive projections.
 *
 * @param projections new projections (copied into an immutable list)
 * @param child new child
 * @return new project
 */
public PhysicalProject<Plan> withProjectionsAndChild(List<NamedExpression> projections, Plan child) {
    ImmutableList<NamedExpression> newProjects = ImmutableList.copyOf(projections);
    return new PhysicalProject<>(newProjects, groupExpression, getLogicalProperties(),
            physicalProperties, statsDeriveResult, child);
}
}

View File

@ -0,0 +1,104 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.postprocess;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.processor.post.MergeProjectPostProcessor;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.Lists;
import mockit.Injectable;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
public class MergeProjectPostProcessTest {
    /**
     * Three stacked projects over a scan must collapse into one project over the scan.
     *
     *      proj1(y, z)
     *         |
     *      proj2(x as y, b as z, c)
     *         |
     *      proj3(a as x, b, c)
     *         |
     *      SCAN(a, b, c)
     *
     * transform to
     *
     *      proj4(a as y, b as z)
     *         |
     *      SCAN(a, b, c)
     *
     */
    @Test
    public void testMergeProj(@Injectable LogicalProperties placeHolder, @Injectable CascadesContext ctx) {
        OlapTable t1 = PlanConstructor.newOlapTable(0, "t1", 0, KeysType.AGG_KEYS);
        List<String> qualifier = new ArrayList<>();
        qualifier.add("test");
        List<Slot> t1Output = new ArrayList<>();
        SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
        SlotReference b = new SlotReference("b", IntegerType.INSTANCE);
        SlotReference c = new SlotReference("c", IntegerType.INSTANCE);
        t1Output.add(a);
        t1Output.add(b);
        t1Output.add(c);
        LogicalProperties t1Properties = new LogicalProperties(() -> t1Output);
        PhysicalOlapScan scan = new PhysicalOlapScan(RelationId.createGenerator().getNextId(), t1, qualifier, 0L,
                Collections.emptyList(), Collections.emptyList(), null, PreAggStatus.on(),
                Optional.empty(),
                t1Properties);

        // proj3(a as x, b, c)
        Alias x = new Alias(a, "x");
        List<NamedExpression> projList3 = Lists.newArrayList(x, b, c);
        PhysicalProject proj3 = new PhysicalProject(projList3, placeHolder, scan);

        // proj2(x as y, b as z, c)
        Alias y = new Alias(x.toSlot(), "y");
        Alias z = new Alias(b, "z");
        List<NamedExpression> projList2 = Lists.newArrayList(y, z, c);
        PhysicalProject proj2 = new PhysicalProject(projList2, placeHolder, proj3);

        // proj1(y, z)
        List<NamedExpression> projList1 = Lists.newArrayList(y.toSlot(), z.toSlot());
        PhysicalProject proj1 = new PhysicalProject(projList1, placeHolder, proj2);

        MergeProjectPostProcessor processor = new MergeProjectPostProcessor();
        PhysicalPlan newPlan = (PhysicalPlan) proj1.accept(processor, ctx);

        Assertions.assertTrue(newPlan instanceof PhysicalProject);
        Assertions.assertTrue(newPlan.child(0) instanceof PhysicalOlapScan);

        // JUnit's assertEquals takes (expected, actual); keep that order so failure
        // messages report the right side as "expected".
        List<NamedExpression> resProjList = ((PhysicalProject<?>) newPlan).getProjects();
        Assertions.assertEquals(2, resProjList.size());

        Assertions.assertTrue(resProjList.get(0) instanceof Alias);
        Assertions.assertEquals("y", resProjList.get(0).getName());
        Assertions.assertEquals(a, ((Alias) resProjList.get(0)).child());

        Assertions.assertTrue(resProjList.get(1) instanceof Alias);
        Assertions.assertEquals("z", resProjList.get(1).getName());
        Assertions.assertEquals(b, ((Alias) resProjList.get(1)).child());
    }
}