[Feature] Modify the cost evaluation algorithm of Broadcast and Shuffle Join (#6274)

issue #6272

After modifying the Hash Table memory cost estimation algorithm, queries such as TPC-DS query-72 now pass without triggering OOM.
Xinyi Zou
2021-07-26 09:38:41 +08:00
committed by GitHub
parent 13ef2c9e1d
commit 75d954ced5
4 changed files with 282 additions and 63 deletions
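
At a high level, the old planner multiplied the raw right-side data size by a flat overhead factor, while the new JoinCostEvaluation class models the hash table's bucket pointers, node array, and tuple pointers separately. A standalone sketch contrasting the two estimates on a small input (the 1.1 value for PlannerContext.HASH_TBL_SPACE_OVERHEAD is an assumption here; everything else follows the diff below):

// Sketch only: old vs. new broadcast hash-table memory estimate for one right-hand input.
public class HashTableEstimateSketch {
    static final double HASH_TBL_SPACE_OVERHEAD = 1.1; // assumed value of PlannerContext's constant

    public static void main(String[] args) {
        long cardinality = 4097; // right-side row count
        float avgRowSize = 16;   // bytes per row
        int tupleIdNum = 2;      // tuples participating in the build
        long rhsDataSize = Math.round((double) cardinality * avgRowSize); // 65552 bytes

        // old estimate: flat overhead on the raw data size
        long oldEstimate = Math.round(rhsDataSize * HASH_TBL_SPACE_OVERHEAD);

        // new estimate: bucket pointers + node data + node overhead + tuple pointers
        double bucketPointerSpace = ((double) cardinality / 0.75) * 8;
        double nodeArrayLen = Math.pow(1.5,
                (int) ((Math.log((double) cardinality / 4096) / Math.log(1.5)) + 1)) * 4096;
        long newEstimate = Math.round((bucketPointerSpace + rhsDataSize
                + nodeArrayLen * 16 + nodeArrayLen * tupleIdNum * 8) * HASH_TBL_SPACE_OVERHEAD);

        System.out.println(oldEstimate + " vs " + newEstimate); // prints "72107 vs 336447"
    }
}

The larger, more realistic estimate makes the perNodeMemLimit check in DistributedPlanner reject broadcasts that previously looked safe.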


@@ -348,53 +348,7 @@ public class DistributedPlanner {
return leftChildFragment;
}
// broadcast: send the rightChildFragment's output to each node executing
// the leftChildFragment; the cost across all nodes is proportional to the
// total amount of data sent
// NOTICE:
// for now, only MysqlScanNode and OlapScanNode have cardinality.
// OlapScanNode's cardinality is calculated from row count and data size,
// and MysqlScanNode's cardinality is always 0.
// Other ScanNodes' cardinality is -1.
//
// So if any other kind of scan node appears in the join query, the join cost cannot be
// calculated normally, both "broadcastCost" and "partitionCost" end up 0, and this leads
// to a SHUFFLE join.
PlanNode rhsTree = rightChildFragment.getPlanRoot();
long rhsDataSize = 0;
long broadcastCost = 0;
if (rhsTree.getCardinality() != -1 && leftChildFragment.getNumNodes() != -1) {
rhsDataSize = Math.round((double) rhsTree.getCardinality() * rhsTree.getAvgRowSize());
broadcastCost = rhsDataSize * leftChildFragment.getNumNodes();
}
if (LOG.isDebugEnabled()) {
LOG.debug("broadcast: cost=" + Long.toString(broadcastCost));
LOG.debug("card=" + Long.toString(rhsTree.getCardinality()) + " row_size="
+ Float.toString(rhsTree.getAvgRowSize()) + " #nodes="
+ Integer.toString(leftChildFragment.getNumNodes()));
}
// repartition: both left- and rightChildFragment are partitioned on the
// join exprs
// TODO: take existing partition of input fragments into account to avoid
// unnecessary repartitioning
PlanNode lhsTree = leftChildFragment.getPlanRoot();
long partitionCost = 0;
if (lhsTree.getCardinality() != -1 && rhsTree.getCardinality() != -1) {
partitionCost = Math.round(
(double) lhsTree.getCardinality() * lhsTree.getAvgRowSize() + (double) rhsTree
.getCardinality() * rhsTree.getAvgRowSize());
}
if (LOG.isDebugEnabled()) {
LOG.debug("partition: cost=" + Long.toString(partitionCost));
LOG.debug("lhs card=" + Long.toString(lhsTree.getCardinality()) + " row_size="
+ Float.toString(lhsTree.getAvgRowSize()));
LOG.debug("rhs card=" + Long.toString(rhsTree.getCardinality()) + " row_size="
+ Float.toString(rhsTree.getAvgRowSize()));
LOG.debug(rhsTree.getExplainString());
}
JoinCostEvaluation joinCostEvaluation = new JoinCostEvaluation(node, rightChildFragment, leftChildFragment);
boolean doBroadcast;
// we do a broadcast join if
// - we're explicitly told to do so
@@ -408,9 +362,9 @@ public class DistributedPlanner {
// respect user join hint
doBroadcast = true;
} else if (!node.getInnerRef().isPartitionJoin()
&& isBroadcastCostSmaller(broadcastCost, partitionCost)
&& joinCostEvaluation.isBroadcastCostSmaller()
&& (perNodeMemLimit == 0
|| Math.round((double) rhsDataSize * PlannerContext.HASH_TBL_SPACE_OVERHEAD) <= perNodeMemLimit)) {
|| joinCostEvaluation.constructHashTableSpace() <= perNodeMemLimit)) {
doBroadcast = true;
} else {
doBroadcast = false;
@@ -418,7 +372,6 @@ public class DistributedPlanner {
} else {
doBroadcast = false;
}
if (doBroadcast) {
node.setDistributionMode(HashJoinNode.DistributionMode.BROADCAST);
// Doesn't create a new fragment, but modifies leftChildFragment to execute
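
To make the comparison concrete, here is the arithmetic for one case exercised by the new unit test (rhs: 20 rows of 10 bytes; lhs: 5000 rows of 2 bytes, 10 left-side nodes):

broadcastCost = round(20 * 10) * 10       = 2000
partitionCost = round(5000 * 2 + 20 * 10) = 10200

Broadcasting 200 bytes to 10 nodes is far cheaper than shuffling both sides, so the planner picks BROADCAST, provided the estimated hash table also fits within perNodeMemLimit.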


@@ -0,0 +1,153 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.planner;
import org.apache.doris.qe.ConnectContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
/**
* Evaluate the cost of Broadcast and Shuffle Join to choose between the two
*
* broadcast: send the rightChildFragment's output to each node executing the leftChildFragment; the cost across
* all nodes is proportional to the total amount of data sent.
* shuffle: also called partitioned join. Both the small and the large table are hash-partitioned on the join key,
* and the join is then performed in a distributed fashion.
*
* NOTICE:
* for now, only MysqlScanNode and OlapScanNode have cardinality. OlapScanNode's cardinality is calculated from
* row count and data size, and MysqlScanNode's cardinality is always 1. Other ScanNodes' cardinality is -1.
* So if any other kind of scan node appears in the join query, the join cost cannot be calculated normally,
* both "broadcastCost" and "partitionCost" end up 0, and this leads to a SHUFFLE join.
*/
public class JoinCostEvaluation {
private final static Logger LOG = LogManager.getLogger(JoinCostEvaluation.class);
private final long rhsTreeCardinality;
private final float rhsTreeAvgRowSize;
private final int rhsTreeTupleIdNum;
private final long lhsTreeCardinality;
private final float lhsTreeAvgRowSize;
private final int lhsTreeNumNodes;
private long broadcastCost = 0;
private long partitionCost = 0;
JoinCostEvaluation(PlanNode node, PlanFragment rightChildFragment, PlanFragment leftChildFragment) {
PlanNode rhsTree = rightChildFragment.getPlanRoot();
rhsTreeCardinality = rhsTree.getCardinality();
rhsTreeAvgRowSize = rhsTree.getAvgRowSize();
rhsTreeTupleIdNum = rhsTree.getTupleIds().size();
PlanNode lhsTree = leftChildFragment.getPlanRoot();
lhsTreeCardinality = lhsTree.getCardinality();
lhsTreeAvgRowSize = lhsTree.getAvgRowSize();
lhsTreeNumNodes = leftChildFragment.getNumNodes();
String nodeOverview = setNodeOverview(node, rightChildFragment, leftChildFragment);
broadcastCost(nodeOverview);
shuffleCost(nodeOverview);
}
private String setNodeOverview(PlanNode node, PlanFragment rightChildFragment, PlanFragment leftChildFragment) {
return "root node id=" + node.getId().toString() + ": " + node.planNodeName
+ " right fragment id=" + rightChildFragment.getFragmentId().toString()
+ " left fragment id=" + leftChildFragment.getFragmentId().toString();
}
private void broadcastCost(String nodeOverview) {
if (rhsTreeCardinality != -1 && lhsTreeNumNodes != -1) {
broadcastCost = Math.round((double) rhsTreeCardinality * rhsTreeAvgRowSize) * lhsTreeNumNodes;
}
if (LOG.isDebugEnabled()) {
LOG.debug(nodeOverview);
LOG.debug("broadcast: cost=" + Long.toString(broadcastCost));
LOG.debug("rhs card=" + Long.toString(rhsTreeCardinality)
+ " rhs row_size=" + Float.toString(rhsTreeAvgRowSize)
+ " lhs nodes=" + Integer.toString(lhsTreeNumNodes));
}
}
/**
* repartition: both left- and rightChildFragment are partitioned on the join exprs
* TODO: take existing partition of input fragments into account to avoid unnecessary repartitioning
*/
private void shuffleCost(String nodeOverview) {
if (lhsTreeCardinality != -1 && rhsTreeCardinality != -1) {
partitionCost = Math.round(
(double) lhsTreeCardinality * lhsTreeAvgRowSize + (double) rhsTreeCardinality * rhsTreeAvgRowSize);
}
if (LOG.isDebugEnabled()) {
LOG.debug(nodeOverview);
LOG.debug("partition: cost=" + Long.toString(partitionCost));
LOG.debug("lhs card=" + Long.toString(lhsTreeCardinality) + " row_size="
+ Float.toString(lhsTreeAvgRowSize));
LOG.debug("rhs card=" + Long.toString(rhsTreeCardinality) + " row_size="
+ Float.toString(rhsTreeAvgRowSize));
}
}
/**
* When broadcastCost and partitionCost are equal, there is no uniform standard for which join implementation
* is better. Some scenarios are suitable for broadcast join, and some scenarios are suitable for shuffle join.
* Therefore, we add a SessionVariable to help users choose a better join implementation.
*/
public boolean isBroadcastCostSmaller() {
String joinMethod = ConnectContext.get().getSessionVariable().getPreferJoinMethod();
if (joinMethod.equalsIgnoreCase("broadcast")) {
return broadcastCost <= partitionCost;
} else {
return broadcastCost < partitionCost;
}
}
/**
* Estimate the memory cost of constructing the Hash Table in a Broadcast Join.
* The memory consumed by the Hash Table = ((cardinality/0.75[1]) * 8[2])[3] + (cardinality * avgRowSize)[4]
* + (nodeArrayLen[5] * 16[6])[7] + (nodeArrayLen * tupleNum[8] * 8[9])[10], which consists of four parts:
* 1) All bucket pointers. 2) Data of all nodes. 3) Overhead of all nodes. 4) Tuple pointers of all nodes.
* - [1] Expansion factor of the number of HashTable buckets;
* - [2] The pointer length of each HashTable bucket;
* - [3] bucketPointerSpace: The memory consumed by all bucket pointers of the HashTable;
* - [4] rhsDataSize: The memory consumed by all nodes of the HashTable, equal to the amount of right-table data
* that participates in constructing the HashTable;
* - [5] The length of the node array stored by the HashTable, which is larger than the actual cardinality.
* The initial value is 4096; whenever the array is full, it grows by half of its current length.
* The lengths after successive increments form the sequence:
* 4096 = pow(3/2, 0) * 4096,
* 6144 = pow(3/2, 1) * 4096,
* 9216 = pow(3/2, 2) * 4096,
* 13824 = pow(3/2, 3) * 4096,
* and it must finally satisfy len(node array) > cardinality,
* so the number of increments = int((ln(cardinality/4096) / ln(3/2)) + 1),
* and finally len(node array) = pow(3/2, int((ln(cardinality/4096) / ln(3/2)) + 1)) * 4096;
* - [6] The overhead length of each HashTable node, comprising the next-node pointer, the hash value,
* and a bool variable;
* - [7] nodeOverheadSpace: The memory consumed by the overhead of all nodes in the HashTable;
* - [8] The number of Tuples participating in the build;
* - [9] The length of each Tuple pointer;
* - [10] nodeTuplePointerSpace: The memory consumed by the Tuple pointers of all nodes in the HashTable;
*/
public long constructHashTableSpace() {
double bucketPointerSpace = ((double) rhsTreeCardinality / 0.75) * 8;
double nodeArrayLen =
Math.pow(1.5, (int) ((Math.log((double) rhsTreeCardinality / 4096) / Math.log(1.5)) + 1)) * 4096;
double nodeOverheadSpace = nodeArrayLen * 16;
double nodeTuplePointerSpace = nodeArrayLen * rhsTreeTupleIdNum * 8;
return Math.round((bucketPointerSpace + (double) rhsTreeCardinality * rhsTreeAvgRowSize
+ nodeOverheadSpace + nodeTuplePointerSpace) * PlannerContext.HASH_TBL_SPACE_OVERHEAD);
}
}
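
The closed-form node-array length in item [5] can be sanity-checked against the growth process it describes. A small standalone sketch (pure arithmetic, no Doris types; expected output is noted in the comment):

// Sketch: compares iterative node-array growth with the closed form used in constructHashTableSpace().
// For the cardinalities below, both columns print 6144, 6144, 9216, 9216, 20736.
public class NodeArrayLenSketch {
    public static void main(String[] args) {
        long[] cardinalities = {4096, 4097, 6144, 6145, 20000};
        for (long card : cardinalities) {
            // iterative growth: start at 4096, add half the current length until len > cardinality
            double iterLen = 4096;
            while (iterLen <= card) {
                iterLen *= 1.5;
            }
            // closed form from the comment above
            double closedLen = Math.pow(1.5,
                    (int) ((Math.log((double) card / 4096) / Math.log(1.5)) + 1)) * 4096;
            System.out.println(card + ": iterative=" + iterLen + ", closed=" + closedLen);
        }
    }
}

(Note that for cardinalities below 4096 the closed form dips under the 4096 floor; the formula models builds of at least 4096 rows.)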


@@ -299,14 +299,14 @@ public class OlapScanNode extends ScanNode {
computeTupleState(analyzer);
/**
Compute inaccurate stats before mv selector and tablet pruning.
Compute inaccurate cardinality before mv selector and tablet pruning.
- Accurate statistical information relies on the selector of materialized views and bucket reduction.
- However, both of those processes occur after the reorder algorithm is completed.
- When Join reorder is turned on, computeStats() must be completed before the reorder algorithm.
- So only inaccurate statistical information can be calculated here.
- When Join reorder is turned on, the cardinality must be calculated before the reorder algorithm.
- So only an inaccurate cardinality can be calculated here.
*/
if (analyzer.safeIsEnableJoinReorderBasedCost()) {
computeInaccurateStats(analyzer);
computeInaccurateCardinality();
}
}
@@ -326,9 +326,8 @@ public class OlapScanNode extends ScanNode {
} catch (AnalysisException e) {
throw new UserException(e.getMessage());
}
if (!analyzer.safeIsEnableJoinReorderBasedCost()) {
computeOldRowSizeAndCardinality();
}
// Relatively accurate cardinality according to ScanRange in getScanRangeLocations
computeStats(analyzer);
computeNumNodes();
}
@@ -338,7 +337,9 @@ public class OlapScanNode extends ScanNode {
}
}
public void computeOldRowSizeAndCardinality() {
@Override
public void computeStats(Analyzer analyzer) {
super.computeStats(analyzer);
if (cardinality > 0) {
avgRowSize = totalBytes / (float) cardinality;
capCardinalityAtLimit();
@@ -357,7 +358,7 @@ public class OlapScanNode extends ScanNode {
}
/**
* Calculate inaccurate stats, such as cardinality.
* Calculate inaccurate cardinality.
* cardinality: the sum of the row counts of the partitions in selectedPartitionIds.
* The cardinality here is actually inaccurate; it will be greater than the actual value.
* There are two reasons
@@ -369,11 +370,8 @@ public class OlapScanNode extends ScanNode {
* 1. Calculate how many rows were scanned
* 2. Apply conjunct
* 3. Apply limit
*
* @param analyzer
*/
private void computeInaccurateStats(Analyzer analyzer) {
super.computeStats(analyzer);
private void computeInaccurateCardinality() {
// step1: Calculate how many rows were scanned
cardinality = 0;
for (long selectedPartitionId : selectedPartitionIds) {
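
The hunk above is cut off just as the three steps begin. A minimal sketch of how they plausibly fit together (the loop shape follows the visible code and capCardinalityAtLimit() appears earlier in this diff; the step-2 helper and the row-count accessor chain are hypothetical stand-ins, not the committed implementation):

private void computeInaccurateCardinality() {
    // step 1: sum the row counts of all selected partitions
    cardinality = 0;
    for (long selectedPartitionId : selectedPartitionIds) {
        // hypothetical accessor chain for a partition's base-index row count
        cardinality += olapTable.getPartition(selectedPartitionId).getBaseIndex().getRowCount();
    }
    // step 2: scale the estimate down by the conjuncts' selectivity (hypothetical helper)
    applyConjunctSelectivity();
    // step 3: never report more rows than the query's LIMIT
    capCardinalityAtLimit();
}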


@@ -0,0 +1,115 @@
package org.apache.doris.planner;
import org.apache.doris.analysis.BinaryPredicate;
import org.apache.doris.analysis.TableRef;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.Lists;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Collections;
import mockit.Expectations;
import mockit.Mocked;
public class JoinCostEvaluationTest {
@Mocked
private ConnectContext ctx;
@Mocked
private PlanNode node;
@Mocked
private TableRef ref;
@Mocked
private PlanFragmentId fragmentId;
@Mocked
private PlanNodeId nodeId;
@Mocked
private DataPartition partition;
@Mocked
private BinaryPredicate expr;
@Before
public void setUp() {
new Expectations() {
{
node.getTupleIds();
result = Lists.newArrayList();
node.getTblRefIds();
result = Lists.newArrayList();
node.getChildren();
result = Lists.newArrayList();
}
};
}
private PlanFragment createPlanFragment(long cardinality, float avgRowSize, int numNodes) {
HashJoinNode root
= new HashJoinNode(nodeId, node, node, ref, Lists.newArrayList(expr), Lists.newArrayList(expr));
root.cardinality = cardinality;
root.avgRowSize = avgRowSize;
root.numNodes = numNodes;
return new PlanFragment(fragmentId, root, partition);
}
private boolean callIsBroadcastCostSmaller(long rhsTreeCardinality, float rhsTreeAvgRowSize,
long lhsTreeCardinality, float lhsTreeAvgRowSize, int lhsTreeNumNodes) {
PlanFragment rightChildFragment = createPlanFragment(rhsTreeCardinality, rhsTreeAvgRowSize, 0);
PlanFragment leftChildFragment = createPlanFragment(lhsTreeCardinality, lhsTreeAvgRowSize, lhsTreeNumNodes);
JoinCostEvaluation joinCostEvaluation = new JoinCostEvaluation(node, rightChildFragment, leftChildFragment);
return joinCostEvaluation.isBroadcastCostSmaller();
}
@Test
public void testIsBroadcastCostSmaller() {
new Expectations() {
{
ConnectContext.get();
result = ctx;
ConnectContext.get().getSessionVariable().getPreferJoinMethod();
result = "broadcast";
}
};
Assert.assertTrue(callIsBroadcastCostSmaller(1, 1, 1, 1, 1));
Assert.assertTrue(callIsBroadcastCostSmaller(1, 1, 1, 1, 2));
Assert.assertFalse(callIsBroadcastCostSmaller(1, 1, 1, 1, 3));
Assert.assertTrue(callIsBroadcastCostSmaller(-1, 1, 1, 1, -1));
Assert.assertTrue(callIsBroadcastCostSmaller(-1, 1, -1, 1, 1));
Assert.assertFalse(callIsBroadcastCostSmaller(1, 1, -1, 1, 1));
Assert.assertTrue(callIsBroadcastCostSmaller(20, 10, 5000, 2, 10));
Assert.assertFalse(callIsBroadcastCostSmaller(20, 10, 5, 2, 10));
}
@Test
public void testConstructHashTableSpace() {
long rhsTreeCardinality = 4097;
float rhsTreeAvgRowSize = 16;
int rhsNodeTupleIdNum = 1;
int rhsTreeTupleIdNum = rhsNodeTupleIdNum * 2;
double nodeArrayLen = 6144;
new Expectations() {
{
node.getTupleIds();
result = new ArrayList<>(Collections.nCopies(rhsNodeTupleIdNum, 0));
}
};
PlanFragment rightChildFragment = createPlanFragment(rhsTreeCardinality, rhsTreeAvgRowSize, 0);
PlanFragment leftChildFragment = createPlanFragment(0, 0, 0);
JoinCostEvaluation joinCostEvaluation = new JoinCostEvaluation(node, rightChildFragment, leftChildFragment);
long hashTableSpace = Math.round((((rhsTreeCardinality / 0.75) * 8)
+ ((double) rhsTreeCardinality * rhsTreeAvgRowSize) + (nodeArrayLen * 16)
+ (nodeArrayLen * rhsTreeTupleIdNum * 8)) * PlannerContext.HASH_TBL_SPACE_OVERHEAD);
Assert.assertEquals(hashTableSpace, joinCostEvaluation.constructHashTableSpace());
}
}
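
A note on the tie cases above: when broadcastCost equals partitionCost, isBroadcastCostSmaller() defers to the prefer_join_method session variable, which is why the Expectations block pins getPreferJoinMethod() to "broadcast". In a live session the same preference would be toggled with something along the lines of set prefer_join_method = "shuffle"; (the exact SET syntax is an assumption; the variable name comes from the planner code above).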