[Improvement](Planner)Enable hash join project (#8618)

This commit is contained in:
EmmyMiao87
2022-04-01 15:42:25 +08:00
committed by GitHub
parent 2730235e5b
commit 9f80f6cf5e
10 changed files with 406 additions and 21 deletions

View File

@ -22,6 +22,7 @@ import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.SlotId;
import org.apache.doris.common.NotImplementedException;
import org.apache.doris.common.UserException;
import org.apache.doris.thrift.TAggregationNode;
import org.apache.doris.thrift.TExplainLevel;
@ -29,17 +30,17 @@ import org.apache.doris.thrift.TExpr;
import org.apache.doris.thrift.TPlanNode;
import org.apache.doris.thrift.TPlanNodeType;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.List;
//import org.apache.doris.thrift.TAggregateFunctionCall;
import java.util.Set;
/**
* Aggregation computation.
@ -327,4 +328,21 @@ public class AggregationNode extends PlanNode {
public int getNumInstances() {
return children.get(0).getNumInstances();
}
@Override
public Set<SlotId> computeInputSlotIds() throws NotImplementedException {
    // The slots this aggregation reads are exactly those referenced by the
    // grouping exprs plus those referenced by the aggregate function calls.
    List<SlotId> referencedSlotIds = Lists.newArrayList();
    // slots referenced by the GROUP BY expressions
    Expr.getIds(aggInfo.getGroupingExprs(), null, referencedSlotIds);
    // slots referenced by the aggregate functions
    Expr.getIds(aggInfo.getAggregateExprs(), null, referencedSlotIds);
    Set<SlotId> inputSlotIds = Sets.newHashSet();
    inputSlotIds.addAll(referencedSlotIds);
    return inputSlotIds;
}
}

View File

@ -31,6 +31,7 @@ import org.apache.doris.catalog.ColumnStats;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Table;
import org.apache.doris.common.CheckedMath;
import org.apache.doris.common.NotImplementedException;
import org.apache.doris.common.Pair;
import org.apache.doris.common.UserException;
import org.apache.doris.common.util.VectorizedUtil;
@ -43,6 +44,7 @@ import org.apache.doris.thrift.TPlanNodeType;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -51,6 +53,7 @@ import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
@ -76,6 +79,8 @@ public class HashJoinNode extends PlanNode {
private String colocateReason = ""; // if can not do colocate join, set reason here
private boolean isBucketShuffle = false; // the flag for bucket shuffle join
private List<SlotId> hashOutputSlotIds;
public HashJoinNode(PlanNodeId id, PlanNode outer, PlanNode inner, TableRef innerRef,
List<Expr> eqJoinConjuncts, List<Expr> otherJoinConjuncts) {
super(id, "HASH JOIN");
@ -164,6 +169,67 @@ public class HashJoinNode extends PlanNode {
isColocate = colocate;
colocateReason = reason;
}
/**
* Calculate the slots output after going through the hash table in the hash join node.
* The most essential difference between 'hashOutputSlots' and 'outputSlots' is that
* it's output needs to contain other conjunct and conjunct columns.
* hash output slots = output slots + conjunct slots + other conjunct slots
* For example:
* select b.k1 from test.t1 a right join test.t1 b on a.k1=b.k1 and b.k2>1 where a.k2>1;
* output slots: b.k1
* other conjuncts: a.k2>1
* conjuncts: b.k2>1
* hash output slots: a.k2, b.k2, b.k1
* eq conjuncts: a.k1=b.k1
* @param slotIdList
*/
/**
 * Calculate the slots that must survive the hash-table stage of this join node.
 * The essential difference between 'hashOutputSlots' and 'outputSlots' is that the
 * hash output additionally contains the slots referenced by the other-join conjuncts
 * and the remaining conjuncts:
 *   hash output slots = output slots + conjunct slots + other conjunct slots
 * For example:
 *   select b.k1 from test.t1 a right join test.t1 b on a.k1=b.k1 and b.k2>1 where a.k2>1;
 *   output slots: b.k1
 *   other conjuncts: a.k2>1
 *   conjuncts: b.k2>1
 *   hash output slots: a.k2, b.k2, b.k1
 *   eq conjuncts: a.k1=b.k1
 *
 * @param slotIdList the slots this node must output to its parent
 */
private void initHashOutputSlotIds(List<SlotId> slotIdList) {
    // LinkedHashSet preserves insertion order (same result order as before) while
    // replacing the original O(n^2) List.contains() duplicate scan with O(1) lookups.
    Set<SlotId> hashOutputSlotIdSet = Sets.newLinkedHashSet(slotIdList);
    List<SlotId> otherAndConjunctSlotIds = Lists.newArrayList();
    Expr.getIds(otherJoinConjuncts, null, otherAndConjunctSlotIds);
    Expr.getIds(conjuncts, null, otherAndConjunctSlotIds);
    hashOutputSlotIdSet.addAll(otherAndConjunctSlotIds);
    hashOutputSlotIds = new ArrayList<>(hashOutputSlotIdSet);
}
@Override
public void initOutputSlotIds(Set<SlotId> requiredSlotIdSet, Analyzer analyzer) {
    // Keep every materialized slot of this node's tuples that the parent requires
    // (a null requiredSlotIdSet means the parent could not compute its needs,
    // so every materialized slot is kept).
    outputSlotIds = Lists.newArrayList();
    for (TupleId tupleId : tupleIds) {
        for (SlotDescriptor slotDesc : analyzer.getTupleDesc(tupleId).getSlots()) {
            if (!slotDesc.isMaterialized()) {
                continue;
            }
            SlotId slotId = slotDesc.getId();
            if (requiredSlotIdSet == null || requiredSlotIdSet.contains(slotId)) {
                outputSlotIds.add(slotId);
            }
        }
    }
    initHashOutputSlotIds(outputSlotIds);
}
// output slots + predicate slots = input slots
@Override
public Set<SlotId> computeInputSlotIds() throws NotImplementedException {
    Preconditions.checkState(outputSlotIds != null);
    // Everything this node reads: the slots it forwards to its parent plus every
    // slot referenced by an eq-join, other-join, or remaining conjunct.
    List<SlotId> predicateSlotIds = Lists.newArrayList();
    Expr.getIds(eqJoinConjuncts, null, predicateSlotIds);    // eq conjuncts
    Expr.getIds(otherJoinConjuncts, null, predicateSlotIds); // other conjuncts
    Expr.getIds(conjuncts, null, predicateSlotIds);          // remaining conjuncts
    Set<SlotId> result = Sets.newHashSet(outputSlotIds);
    result.addAll(predicateSlotIds);
    return result;
}
@Override
public void init(Analyzer analyzer) throws UserException {
@ -607,6 +673,11 @@ public class HashJoinNode extends PlanNode {
if (votherJoinConjunct != null) {
msg.hash_join_node.setVotherJoinConjunct(votherJoinConjunct.treeToThrift());
}
if (hashOutputSlotIds != null) {
for (SlotId slotId : hashOutputSlotIds) {
msg.hash_join_node.addToHashOutputSlotIds(slotId.asInt());
}
}
}
@Override
@ -638,6 +709,21 @@ public class HashJoinNode extends PlanNode {
}
output.append(detailPrefix).append(String.format(
"cardinality=%s", cardinality)).append("\n");
// todo unify in plan node
if (outputSlotIds != null) {
output.append(detailPrefix).append("output slot ids: ");
for (SlotId slotId : outputSlotIds) {
output.append(slotId).append(" ");
}
output.append("\n");
}
if (hashOutputSlotIds != null) {
output.append(detailPrefix).append("hash output slot ids: ");
for (SlotId slotId : hashOutputSlotIds) {
output.append(slotId).append(" ");
}
output.append("\n");
}
return output.toString();
}

View File

@ -29,6 +29,7 @@ import org.apache.doris.analysis.TupleId;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.NotImplementedException;
import org.apache.doris.common.TreeNode;
import org.apache.doris.common.UserException;
import org.apache.doris.common.util.VectorizedUtil;
@ -129,6 +130,8 @@ abstract public class PlanNode extends TreeNode<PlanNode> {
private boolean cardinalityIsDone = false;
protected List<SlotId> outputSlotIds;
protected PlanNode(PlanNodeId id, ArrayList<TupleId> tupleIds, String planNodeName) {
this.id = id;
this.limit = -1;
@ -491,6 +494,11 @@ abstract public class PlanNode extends TreeNode<PlanNode> {
}
msg.compact_data = compactData;
if (outputSlotIds != null) {
for (SlotId slotId : outputSlotIds) {
msg.addToOutputSlotIds(slotId.asInt());
}
}
toThrift(msg);
container.addToNodes(msg);
if (this instanceof ExchangeNode) {
@ -851,15 +859,6 @@ abstract public class PlanNode extends TreeNode<PlanNode> {
return Joiner.on(", ").join(filtersStr) + "\n";
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("[").append(getId().asInt()).append(": ").append(getPlanNodeName()).append("]");
sb.append("\nFragment: ").append(getFragmentId().asInt()).append("]");
sb.append("\n").append(getNodeExplainString("", TExplainLevel.BRIEF));
return sb.toString();
}
public void convertToVectoriezd() {
if (!conjuncts.isEmpty()) {
vconjunct = convertConjunctsToAndCompoundPredicate(conjuncts);
@ -870,4 +869,64 @@ abstract public class PlanNode extends TreeNode<PlanNode> {
child.convertToVectoriezd();
}
}
/**
 * If a plan node overrides this method, the node itself supports projection
 * optimization: it can restrict its materialized output to the slots its
 * parent requires.
 *
 * @param requiredSlotIdSet The upper plan node's required slot set for the current plan node.
 *                          The requiredSlotIdSet may be null when the upper plan node cannot
 *                          calculate the required slots (then no pruning should happen).
 * @param analyzer analyzer used to look up tuple/slot descriptors
 * @throws NotImplementedException when the concrete node does not support projection
 *
 * For example:
 * Query: select a.k1 from a, b where a.k1=b.k1
 * PlanNodeTree:
 *      output exprs: a.k1
 *            |
 *      hash join node
 *   (input slots: a.k1, b.k1)
 *        |        |
 *  scan a(k1)  scan b(k1)
 *
 * Function params: requiredSlotIdSet = a.k1
 * After function:
 *      hash join node
 *   (output slots: a.k1)
 *   (input slots: a.k1, b.k1)
 */
public void initOutputSlotIds(Set<SlotId> requiredSlotIdSet, Analyzer analyzer) throws NotImplementedException {
    throw new NotImplementedException("The `initOutputSlotIds` hasn't been implemented in " + planNodeName);
}
/**
 * If a plan node overrides this method, its child plan nodes gain the ability to be
 * projected: the return value of this method is used as the input (requiredSlotIdSet)
 * of each child's initOutputSlotIds. That is to say, only when a plan node implements
 * this method can its children perform projection optimization.
 *
 * @return the slot ids this plan node requires from its children
 * @throws NotImplementedException when the concrete node cannot compute its input slots
 * PlanNodeTree:
 *      agg node(group by a.k1)
 *            |
 *   hash join node(a.k1=b.k1)
 *        |        |
 *  scan a(k1)  scan b(k1)
 * After function:
 *      agg node
 *  (required slots: a.k1)
 */
public Set<SlotId> computeInputSlotIds() throws NotImplementedException {
    throw new NotImplementedException("The `computeInputSlotIds` hasn't been implemented in " + planNodeName);
}
/**
 * Debug representation: "[id: name]\nFragment: [fragmentId]" followed by the
 * BRIEF explain output of this node.
 */
@Override
public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append("[").append(getId().asInt()).append(": ").append(getPlanNodeName()).append("]");
    // Fix: the original appended a closing "]" with no opening "[", producing
    // unbalanced output like "Fragment: 3]".
    sb.append("\nFragment: [").append(getFragmentId().asInt()).append("]");
    sb.append("\n").append(getNodeExplainString("", TExplainLevel.BRIEF));
    return sb.toString();
}
}

View File

@ -30,7 +30,6 @@ import org.apache.doris.analysis.StorageBackend;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.common.util.VectorizedUtil;
import org.apache.doris.common.UserException;
import org.apache.doris.common.profile.PlanTreeBuilder;
import org.apache.doris.common.profile.PlanTreePrinter;
@ -173,6 +172,13 @@ public class Planner {
singleNodePlan.convertToVectoriezd();
}
if (analyzer.getContext() != null
&& analyzer.getContext().getSessionVariable().isEnableProjection()
&& statement instanceof SelectStmt) {
ProjectPlanner projectPlanner = new ProjectPlanner(analyzer);
projectPlanner.projectSingleNodePlan(queryStmt.getResultExprs(), singleNodePlan);
}
if (statement instanceof InsertStmt) {
InsertStmt insertStmt = (InsertStmt) statement;
insertStmt.prepareExpressions();

View File

@ -0,0 +1,79 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.planner;
import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.SlotId;
import org.apache.doris.common.NotImplementedException;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.List;
import java.util.Set;
/**
 * Top-down planner pass that initializes output/input slot sets on the single-node
 * plan tree so nodes can project away unneeded slots (see PlanNode.initOutputSlotIds
 * and PlanNode.computeInputSlotIds).
 */
public class ProjectPlanner {
    // Fix: the logger was created with PlanNode.class, mislabeling this class's
    // log lines; also use the conventional `private static final` modifier order.
    private static final Logger LOG = LogManager.getLogger(ProjectPlanner.class);

    private final Analyzer analyzer;

    public ProjectPlanner(Analyzer analyzer) {
        this.analyzer = analyzer;
    }

    /**
     * Entry point: projects the plan tree rooted at `root`, starting from the
     * slots referenced by the query's result exprs.
     */
    public void projectSingleNodePlan(List<Expr> resultExprs, PlanNode root) {
        Set<SlotId> resultSlotIds = getSlotIds(resultExprs, root);
        projectPlanNode(resultSlotIds, root);
    }

    /**
     * Initializes the output slots of `planNode` (when it supports projection) and
     * recursively propagates the node's required input slots to its children.
     *
     * @param outputSlotIds slots required by the parent; may be null (no pruning)
     */
    public void projectPlanNode(Set<SlotId> outputSlotIds, PlanNode planNode) {
        try {
            planNode.initOutputSlotIds(outputSlotIds, analyzer);
        } catch (NotImplementedException e) {
            // Node type does not support projection; leave its output untouched.
            LOG.debug(e);
        }
        if (planNode.getChildren().isEmpty()) {
            return;
        }
        Set<SlotId> inputSlotIds = null;
        try {
            inputSlotIds = planNode.computeInputSlotIds();
        } catch (NotImplementedException e) {
            // A null requirement tells children they cannot prune their output.
            LOG.debug(e);
        }
        for (PlanNode child : planNode.getChildren()) {
            projectPlanNode(inputSlotIds, child);
        }
    }

    // Collects the slot ids referenced by resultExprs after substituting them
    // through the root node's output smap.
    private Set<SlotId> getSlotIds(List<Expr> resultExprs, PlanNode root) {
        List<Expr> resExprs = Expr.substituteList(resultExprs,
                root.getOutputSmap(), analyzer, false);
        Set<SlotId> result = Sets.newHashSet();
        for (Expr expr : resExprs) {
            List<SlotId> slotIdList = Lists.newArrayList();
            expr.getIds(null, slotIdList);
            result.addAll(slotIdList);
        }
        return result;
    }
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotId;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.SortInfo;
import org.apache.doris.common.NotImplementedException;
import org.apache.doris.common.UserException;
import org.apache.doris.thrift.TExplainLevel;
import org.apache.doris.thrift.TPlanNode;
@ -31,16 +32,18 @@ import org.apache.doris.thrift.TPlanNodeType;
import org.apache.doris.thrift.TSortInfo;
import org.apache.doris.thrift.TSortNode;
import com.google.common.base.Joiner;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
/**
* Sorting.
@ -148,6 +151,13 @@ public class SortNode extends PlanNode {
LOG.debug("stats Sort: cardinality=" + Long.toString(cardinality));
}
@Override
public Set<SlotId> computeInputSlotIds() throws NotImplementedException {
    // The sort node needs every slot referenced by its resolved tuple exprs.
    List<SlotId> referencedSlotIds = Lists.newArrayList();
    Expr.getIds(resolvedTupleExprs, null, referencedSlotIds);
    Set<SlotId> inputSlotIds = new HashSet<>(referencedSlotIds);
    return inputSlotIds;
}
@Override
protected String debugString() {
List<String> strings = Lists.newArrayList();
@ -155,8 +165,8 @@ public class SortNode extends PlanNode {
strings.add(isAsc ? "a" : "d");
}
return MoreObjects.toStringHelper(this).add("ordering_exprs",
Expr.debugString(info.getOrderingExprs())).add("is_asc",
"[" + Joiner.on(" ").join(strings) + "]").addValue(super.debugString()).toString();
Expr.debugString(info.getOrderingExprs())).add("is_asc",
"[" + Joiner.on(" ").join(strings) + "]").addValue(super.debugString()).toString();
}
@Override

View File

@ -84,6 +84,7 @@ public class TableFunctionNode extends PlanNode {
* Query: select k1 from table a lateral view explode_split(v1, ",") t1 as c1;
* The outputSlots: [k1, c1]
*/
// TODO(ml): Unified to projectplanner
public void projectSlots(Analyzer analyzer, SelectStmt selectStmt) throws AnalysisException {
// TODO(ml): Support project calculations that include aggregation and sorting in select stmt
if ((selectStmt.hasAggInfo() || selectStmt.getSortInfo() != null || selectStmt.hasAnalyticInfo())

View File

@ -176,6 +176,8 @@ public class SessionVariable implements Serializable, Writable {
public static final String BLOCK_ENCRYPTION_MODE = "block_encryption_mode";
public static final String ENABLE_PROJECTION = "enable_projection";
// session origin value
public Map<Field, String> sessionOriginValue = new HashMap<Field, String>();
// check stmt is or not [select /*+ SET_VAR(...)*/ ...]
@ -429,6 +431,9 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = BLOCK_ENCRYPTION_MODE)
private String blockEncryptionMode = "";
@VariableMgr.VarAttr(name = ENABLE_PROJECTION)
private boolean enableProjection = false;
public String getBlockEncryptionMode() {
return blockEncryptionMode;
}
@ -893,6 +898,10 @@ public class SessionVariable implements Serializable, Writable {
public void setEnableInferPredicate(boolean enableInferPredicate) { this.enableInferPredicate = enableInferPredicate; }
// Whether the projection optimization (ProjectPlanner) is applied to the
// single-node plan; controlled by session variable ENABLE_PROJECTION.
public boolean isEnableProjection() {
    return enableProjection;
}
// Serialize to thrift object
// used for rest api
public TQueryOptions toThrift() {

View File

@ -0,0 +1,111 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.planner;
import org.apache.doris.analysis.CreateDbStmt;
import org.apache.doris.analysis.CreateTableStmt;
import org.apache.doris.catalog.Catalog;
import org.apache.doris.common.jmockit.Deencapsulation;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.utframe.UtFrameUtils;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import java.io.File;
import java.util.UUID;
public class ProjectPlannerFunctionTest {

    private static String runningDir = "fe/mocked/ProjectPlannerFunctionTest/" + UUID.randomUUID().toString() + "/";

    private static ConnectContext connectContext;

    @BeforeClass
    public static void beforeClass() throws Exception {
        UtFrameUtils.createDorisCluster(runningDir);
        // create connect context
        connectContext = UtFrameUtils.createDefaultCtx();
        // enable projection.
        // Fix: the field in SessionVariable is named `enableProjection`
        // (session variable ENABLE_PROJECTION); the previous name
        // "enableHashProject" does not exist, so the feature under test
        // was never switched on through this call.
        Deencapsulation.setField(connectContext.getSessionVariable(), "enableProjection", true);
        // create database
        String createDbStmtStr = "create database test;";
        CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, connectContext);
        Catalog.getCurrentCatalog().createDb(createDbStmt);
        String createTableStmtStr = "create table test.t1 (k1 int, k2 int) distributed by hash (k1) "
                + "properties(\"replication_num\" = \"1\")";
        CreateTableStmt createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTableStmtStr, connectContext);
        Catalog.getCurrentCatalog().createTable(createTableStmt);
    }

    @AfterClass
    public static void tearDown() {
        // NOTE(review): File.delete() fails silently on a non-empty directory;
        // best-effort cleanup, kept consistent with sibling tests.
        File file = new File(runningDir);
        file.delete();
    }

    // keep a.k2 after a join b
    @Test
    public void projectByAgg() throws Exception {
        String queryStr = "desc verbose select a.k2 from test.t1 a , test.t1 b where a.k1=b.k1 group by a.k2;";
        String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, queryStr);
        Assert.assertTrue(explainString.contains("output slot ids: 0"));
    }

    // keep a.k2 after a join b
    @Test
    public void projectBySort() throws Exception {
        String queryStr = "desc verbose select a.k2 from test.t1 a , test.t1 b where a.k1=b.k1 order by a.k2;";
        String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, queryStr);
        Assert.assertTrue(explainString.contains("output slot ids: 0"));
    }

    // keep a.k2 after a join c
    // keep a.k1, a.k2 after a join b
    @Test
    public void projectByJoin() throws Exception {
        String queryStr = "desc verbose select a.k2 from test.t1 a inner join test.t1 b on a.k1=b.k1 "
                + "inner join test.t1 c on a.k1=c.k1;";
        String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, queryStr);
        Assert.assertTrue(explainString.contains("output slot ids: 3"));
        Assert.assertTrue(explainString.contains("output slot ids: 0 3"));
    }

    // keep a.k2 after a join b
    @Test
    public void projectByResultExprs() throws Exception {
        String queryStr = "desc verbose select a.k2 from test.t1 a , test.t1 b where a.k1=b.k1;";
        String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, queryStr);
        Assert.assertTrue(explainString.contains("output slot ids: 0"));
    }

    // keep b.k1 after a join b
    // keep a.k2, b.k1, b.k2 after <a,b> hash table
    @Test
    public void projectHashTable() throws Exception {
        String queryStr = "desc verbose select b.k1 from test.t1 a right join test.t1 b on a.k1=b.k1 and b.k2>1 where a.k2>1;";
        String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, queryStr);
        Assert.assertTrue(explainString.contains("output slot ids: 1"));
        Assert.assertTrue(explainString.contains("hash output slot ids: 1 2 3"));
    }
}

View File

@ -393,7 +393,10 @@ struct THashJoinNode {
// anything from the ON or USING clauses (but *not* the WHERE clause) that's not an
// equi-join predicate, only use in vec exec engine
5: optional Exprs.TExpr vother_join_conjunct
// hash output column
6: optional list<Types.TSlotId> hash_output_slot_ids
}
struct TMergeJoinNode {
@ -789,6 +792,9 @@ struct TPlanNode {
40: optional Exprs.TExpr vconjunct
41: optional TTableFunctionNode table_function_node
// output column
42: optional list<Types.TSlotId> output_slot_ids
}
// A flattened representation of a tree of PlanNodes, obtained by depth-first