[SQL Plan] Fix explicit broadcast join bug (#4424)

Use a broadcast join when users explicitly specify [BROADCAST] in a query.
Author: wyb
Date: 2020-08-25 22:06:45 +08:00 (committed by GitHub)
Parent: c201cf6e4f
Commit: 691227922e
2 changed files with 78 additions and 19 deletions
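
For reference, the hint syntax this commit honors, quoted verbatim from the new regression test below:

    select * from db1.tbl1 join [BROADCAST] db1.tbl2 on tbl1.k1 = tbl2.k3   -- force a broadcast join
    select * from db1.tbl1 join [SHUFFLE] db1.tbl2 on tbl1.k1 = tbl2.k3     -- force a partitioned (shuffle) join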

File: DistributedPlanner.java

@@ -361,13 +361,18 @@ public class DistributedPlanner {
         // side to be partitioned for correctness)
         // - and the expected size of the hash tbl doesn't exceed perNodeMemLimit
         // we set partition join as default when broadcast join cost equals partition join cost
-        if (node.getJoinOp() != JoinOperator.RIGHT_OUTER_JOIN
-                && node.getJoinOp() != JoinOperator.FULL_OUTER_JOIN
-                && (perNodeMemLimit == 0 || Math.round(
-                (double) rhsDataSize * PlannerContext.HASH_TBL_SPACE_OVERHEAD) <= perNodeMemLimit)
-                && (node.getInnerRef().isBroadcastJoin() || (!node.getInnerRef().isPartitionJoin()
-                && isBroadcastCostSmaller(broadcastCost, partitionCost)))) {
-            doBroadcast = true;
+        if (node.getJoinOp() != JoinOperator.RIGHT_OUTER_JOIN && node.getJoinOp() != JoinOperator.FULL_OUTER_JOIN) {
+            if (node.getInnerRef().isBroadcastJoin()) {
+                // respect user join hint
+                doBroadcast = true;
+            } else if (!node.getInnerRef().isPartitionJoin()
+                    && isBroadcastCostSmaller(broadcastCost, partitionCost)
+                    && (perNodeMemLimit == 0
+                    || Math.round((double) rhsDataSize * PlannerContext.HASH_TBL_SPACE_OVERHEAD) <= perNodeMemLimit)) {
+                doBroadcast = true;
+            } else {
+                doBroadcast = false;
+            }
+        } else {
+            doBroadcast = false;
+        }
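
The behavioral change: in the old condition the perNodeMemLimit check was ANDed in front of the hint check, so an explicit [BROADCAST] hint was silently dropped whenever the estimated hash table exceeded the per-node memory limit. After the fix, the user's hint short-circuits both the cost comparison and the memory check; only RIGHT/FULL OUTER joins, which must partition the right side for correctness, still override it. A minimal standalone sketch of the new decision order — all names here are illustrative stand-ins, not Doris APIs:

    // Sketch of the broadcast-vs-partitioned decision after this fix.
    // Only the ordering of the checks mirrors the patched planner logic.
    public class BroadcastDecisionSketch {
        static boolean chooseBroadcast(boolean isRightOrFullOuterJoin,
                                       boolean hintBroadcast, boolean hintShuffle,
                                       long broadcastCost, long partitionCost,
                                       long estimatedHashTblBytes, long perNodeMemLimit) {
            if (isRightOrFullOuterJoin) {
                return false; // right/full outer joins must partition the right side
            }
            if (hintBroadcast) {
                return true;  // explicit [BROADCAST] hint now wins unconditionally
            }
            // Cost-based default: broadcast only when strictly cheaper and the
            // estimated hash table fits the per-node memory limit (0 = no limit).
            return !hintShuffle
                    && broadcastCost < partitionCost
                    && (perNodeMemLimit == 0 || estimatedHashTblBytes <= perNodeMemLimit);
        }

        public static void main(String[] args) {
            // Pre-fix, a [BROADCAST] hint was ignored when the hash table exceeded
            // the memory limit; post-fix the hint is honored:
            System.out.println(chooseBroadcast(false, true, false, 100, 10, 1 << 30, 1 << 20)); // true
        }
    }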

File: DistributedPlannerTest.java

@@ -17,26 +17,60 @@
 package org.apache.doris.planner;

+import org.apache.doris.analysis.CreateDbStmt;
+import org.apache.doris.analysis.CreateTableStmt;
 import org.apache.doris.analysis.TupleId;
+import org.apache.doris.catalog.Catalog;
+import org.apache.doris.common.jmockit.Deencapsulation;
+import org.apache.doris.qe.ConnectContext;
+import org.apache.doris.qe.StmtExecutor;
+import org.apache.doris.thrift.TExplainLevel;
+import org.apache.doris.utframe.UtFrameUtils;

 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
-import org.apache.doris.common.jmockit.Deencapsulation;
-import org.junit.Assert;
-import org.junit.Test;
-
-import java.util.List;
-import java.util.Set;
-
 import mockit.Expectations;
 import mockit.Injectable;
 import mockit.Mocked;
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.List;
+import java.util.Set;
+import java.util.UUID;

 public class DistributedPlannerTest {
-    @Mocked
-    PlannerContext plannerContext;
+    private static String runningDir = "fe/mocked/DemoTest/" + UUID.randomUUID().toString() + "/";
+    private static ConnectContext ctx;
+
+    @BeforeClass
+    public static void setUp() throws Exception {
+        UtFrameUtils.createMinDorisCluster(runningDir);
+        ctx = UtFrameUtils.createDefaultCtx();
+        String createDbStmtStr = "create database db1;";
+        CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx);
+        Catalog.getCurrentCatalog().createDb(createDbStmt);
+        // create table tbl1
+        String createTblStmtStr = "create table db1.tbl1(k1 int, k2 varchar(32), v bigint sum) "
+                + "AGGREGATE KEY(k1,k2) distributed by hash(k1) buckets 1 properties('replication_num' = '1');";
+        CreateTableStmt createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, ctx);
+        Catalog.getCurrentCatalog().createTable(createTableStmt);
+        // create table tbl2
+        createTblStmtStr = "create table db1.tbl2(k3 int, k4 varchar(32)) "
+                + "DUPLICATE KEY(k3) distributed by hash(k3) buckets 1 properties('replication_num' = '1');";
+        createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, ctx);
+        Catalog.getCurrentCatalog().createTable(createTableStmt);
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        FileUtils.deleteDirectory(new File(runningDir));
+    }

     @Test
     public void testAssertFragmentWithDistributedInput(@Injectable AssertNumRowsNode assertNumRowsNode,
@@ -44,7 +78,8 @@ public class DistributedPlannerTest {
                                                        @Injectable PlanNodeId planNodeId,
                                                        @Injectable PlanFragmentId planFragmentId,
                                                        @Injectable PlanNode inputPlanRoot,
-                                                       @Injectable TupleId tupleId) {
+                                                       @Injectable TupleId tupleId,
+                                                       @Mocked PlannerContext plannerContext) {
         DistributedPlanner distributedPlanner = new DistributedPlanner(plannerContext);
         List<TupleId> tupleIdList = Lists.newArrayList(tupleId);
@@ -82,7 +117,8 @@ public class DistributedPlannerTest {
     @Test
     public void testAssertFragmentWithUnpartitionInput(@Injectable AssertNumRowsNode assertNumRowsNode,
-                                                       @Injectable PlanFragment inputFragment){
+                                                       @Injectable PlanFragment inputFragment,
+                                                       @Mocked PlannerContext plannerContext){
         DistributedPlanner distributedPlanner = new DistributedPlanner(plannerContext);
         PlanFragment assertFragment = Deencapsulation.invoke(distributedPlanner, "createAssertFragment",
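
Note on the two signature changes above: the @Mocked PlannerContext moved from a class-level field (removed in the first hunk of this file) to a parameter of each fragment test. In JMockit, a mocked field applies to every test in the class, which would also have mocked the real PlannerContext needed by the end-to-end test added below; a mocked parameter is scoped to its own test method.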
@@ -91,4 +127,22 @@ public class DistributedPlannerTest {
         Assert.assertTrue(assertFragment.getPlanRoot() instanceof AssertNumRowsNode);
     }
+
+    @Test
+    public void testExplicitlyBroadcastJoin() throws Exception {
+        String sql = "explain select * from db1.tbl1 join [BROADCAST] db1.tbl2 on tbl1.k1 = tbl2.k3";
+        StmtExecutor stmtExecutor = new StmtExecutor(ctx, sql);
+        stmtExecutor.execute();
+        Planner planner = stmtExecutor.planner();
+        List<PlanFragment> fragments = planner.getFragments();
+        String plan = planner.getExplainString(fragments, TExplainLevel.NORMAL);
+        Assert.assertEquals(1, StringUtils.countMatches(plan, "INNER JOIN (BROADCAST)"));
+        sql = "explain select * from db1.tbl1 join [SHUFFLE] db1.tbl2 on tbl1.k1 = tbl2.k3";
+        stmtExecutor = new StmtExecutor(ctx, sql);
+        stmtExecutor.execute();
+        planner = stmtExecutor.planner();
+        fragments = planner.getFragments();
+        plan = planner.getExplainString(fragments, TExplainLevel.NORMAL);
+        Assert.assertEquals(1, StringUtils.countMatches(plan, "INNER JOIN (PARTITIONED)"));
+    }
 }
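
The new testExplicitlyBroadcastJoin drives real queries through the UtFrameUtils mini cluster and asserts on the explain output: with [BROADCAST] the plan must contain exactly one "INNER JOIN (BROADCAST)", and with [SHUFFLE] exactly one "INNER JOIN (PARTITIONED)", confirming that the explicit hint, not the cost model, decides the join distribution.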