[enhancement](Nereids): speedup graphsimplifier (#26066)

1. fix some bugs in graphsimplifier
2. remove some time costed code
3. add shape check for graphsimplifier
This commit is contained in:
谢健
2023-10-30 20:13:13 +08:00
committed by GitHub
parent 619f2bbbda
commit b3f31f9e1a
5 changed files with 187 additions and 175 deletions

View File

@ -27,16 +27,6 @@ import org.apache.doris.qe.ConnectContext;
public interface Cost {
double getValue();
/**
* This is for calculating the cost in simplifier
*/
static Cost withRowCount(double rowCount) {
if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) {
return new CostV2(0, rowCount, 0);
}
return new CostV1(rowCount);
}
/**
* return zero cost
*/

View File

@ -45,13 +45,6 @@ class CostV1 implements Cost {
+ costWeight.networkWeight * networkCost;
}
public CostV1(double cost) {
this.cost = cost;
this.cpuCost = 0;
this.networkCost = 0;
this.memoryCost = 0;
}
public static CostV1 infinite() {
return INFINITE;
}

View File

@ -18,32 +18,28 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
import org.apache.doris.nereids.stats.JoinEstimation;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.Stack;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
@ -70,10 +66,9 @@ public class GraphSimplifier {
// because it's just used for simulating join. In fact, the graph simplifier
// just generate the partial order of join operator.
private final HashMap<Long, Statistics> cacheStats = new HashMap<>();
private final HashMap<Long, Cost> cacheCost = new HashMap<>();
private final Stack<SimplificationStep> appliedSteps = new Stack<>();
private final Stack<SimplificationStep> unAppliedSteps = new Stack<>();
private final HashMap<Long, Double> cacheCost = new HashMap<>();
private final Deque<SimplificationStep> appliedSteps = new ArrayDeque<>();
private final Deque<SimplificationStep> unAppliedSteps = new ArrayDeque<>();
private final Set<Edge> validEdges;
@ -91,7 +86,7 @@ public class GraphSimplifier {
}
for (Node node : graph.getNodes()) {
cacheStats.put(node.getNodeMap(), node.getGroup().getStatistics());
cacheCost.put(node.getNodeMap(), Cost.withRowCount(node.getRowCount()));
cacheCost.put(node.getNodeMap(), node.getRowCount());
}
validEdges = graph.getEdges().stream()
.filter(e -> {
@ -116,6 +111,13 @@ public class GraphSimplifier {
initFirstStep();
}
private boolean isOverlap(Edge edge1, Edge edge2) {
return (LongBitmap.isOverlap(edge1.getLeftExtendedNodes(), edge2.getLeftExtendedNodes())
&& LongBitmap.isOverlap(edge1.getRightExtendedNodes(), edge2.getRightExtendedNodes()))
|| (LongBitmap.isOverlap(edge1.getLeftExtendedNodes(), edge2.getRightExtendedNodes())
&& LongBitmap.isOverlap(edge1.getRightExtendedNodes(), edge2.getLeftExtendedNodes()));
}
private void initFirstStep() {
extractJoinDependencies();
for (int i = 0; i < edgeSize; i += 1) {
@ -138,8 +140,10 @@ public class GraphSimplifier {
tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getRightExtendedNodes(), superset);
tryGetSuperset(edge1.getRightExtendedNodes(), edge2.getLeftExtendedNodes(), superset);
tryGetSuperset(edge1.getRightExtendedNodes(), edge2.getRightExtendedNodes(), superset);
if (!circleDetector.checkCircleWithEdge(i, j) && !circleDetector.checkCircleWithEdge(j, i)
&& !edge2.isSub(edge1) && !edge1.isSub(edge2) && !superset.isEmpty()) {
if (edge2.isSub(edge1) || edge1.isSub(edge2) || superset.isEmpty() || isOverlap(edge1, edge2)) {
continue;
}
if (!(circleDetector.checkCircleWithEdge(i, j) || circleDetector.checkCircleWithEdge(j, i))) {
return false;
}
}
@ -211,7 +215,7 @@ public class GraphSimplifier {
appliedSteps.push(bestStep);
Preconditions.checkArgument(
cacheStats.containsKey(bestStep.newLeft) && cacheStats.containsKey(bestStep.newRight),
String.format("%s - %s", bestStep.newLeft, bestStep.newRight));
"<%s - %s> must has been stats derived", bestStep.newLeft, bestStep.newRight);
graph.modifyEdge(bestStep.afterIndex, bestStep.newLeft, bestStep.newRight);
if (needProcessNeighbor) {
processNeighbors(bestStep.afterIndex, 0, edgeSize);
@ -220,7 +224,8 @@ public class GraphSimplifier {
}
private boolean unApplySimplificationStep() {
Preconditions.checkArgument(appliedSteps.size() > 0);
Preconditions.checkArgument(!appliedSteps.isEmpty(),
"try to unapply a simplification step but there is no step applied");
SimplificationStep bestStep = appliedSteps.pop();
unAppliedSteps.push(bestStep);
graph.modifyEdge(bestStep.afterIndex, bestStep.oldLeft, bestStep.oldRight);
@ -350,8 +355,8 @@ public class GraphSimplifier {
|| !cacheStats.containsKey(left2) || !cacheStats.containsKey(right2)) {
return Optional.empty();
}
Pair<Statistics, Edge> edge1Before2;
Pair<Statistics, Edge> edge2Before1;
Edge edge1Before2;
Edge edge2Before1;
List<Long> superBitset = new ArrayList<>();
if (tryGetSuperset(left1, left2, superBitset)) {
// (common Join1 right1) Join2 right2
@ -377,108 +382,145 @@ public class GraphSimplifier {
return Optional.empty();
}
if (edge1Before2 == null || edge2Before1 == null) {
return Optional.empty();
}
// edge1 is not the neighborhood of edge2
SimplificationStep simplificationStep = orderJoin(edge1Before2, edge2Before1, edgeIndex1, edgeIndex2);
return Optional.of(simplificationStep);
}
Pair<Statistics, Edge> threeLeftJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) {
// (plan1 edge1 plan2) edge2 plan3
// The join may have redundant table, e.g., t1,t2 join t3 join t2,t4
// Therefore, the cost is not accurate
Preconditions.checkArgument(
cacheStats.containsKey(bitmap1) && cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3));
Statistics leftStats = JoinEstimation.estimate(cacheStats.get(bitmap1), cacheStats.get(bitmap2),
edge1.getJoin());
Statistics joinStats = JoinEstimation.estimate(leftStats, cacheStats.get(bitmap3), edge2.getJoin());
Edge edge = new Edge(
edge2.getJoin(), -1, edge2.getLeftChildEdges(), edge2.getRightChildEdges(), edge2.getSubTreeNodes());
long newLeft = LongBitmap.newBitmapUnion(bitmap1, bitmap2);
// To avoid overlapping the left and the right, the newLeft is calculated, Note the
// newLeft is not totally include the bitset1 and bitset2, we use circle detector to trace the dependency
newLeft = LongBitmap.andNot(newLeft, bitmap3);
edge.addLeftNodes(newLeft);
edge.addRightNode(edge2.getRightExtendedNodes());
cacheStats.put(newLeft, leftStats);
cacheCost.put(newLeft, calCost(edge2, leftStats, cacheStats.get(bitmap1), cacheStats.get(bitmap2)));
return Pair.of(joinStats, edge);
private Edge constructEdge(long leftNodes, Edge edge, long rightNodes) {
if (graph.getEdges().size() > 64 * 63 / 8) {
// If there are too many edges, it is advisable to return the "edge" directly
// to avoid lengthy enumeration time.
return edge;
}
BitSet validEdgesMap = graph.getEdgesInOperator(leftNodes, rightNodes);
List<Expression> hashConditions = validEdgesMap.stream()
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
.flatMap(Collection::stream)
.collect(Collectors.toList());
List<Expression> otherConditions = validEdgesMap.stream()
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
.flatMap(Collection::stream)
.collect(Collectors.toList());
LogicalJoin<? extends Plan, ? extends Plan> join =
edge.getJoin().withJoinConjuncts(hashConditions, otherConditions);
Edge newEdge = new Edge(
join,
-1, edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes());
newEdge.setLeftRequiredNodes(edge.getLeftRequiredNodes());
newEdge.setRightRequiredNodes(edge.getRightRequiredNodes());
newEdge.addLeftNode(leftNodes);
newEdge.addRightNode(rightNodes);
return newEdge;
}
Pair<Statistics, Edge> threeRightJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) {
private void deriveStats(Edge edge, long leftBitmap, long rightBitmap) {
// The bitmap may differ from the edge's reference slots.
// Taking into account the order: edge1<{1} - {2}> edge2<{1,3} - {4}>.
// Actually, we are considering the sequence {1,3} - {2} - {4}
long bitmap = LongBitmap.newBitmapUnion(leftBitmap, rightBitmap);
if (cacheStats.containsKey(bitmap)) {
return;
}
// Note the edge in graphSimplifier contains all tree nodes
Statistics joinStats = JoinEstimation
.estimate(cacheStats.get(leftBitmap),
cacheStats.get(rightBitmap), edge.getJoin());
cacheStats.put(bitmap, joinStats);
}
private double calCost(Edge edge, long leftBitmap, long rightBitmap) {
long bitmap = LongBitmap.newBitmapUnion(leftBitmap, rightBitmap);
Preconditions.checkArgument(cacheStats.containsKey(leftBitmap) && cacheStats.containsKey(rightBitmap)
&& cacheStats.containsKey(bitmap),
"graph simplifier meet an edge %s that have not been derived stats", edge);
LogicalJoin<? extends Plan, ? extends Plan> join = edge.getJoin();
Statistics leftStats = cacheStats.get(leftBitmap);
Statistics rightStats = cacheStats.get(rightBitmap);
double cost;
if (JoinUtils.shouldNestedLoopJoin(join)) {
cost = cacheCost.get(leftBitmap) + cacheCost.get(rightBitmap)
+ rightStats.getRowCount() + 1 / leftStats.getRowCount();
} else {
cost = cacheCost.get(leftBitmap) + cacheCost.get(rightBitmap)
+ (rightStats.getRowCount() + 1 / leftStats.getRowCount()) * 1.2;
}
if (!cacheCost.containsKey(bitmap) || cacheCost.get(bitmap) > cost) {
cacheCost.put(bitmap, cost);
}
return cost;
}
private @Nullable Edge threeLeftJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) {
// (plan1 edge1 plan2) edge2 plan3
// if the left and right is overlapping, just return null.
Preconditions.checkArgument(
cacheStats.containsKey(bitmap1) && cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3));
// construct new Edge
long newLeft = LongBitmap.newBitmapUnion(bitmap1, bitmap2);
if (LongBitmap.isOverlap(newLeft, bitmap3)) {
return null;
}
Edge newEdge = constructEdge(newLeft, edge2, bitmap3);
deriveStats(edge1, bitmap1, bitmap2);
deriveStats(newEdge, newLeft, bitmap3);
calCost(edge1, bitmap1, bitmap2);
return newEdge;
}
private @Nullable Edge threeRightJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) {
Preconditions.checkArgument(
cacheStats.containsKey(bitmap1) && cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3));
// plan1 edge1 (plan2 edge2 plan3)
Statistics rightStats = JoinEstimation.estimate(cacheStats.get(bitmap2), cacheStats.get(bitmap3),
edge2.getJoin());
Statistics joinStats = JoinEstimation.estimate(cacheStats.get(bitmap1), rightStats, edge1.getJoin());
Edge edge = new Edge(
edge1.getJoin(), -1, edge1.getLeftChildEdges(), edge1.getRightChildEdges(), edge1.getSubTreeNodes());
long newRight = LongBitmap.newBitmapUnion(bitmap2, bitmap3);
newRight = LongBitmap.andNot(newRight, bitmap1);
edge.addLeftNode(edge1.getLeftExtendedNodes());
edge.addRightNode(newRight);
cacheStats.put(newRight, rightStats);
cacheCost.put(newRight, calCost(edge2, rightStats, cacheStats.get(bitmap2), cacheStats.get(bitmap3)));
return Pair.of(joinStats, edge);
}
private Edge processMissedEdges(int edgeIndex1, int edgeIndex2, Edge edge) {
List<Edge> edges = Lists.newArrayList(edge);
edges.addAll(graph.getEdges().stream()
.filter(e -> e.getIndex() != edgeIndex1 && e.getIndex() != edgeIndex2
&& LongBitmap.isSubset(e.getReferenceNodes(), edge.getReferenceNodes())
&& !LongBitmap.isSubset(e.getReferenceNodes(), edge.getLeftExtendedNodes())
&& !LongBitmap.isSubset(e.getReferenceNodes(), edge.getRightExtendedNodes()))
.collect(Collectors.toList()));
if (edges.size() > 1) {
List<Expression> hashConjuncts = new ArrayList<>();
List<Expression> otherConjuncts = new ArrayList<>();
JoinType joinType = Edge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
LogicalJoin oldJoin = edge.getJoin();
LogicalJoin newJoin = new LogicalJoin<>(joinType, hashConjuncts,
otherConjuncts, oldJoin.getHint(), oldJoin.left(), oldJoin.right());
Edge newEdge = Edge.createTempEdge(newJoin);
newEdge.setLeftExtendedNodes(edge.getLeftExtendedNodes());
newEdge.setRightExtendedNodes(edge.getRightExtendedNodes());
return newEdge;
} else {
return edge;
if (LongBitmap.isOverlap(bitmap1, newRight)) {
return null;
}
Edge newEdge = constructEdge(bitmap1, edge1, newRight);
deriveStats(edge2, bitmap2, bitmap3);
deriveStats(newEdge, bitmap1, newRight);
calCost(edge1, bitmap2, bitmap3);
return newEdge;
}
private SimplificationStep orderJoin(Pair<Statistics, Edge> edge1Before2,
Pair<Statistics, Edge> edge2Before1, int edgeIndex1, int edgeIndex2) {
// TODO: Consider miss edges when construct join.
// considering
// a
// / \
// b - c
// when constructing edge_ab before edge_bc. edge_ac should be added on top join
Cost cost1Before2 = calCost(edge1Before2.second, edge1Before2.first,
cacheStats.get(edge1Before2.second.getLeftExtendedNodes()),
cacheStats.get(edge1Before2.second.getRightExtendedNodes()));
Cost cost2Before1 = calCost(edge2Before1.second, edge2Before1.first,
cacheStats.get(edge2Before1.second.getLeftExtendedNodes()),
cacheStats.get(edge2Before1.second.getRightExtendedNodes()));
private SimplificationStep orderJoin(Edge edge1Before2,
Edge edge2Before1, int edgeIndex1, int edgeIndex2) {
double cost1Before2 = calCost(edge1Before2,
edge1Before2.getLeftExtendedNodes(), edge1Before2.getRightExtendedNodes());
double cost2Before1 = calCost(edge2Before1,
edge2Before1.getLeftExtendedNodes(), edge2Before1.getRightExtendedNodes());
double benefit = Double.MAX_VALUE;
SimplificationStep step;
// Choose the plan with smaller cost and make the simplification step to replace the old edge by it.
if (cost1Before2.getValue() < cost2Before1.getValue()) {
if (cost1Before2.getValue() != 0) {
benefit = cost2Before1.getValue() / cost1Before2.getValue();
if (cost1Before2 < cost2Before1) {
if (cost1Before2 != 0) {
benefit = cost2Before1 / cost1Before2;
}
// choose edge1Before2
step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, edge1Before2.second.getLeftExtendedNodes(),
edge1Before2.second.getRightExtendedNodes(), graph.getEdge(edgeIndex2).getLeftExtendedNodes(),
step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2,
edge1Before2.getLeftExtendedNodes(),
edge1Before2.getRightExtendedNodes(),
graph.getEdge(edgeIndex2).getLeftExtendedNodes(),
graph.getEdge(edgeIndex2).getRightExtendedNodes());
} else {
if (cost2Before1.getValue() != 0) {
benefit = cost1Before2.getValue() / cost2Before1.getValue();
if (cost2Before1 != 0) {
benefit = cost1Before2 / cost2Before1;
}
// choose edge2Before1
step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.second.getLeftExtendedNodes(),
edge2Before1.second.getRightExtendedNodes(), graph.getEdge(edgeIndex1).getLeftExtendedNodes(),
step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.getLeftExtendedNodes(),
edge2Before1.getRightExtendedNodes(), graph.getEdge(edgeIndex1).getLeftExtendedNodes(),
graph.getEdge(edgeIndex1).getRightExtendedNodes());
}
return step;
@ -495,41 +537,6 @@ public class GraphSimplifier {
return false;
}
private Cost calCost(Edge edge, Statistics stats,
Statistics leftStats, Statistics rightStats) {
LogicalJoin join = edge.getJoin();
PlanContext planContext = new PlanContext(stats, ImmutableList.of(leftStats, rightStats));
Cost cost;
if (JoinUtils.shouldNestedLoopJoin(join)) {
PhysicalNestedLoopJoin nestedLoopJoin = new PhysicalNestedLoopJoin<>(
join.getJoinType(),
join.getHashJoinConjuncts(),
join.getOtherJoinConjuncts(),
join.getMarkJoinSlotReference(),
join.getLogicalProperties(),
join.left(),
join.right());
cost = CostCalculator.calculateCost(nestedLoopJoin, planContext);
cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getLeftExtendedNodes()), 0);
cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getRightExtendedNodes()), 1);
} else {
PhysicalHashJoin hashJoin = new PhysicalHashJoin<>(
join.getJoinType(),
join.getHashJoinConjuncts(),
join.getOtherJoinConjuncts(),
join.getHint(),
join.getMarkJoinSlotReference(),
join.getLogicalProperties(),
join.left(),
join.right());
cost = CostCalculator.calculateCost(hashJoin, planContext);
cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getLeftExtendedNodes()), 0);
cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getRightExtendedNodes()), 1);
}
return cost;
}
/**
* Put join dependencies into circle detector.
*/
@ -594,7 +601,7 @@ public class GraphSimplifier {
@Override
public int compareTo(GraphSimplifier.BestSimplification o) {
Preconditions.checkArgument(step.isPresent());
return Double.compare(getBenefit(), o.getBenefit());
return Double.compare(o.getBenefit(), getBenefit());
}
public double getBenefit() {

View File

@ -50,6 +50,8 @@ public class HyperGraph {
private final List<Node> nodes = new ArrayList<>();
private final HashSet<Group> nodeSet = new HashSet<>();
private final HashMap<Slot, Long> slotToNodeMap = new HashMap<>();
// record all edges that can be placed on the subgraph
private final Map<Long, BitSet> treeEdgesCache = new HashMap<>();
// Record the complex project expression for some subgraph
// e.g. project (a + b)
@ -268,6 +270,30 @@ public class HyperGraph {
return Pair.of(left, right);
}
public BitSet getEdgesInOperator(long left, long right) {
BitSet operatorEdgesMap = new BitSet();
operatorEdgesMap.or(getEdgesInTree(LongBitmap.or(left, right)));
operatorEdgesMap.andNot(getEdgesInTree(left));
operatorEdgesMap.andNot(getEdgesInTree(right));
return operatorEdgesMap;
}
/**
* Returns all edges in the tree
*/
public BitSet getEdgesInTree(long treeNodesMap) {
if (!treeEdgesCache.containsKey(treeNodesMap)) {
BitSet edgesMap = new BitSet();
for (Edge edge : edges) {
if (LongBitmap.isSubset(edge.getReferenceNodes(), treeNodesMap)) {
edgesMap.set(edge.getIndex());
}
}
treeEdgesCache.put(treeNodesMap, edgesMap);
}
return treeEdgesCache.get(treeNodesMap);
}
private long calNodeMap(Set<Slot> slots) {
Preconditions.checkArgument(slots.size() != 0);
long bitmap = LongBitmap.newBitmap();