[enhancement](Nereids) support other join framework in DPHyper (#21835)

Implement the CD-A algorithm in order to support other join types in DPHyper.
The algorithm is described in the paper "On the Correct and Complete Enumeration of the Core Search Space".
This commit is contained in:
谢健
2023-07-21 18:31:52 +08:00
committed by GitHub
parent bed940b7fc
commit b76d0d84ac
12 changed files with 208 additions and 50 deletions

View File

@ -60,7 +60,9 @@ public class StatementContext {
private StatementBase parsedStatement;
private ColumnAliasGenerator columnAliasGenerator;
private int joinCount = 0;
private int maxNAryInnerJoin = 0;
private boolean isDpHyp = false;
private boolean isOtherJoinReorder = false;
@ -112,6 +114,16 @@ public class StatementContext {
return maxNAryInnerJoin;
}
/**
 * Raises the stored {@code joinCount} to the given value when that value is larger;
 * a smaller or equal value leaves the field as it is.
 *
 * @param joinCount candidate join count
 */
public void setMaxContinuousJoin(int joinCount) {
    this.joinCount = Math.max(this.joinCount, joinCount);
}
/**
 * @return the current value of the {@code joinCount} field
 */
public int getMaxContinuousJoin() {
    return this.joinCount;
}
/**
 * @return the current value of the {@code isDpHyp} flag
 */
public boolean isDpHyp() {
    return this.isDpHyp;
}

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.jobs.executor;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupJob;
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
@ -57,9 +56,10 @@ public class Optimizer {
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
serializeStatUsed(cascadesContext.getConnectContext());
// DPHyp optimize
StatementContext statementContext = cascadesContext.getStatementContext();
boolean isDpHyp = getSessionVariable().enableDPHypOptimizer || statementContext.getMaxNAryInnerJoin()
> getSessionVariable().getMaxTableCountUseCascadesJoinReorder();
int maxJoinCount = cascadesContext.getMemo().countMaxContinuousJoin();
cascadesContext.getStatementContext().setMaxContinuousJoin(maxJoinCount);
boolean isDpHyp = getSessionVariable().enableDPHypOptimizer
|| maxJoinCount > getSessionVariable().getMaxTableCountUseCascadesJoinReorder();
cascadesContext.getStatementContext().setDpHyp(isDpHyp);
cascadesContext.getStatementContext().setOtherJoinReorder(false);
if (!getSessionVariable().isDisableJoinReorder() && isDpHyp) {

View File

@ -37,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.Set;
@ -66,7 +67,7 @@ public class JoinOrderJob extends Job {
}
private Group optimizePlan(Group group) {
if (group.isInnerJoinGroup()) {
if (group.isValidJoinGroup()) {
return optimizeJoin(group);
}
GroupExpression rootExpr = group.getLogicalExpression();
@ -111,19 +112,19 @@ public class JoinOrderJob extends Job {
* @param group root group, should be join type
* @param hyperGraph build hyperGraph
*/
public void buildGraph(Group group, HyperGraph hyperGraph) {
public BitSet buildGraph(Group group, HyperGraph hyperGraph) {
if (group.isProjectGroup()) {
buildGraph(group.getLogicalExpression().child(0), hyperGraph);
BitSet edgeMap = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
processProjectPlan(hyperGraph, group);
return;
return edgeMap;
}
if (!group.isInnerJoinGroup()) {
if (!group.isValidJoinGroup()) {
hyperGraph.addNode(optimizePlan(group));
return;
return new BitSet();
}
buildGraph(group.getLogicalExpression().child(0), hyperGraph);
buildGraph(group.getLogicalExpression().child(1), hyperGraph);
hyperGraph.addEdge(group);
BitSet leftEdgeMap = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
BitSet rightEdgeMap = buildGraph(group.getLogicalExpression().child(1), hyperGraph);
return hyperGraph.addEdge(group, leftEdgeMap, rightEdgeMap);
}
/**

View File

@ -35,6 +35,7 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@ -112,7 +113,7 @@ public class HyperGraph {
* @param group The group that is the end node in graph
*/
public void addNode(Group group) {
Preconditions.checkArgument(!group.isInnerJoinGroup());
Preconditions.checkArgument(!group.isValidJoinGroup());
for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) {
Preconditions.checkArgument(!slotToNodeMap.containsKey(slot));
slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size()));
@ -134,10 +135,11 @@ public class HyperGraph {
*
* @param group The join group
*/
public void addEdge(Group group) {
Preconditions.checkArgument(group.isInnerJoinGroup());
public BitSet addEdge(Group group, BitSet leftEdgeMap, BitSet rightEdgeMap) {
Preconditions.checkArgument(group.isValidJoinGroup());
LogicalJoin<? extends Plan, ? extends Plan> join = (LogicalJoin) group.getLogicalExpression().getPlan();
HashMap<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new HashMap<>();
for (Expression expression : join.getHashJoinConjuncts()) {
Pair<Long, Long> ends = findEnds(expression);
if (!conjuncts.containsKey(ends)) {
@ -152,25 +154,61 @@ public class HyperGraph {
}
conjuncts.get(ends).second.add(expression);
}
BitSet edgeMap = new BitSet();
edgeMap.or(leftEdgeMap);
edgeMap.or(rightEdgeMap);
for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> entry : conjuncts
.entrySet()) {
LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first,
entry.getValue().second, JoinHint.NONE, join.left(), join.right());
entry.getValue().second, JoinHint.NONE, join.getMarkJoinSlotReference(),
Lists.newArrayList(join.left(), join.right()));
Edge edge = new Edge(singleJoin, edges.size());
Pair<Long, Long> ends = entry.getKey();
edge.setLeft(ends.first);
edge.setOriginalLeft(ends.first);
edge.setRight(ends.second);
edge.setOriginalRight(ends.second);
initEdgeEnds(ends, edge, leftEdgeMap, rightEdgeMap);
for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
nodes.get(nodeIndex).attachEdge(edge);
}
edgeMap.set(edge.getIndex());
edges.add(edge);
}
return edgeMap;
// In MySQL, each edge is reversed and store in edges again for reducing the branch miss
// We don't implement this trick now.
}
// Computes the end-node bitmaps of a new edge with the CD-A algorithm from the paper
// "On the correct and complete enumeration of the core search space".
//
// Both sides start from the node sets in `ends`. Every edge already recorded for the
// left / right input is then compared with the new edge: whenever the JoinType check
// for a reordering fails, the nodes of that earlier edge are merged into the matching
// side, so the new edge cannot be applied without them.
//
// NOTE(review): the right-hand loop passes (childEdge, newEdge) to isAssoc/isRAssoc, the
// same argument order as the left-hand loop — confirm this matches the paper once the
// JoinType matrices stop being symmetric.
private void initEdgeEnds(Pair<Long, Long> ends, Edge edge, BitSet leftEdges, BitSet rightEdges) {
    long leftNodes = ends.first;
    long rightNodes = ends.second;

    // Edges coming from the left input.
    int idx = leftEdges.nextSetBit(0);
    while (idx >= 0) {
        Edge childEdge = edges.get(idx);
        if (!JoinType.isAssoc(childEdge.getJoinType(), edge.getJoinType())) {
            leftNodes = LongBitmap.or(leftNodes, childEdge.getLeft());
        }
        if (!JoinType.isLAssoc(childEdge.getJoinType(), edge.getJoinType())) {
            leftNodes = LongBitmap.or(leftNodes, childEdge.getRight());
        }
        idx = leftEdges.nextSetBit(idx + 1);
    }

    // Edges coming from the right input.
    idx = rightEdges.nextSetBit(0);
    while (idx >= 0) {
        Edge childEdge = edges.get(idx);
        if (!JoinType.isAssoc(childEdge.getJoinType(), edge.getJoinType())) {
            rightNodes = LongBitmap.or(rightNodes, childEdge.getRight());
        }
        if (!JoinType.isRAssoc(childEdge.getJoinType(), edge.getJoinType())) {
            rightNodes = LongBitmap.or(rightNodes, childEdge.getLeft());
        }
        idx = rightEdges.nextSetBit(idx + 1);
    }

    // The same bitmaps are stored as both the "original" and the current ends.
    edge.setOriginalLeft(leftNodes);
    edge.setOriginalRight(rightNodes);
    edge.setLeft(leftNodes);
    edge.setRight(rightNodes);
}
private int findRoot(List<Integer> parent, int idx) {
int root = parent.get(idx);
if (root != idx) {

View File

@ -59,6 +59,7 @@ import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
* The Receiver is used for cached the plan that has been emitted and build the new plan
@ -117,6 +118,9 @@ public class PlanReceiver implements AbstractReceiver {
List<Expression> hashConjuncts = new ArrayList<>();
List<Expression> otherConjuncts = new ArrayList<>();
JoinType joinType = extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
if (joinType == null) {
return true;
}
long fullKey = LongBitmap.newBitmapUnion(left, right);
List<Plan> physicalJoins = proposeAllPhysicalJoins(joinType, leftPlan, rightPlan, hashConjuncts,
@ -207,30 +211,37 @@ public class PlanReceiver implements AbstractReceiver {
// Check whether only NSL can be performed
LogicalProperties joinProperties = new LogicalProperties(
() -> JoinUtils.getJoinOutput(joinType, left, right));
List<Plan> plans = Lists.newArrayList();
if (JoinUtils.shouldNestedLoopJoin(joinType, hashConjuncts)) {
return Lists.newArrayList(
new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts,
plans.add(new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts,
Optional.empty(), joinProperties,
left, right),
new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, Optional.empty(),
joinProperties,
right, left));
left, right));
if (joinType.isSwapJoinType()) {
plans.add(new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, Optional.empty(),
joinProperties,
right, left));
}
} else {
return Lists.newArrayList(
new PhysicalHashJoin<>(joinType, hashConjuncts, otherConjuncts, JoinHint.NONE, Optional.empty(),
joinProperties,
left, right),
new PhysicalHashJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, JoinHint.NONE,
Optional.empty(),
joinProperties,
right, left));
plans.add(new PhysicalHashJoin<>(joinType, hashConjuncts, otherConjuncts, JoinHint.NONE, Optional.empty(),
joinProperties,
left, right));
if (joinType.isSwapJoinType()) {
plans.add(new PhysicalHashJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, JoinHint.NONE,
Optional.empty(),
joinProperties,
right, left));
}
}
return plans;
}
private JoinType extractJoinTypeAndConjuncts(List<Edge> edges, List<Expression> hashConjuncts,
private @Nullable JoinType extractJoinTypeAndConjuncts(List<Edge> edges, List<Expression> hashConjuncts,
List<Expression> otherConjuncts) {
JoinType joinType = null;
for (Edge edge : edges) {
if (edge.getJoinType() != joinType && joinType != null) {
return null;
}
Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType());
joinType = edge.getJoinType();
for (Expression expression : edge.getExpressions()) {

View File

@ -21,7 +21,6 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@ -374,16 +373,11 @@ public class Group {
/**
* This function used to check whether the group is an end node in DPHyp
*/
public boolean isInnerJoinGroup() {
public boolean isValidJoinGroup() {
Plan plan = getLogicalExpression().getPlan();
if (plan instanceof LogicalJoin
&& ((LogicalJoin) plan).getJoinType() == JoinType.INNER_JOIN) {
// Right now, we only support inner join
Preconditions.checkArgument(!((LogicalJoin) plan).getExpressions().isEmpty(),
"inner join must have join conjuncts");
return true;
}
return false;
return plan instanceof LogicalJoin
&& !((LogicalJoin) plan).isMarkJoin()
&& ((LogicalJoin) plan).getExpressions().size() > 0;
}
public boolean isProjectGroup() {

View File

@ -163,6 +163,41 @@ public class Memo {
return plan;
}
/**
 * Runs {@link #countGroupJoin(Group)} on the root group and reports the second
 * component of its result.
 *
 * @return the maximum continuous join count found below the root group
 */
public int countMaxContinuousJoin() {
    Pair<Integer, Integer> rootCounts = countGroupJoin(root);
    return rootCounts.second;
}
/**
 * Recursively counts continuous join operators in the plan tree rooted at {@code group},
 * looking only at each group's logical expression.
 *
 * <p>A project group is transparent: it simply forwards the result of its first child.
 * A group for which {@code isValidJoinGroup()} holds extends the chains of its children
 * by one. Any other group breaks the chain (its own continuous count is 0) but still
 * propagates the best count found underneath it.
 *
 * @param group the group to inspect
 * @return a pair whose {@code first} is the number of continuous joins ending at this
 *         group and whose {@code second} is the maximum continuous join count seen
 *         anywhere in this subtree
 */
public Pair<Integer, Integer> countGroupJoin(Group group) {
    GroupExpression logicalExpr = group.getLogicalExpression();
    List<Pair<Integer, Integer>> children = new ArrayList<>();
    for (Group child : logicalExpr.children()) {
        children.add(countGroupJoin(child));
    }

    if (group.isProjectGroup()) {
        // Projects do not interrupt a join chain; pass the child's counts through.
        return children.get(0);
    }

    int maxJoinCount = 0;
    for (Pair<Integer, Integer> child : children) {
        maxJoinCount = Math.max(maxJoinCount, child.second);
    }

    int continuousJoinCount = 0;
    if (group.isValidJoinGroup()) {
        // This join plus the chains that end directly at each of its children.
        continuousJoinCount = 1;
        for (Pair<Integer, Integer> child : children) {
            continuousJoinCount += child.first;
        }
    }
    return Pair.of(continuousJoinCount, Math.max(continuousJoinCount, maxJoinCount));
}
/**
* Add plan to Memo.
*

View File

@ -195,7 +195,6 @@ public class RuleSet {
public static final List<Rule> DPHYP_REORDER_RULES = ImmutableList.<Rule>builder()
.add(JoinCommute.BUSHY.build())
.addAll(OTHER_REORDER_RULES)
.build();
public List<Rule> getDPHypReorderRules() {

View File

@ -21,6 +21,7 @@ import org.apache.doris.analysis.JoinOperator;
import org.apache.doris.common.AnalysisException;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.Map;
@ -53,6 +54,37 @@ public enum JoinType {
.put(RIGHT_ANTI_JOIN, LEFT_ANTI_JOIN)
.build();
// TODO: the right-semi/right-anti/right-outer join is not derived in paper. We need to derive them
/*
 * ASSOC:
 *         topJoin                  bottomJoin
 *         /     \                  /        \
 *   bottomJoin   C       ->       A        topJoin
 *     /    \                               /     \
 *    A      B                             B       C
 * ====================================
 *               topJoin   bottomJoin
 *   topJoin        -          -
 *   bottomJoin     +          -
 *
 * NOTE(review): the legend of the +/- table is not explained here — confirm against the paper.
 *
 * Each matrix below maps a join type to the set of join types it may be paired with in the
 * corresponding isAssoc / isLAssoc / isRAssoc check. Only CROSS_JOIN and INNER_JOIN have
 * entries, so every pairing that involves another join type is reported as not allowed.
 */
private static final Map<JoinType, ImmutableSet<JoinType>> assocJoinMatrix = ImmutableMap.of(
        CROSS_JOIN, ImmutableSet.of(CROSS_JOIN, INNER_JOIN),
        INNER_JOIN, ImmutableSet.of(CROSS_JOIN, INNER_JOIN));

private static final Map<JoinType, ImmutableSet<JoinType>> lAssocJoinMatrix = ImmutableMap.of(
        CROSS_JOIN, ImmutableSet.of(CROSS_JOIN, INNER_JOIN),
        INNER_JOIN, ImmutableSet.of(CROSS_JOIN, INNER_JOIN));

private static final Map<JoinType, ImmutableSet<JoinType>> rAssocJoinMatrix = ImmutableMap.of(
        CROSS_JOIN, ImmutableSet.of(CROSS_JOIN, INNER_JOIN),
        INNER_JOIN, ImmutableSet.of(CROSS_JOIN, INNER_JOIN));
/**
* Convert join type in Nereids to legacy join type in Doris.
*
@ -157,6 +189,18 @@ public enum JoinType {
return joinSwapMap.containsKey(this);
}
/**
 * Looks the pair up in {@code assocJoinMatrix}.
 *
 * @param join1 key into the matrix
 * @param join2 join type searched for in the set stored under {@code join1}
 * @return true only when {@code join1} has an entry and that entry contains {@code join2}
 */
public static boolean isAssoc(JoinType join1, JoinType join2) {
    ImmutableSet<JoinType> partners = assocJoinMatrix.get(join1);
    return partners != null && partners.contains(join2);
}
/**
 * Looks the pair up in {@code lAssocJoinMatrix}.
 *
 * @param join1 key into the matrix
 * @param join2 join type searched for in the set stored under {@code join1}
 * @return true only when {@code join1} has an entry and that entry contains {@code join2}
 */
public static boolean isLAssoc(JoinType join1, JoinType join2) {
    ImmutableSet<JoinType> partners = lAssocJoinMatrix.get(join1);
    return partners != null && partners.contains(join2);
}
/**
 * Looks the pair up in {@code rAssocJoinMatrix}.
 *
 * @param join1 key into the matrix
 * @param join2 join type searched for in the set stored under {@code join1}
 * @return true only when {@code join1} has an entry and that entry contains {@code join2}
 */
public static boolean isRAssoc(JoinType join1, JoinType join2) {
    ImmutableSet<JoinType> partners = rAssocJoinMatrix.get(join1);
    return partners != null && partners.contains(join2);
}
/**
 * Looks this join type up in {@code joinSwapMap}.
 *
 * @return the entry stored for this join type (presumably {@code null} when the map
 *         has no entry — verify against the map's implementation)
 */
public JoinType swap() {
    JoinType swapped = joinSwapMap.get(this);
    return swapped;
}

View File

@ -17,8 +17,10 @@
package org.apache.doris.nereids.sqltest;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.util.PlanChecker;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
public class JoinOrderJobTest extends SqlTestBase {
@ -84,4 +86,26 @@ public class JoinOrderJobTest extends SqlTestBase {
.rewrite()
.dpHypOptimize();
}
/**
 * Verifies {@code Memo#countMaxContinuousJoin} on a query that joins T1 with two
 * aggregated sub-queries, one of which contains a join of its own (T2 with T3).
 * Presumably the longest chain is the two outer joins, and the T2/T3 join is not
 * counted with them because an aggregate sits in between.
 */
@Test
protected void testCountJoin() {
    String sql = "select count(*) \n"
            + "from \n"
            + "T1, \n"
            + "(\n"
            + "select sum(T2.score + T3.score) as score from T2 join T3 on T2.id = T3.id"
            + ") subTable, \n"
            + "( \n"
            + "select sum(T4.id*2) as id from T4"
            + ") doubleT4 \n"
            + "where \n"
            + "T1.id = doubleT4.id and \n"
            + "T1.score = subTable.score;\n";
    Memo memo = PlanChecker.from(connectContext)
            .analyze(sql)
            .rewrite()
            .getCascadesContext()
            .getMemo();
    // JUnit's assertEquals takes (expected, actual); keeping that order makes a
    // failure message report the right values.
    Assertions.assertEquals(2, memo.countMaxContinuousJoin());
}
}

View File

@ -296,7 +296,7 @@ public class HyperGraphBuilder {
}
private void injectRowcount(Group group) {
if (!group.isInnerJoinGroup()) {
if (!group.isValidJoinGroup()) {
LogicalOlapScan scanPlan = (LogicalOlapScan) group.getLogicalExpression().getPlan();
Statistics stats = injectRowcount(scanPlan);
group.setStatistics(stats);

View File

@ -245,7 +245,7 @@ public class PlanChecker {
double now = System.currentTimeMillis();
Group root = cascadesContext.getMemo().getRoot();
boolean changeRoot = false;
if (root.isInnerJoinGroup()) {
if (root.isValidJoinGroup()) {
// If the root group is join group, DPHyp can change the root group.
// To keep the root group is not changed, we add a dummy project operator above join
List<Slot> outputs = root.getLogicalExpression().getPlan().getOutput();
@ -434,7 +434,7 @@ public class PlanChecker {
public PlanChecker orderJoin() {
Group root = cascadesContext.getMemo().getRoot();
boolean changeRoot = false;
if (root.isInnerJoinGroup()) {
if (root.isValidJoinGroup()) {
List<Slot> outputs = root.getLogicalExpression().getPlan().getOutput();
// FIXME: can't match type, convert List<Slot> to List<NamedExpression>
GroupExpression newExpr = new GroupExpression(