[enhancement](Nereids): when the DPhyper failed, roll back to cascades without join reorder (#26390)
When DPHyp fails, fall back to cascades without join reorder.
This commit is contained in:
@ -37,6 +37,8 @@ import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.BitSet;
|
||||
@ -47,6 +49,7 @@ import java.util.Set;
|
||||
* Join Order job with DPHyp
|
||||
*/
|
||||
public class JoinOrderJob extends Job {
|
||||
public static final Logger LOG = LogManager.getLogger(JoinOrderJob.class);
|
||||
private final Group group;
|
||||
private final Set<NamedExpression> otherProject = new HashSet<>();
|
||||
|
||||
@ -87,20 +90,14 @@ public class JoinOrderJob extends Job {
|
||||
int limit = 1000;
|
||||
PlanReceiver planReceiver = new PlanReceiver(this.context, limit, hyperGraph,
|
||||
group.getLogicalProperties().getOutputSet());
|
||||
SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(planReceiver, hyperGraph);
|
||||
if (!subgraphEnumerator.enumerate()) {
|
||||
GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph);
|
||||
graphSimplifier.simplifyGraph(limit);
|
||||
if (!subgraphEnumerator.enumerate()) {
|
||||
throw new RuntimeException("DPHyp can not enumerate all sub graphs with limit=" + limit);
|
||||
}
|
||||
if (!tryEnumerateJoin(hyperGraph, planReceiver, limit)) {
|
||||
return group;
|
||||
}
|
||||
Group optimized = planReceiver.getBestPlan(hyperGraph.getNodesMap());
|
||||
|
||||
// For other projects, such as project constant or project nullable, we construct a new project above root
|
||||
if (otherProject.size() != 0) {
|
||||
if (!otherProject.isEmpty()) {
|
||||
otherProject.addAll(optimized.getLogicalExpression().getPlan().getOutput());
|
||||
LogicalProject logicalProject = new LogicalProject<>(new ArrayList<>(otherProject),
|
||||
LogicalProject<Plan> logicalProject = new LogicalProject<>(new ArrayList<>(otherProject),
|
||||
optimized.getLogicalExpression().getPlan());
|
||||
GroupExpression groupExpression = new GroupExpression(logicalProject, Lists.newArrayList(group));
|
||||
optimized = context.getCascadesContext().getMemo().copyInGroupExpression(groupExpression);
|
||||
@ -108,6 +105,15 @@ public class JoinOrderJob extends Job {
|
||||
return optimized;
|
||||
}
|
||||
|
||||
private boolean tryEnumerateJoin(HyperGraph hyperGraph, PlanReceiver planReceiver, int limit) {
|
||||
SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(planReceiver, hyperGraph);
|
||||
if (!subgraphEnumerator.enumerate()) {
|
||||
GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph);
|
||||
return graphSimplifier.simplifyGraph(limit) && subgraphEnumerator.enumerate();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* build a hyperGraph for the root group
|
||||
*
|
||||
|
||||
@ -392,26 +392,27 @@ public class GraphSimplifier {
|
||||
}
|
||||
|
||||
private Edge constructEdge(long leftNodes, Edge edge, long rightNodes) {
|
||||
LogicalJoin<? extends Plan, ? extends Plan> join;
|
||||
if (graph.getEdges().size() > 64 * 63 / 8) {
|
||||
// If there are too many edges, it is advisable to return the "edge" directly
|
||||
// to avoid lengthy enumeration time.
|
||||
return edge;
|
||||
join = edge.getJoin();
|
||||
} else {
|
||||
BitSet validEdgesMap = graph.getEdgesInOperator(leftNodes, rightNodes);
|
||||
List<Expression> hashConditions = validEdgesMap.stream()
|
||||
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toList());
|
||||
List<Expression> otherConditions = validEdgesMap.stream()
|
||||
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toList());
|
||||
join = edge.getJoin().withJoinConjuncts(hashConditions, otherConditions);
|
||||
}
|
||||
BitSet validEdgesMap = graph.getEdgesInOperator(leftNodes, rightNodes);
|
||||
List<Expression> hashConditions = validEdgesMap.stream()
|
||||
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toList());
|
||||
List<Expression> otherConditions = validEdgesMap.stream()
|
||||
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toList());
|
||||
LogicalJoin<? extends Plan, ? extends Plan> join =
|
||||
edge.getJoin().withJoinConjuncts(hashConditions, otherConditions);
|
||||
|
||||
Edge newEdge = new Edge(
|
||||
join,
|
||||
-1, edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes());
|
||||
edge.getIndex(), edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes());
|
||||
newEdge.setLeftRequiredNodes(edge.getLeftRequiredNodes());
|
||||
newEdge.setRightRequiredNodes(edge.getRightRequiredNodes());
|
||||
newEdge.addLeftNode(leftNodes);
|
||||
@ -462,7 +463,6 @@ public class GraphSimplifier {
|
||||
// if the left and right overlap, just return null.
|
||||
Preconditions.checkArgument(
|
||||
cacheStats.containsKey(bitmap1) && cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3));
|
||||
|
||||
// construct new Edge
|
||||
long newLeft = LongBitmap.newBitmapUnion(bitmap1, bitmap2);
|
||||
if (LongBitmap.isOverlap(newLeft, bitmap3)) {
|
||||
|
||||
@ -316,10 +316,16 @@ public class HyperGraph {
|
||||
// For the nodes that are only in the old edge, we need to remove the edge from them
|
||||
// For these nodes that are only in the new edge, we need to add the edge to them
|
||||
Edge edge = edges.get(edgeIndex);
|
||||
if (treeEdgesCache.containsKey(edge.getReferenceNodes())) {
|
||||
treeEdgesCache.get(edge.getReferenceNodes()).set(edgeIndex, false);
|
||||
}
|
||||
updateEdges(edge, edge.getLeftExtendedNodes(), newLeft);
|
||||
updateEdges(edge, edge.getRightExtendedNodes(), newRight);
|
||||
edges.get(edgeIndex).setLeftExtendedNodes(newLeft);
|
||||
edges.get(edgeIndex).setRightExtendedNodes(newRight);
|
||||
if (treeEdgesCache.containsKey(edge.getReferenceNodes())) {
|
||||
treeEdgesCache.get(edge.getReferenceNodes()).set(edgeIndex, true);
|
||||
}
|
||||
}
|
||||
|
||||
private void updateEdges(Edge edge, long oldNodes, long newNodes) {
|
||||
|
||||
@ -95,4 +95,8 @@ public class Counter implements AbstractReceiver {
|
||||
public int getLimit() {
|
||||
return limit;
|
||||
}
|
||||
|
||||
public int getEmitCount() {
|
||||
return emitCount;
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,6 +46,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.JoinUtils;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
@ -76,6 +77,8 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
|
||||
HyperGraph hyperGraph;
|
||||
final Set<Slot> finalOutputs;
|
||||
long startTime = System.currentTimeMillis();
|
||||
long timeLimit = ConnectContext.get().getSessionVariable().joinReorderTimeLimit;
|
||||
|
||||
public PlanReceiver(JobContext jobContext, int limit, HyperGraph hyperGraph, Set<Slot> outputs) {
|
||||
this.jobContext = jobContext;
|
||||
@ -104,7 +107,7 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
processMissedEdges(left, right, edges);
|
||||
|
||||
emitCount += 1;
|
||||
if (emitCount > limit) {
|
||||
if (emitCount > limit || System.currentTimeMillis() - startTime > timeLimit) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -272,6 +275,7 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
usdEdges.clear();
|
||||
complexProjectMap.clear();
|
||||
complexProjectMap.putAll(hyperGraph.getComplexProject());
|
||||
startTime = System.currentTimeMillis();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -346,6 +346,7 @@ public class SessionVariable implements Serializable, Writable {
|
||||
public static final String MAX_TABLE_COUNT_USE_CASCADES_JOIN_REORDER = "max_table_count_use_cascades_join_reorder";
|
||||
public static final int MIN_JOIN_REORDER_TABLE_COUNT = 2;
|
||||
|
||||
public static final String JOIN_REORDER_TIME_LIMIT = "join_order_time_limit";
|
||||
public static final String SHOW_USER_DEFAULT_ROLE = "show_user_default_role";
|
||||
|
||||
public static final String ENABLE_MINIDUMP = "enable_minidump";
|
||||
@ -1106,6 +1107,9 @@ public class SessionVariable implements Serializable, Writable {
|
||||
@VariableMgr.VarAttr(name = MAX_TABLE_COUNT_USE_CASCADES_JOIN_REORDER, needForward = true)
|
||||
public int maxTableCountUseCascadesJoinReorder = 10;
|
||||
|
||||
@VariableMgr.VarAttr(name = JOIN_REORDER_TIME_LIMIT, needForward = true)
|
||||
public long joinReorderTimeLimit = 1000;
|
||||
|
||||
// If this is true, the result of `show roles` will include every user's default role
|
||||
@VariableMgr.VarAttr(name = SHOW_USER_DEFAULT_ROLE, needForward = true)
|
||||
public boolean showUserDefaultRole = false;
|
||||
|
||||
@ -250,6 +250,23 @@ class GraphSimplifierTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Disabled
|
||||
@Test
|
||||
void test64Clique() {
|
||||
HyperGraph hyperGraph = new HyperGraphBuilder(Sets.newHashSet(JoinType.INNER_JOIN))
|
||||
.randomBuildWith(64, 67);
|
||||
Counter counter = new Counter();
|
||||
SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(counter, hyperGraph);
|
||||
GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph);
|
||||
graphSimplifier.simplifyGraph(1);
|
||||
|
||||
for (Edge edge : hyperGraph.getEdges()) {
|
||||
System.out.println(edge);
|
||||
}
|
||||
Assertions.assertTrue(subgraphEnumerator.enumerate());
|
||||
System.out.println(counter.getEmitCount());
|
||||
}
|
||||
|
||||
@Disabled
|
||||
@Test
|
||||
void benchGraphSimplifier() {
|
||||
|
||||
@ -19,13 +19,16 @@ package org.apache.doris.nereids.sqltest;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.memo.Memo;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.HyperGraphBuilder;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
|
||||
import com.google.common.collect.Sets;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class JoinOrderJobTest extends SqlTestBase {
|
||||
@ -141,4 +144,19 @@ public class JoinOrderJobTest extends SqlTestBase {
|
||||
.optimize()
|
||||
.getBestPlanTree();
|
||||
}
|
||||
|
||||
@Disabled
|
||||
@Test
|
||||
void test64CliqueJoin() {
|
||||
HyperGraphBuilder hyperGraphBuilder = new HyperGraphBuilder(Sets.newHashSet(JoinType.INNER_JOIN));
|
||||
Plan plan = hyperGraphBuilder
|
||||
.randomBuildPlanWith(64, 64 * 63 / 2);
|
||||
plan = new LogicalProject(plan.getOutput(), plan);
|
||||
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan);
|
||||
hyperGraphBuilder.initStats(cascadesContext);
|
||||
PlanChecker.from(cascadesContext)
|
||||
.rewrite()
|
||||
.dpHypOptimize()
|
||||
.getBestPlanTree();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user