[enhancement](Nereids): when the DPhyper failed, roll back to cascades without join reorder (#26390)

when the DPhyper failed, roll back to cascades without join reorder
This commit is contained in:
谢健
2023-11-07 20:05:40 +08:00
committed by GitHub
parent 5e9a23e643
commit 2be6c9ff7d
8 changed files with 84 additions and 25 deletions

View File

@ -37,6 +37,8 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.Lists;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.BitSet;
@ -47,6 +49,7 @@ import java.util.Set;
* Join Order job with DPHyp
*/
public class JoinOrderJob extends Job {
public static final Logger LOG = LogManager.getLogger(JoinOrderJob.class);
private final Group group;
private final Set<NamedExpression> otherProject = new HashSet<>();
@ -87,20 +90,14 @@ public class JoinOrderJob extends Job {
int limit = 1000;
PlanReceiver planReceiver = new PlanReceiver(this.context, limit, hyperGraph,
group.getLogicalProperties().getOutputSet());
SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(planReceiver, hyperGraph);
if (!subgraphEnumerator.enumerate()) {
GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph);
graphSimplifier.simplifyGraph(limit);
if (!subgraphEnumerator.enumerate()) {
throw new RuntimeException("DPHyp can not enumerate all sub graphs with limit=" + limit);
}
if (!tryEnumerateJoin(hyperGraph, planReceiver, limit)) {
return group;
}
Group optimized = planReceiver.getBestPlan(hyperGraph.getNodesMap());
// For other projects, such as project constant or project nullable, we construct a new project above root
if (otherProject.size() != 0) {
if (!otherProject.isEmpty()) {
otherProject.addAll(optimized.getLogicalExpression().getPlan().getOutput());
LogicalProject logicalProject = new LogicalProject<>(new ArrayList<>(otherProject),
LogicalProject<Plan> logicalProject = new LogicalProject<>(new ArrayList<>(otherProject),
optimized.getLogicalExpression().getPlan());
GroupExpression groupExpression = new GroupExpression(logicalProject, Lists.newArrayList(group));
optimized = context.getCascadesContext().getMemo().copyInGroupExpression(groupExpression);
@ -108,6 +105,15 @@ public class JoinOrderJob extends Job {
return optimized;
}
private boolean tryEnumerateJoin(HyperGraph hyperGraph, PlanReceiver planReceiver, int limit) {
SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(planReceiver, hyperGraph);
if (!subgraphEnumerator.enumerate()) {
GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph);
return graphSimplifier.simplifyGraph(limit) && subgraphEnumerator.enumerate();
}
return true;
}
/**
* build a hyperGraph for the root group
*

View File

@ -392,26 +392,27 @@ public class GraphSimplifier {
}
private Edge constructEdge(long leftNodes, Edge edge, long rightNodes) {
LogicalJoin<? extends Plan, ? extends Plan> join;
if (graph.getEdges().size() > 64 * 63 / 8) {
// If there are too many edges, it is advisable to return the "edge" directly
// to avoid lengthy enumeration time.
return edge;
join = edge.getJoin();
} else {
BitSet validEdgesMap = graph.getEdgesInOperator(leftNodes, rightNodes);
List<Expression> hashConditions = validEdgesMap.stream()
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
.flatMap(Collection::stream)
.collect(Collectors.toList());
List<Expression> otherConditions = validEdgesMap.stream()
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
.flatMap(Collection::stream)
.collect(Collectors.toList());
join = edge.getJoin().withJoinConjuncts(hashConditions, otherConditions);
}
BitSet validEdgesMap = graph.getEdgesInOperator(leftNodes, rightNodes);
List<Expression> hashConditions = validEdgesMap.stream()
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
.flatMap(Collection::stream)
.collect(Collectors.toList());
List<Expression> otherConditions = validEdgesMap.stream()
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
.flatMap(Collection::stream)
.collect(Collectors.toList());
LogicalJoin<? extends Plan, ? extends Plan> join =
edge.getJoin().withJoinConjuncts(hashConditions, otherConditions);
Edge newEdge = new Edge(
join,
-1, edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes());
edge.getIndex(), edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes());
newEdge.setLeftRequiredNodes(edge.getLeftRequiredNodes());
newEdge.setRightRequiredNodes(edge.getRightRequiredNodes());
newEdge.addLeftNode(leftNodes);
@ -462,7 +463,6 @@ public class GraphSimplifier {
// if the left and right is overlapping, just return null.
Preconditions.checkArgument(
cacheStats.containsKey(bitmap1) && cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3));
// construct new Edge
long newLeft = LongBitmap.newBitmapUnion(bitmap1, bitmap2);
if (LongBitmap.isOverlap(newLeft, bitmap3)) {

View File

@ -316,10 +316,16 @@ public class HyperGraph {
// For these nodes that are only in the old edge, we need remove the edge from them
// For these nodes that are only in the new edge, we need to add the edge to them
Edge edge = edges.get(edgeIndex);
if (treeEdgesCache.containsKey(edge.getReferenceNodes())) {
treeEdgesCache.get(edge.getReferenceNodes()).set(edgeIndex, false);
}
updateEdges(edge, edge.getLeftExtendedNodes(), newLeft);
updateEdges(edge, edge.getRightExtendedNodes(), newRight);
edges.get(edgeIndex).setLeftExtendedNodes(newLeft);
edges.get(edgeIndex).setRightExtendedNodes(newRight);
if (treeEdgesCache.containsKey(edge.getReferenceNodes())) {
treeEdgesCache.get(edge.getReferenceNodes()).set(edgeIndex, true);
}
}
private void updateEdges(Edge edge, long oldNodes, long newNodes) {

View File

@ -95,4 +95,8 @@ public class Counter implements AbstractReceiver {
public int getLimit() {
return limit;
}
public int getEmitCount() {
return emitCount;
}
}

View File

@ -46,6 +46,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
@ -76,6 +77,8 @@ public class PlanReceiver implements AbstractReceiver {
HyperGraph hyperGraph;
final Set<Slot> finalOutputs;
long startTime = System.currentTimeMillis();
long timeLimit = ConnectContext.get().getSessionVariable().joinReorderTimeLimit;
public PlanReceiver(JobContext jobContext, int limit, HyperGraph hyperGraph, Set<Slot> outputs) {
this.jobContext = jobContext;
@ -104,7 +107,7 @@ public class PlanReceiver implements AbstractReceiver {
processMissedEdges(left, right, edges);
emitCount += 1;
if (emitCount > limit) {
if (emitCount > limit || System.currentTimeMillis() - startTime > timeLimit) {
return false;
}
@ -272,6 +275,7 @@ public class PlanReceiver implements AbstractReceiver {
usdEdges.clear();
complexProjectMap.clear();
complexProjectMap.putAll(hyperGraph.getComplexProject());
startTime = System.currentTimeMillis();
}
@Override

View File

@ -346,6 +346,7 @@ public class SessionVariable implements Serializable, Writable {
public static final String MAX_TABLE_COUNT_USE_CASCADES_JOIN_REORDER = "max_table_count_use_cascades_join_reorder";
public static final int MIN_JOIN_REORDER_TABLE_COUNT = 2;
public static final String JOIN_REORDER_TIME_LIMIT = "join_order_time_limit";
public static final String SHOW_USER_DEFAULT_ROLE = "show_user_default_role";
public static final String ENABLE_MINIDUMP = "enable_minidump";
@ -1106,6 +1107,9 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = MAX_TABLE_COUNT_USE_CASCADES_JOIN_REORDER, needForward = true)
public int maxTableCountUseCascadesJoinReorder = 10;
@VariableMgr.VarAttr(name = JOIN_REORDER_TIME_LIMIT, needForward = true)
public long joinReorderTimeLimit = 1000;
// If this is true, the result of `show roles` will return all user default role
@VariableMgr.VarAttr(name = SHOW_USER_DEFAULT_ROLE, needForward = true)
public boolean showUserDefaultRole = false;

View File

@ -250,6 +250,23 @@ class GraphSimplifierTest {
}
}
@Disabled
@Test
void test64Clique() {
HyperGraph hyperGraph = new HyperGraphBuilder(Sets.newHashSet(JoinType.INNER_JOIN))
.randomBuildWith(64, 67);
Counter counter = new Counter();
SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(counter, hyperGraph);
GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph);
graphSimplifier.simplifyGraph(1);
for (Edge edge : hyperGraph.getEdges()) {
System.out.println(edge);
}
Assertions.assertTrue(subgraphEnumerator.enumerate());
System.out.println(counter.getEmitCount());
}
@Disabled
@Test
void benchGraphSimplifier() {

View File

@ -19,13 +19,16 @@ package org.apache.doris.nereids.sqltest;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.HyperGraphBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
public class JoinOrderJobTest extends SqlTestBase {
@ -141,4 +144,19 @@ public class JoinOrderJobTest extends SqlTestBase {
.optimize()
.getBestPlanTree();
}
@Disabled
@Test
void test64CliqueJoin() {
HyperGraphBuilder hyperGraphBuilder = new HyperGraphBuilder(Sets.newHashSet(JoinType.INNER_JOIN));
Plan plan = hyperGraphBuilder
.randomBuildPlanWith(64, 64 * 63 / 2);
plan = new LogicalProject(plan.getOutput(), plan);
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan);
hyperGraphBuilder.initStats(cascadesContext);
PlanChecker.from(cascadesContext)
.rewrite()
.dpHypOptimize()
.getBestPlanTree();
}
}