From 75d954ced5dd5854567551cb2d5fbd9eb94d1353 Mon Sep 17 00:00:00 2001 From: Xinyi Zou Date: Mon, 26 Jul 2021 09:38:41 +0800 Subject: [PATCH] [Feature] Modify the cost evaluation algorithm of Broadcast and Shuffle Join (#6274) issue #6272 After modifying the Hash Table memory cost estimation algorithm, TPC-DS query-72 etc. can be passed, avoiding OOM. --- .../doris/planner/DistributedPlanner.java | 53 +----- .../doris/planner/JoinCostEvaluation.java | 153 ++++++++++++++++++ .../apache/doris/planner/OlapScanNode.java | 24 ++- .../doris/planner/JoinCostEvaluationTest.java | 115 +++++++++++++ 4 files changed, 282 insertions(+), 63 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/planner/JoinCostEvaluation.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/planner/JoinCostEvaluationTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/DistributedPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/planner/DistributedPlanner.java index 5aa9bb6fe9..7b8efe845f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/DistributedPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/DistributedPlanner.java @@ -348,53 +348,7 @@ public class DistributedPlanner { return leftChildFragment; } - // broadcast: send the rightChildFragment's output to each node executing - // the leftChildFragment; the cost across all nodes is proportional to the - // total amount of data sent - - // NOTICE: - // for now, only MysqlScanNode and OlapScanNode has Cardinality. - // OlapScanNode's cardinality is calculated by row num and data size, - // and MysqlScanNode's cardinality is always 0. - // Other ScanNode's cardinality is -1. - // - // So if there are other kind of scan node in join query, it won't be able to calculate the cost of - // join normally and result in both "broadcastCost" and "partitionCost" be 0. And this will lead - // to a SHUFFLE join. 
- PlanNode rhsTree = rightChildFragment.getPlanRoot(); - long rhsDataSize = 0; - long broadcastCost = 0; - if (rhsTree.getCardinality() != -1 && leftChildFragment.getNumNodes() != -1) { - rhsDataSize = Math.round((double) rhsTree.getCardinality() * rhsTree.getAvgRowSize()); - broadcastCost = rhsDataSize * leftChildFragment.getNumNodes(); - } - if (LOG.isDebugEnabled()) { - LOG.debug("broadcast: cost=" + Long.toString(broadcastCost)); - LOG.debug("card=" + Long.toString(rhsTree.getCardinality()) + " row_size=" - + Float.toString(rhsTree.getAvgRowSize()) + " #nodes=" - + Integer.toString(leftChildFragment.getNumNodes())); - } - - // repartition: both left- and rightChildFragment are partitioned on the - // join exprs - // TODO: take existing partition of input fragments into account to avoid - // unnecessary repartitioning - PlanNode lhsTree = leftChildFragment.getPlanRoot(); - long partitionCost = 0; - if (lhsTree.getCardinality() != -1 && rhsTree.getCardinality() != -1) { - partitionCost = Math.round( - (double) lhsTree.getCardinality() * lhsTree.getAvgRowSize() + (double) rhsTree - .getCardinality() * rhsTree.getAvgRowSize()); - } - if (LOG.isDebugEnabled()) { - LOG.debug("partition: cost=" + Long.toString(partitionCost)); - LOG.debug("lhs card=" + Long.toString(lhsTree.getCardinality()) + " row_size=" - + Float.toString(lhsTree.getAvgRowSize())); - LOG.debug("rhs card=" + Long.toString(rhsTree.getCardinality()) + " row_size=" - + Float.toString(rhsTree.getAvgRowSize())); - LOG.debug(rhsTree.getExplainString()); - } - + JoinCostEvaluation joinCostEvaluation = new JoinCostEvaluation(node, rightChildFragment, leftChildFragment); boolean doBroadcast; // we do a broadcast join if // - we're explicitly told to do so @@ -408,9 +362,9 @@ public class DistributedPlanner { // respect user join hint doBroadcast = true; } else if (!node.getInnerRef().isPartitionJoin() - && isBroadcastCostSmaller(broadcastCost, partitionCost) + && joinCostEvaluation.isBroadcastCostSmaller() && (perNodeMemLimit == 0 - || Math.round((double) rhsDataSize * PlannerContext.HASH_TBL_SPACE_OVERHEAD) <= perNodeMemLimit)) { + || joinCostEvaluation.constructHashTableSpace() <= perNodeMemLimit)) { doBroadcast = true; } else { doBroadcast = false; @@ -418,7 +372,6 @@ public class DistributedPlanner { } else { doBroadcast = false; } - if (doBroadcast) { node.setDistributionMode(HashJoinNode.DistributionMode.BROADCAST); // Doesn't create a new fragment, but modifies leftChildFragment to execute diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/JoinCostEvaluation.java b/fe/fe-core/src/main/java/org/apache/doris/planner/JoinCostEvaluation.java new file mode 100644 index 0000000000..e0b918d95a --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/JoinCostEvaluation.java @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.planner; + +import org.apache.doris.qe.ConnectContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * Evaluate the cost of Broadcast and Shuffle Join to choose between the two + * + * broadcast: send the rightChildFragment's output to each node executing the leftChildFragment; the cost across + * all nodes is proportional to the total amount of data sent. + * shuffle: also called Partitioned Join. That is, small tables and large tables are hashed according to the Join key, + * and then distributed Join is performed. + * + * NOTICE: + * for now, only MysqlScanNode and OlapScanNode has Cardinality. OlapScanNode's cardinality is calculated by row num + * and data size, and MysqlScanNode's cardinality is always 1. Other ScanNode's cardinality is -1. + * So if there are other kind of scan node in join query, it won't be able to calculate the cost of join normally + * and result in both "broadcastCost" and "partitionCost" be 0. And this will lead to a SHUFFLE join. + */ +public class JoinCostEvaluation { + private final static Logger LOG = LogManager.getLogger(JoinCostEvaluation.class); + + private final long rhsTreeCardinality; + private final float rhsTreeAvgRowSize; + private final int rhsTreeTupleIdNum; + private final long lhsTreeCardinality; + private final float lhsTreeAvgRowSize; + private final int lhsTreeNumNodes; + private long broadcastCost = 0; + private long partitionCost = 0; + + JoinCostEvaluation(PlanNode node, PlanFragment rightChildFragment, PlanFragment leftChildFragment) { + PlanNode rhsTree = rightChildFragment.getPlanRoot(); + rhsTreeCardinality = rhsTree.getCardinality(); + rhsTreeAvgRowSize = rhsTree.getAvgRowSize(); + rhsTreeTupleIdNum = rhsTree.getTupleIds().size(); + PlanNode lhsTree = leftChildFragment.getPlanRoot(); + lhsTreeCardinality = lhsTree.getCardinality(); + lhsTreeAvgRowSize = lhsTree.getAvgRowSize(); + lhsTreeNumNodes = leftChildFragment.getNumNodes(); + + String nodeOverview = setNodeOverview(node, rightChildFragment, leftChildFragment); + broadcastCost(nodeOverview); + shuffleCost(nodeOverview); + } + + private String setNodeOverview(PlanNode node, PlanFragment rightChildFragment, PlanFragment leftChildFragment) { + return "root node id=" + node.getId().toString() + ": " + node.planNodeName + + " right fragment id=" + rightChildFragment.getFragmentId().toString() + + " left fragment id=" + leftChildFragment.getFragmentId().toString(); + } + + private void broadcastCost(String nodeOverview) { + if (rhsTreeCardinality != -1 && lhsTreeNumNodes != -1) { + broadcastCost = Math.round((double) rhsTreeCardinality * rhsTreeAvgRowSize) * lhsTreeNumNodes; + } + if (LOG.isDebugEnabled()) { + LOG.debug(nodeOverview); + LOG.debug("broadcast: cost=" + Long.toString(broadcastCost)); + LOG.debug("rhs card=" + Long.toString(rhsTreeCardinality) + + " rhs row_size=" + Float.toString(rhsTreeAvgRowSize) + + " lhs nodes=" + Integer.toString(lhsTreeNumNodes)); + } + } + + /** + * repartition: both left- and rightChildFragment are partitioned on the join exprs + * TODO: take existing partition of input fragments into account to avoid unnecessary repartitioning + */ + private void shuffleCost(String nodeOverview) { + if (lhsTreeCardinality != -1 && rhsTreeCardinality != -1) { + partitionCost = Math.round( + (double) lhsTreeCardinality * lhsTreeAvgRowSize + (double) rhsTreeCardinality * 
rhsTreeAvgRowSize); + } + if (LOG.isDebugEnabled()) { + LOG.debug(nodeOverview); + LOG.debug("partition: cost=" + Long.toString(partitionCost)); + LOG.debug("lhs card=" + Long.toString(lhsTreeCardinality) + " row_size=" + + Float.toString(lhsTreeAvgRowSize)); + LOG.debug("rhs card=" + Long.toString(rhsTreeCardinality) + " row_size=" + + Float.toString(rhsTreeAvgRowSize)); + } + } + + /** + * When broadcastCost and partitionCost are equal, there is no uniform standard for which join implementation + * is better. Some scenarios are suitable for broadcast join, and some scenarios are suitable for shuffle join. + * Therefore, we add a SessionVariable to help users choose a better join implementation. + */ + public boolean isBroadcastCostSmaller() { + String joinMethod = ConnectContext.get().getSessionVariable().getPreferJoinMethod(); + if (joinMethod.equalsIgnoreCase("broadcast")) { + return broadcastCost <= partitionCost; + } else { + return broadcastCost < partitionCost; + } + } + + /** + * Estimate the memory cost of constructing Hash Table in Broadcast Join. + * The memory cost by the Hash Table = ((cardinality/0.75[1]) * 8[2])[3] + (cardinality * avgRowSize)[4] + * + (nodeArrayLen[5] * 16[6])[7] + (nodeArrayLen * tupleNum[8] * 8[9])[10]. consists of four parts: + * 1) All bucket pointers. 2) Length of the node array. 3) Overhead of all nodes. 4) Tuple pointers of all nodes. + * - [1] Expansion factor of the number of HashTable buckets; + * - [2] The pointer length of each bucket of HashTable; + * - [3] bucketPointerSpace: The memory cost by all bucket pointers of HashTable; + * - [4] rhsDataSize: The memory cost by all nodes of HashTable, equal to the amount of data that the right table + * participates in the construction of HashTable; + * - [5] HashTable stores the length of the node array, which is larger than the actual cardinality. The initial + * value is 4096. When the storage is full, one-half of the current array length is added each time. 
+ * The length of the array after each increment is actually a sequence of numbers: + * 4096 = pow(3/2, 0) * 4096, + * 6144 = pow(3/2, 1) * 4096, + * 9216 = pow(3/2, 2) * 4096, + * 13824 = pow(3/2, 3) * 4096, + * finally need to satisfy len(node array)> cardinality, + * so the number of increments = int((ln(cardinality/4096) / ln(3/2)) + 1), + * finally len(node array) = pow(3/2, int((ln(cardinality/4096) / ln(3/2)) + 1) * 4096 + * - [6] The overhead length of each node of HashTable, including the next node pointer, Hash value, + * and a bool type variable; + * - [7] nodeOverheadSpace: The memory cost by the overhead of all nodes in the HashTable; + * - [8] Number of Tuples participating in the build; + * - [9] The length of each Tuple pointer; + * - [10] nodeTuplePointerSpace: The memory cost by Tuple pointers of all nodes in HashTable; + */ + public long constructHashTableSpace() { + double bucketPointerSpace = ((double) rhsTreeCardinality / 0.75) * 8; + double nodeArrayLen = + Math.pow(1.5, (int) ((Math.log((double) rhsTreeCardinality/4096) / Math.log(1.5)) + 1)) * 4096; + double nodeOverheadSpace = nodeArrayLen * 16; + double nodeTuplePointerSpace = nodeArrayLen * rhsTreeTupleIdNum * 8; + return Math.round((bucketPointerSpace + (double) rhsTreeCardinality * rhsTreeAvgRowSize + + nodeOverheadSpace + nodeTuplePointerSpace) * PlannerContext.HASH_TBL_SPACE_OVERHEAD); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java index a69c08d514..aab988f650 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java @@ -299,14 +299,14 @@ public class OlapScanNode extends ScanNode { computeTupleState(analyzer); /** - * Compute InAccurate stats before mv selector and tablet pruning. + * Compute InAccurate cardinality before mv selector and tablet pruning. * - Accurate statistical information relies on the selector of materialized views and bucket reduction. * - However, Those both processes occur after the reorder algorithm is completed. - * - When Join reorder is turned on, the computeStats() must be completed before the reorder algorithm. - * - So only an inaccurate statistical information can be calculated here. + * - When Join reorder is turned on, the cardinality must be calculated before the reorder algorithm. + * - So only an inaccurate cardinality can be calculated here. */ if (analyzer.safeIsEnableJoinReorderBasedCost()) { - computeInaccurateStats(analyzer); + computeInaccurateCardinality(); } } @@ -326,9 +326,8 @@ public class OlapScanNode extends ScanNode { } catch (AnalysisException e) { throw new UserException(e.getMessage()); } - if (!analyzer.safeIsEnableJoinReorderBasedCost()) { - computeOldRowSizeAndCardinality(); - } + // Relatively accurate cardinality according to ScanRange in getScanRangeLocations + computeStats(analyzer); computeNumNodes(); } @@ -338,7 +337,9 @@ public class OlapScanNode extends ScanNode { } } - public void computeOldRowSizeAndCardinality() { + @Override + public void computeStats(Analyzer analyzer) { + super.computeStats(analyzer); if (cardinality > 0) { avgRowSize = totalBytes / (float) cardinality; capCardinalityAtLimit(); @@ -357,7 +358,7 @@ public class OlapScanNode extends ScanNode { } /** - * Calculate inaccurate stats such as: cardinality. + * Calculate inaccurate cardinality. 
* cardinality: the value of cardinality is the sum of rowcount which belongs to selectedPartitionIds * The cardinality here is actually inaccurate, it will be greater than the actual value. * There are two reasons @@ -369,11 +370,8 @@ public class OlapScanNode extends ScanNode { * 1. Calculate how many rows were scanned * 2. Apply conjunct * 3. Apply limit - * - * @param analyzer */ - private void computeInaccurateStats(Analyzer analyzer) { - super.computeStats(analyzer); + private void computeInaccurateCardinality() { // step1: Calculate how many rows were scanned cardinality = 0; for (long selectedPartitionId : selectedPartitionIds) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/JoinCostEvaluationTest.java b/fe/fe-core/src/test/java/org/apache/doris/planner/JoinCostEvaluationTest.java new file mode 100644 index 0000000000..26baf13ffb --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/planner/JoinCostEvaluationTest.java @@ -0,0 +1,115 @@ +package org.apache.doris.planner; + +import org.apache.doris.analysis.BinaryPredicate; +import org.apache.doris.analysis.TableRef; +import org.apache.doris.qe.ConnectContext; + +import com.google.common.collect.Lists; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collections; + +import mockit.Expectations; +import mockit.Mocked; + +public class JoinCostEvaluationTest { + + @Mocked + private ConnectContext ctx; + + @Mocked + private PlanNode node; + + @Mocked + private TableRef ref; + + @Mocked + private PlanFragmentId fragmentId; + + @Mocked + private PlanNodeId nodeId; + + @Mocked + private DataPartition partition; + + @Mocked + private BinaryPredicate expr; + + @Before + public void setUp() { + new Expectations() { + { + node.getTupleIds(); + result = Lists.newArrayList(); + node.getTblRefIds(); + result = Lists.newArrayList(); + node.getChildren(); + result = Lists.newArrayList(); + } + }; + } + + private PlanFragment createPlanFragment(long cardinality, float avgRowSize, int numNodes) { + HashJoinNode root + = new HashJoinNode(nodeId, node, node, ref, Lists.newArrayList(expr), Lists.newArrayList(expr)); + root.cardinality = cardinality; + root.avgRowSize = avgRowSize; + root.numNodes = numNodes; + return new PlanFragment(fragmentId, root, partition); + } + + private boolean callIsBroadcastCostSmaller(long rhsTreeCardinality, float rhsTreeAvgRowSize, + long lhsTreeCardinality, float lhsTreeAvgRowSize, int lhsTreeNumNodes) { + PlanFragment rightChildFragment = createPlanFragment(rhsTreeCardinality, rhsTreeAvgRowSize, 0); + PlanFragment leftChildFragment = createPlanFragment(lhsTreeCardinality, lhsTreeAvgRowSize, lhsTreeNumNodes); + JoinCostEvaluation joinCostEvaluation = new JoinCostEvaluation(node, rightChildFragment, leftChildFragment); + return joinCostEvaluation.isBroadcastCostSmaller(); + } + + @Test + public void testIsBroadcastCostSmaller() { + new Expectations() { + { + ConnectContext.get(); + result = ctx; + ConnectContext.get().getSessionVariable().getPreferJoinMethod(); + result = "broadcast"; + } + }; + Assert.assertTrue(callIsBroadcastCostSmaller(1, 1, 1, 1, 1)); + Assert.assertTrue(callIsBroadcastCostSmaller(1, 1, 1, 1, 2)); + Assert.assertFalse(callIsBroadcastCostSmaller(1, 1, 1, 1, 3)); + Assert.assertTrue(callIsBroadcastCostSmaller(-1, 1, 1, 1, -1)); + Assert.assertTrue(callIsBroadcastCostSmaller(-1, 1, -1, 1, 1)); + Assert.assertFalse(callIsBroadcastCostSmaller(1, 1, -1, 1, 1)); + 
Assert.assertTrue(callIsBroadcastCostSmaller(20, 10, 5000, 2, 10)); + Assert.assertFalse(callIsBroadcastCostSmaller(20, 10, 5, 2, 10)); + } + + @Test + public void testConstructHashTableSpace() { + long rhsTreeCardinality = 4097; + float rhsTreeAvgRowSize = 16; + int rhsNodeTupleIdNum = 1; + int rhsTreeTupleIdNum = rhsNodeTupleIdNum * 2; + double nodeArrayLen = 6144; + new Expectations() { + { + node.getTupleIds(); + result = new ArrayList<>(Collections.nCopies(rhsNodeTupleIdNum, 0)); + } + }; + PlanFragment rightChildFragment = createPlanFragment(rhsTreeCardinality, rhsTreeAvgRowSize, 0); + PlanFragment leftChildFragment = createPlanFragment(0, 0, 0); + JoinCostEvaluation joinCostEvaluation = new JoinCostEvaluation(node, rightChildFragment, leftChildFragment); + long hashTableSpace = Math.round((((rhsTreeCardinality / 0.75) * 8) + + ((double) rhsTreeCardinality * rhsTreeAvgRowSize) + (nodeArrayLen * 16) + + (nodeArrayLen * rhsTreeTupleIdNum * 8)) * PlannerContext.HASH_TBL_SPACE_OVERHEAD); + + Assert.assertEquals(hashTableSpace, joinCostEvaluation.constructHashTableSpace()); + } +} \ No newline at end of file
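
For reference, the broadcast/shuffle decision introduced by JoinCostEvaluation reduces to comparing two estimates: broadcastCost = rhsCardinality * rhsAvgRowSize * numLeftNodes and partitionCost = lhsCardinality * lhsAvgRowSize + rhsCardinality * rhsAvgRowSize, with the preferJoinMethod session variable breaking ties. The following standalone sketch (illustrative class and helper names, not the Doris planner API) reproduces that comparison using the same figures as the testIsBroadcastCostSmaller case callIsBroadcastCostSmaller(20, 10, 5000, 2, 10):

// Standalone sketch of the broadcast-vs-shuffle comparison described above.
// Class and method names are illustrative only; they are not part of the patch.
public class JoinCostSketch {
    static long broadcastCost(long rhsCard, float rhsRowSize, int lhsNumNodes) {
        if (rhsCard == -1 || lhsNumNodes == -1) {
            return 0; // unknown cardinality: cost stays 0, which biases the planner toward shuffle
        }
        return Math.round((double) rhsCard * rhsRowSize) * lhsNumNodes;
    }

    static long partitionCost(long lhsCard, float lhsRowSize, long rhsCard, float rhsRowSize) {
        if (lhsCard == -1 || rhsCard == -1) {
            return 0;
        }
        return Math.round((double) lhsCard * lhsRowSize + (double) rhsCard * rhsRowSize);
    }

    static boolean isBroadcastCostSmaller(long bCost, long pCost, String preferJoinMethod) {
        // On a tie, the prefer_join_method session variable decides which implementation wins.
        return "broadcast".equalsIgnoreCase(preferJoinMethod) ? bCost <= pCost : bCost < pCost;
    }

    public static void main(String[] args) {
        // Small right table (20 rows x 10 bytes) joined to a large left table spread over 10 nodes.
        long b = broadcastCost(20, 10, 10);        // 20 * 10 * 10  = 2000
        long p = partitionCost(5000, 2, 20, 10);   // 5000 * 2 + 20 * 10 = 10200
        System.out.println(isBroadcastCostSmaller(b, p, "shuffle")); // true -> broadcast is chosen
    }
}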
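
The memory estimate in constructHashTableSpace() can likewise be checked by hand. The sketch below restates the formula from the javadoc with the inputs used in testConstructHashTableSpace (cardinality 4097, average row size 16, 2 tuple ids, so the node array grows once from 4096 to 6144). The overhead factor is an assumption standing in for PlannerContext.HASH_TBL_SPACE_OVERHEAD, whose actual value is defined in the planner:

// Worked example of the hash-table space estimate described in the javadoc of
// constructHashTableSpace(). The overhead constant below is an assumption; the real
// value comes from PlannerContext.HASH_TBL_SPACE_OVERHEAD.
public class HashTableSpaceSketch {
    static final double HASH_TBL_SPACE_OVERHEAD = 1.1; // assumed, see PlannerContext

    static long estimate(long cardinality, float avgRowSize, int tupleNum) {
        // [1][2][3] bucket pointers: (cardinality / 0.75) buckets * 8 bytes per pointer
        double bucketPointerSpace = ((double) cardinality / 0.75) * 8;
        // [5] node array grows from 4096 by a factor of 3/2 until it exceeds cardinality
        double nodeArrayLen =
                Math.pow(1.5, (int) ((Math.log((double) cardinality / 4096) / Math.log(1.5)) + 1)) * 4096;
        // [6][7] 16 bytes of per-node overhead (next pointer, hash value, bool flag)
        double nodeOverheadSpace = nodeArrayLen * 16;
        // [8][9][10] one 8-byte tuple pointer per tuple per node
        double nodeTuplePointerSpace = nodeArrayLen * tupleNum * 8;
        // [4] rhsDataSize = cardinality * avgRowSize, then everything scaled by the overhead factor
        return Math.round((bucketPointerSpace + (double) cardinality * avgRowSize
                + nodeOverheadSpace + nodeTuplePointerSpace) * HASH_TBL_SPACE_OVERHEAD);
    }

    public static void main(String[] args) {
        // Same inputs as testConstructHashTableSpace: 4097 rows of 16 bytes with 2 tuple ids.
        // 4097 just exceeds 4096, so the node array grows once: 1.5 * 4096 = 6144.
        System.out.println(estimate(4097, 16, 2));
    }
}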
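
The OlapScanNode part of the patch splits cardinality estimation into two passes: computeInaccurateCardinality() runs before join reorder and sums the row counts of the selected partitions, then applies conjuncts and the limit, while computeStats() recomputes a relatively accurate cardinality once scan ranges are known. A schematic sketch of the first pass follows; the row-count map, selectivity parameter and class name are hypothetical, not the OlapScanNode API:

// Schematic sketch of the pre-reorder cardinality estimate described above.
import java.util.List;
import java.util.Map;

public class InaccurateCardinalitySketch {
    static long inaccurateCardinality(Map<Long, Long> partitionRowCounts,
                                      List<Long> selectedPartitionIds,
                                      double conjunctSelectivity, long limit) {
        // step 1: how many rows would be scanned across the selected partitions
        long cardinality = 0;
        for (long partitionId : selectedPartitionIds) {
            cardinality += partitionRowCounts.getOrDefault(partitionId, 0L);
        }
        // step 2: apply the estimated selectivity of the conjuncts
        cardinality = Math.round(cardinality * conjunctSelectivity);
        // step 3: apply the limit, if any
        return limit >= 0 ? Math.min(cardinality, limit) : cardinality;
    }
}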