[ehancement](nereids) eliminate project in the post process phase (#14490)

Remove those projects that used for column pruning only and don't do any expression calculation, So that we could avoid some redundant data copy in do_projection of BE side.
This commit is contained in:
Kikyou1997
2022-11-28 00:39:36 +08:00
committed by GitHub
parent 280f8be4bd
commit b6605b99aa
4 changed files with 141 additions and 19 deletions

View File

@ -38,6 +38,7 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -80,6 +81,7 @@ import org.apache.doris.planner.EmptySetNode;
import org.apache.doris.planner.ExchangeNode;
import org.apache.doris.planner.HashJoinNode;
import org.apache.doris.planner.HashJoinNode.DistributionMode;
import org.apache.doris.planner.JoinNodeBase;
import org.apache.doris.planner.NestedLoopJoinNode;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.planner.PlanFragment;
@ -135,6 +137,13 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
rootFragment = exchangeToMergeFragment(rootFragment, context);
}
List<Expr> outputExprs = Lists.newArrayList();
if (physicalPlan instanceof PhysicalProject) {
PhysicalProject project = (PhysicalProject) physicalPlan;
if (isUnnecessaryProject(project) && !projectOnAgg(project)) {
List<Slot> slotReferences = removeAlias(project);
physicalPlan = (PhysicalPlan) physicalPlan.child(0).withOutput(slotReferences);
}
}
physicalPlan.getOutput().stream().map(Slot::getExprId)
.forEach(exprId -> outputExprs.add(context.findSlotRef(exprId)));
rootFragment.setOutputExprs(outputExprs);
@ -678,6 +687,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
if (hashJoin.getOtherJoinConjuncts().isEmpty()
&& (joinType == JoinType.LEFT_ANTI_JOIN || joinType == JoinType.LEFT_SEMI_JOIN)) {
for (SlotDescriptor leftSlotDescriptor : leftSlotDescriptors) {
if (!leftSlotDescriptor.isMaterialized()) {
continue;
}
SlotReference sf = leftChildOutputMap.get(context.findExprId(leftSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
leftIntermediateSlotDescriptor.add(sd);
@ -685,12 +697,18 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
} else if (hashJoin.getOtherJoinConjuncts().isEmpty()
&& (joinType == JoinType.RIGHT_ANTI_JOIN || joinType == JoinType.RIGHT_SEMI_JOIN)) {
for (SlotDescriptor rightSlotDescriptor : rightSlotDescriptors) {
if (!rightSlotDescriptor.isMaterialized()) {
continue;
}
SlotReference sf = rightChildOutputMap.get(context.findExprId(rightSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
rightIntermediateSlotDescriptor.add(sd);
}
} else {
for (SlotDescriptor leftSlotDescriptor : leftSlotDescriptors) {
if (!leftSlotDescriptor.isMaterialized()) {
continue;
}
SlotReference sf = leftChildOutputMap.get(context.findExprId(leftSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
if (hashOutputSlotReferenceMap.get(sf.getExprId()) != null) {
@ -699,6 +717,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
leftIntermediateSlotDescriptor.add(sd);
}
for (SlotDescriptor rightSlotDescriptor : rightSlotDescriptors) {
if (!rightSlotDescriptor.isMaterialized()) {
continue;
}
SlotReference sf = rightChildOutputMap.get(context.findExprId(rightSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
if (hashOutputSlotReferenceMap.get(sf.getExprId()) != null) {
@ -824,6 +845,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
if (nestedLoopJoinNode.getConjuncts().isEmpty()
&& (joinType == JoinType.LEFT_ANTI_JOIN || joinType == JoinType.LEFT_SEMI_JOIN)) {
for (SlotDescriptor leftSlotDescriptor : leftSlotDescriptors) {
if (!leftSlotDescriptor.isMaterialized()) {
continue;
}
SlotReference sf = leftChildOutputMap.get(context.findExprId(leftSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
leftIntermediateSlotDescriptor.add(sd);
@ -831,17 +855,26 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
} else if (nestedLoopJoinNode.getConjuncts().isEmpty()
&& (joinType == JoinType.RIGHT_ANTI_JOIN || joinType == JoinType.RIGHT_SEMI_JOIN)) {
for (SlotDescriptor rightSlotDescriptor : rightSlotDescriptors) {
if (!rightSlotDescriptor.isMaterialized()) {
continue;
}
SlotReference sf = rightChildOutputMap.get(context.findExprId(rightSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
rightIntermediateSlotDescriptor.add(sd);
}
} else {
for (SlotDescriptor leftSlotDescriptor : leftSlotDescriptors) {
if (!leftSlotDescriptor.isMaterialized()) {
continue;
}
SlotReference sf = leftChildOutputMap.get(context.findExprId(leftSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
leftIntermediateSlotDescriptor.add(sd);
}
for (SlotDescriptor rightSlotDescriptor : rightSlotDescriptors) {
if (!rightSlotDescriptor.isMaterialized()) {
continue;
}
SlotReference sf = rightChildOutputMap.get(context.findExprId(rightSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
rightIntermediateSlotDescriptor.add(sd);
@ -904,43 +937,51 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
}
}
PlanFragment inputFragment = project.child(0).accept(this, context);
List<Expr> execExprList = project.getProjects()
.stream()
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
// TODO: fix the project alias of an aliased relation.
List<Slot> slotList = project.getOutput();
TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, null, context);
PlanNode inputPlanNode = inputFragment.getPlanRoot();
List<Slot> slotList = project.getOutput();
// For hash join node, use vSrcToOutputSMap to describe the expression calculation, use
// vIntermediateTupleDescList as input, and set vOutputTupleDesc as the final output.
// TODO: HashJoinNode's be implementation is not support projection yet, remove this after when supported.
if (inputPlanNode instanceof HashJoinNode) {
HashJoinNode hashJoinNode = (HashJoinNode) inputPlanNode;
if (inputPlanNode instanceof JoinNodeBase) {
TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, null, context);
JoinNodeBase hashJoinNode = (JoinNodeBase) inputPlanNode;
hashJoinNode.setvOutputTupleDesc(tupleDescriptor);
hashJoinNode.setvSrcToOutputSMap(execExprList);
return inputFragment;
}
if (inputPlanNode instanceof NestedLoopJoinNode) {
NestedLoopJoinNode nestedLoopJoinNode = (NestedLoopJoinNode) inputPlanNode;
nestedLoopJoinNode.setvOutputTupleDesc(tupleDescriptor);
nestedLoopJoinNode.setvSrcToOutputSMap(execExprList);
return inputFragment;
}
inputPlanNode.setProjectList(execExprList);
inputPlanNode.setOutputTupleDesc(tupleDescriptor);
List<Expr> predicateList = inputPlanNode.getConjuncts();
Set<Integer> requiredSlotIdList = new HashSet<>();
for (Expr expr : predicateList) {
extractExecSlot(expr, requiredSlotIdList);
}
boolean nonPredicate = CollectionUtils.isEmpty(requiredSlotIdList);
for (Expr expr : execExprList) {
extractExecSlot(expr, requiredSlotIdList);
}
if (!hasExprCalc(project) && (!hasPrune(project) || nonPredicate) && !projectOnAgg(project)) {
List<NamedExpression> namedExpressions = project.getProjects();
for (int i = 0; i < namedExpressions.size(); i++) {
NamedExpression n = namedExpressions.get(i);
for (Expression e : n.children()) {
SlotReference slotReference = (SlotReference) e;
SlotRef slotRef = context.findSlotRef(slotReference.getExprId());
context.addExprIdSlotRefPair(slotList.get(i).getExprId(), slotRef);
}
}
} else {
TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, null, context);
inputPlanNode.setProjectList(execExprList);
inputPlanNode.setOutputTupleDesc(tupleDescriptor);
}
if (inputPlanNode instanceof OlapScanNode) {
updateChildSlotsMaterialization(inputPlanNode, requiredSlotIdList, context);
return inputFragment;
}
return inputFragment;
}
@ -1249,4 +1290,50 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
.map(slot -> slot.getId().asInt())
.collect(ImmutableList.toImmutableList());
}
private boolean isUnnecessaryProject(PhysicalProject project) {
// The project list for agg is always needed,since tuple of agg contains the slots used by group by expr
return !hasPrune(project) && !hasExprCalc(project);
}
private boolean hasPrune(PhysicalProject project) {
PhysicalPlan child = (PhysicalPlan) project.child(0);
return project.getProjects().size() != child.getOutput().size();
}
private boolean projectOnAgg(PhysicalProject project) {
PhysicalPlan child = (PhysicalPlan) project.child(0);
while (child instanceof PhysicalFilter || child instanceof PhysicalDistribute) {
child = (PhysicalPlan) child.child(0);
}
return child instanceof PhysicalAggregate;
}
private boolean hasExprCalc(PhysicalProject<? extends Plan> project) {
for (NamedExpression p : project.getProjects()) {
if (p.children().size() > 1) {
return true;
}
for (Expression e : p.children()) {
if (!(e instanceof SlotReference)) {
return true;
}
}
}
return false;
}
private List<Slot> removeAlias(PhysicalProject project) {
List<NamedExpression> namedExpressions = project.getProjects();
List<Slot> slotReferences = new ArrayList<>();
for (NamedExpression n : namedExpressions) {
if (n instanceof Alias) {
slotReferences.add((SlotReference) n.child(0));
} else {
slotReferences.add((SlotReference) n);
}
}
return slotReferences;
}
}

View File

@ -46,7 +46,7 @@ public class PlanPostProcessors {
public PhysicalPlan process(PhysicalPlan physicalPlan) {
PhysicalPlan resultPlan = physicalPlan;
for (PlanPostProcessor processor : getProcessors()) {
resultPlan = (PhysicalPlan) physicalPlan.accept(processor, cascadesContext);
resultPlan = (PhysicalPlan) resultPlan.accept(processor, cascadesContext);
}
return resultPlan;
}

View File

@ -26,19 +26,39 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Set;
/**
* test ELIMINATE_UNNECESSARY_PROJECT rule.
*/
public class EliminateUnnecessaryProjectTest {
public class EliminateUnnecessaryProjectTest extends TestWithFeService {
@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
connectContext.setDatabase("default_cluster:test");
createTable("CREATE TABLE t1 (col1 int not null, col2 int not null, col3 int not null)\n"
+ "DISTRIBUTED BY HASH(col3)\n"
+ "BUCKETS 1\n"
+ "PROPERTIES(\n"
+ " \"replication_num\"=\"1\"\n"
+ ");");
}
@Test
public void testEliminateNonTopUnnecessaryProject() {
@ -82,4 +102,19 @@ public class EliminateUnnecessaryProjectTest {
Plan actual = cascadesContext.getMemo().copyOut();
Assertions.assertTrue(actual instanceof LogicalProject);
}
@Test
public void testEliminationForThoseNeitherDoPruneNorDoExprCalc() {
PlanChecker.from(connectContext).checkPlannerResult("SELECT col1 FROM t1",
p -> {
List<PlanFragment> fragments = p.getFragments();
Assertions.assertTrue(fragments.stream()
.flatMap(fragment -> {
Set<OlapScanNode> scans = Sets.newHashSet();
fragment.getPlanRoot().collect(OlapScanNode.class, scans);
return scans.stream();
})
.noneMatch(s -> s.getProjectList() != null));
});
}
}

View File

@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
suite("explain") {
suite("nereids_explain") {
sql """
SET enable_vectorized_engine=true
"""
@ -28,8 +28,8 @@ suite("explain") {
explain {
sql("select count(2) + 1, sum(2) + sum(lo_suppkey) from lineorder")
contains "projections: lo_suppkey"
contains "project output tuple id: 1"
contains "(sum(2) + sum(lo_suppkey))[#24]"
contains "project output tuple id: 3"
}