[Feature] Modify the cost evaluation algorithm of Broadcast and Shuffle Join (#6274)
Fixes issue #6272. After modifying the Hash Table memory cost estimation algorithm, TPC-DS query-72 and similar queries can pass without OOM.
This commit is contained in:
@ -348,53 +348,7 @@ public class DistributedPlanner {
|
||||
return leftChildFragment;
|
||||
}
|
||||
|
||||
// broadcast: send the rightChildFragment's output to each node executing
|
||||
// the leftChildFragment; the cost across all nodes is proportional to the
|
||||
// total amount of data sent
|
||||
|
||||
// NOTICE:
|
||||
// for now, only MysqlScanNode and OlapScanNode has Cardinality.
|
||||
// OlapScanNode's cardinality is calculated by row num and data size,
|
||||
// and MysqlScanNode's cardinality is always 0.
|
||||
// Other ScanNode's cardinality is -1.
|
||||
//
|
||||
// So if there are other kind of scan node in join query, it won't be able to calculate the cost of
|
||||
// join normally and result in both "broadcastCost" and "partitionCost" be 0. And this will lead
|
||||
// to a SHUFFLE join.
|
||||
PlanNode rhsTree = rightChildFragment.getPlanRoot();
|
||||
long rhsDataSize = 0;
|
||||
long broadcastCost = 0;
|
||||
if (rhsTree.getCardinality() != -1 && leftChildFragment.getNumNodes() != -1) {
|
||||
rhsDataSize = Math.round((double) rhsTree.getCardinality() * rhsTree.getAvgRowSize());
|
||||
broadcastCost = rhsDataSize * leftChildFragment.getNumNodes();
|
||||
}
|
||||
if (LOG.isDebugEnabled()) {
|
||||
LOG.debug("broadcast: cost=" + Long.toString(broadcastCost));
|
||||
LOG.debug("card=" + Long.toString(rhsTree.getCardinality()) + " row_size="
|
||||
+ Float.toString(rhsTree.getAvgRowSize()) + " #nodes="
|
||||
+ Integer.toString(leftChildFragment.getNumNodes()));
|
||||
}
|
||||
|
||||
// repartition: both left- and rightChildFragment are partitioned on the
|
||||
// join exprs
|
||||
// TODO: take existing partition of input fragments into account to avoid
|
||||
// unnecessary repartitioning
|
||||
PlanNode lhsTree = leftChildFragment.getPlanRoot();
|
||||
long partitionCost = 0;
|
||||
if (lhsTree.getCardinality() != -1 && rhsTree.getCardinality() != -1) {
|
||||
partitionCost = Math.round(
|
||||
(double) lhsTree.getCardinality() * lhsTree.getAvgRowSize() + (double) rhsTree
|
||||
.getCardinality() * rhsTree.getAvgRowSize());
|
||||
}
|
||||
if (LOG.isDebugEnabled()) {
|
||||
LOG.debug("partition: cost=" + Long.toString(partitionCost));
|
||||
LOG.debug("lhs card=" + Long.toString(lhsTree.getCardinality()) + " row_size="
|
||||
+ Float.toString(lhsTree.getAvgRowSize()));
|
||||
LOG.debug("rhs card=" + Long.toString(rhsTree.getCardinality()) + " row_size="
|
||||
+ Float.toString(rhsTree.getAvgRowSize()));
|
||||
LOG.debug(rhsTree.getExplainString());
|
||||
}
|
||||
|
||||
JoinCostEvaluation joinCostEvaluation = new JoinCostEvaluation(node, rightChildFragment, leftChildFragment);
|
||||
boolean doBroadcast;
|
||||
// we do a broadcast join if
|
||||
// - we're explicitly told to do so
|
||||
@ -408,9 +362,9 @@ public class DistributedPlanner {
|
||||
// respect user join hint
|
||||
doBroadcast = true;
|
||||
} else if (!node.getInnerRef().isPartitionJoin()
|
||||
&& isBroadcastCostSmaller(broadcastCost, partitionCost)
|
||||
&& joinCostEvaluation.isBroadcastCostSmaller()
|
||||
&& (perNodeMemLimit == 0
|
||||
|| Math.round((double) rhsDataSize * PlannerContext.HASH_TBL_SPACE_OVERHEAD) <= perNodeMemLimit)) {
|
||||
|| joinCostEvaluation.constructHashTableSpace() <= perNodeMemLimit)) {
|
||||
doBroadcast = true;
|
||||
} else {
|
||||
doBroadcast = false;
|
||||
@ -418,7 +372,6 @@ public class DistributedPlanner {
|
||||
} else {
|
||||
doBroadcast = false;
|
||||
}
|
||||
|
||||
if (doBroadcast) {
|
||||
node.setDistributionMode(HashJoinNode.DistributionMode.BROADCAST);
|
||||
// Doesn't create a new fragment, but modifies leftChildFragment to execute
|
||||
|
||||
@ -0,0 +1,153 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.planner;
|
||||
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
/**
|
||||
* Evaluate the cost of Broadcast and Shuffle Join to choose between the two
|
||||
*
|
||||
* broadcast: send the rightChildFragment's output to each node executing the leftChildFragment; the cost across
|
||||
* all nodes is proportional to the total amount of data sent.
|
||||
* shuffle: also called Partitioned Join. That is, small tables and large tables are hashed according to the Join key,
|
||||
* and then distributed Join is performed.
|
||||
*
|
||||
* NOTICE:
|
||||
* for now, only MysqlScanNode and OlapScanNode has Cardinality. OlapScanNode's cardinality is calculated by row num
|
||||
* and data size, and MysqlScanNode's cardinality is always 1. Other ScanNode's cardinality is -1.
|
||||
* So if there are other kind of scan node in join query, it won't be able to calculate the cost of join normally
|
||||
* and result in both "broadcastCost" and "partitionCost" be 0. And this will lead to a SHUFFLE join.
|
||||
*/
|
||||
public class JoinCostEvaluation {
|
||||
private final static Logger LOG = LogManager.getLogger(JoinCostEvaluation.class);
|
||||
|
||||
private final long rhsTreeCardinality;
|
||||
private final float rhsTreeAvgRowSize;
|
||||
private final int rhsTreeTupleIdNum;
|
||||
private final long lhsTreeCardinality;
|
||||
private final float lhsTreeAvgRowSize;
|
||||
private final int lhsTreeNumNodes;
|
||||
private long broadcastCost = 0;
|
||||
private long partitionCost = 0;
|
||||
|
||||
JoinCostEvaluation(PlanNode node, PlanFragment rightChildFragment, PlanFragment leftChildFragment) {
|
||||
PlanNode rhsTree = rightChildFragment.getPlanRoot();
|
||||
rhsTreeCardinality = rhsTree.getCardinality();
|
||||
rhsTreeAvgRowSize = rhsTree.getAvgRowSize();
|
||||
rhsTreeTupleIdNum = rhsTree.getTupleIds().size();
|
||||
PlanNode lhsTree = leftChildFragment.getPlanRoot();
|
||||
lhsTreeCardinality = lhsTree.getCardinality();
|
||||
lhsTreeAvgRowSize = lhsTree.getAvgRowSize();
|
||||
lhsTreeNumNodes = leftChildFragment.getNumNodes();
|
||||
|
||||
String nodeOverview = setNodeOverview(node, rightChildFragment, leftChildFragment);
|
||||
broadcastCost(nodeOverview);
|
||||
shuffleCost(nodeOverview);
|
||||
}
|
||||
|
||||
private String setNodeOverview(PlanNode node, PlanFragment rightChildFragment, PlanFragment leftChildFragment) {
|
||||
return "root node id=" + node.getId().toString() + ": " + node.planNodeName
|
||||
+ " right fragment id=" + rightChildFragment.getFragmentId().toString()
|
||||
+ " left fragment id=" + leftChildFragment.getFragmentId().toString();
|
||||
}
|
||||
|
||||
private void broadcastCost(String nodeOverview) {
|
||||
if (rhsTreeCardinality != -1 && lhsTreeNumNodes != -1) {
|
||||
broadcastCost = Math.round((double) rhsTreeCardinality * rhsTreeAvgRowSize) * lhsTreeNumNodes;
|
||||
}
|
||||
if (LOG.isDebugEnabled()) {
|
||||
LOG.debug(nodeOverview);
|
||||
LOG.debug("broadcast: cost=" + Long.toString(broadcastCost));
|
||||
LOG.debug("rhs card=" + Long.toString(rhsTreeCardinality)
|
||||
+ " rhs row_size=" + Float.toString(rhsTreeAvgRowSize)
|
||||
+ " lhs nodes=" + Integer.toString(lhsTreeNumNodes));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* repartition: both left- and rightChildFragment are partitioned on the join exprs
|
||||
* TODO: take existing partition of input fragments into account to avoid unnecessary repartitioning
|
||||
*/
|
||||
private void shuffleCost(String nodeOverview) {
|
||||
if (lhsTreeCardinality != -1 && rhsTreeCardinality != -1) {
|
||||
partitionCost = Math.round(
|
||||
(double) lhsTreeCardinality * lhsTreeAvgRowSize + (double) rhsTreeCardinality * rhsTreeAvgRowSize);
|
||||
}
|
||||
if (LOG.isDebugEnabled()) {
|
||||
LOG.debug(nodeOverview);
|
||||
LOG.debug("partition: cost=" + Long.toString(partitionCost));
|
||||
LOG.debug("lhs card=" + Long.toString(lhsTreeCardinality) + " row_size="
|
||||
+ Float.toString(lhsTreeAvgRowSize));
|
||||
LOG.debug("rhs card=" + Long.toString(rhsTreeCardinality) + " row_size="
|
||||
+ Float.toString(rhsTreeAvgRowSize));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* When broadcastCost and partitionCost are equal, there is no uniform standard for which join implementation
|
||||
* is better. Some scenarios are suitable for broadcast join, and some scenarios are suitable for shuffle join.
|
||||
* Therefore, we add a SessionVariable to help users choose a better join implementation.
|
||||
*/
|
||||
public boolean isBroadcastCostSmaller() {
|
||||
String joinMethod = ConnectContext.get().getSessionVariable().getPreferJoinMethod();
|
||||
if (joinMethod.equalsIgnoreCase("broadcast")) {
|
||||
return broadcastCost <= partitionCost;
|
||||
} else {
|
||||
return broadcastCost < partitionCost;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate the memory cost of constructing Hash Table in Broadcast Join.
|
||||
* The memory cost by the Hash Table = ((cardinality/0.75[1]) * 8[2])[3] + (cardinality * avgRowSize)[4]
|
||||
* + (nodeArrayLen[5] * 16[6])[7] + (nodeArrayLen * tupleNum[8] * 8[9])[10]. consists of four parts:
|
||||
* 1) All bucket pointers. 2) Length of the node array. 3) Overhead of all nodes. 4) Tuple pointers of all nodes.
|
||||
* - [1] Expansion factor of the number of HashTable buckets;
|
||||
* - [2] The pointer length of each bucket of HashTable;
|
||||
* - [3] bucketPointerSpace: The memory cost by all bucket pointers of HashTable;
|
||||
* - [4] rhsDataSize: The memory cost by all nodes of HashTable, equal to the amount of data that the right table
|
||||
* participates in the construction of HashTable;
|
||||
* - [5] HashTable stores the length of the node array, which is larger than the actual cardinality. The initial
|
||||
* value is 4096. When the storage is full, one-half of the current array length is added each time.
|
||||
* The length of the array after each increment is actually a sequence of numbers:
|
||||
* 4096 = pow(3/2, 0) * 4096,
|
||||
* 6144 = pow(3/2, 1) * 4096,
|
||||
* 9216 = pow(3/2, 2) * 4096,
|
||||
* 13824 = pow(3/2, 3) * 4096,
|
||||
* finally need to satisfy len(node array)> cardinality,
|
||||
* so the number of increments = int((ln(cardinality/4096) / ln(3/2)) + 1),
|
||||
* finally len(node array) = pow(3/2, int((ln(cardinality/4096) / ln(3/2)) + 1) * 4096
|
||||
* - [6] The overhead length of each node of HashTable, including the next node pointer, Hash value,
|
||||
* and a bool type variable;
|
||||
* - [7] nodeOverheadSpace: The memory cost by the overhead of all nodes in the HashTable;
|
||||
* - [8] Number of Tuples participating in the build;
|
||||
* - [9] The length of each Tuple pointer;
|
||||
* - [10] nodeTuplePointerSpace: The memory cost by Tuple pointers of all nodes in HashTable;
|
||||
*/
|
||||
public long constructHashTableSpace() {
|
||||
double bucketPointerSpace = ((double) rhsTreeCardinality / 0.75) * 8;
|
||||
double nodeArrayLen =
|
||||
Math.pow(1.5, (int) ((Math.log((double) rhsTreeCardinality/4096) / Math.log(1.5)) + 1)) * 4096;
|
||||
double nodeOverheadSpace = nodeArrayLen * 16;
|
||||
double nodeTuplePointerSpace = nodeArrayLen * rhsTreeTupleIdNum * 8;
|
||||
return Math.round((bucketPointerSpace + (double) rhsTreeCardinality * rhsTreeAvgRowSize
|
||||
+ nodeOverheadSpace + nodeTuplePointerSpace) * PlannerContext.HASH_TBL_SPACE_OVERHEAD);
|
||||
}
|
||||
}
|
||||
@ -299,14 +299,14 @@ public class OlapScanNode extends ScanNode {
|
||||
computeTupleState(analyzer);
|
||||
|
||||
/**
|
||||
* Compute InAccurate stats before mv selector and tablet pruning.
|
||||
* Compute InAccurate cardinality before mv selector and tablet pruning.
|
||||
* - Accurate statistical information relies on the selector of materialized views and bucket reduction.
|
||||
* - However, Those both processes occur after the reorder algorithm is completed.
|
||||
* - When Join reorder is turned on, the computeStats() must be completed before the reorder algorithm.
|
||||
* - So only an inaccurate statistical information can be calculated here.
|
||||
* - When Join reorder is turned on, the cardinality must be calculated before the reorder algorithm.
|
||||
* - So only an inaccurate cardinality can be calculated here.
|
||||
*/
|
||||
if (analyzer.safeIsEnableJoinReorderBasedCost()) {
|
||||
computeInaccurateStats(analyzer);
|
||||
computeInaccurateCardinality();
|
||||
}
|
||||
}
|
||||
|
||||
@ -326,9 +326,8 @@ public class OlapScanNode extends ScanNode {
|
||||
} catch (AnalysisException e) {
|
||||
throw new UserException(e.getMessage());
|
||||
}
|
||||
if (!analyzer.safeIsEnableJoinReorderBasedCost()) {
|
||||
computeOldRowSizeAndCardinality();
|
||||
}
|
||||
// Relatively accurate cardinality according to ScanRange in getScanRangeLocations
|
||||
computeStats(analyzer);
|
||||
computeNumNodes();
|
||||
}
|
||||
|
||||
@ -338,7 +337,9 @@ public class OlapScanNode extends ScanNode {
|
||||
}
|
||||
}
|
||||
|
||||
public void computeOldRowSizeAndCardinality() {
|
||||
@Override
|
||||
public void computeStats(Analyzer analyzer) {
|
||||
super.computeStats(analyzer);
|
||||
if (cardinality > 0) {
|
||||
avgRowSize = totalBytes / (float) cardinality;
|
||||
capCardinalityAtLimit();
|
||||
@ -357,7 +358,7 @@ public class OlapScanNode extends ScanNode {
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate inaccurate stats such as: cardinality.
|
||||
* Calculate inaccurate cardinality.
|
||||
* cardinality: the value of cardinality is the sum of rowcount which belongs to selectedPartitionIds
|
||||
* The cardinality here is actually inaccurate, it will be greater than the actual value.
|
||||
* There are two reasons
|
||||
@ -369,11 +370,8 @@ public class OlapScanNode extends ScanNode {
|
||||
* 1. Calculate how many rows were scanned
|
||||
* 2. Apply conjunct
|
||||
* 3. Apply limit
|
||||
*
|
||||
* @param analyzer
|
||||
*/
|
||||
private void computeInaccurateStats(Analyzer analyzer) {
|
||||
super.computeStats(analyzer);
|
||||
private void computeInaccurateCardinality() {
|
||||
// step1: Calculate how many rows were scanned
|
||||
cardinality = 0;
|
||||
for (long selectedPartitionId : selectedPartitionIds) {
|
||||
|
||||
@ -0,0 +1,115 @@
|
||||
package org.apache.doris.planner;
|
||||
|
||||
import org.apache.doris.analysis.BinaryPredicate;
|
||||
import org.apache.doris.analysis.TableRef;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
|
||||
import mockit.Expectations;
|
||||
import mockit.Mocked;
|
||||
|
||||
/**
 * Unit tests for JoinCostEvaluation: checks the broadcast-vs-shuffle cost comparison
 * and the estimated memory cost of building the Broadcast Join hash table.
 */
public class JoinCostEvaluationTest {

    // All collaborators are JMockit fakes; only the cardinality/avgRowSize/numNodes
    // values set on the plan roots matter to the cost computation.
    @Mocked
    private ConnectContext ctx;

    @Mocked
    private PlanNode node;

    @Mocked
    private TableRef ref;

    @Mocked
    private PlanFragmentId fragmentId;

    @Mocked
    private PlanNodeId nodeId;

    @Mocked
    private DataPartition partition;

    @Mocked
    private BinaryPredicate expr;

    @Before
    public void setUp() {
        // Default stubs so HashJoinNode construction sees empty tuple/table-ref/child lists.
        new Expectations() {
            {
                node.getTupleIds();
                result = Lists.newArrayList();
                node.getTblRefIds();
                result = Lists.newArrayList();
                node.getChildren();
                result = Lists.newArrayList();
            }
        };
    }

    // Builds a fragment whose plan root carries the given statistics.
    private PlanFragment createPlanFragment(long cardinality, float avgRowSize, int numNodes) {
        HashJoinNode root
                = new HashJoinNode(nodeId, node, node, ref, Lists.newArrayList(expr), Lists.newArrayList(expr));
        root.cardinality = cardinality;
        root.avgRowSize = avgRowSize;
        root.numNodes = numNodes;
        return new PlanFragment(fragmentId, root, partition);
    }

    // Evaluates isBroadcastCostSmaller() for the given right/left tree statistics.
    // The right fragment's numNodes is irrelevant to the cost and fixed at 0.
    private boolean callIsBroadcastCostSmaller(long rhsTreeCardinality, float rhsTreeAvgRowSize,
            long lhsTreeCardinality, float lhsTreeAvgRowSize, int lhsTreeNumNodes) {
        PlanFragment rightChildFragment = createPlanFragment(rhsTreeCardinality, rhsTreeAvgRowSize, 0);
        PlanFragment leftChildFragment = createPlanFragment(lhsTreeCardinality, lhsTreeAvgRowSize, lhsTreeNumNodes);
        JoinCostEvaluation joinCostEvaluation = new JoinCostEvaluation(node, rightChildFragment, leftChildFragment);
        return joinCostEvaluation.isBroadcastCostSmaller();
    }

    @Test
    public void testIsBroadcastCostSmaller() {
        // Session prefers broadcast, so ties (broadcastCost == partitionCost) pick broadcast.
        new Expectations() {
            {
                ConnectContext.get();
                result = ctx;
                ConnectContext.get().getSessionVariable().getPreferJoinMethod();
                result = "broadcast";
            }
        };
        // broadcastCost = rhsCard * rhsRowSize * lhsNumNodes;
        // partitionCost = lhsCard * lhsRowSize + rhsCard * rhsRowSize.
        Assert.assertTrue(callIsBroadcastCostSmaller(1, 1, 1, 1, 1));
        Assert.assertTrue(callIsBroadcastCostSmaller(1, 1, 1, 1, 2));
        Assert.assertFalse(callIsBroadcastCostSmaller(1, 1, 1, 1, 3));
        // A cardinality or numNodes of -1 leaves the corresponding cost at 0.
        Assert.assertTrue(callIsBroadcastCostSmaller(-1, 1, 1, 1, -1));
        Assert.assertTrue(callIsBroadcastCostSmaller(-1, 1, -1, 1, 1));
        Assert.assertFalse(callIsBroadcastCostSmaller(1, 1, -1, 1, 1));
        Assert.assertTrue(callIsBroadcastCostSmaller(20, 10, 5000, 2, 10));
        Assert.assertFalse(callIsBroadcastCostSmaller(20, 10, 5, 2, 10));
    }

    @Test
    public void testConstructHashTableSpace() {
        // Cardinality just above the initial 4096-entry node array forces one growth
        // step, so the expected node array length is 4096 * 1.5 = 6144.
        long rhsTreeCardinality = 4097;
        float rhsTreeAvgRowSize = 16;
        int rhsNodeTupleIdNum = 1;
        // The join root combines the tuple ids of its two (identical) mocked children,
        // hence the factor of 2.
        int rhsTreeTupleIdNum = rhsNodeTupleIdNum * 2;
        double nodeArrayLen = 6144;
        new Expectations() {
            {
                node.getTupleIds();
                result = new ArrayList<>(Collections.nCopies(rhsNodeTupleIdNum, 0));
            }
        };
        PlanFragment rightChildFragment = createPlanFragment(rhsTreeCardinality, rhsTreeAvgRowSize, 0);
        PlanFragment leftChildFragment = createPlanFragment(0, 0, 0);
        JoinCostEvaluation joinCostEvaluation = new JoinCostEvaluation(node, rightChildFragment, leftChildFragment);
        // Mirror of the documented estimate: bucket pointers + node data + node overhead
        // + tuple pointers, all scaled by HASH_TBL_SPACE_OVERHEAD.
        long hashTableSpace = Math.round((((rhsTreeCardinality / 0.75) * 8)
                + ((double) rhsTreeCardinality * rhsTreeAvgRowSize) + (nodeArrayLen * 16)
                + (nodeArrayLen * rhsTreeTupleIdNum * 8)) * PlannerContext.HASH_TBL_SPACE_OVERHEAD);

        Assert.assertEquals(hashTableSpace, joinCostEvaluation.constructHashTableSpace());
    }
}
|
||||
Reference in New Issue
Block a user