[improvement](nereids)dphyper GraphSimplifier should consider missed edges when estimating join cost (#21747)

This commit is contained in:
starocean999
2023-09-28 09:30:57 +08:00
committed by GitHub
parent 430634367a
commit 584646c054
3 changed files with 57 additions and 24 deletions

View File

@ -31,6 +31,7 @@ import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.annotation.Nullable;
/**
* Edge in HyperGraph
@ -214,10 +215,6 @@ public class Edge {
return join.getExpressions().get(0);
}
public List<? extends Expression> getExpressions() {
return join.getExpressions();
}
public List<Expression> getHashJoinConjuncts() {
return join.getHashJoinConjuncts();
}
@ -237,5 +234,27 @@ public class Edge {
return String.format("<%s - %s>", LongBitmap.toString(leftExtendedNodes), LongBitmap.toString(
rightExtendedNodes));
}
/**
* extract join type and conjuncts from edges
*/
public static @Nullable JoinType extractJoinTypeAndConjuncts(List<Edge> edges,
List<Expression> hashConjuncts, List<Expression> otherConjuncts) {
JoinType joinType = null;
for (Edge edge : edges) {
if (edge.getJoinType() != joinType && joinType != null) {
return null;
}
Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType());
joinType = edge.getJoinType();
hashConjuncts.addAll(edge.getHashJoinConjuncts());
otherConjuncts.addAll(edge.getOtherJoinConjuncts());
}
return joinType;
}
public static Edge createTempEdge(LogicalJoin join) {
return new Edge(join, -1, null, null, 0L);
}
}

View File

@ -24,6 +24,8 @@ import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
import org.apache.doris.nereids.stats.JoinEstimation;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
@ -32,6 +34,7 @@ import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
@ -39,6 +42,7 @@ import java.util.List;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Stack;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
@ -388,12 +392,38 @@ public class GraphSimplifier {
return Pair.of(joinStats, edge);
}
private Edge processMissedEdges(int edgeIndex1, int edgeIndex2, Edge edge) {
List<Edge> edges = Lists.newArrayList(edge);
edges.addAll(graph.getEdges().stream()
.filter(e -> e.getIndex() != edgeIndex1 && e.getIndex() != edgeIndex2
&& LongBitmap.isSubset(e.getReferenceNodes(), edge.getReferenceNodes())
&& !LongBitmap.isSubset(e.getReferenceNodes(), edge.getLeftExtendedNodes())
&& !LongBitmap.isSubset(e.getReferenceNodes(), edge.getRightExtendedNodes()))
.collect(Collectors.toList()));
if (edges.size() > 1) {
List<Expression> hashConjuncts = new ArrayList<>();
List<Expression> otherConjuncts = new ArrayList<>();
JoinType joinType = Edge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
LogicalJoin oldJoin = edge.getJoin();
LogicalJoin newJoin = new LogicalJoin<>(joinType, hashConjuncts,
otherConjuncts, oldJoin.getHint(), oldJoin.left(), oldJoin.right());
Edge newEdge = Edge.createTempEdge(newJoin);
newEdge.setLeftExtendedNodes(edge.getLeftExtendedNodes());
newEdge.setRightExtendedNodes(edge.getRightExtendedNodes());
return newEdge;
} else {
return edge;
}
}
private SimplificationStep orderJoin(Pair<Statistics, Edge> edge1Before2,
Pair<Statistics, Edge> edge2Before1, int edgeIndex1, int edgeIndex2) {
Cost cost1Before2 = calCost(edge1Before2.second, edge1Before2.first,
Pair<Statistics, Edge> edge2Before1, int edgeIndex1, int edgeIndex2) {
Edge edge = processMissedEdges(edgeIndex1, edgeIndex2, edge1Before2.second);
Cost cost1Before2 = calCost(edge, edge1Before2.first,
cacheStats.get(edge1Before2.second.getLeftExtendedNodes()),
cacheStats.get(edge1Before2.second.getRightExtendedNodes()));
Cost cost2Before1 = calCost(edge2Before1.second, edge1Before2.first,
edge = processMissedEdges(edgeIndex1, edgeIndex2, edge2Before1.second);
Cost cost2Before1 = calCost(edge, edge1Before2.first,
cacheStats.get(edge1Before2.second.getLeftExtendedNodes()),
cacheStats.get(edge1Before2.second.getRightExtendedNodes()));
double benefit = Double.MAX_VALUE;

View File

@ -60,7 +60,6 @@ import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
/**
* The Receiver is used for cached the plan that has been emitted and build the new plan
@ -118,7 +117,7 @@ public class PlanReceiver implements AbstractReceiver {
List<Expression> hashConjuncts = new ArrayList<>();
List<Expression> otherConjuncts = new ArrayList<>();
JoinType joinType = extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
JoinType joinType = Edge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
if (joinType == null) {
return true;
}
@ -237,21 +236,6 @@ public class PlanReceiver implements AbstractReceiver {
return plans;
}
private @Nullable JoinType extractJoinTypeAndConjuncts(List<Edge> edges, List<Expression> hashConjuncts,
List<Expression> otherConjuncts) {
JoinType joinType = null;
for (Edge edge : edges) {
if (edge.getJoinType() != joinType && joinType != null) {
return null;
}
Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType());
joinType = edge.getJoinType();
hashConjuncts.addAll(edge.getHashJoinConjuncts());
otherConjuncts.addAll(edge.getOtherJoinConjuncts());
}
return joinType;
}
private boolean extractIsMarkJoin(List<Edge> edges) {
boolean isMarkJoin = false;
JoinType joinType = null;