[improvement](nereids)dphyper GraphSimplifier should consider missed edges when estimating join cost (#21747)
This commit is contained in:
@ -31,6 +31,7 @@ import java.util.BitSet;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
* Edge in HyperGraph
|
||||
@ -214,10 +215,6 @@ public class Edge {
|
||||
return join.getExpressions().get(0);
|
||||
}
|
||||
|
||||
public List<? extends Expression> getExpressions() {
|
||||
return join.getExpressions();
|
||||
}
|
||||
|
||||
public List<Expression> getHashJoinConjuncts() {
|
||||
return join.getHashJoinConjuncts();
|
||||
}
|
||||
@ -237,5 +234,27 @@ public class Edge {
|
||||
return String.format("<%s - %s>", LongBitmap.toString(leftExtendedNodes), LongBitmap.toString(
|
||||
rightExtendedNodes));
|
||||
}
|
||||
|
||||
/**
|
||||
* extract join type and conjuncts from edges
|
||||
*/
|
||||
public static @Nullable JoinType extractJoinTypeAndConjuncts(List<Edge> edges,
|
||||
List<Expression> hashConjuncts, List<Expression> otherConjuncts) {
|
||||
JoinType joinType = null;
|
||||
for (Edge edge : edges) {
|
||||
if (edge.getJoinType() != joinType && joinType != null) {
|
||||
return null;
|
||||
}
|
||||
Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType());
|
||||
joinType = edge.getJoinType();
|
||||
hashConjuncts.addAll(edge.getHashJoinConjuncts());
|
||||
otherConjuncts.addAll(edge.getOtherJoinConjuncts());
|
||||
}
|
||||
return joinType;
|
||||
}
|
||||
|
||||
public static Edge createTempEdge(LogicalJoin join) {
|
||||
return new Edge(join, -1, null, null, 0L);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -24,6 +24,8 @@ import org.apache.doris.nereids.cost.CostCalculator;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
|
||||
import org.apache.doris.nereids.stats.JoinEstimation;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
|
||||
@ -32,6 +34,7 @@ import org.apache.doris.statistics.Statistics;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
@ -39,6 +42,7 @@ import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.Stack;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
@ -388,12 +392,38 @@ public class GraphSimplifier {
|
||||
return Pair.of(joinStats, edge);
|
||||
}
|
||||
|
||||
private Edge processMissedEdges(int edgeIndex1, int edgeIndex2, Edge edge) {
|
||||
List<Edge> edges = Lists.newArrayList(edge);
|
||||
edges.addAll(graph.getEdges().stream()
|
||||
.filter(e -> e.getIndex() != edgeIndex1 && e.getIndex() != edgeIndex2
|
||||
&& LongBitmap.isSubset(e.getReferenceNodes(), edge.getReferenceNodes())
|
||||
&& !LongBitmap.isSubset(e.getReferenceNodes(), edge.getLeftExtendedNodes())
|
||||
&& !LongBitmap.isSubset(e.getReferenceNodes(), edge.getRightExtendedNodes()))
|
||||
.collect(Collectors.toList()));
|
||||
if (edges.size() > 1) {
|
||||
List<Expression> hashConjuncts = new ArrayList<>();
|
||||
List<Expression> otherConjuncts = new ArrayList<>();
|
||||
JoinType joinType = Edge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
|
||||
LogicalJoin oldJoin = edge.getJoin();
|
||||
LogicalJoin newJoin = new LogicalJoin<>(joinType, hashConjuncts,
|
||||
otherConjuncts, oldJoin.getHint(), oldJoin.left(), oldJoin.right());
|
||||
Edge newEdge = Edge.createTempEdge(newJoin);
|
||||
newEdge.setLeftExtendedNodes(edge.getLeftExtendedNodes());
|
||||
newEdge.setRightExtendedNodes(edge.getRightExtendedNodes());
|
||||
return newEdge;
|
||||
} else {
|
||||
return edge;
|
||||
}
|
||||
}
|
||||
|
||||
private SimplificationStep orderJoin(Pair<Statistics, Edge> edge1Before2,
|
||||
Pair<Statistics, Edge> edge2Before1, int edgeIndex1, int edgeIndex2) {
|
||||
Cost cost1Before2 = calCost(edge1Before2.second, edge1Before2.first,
|
||||
Pair<Statistics, Edge> edge2Before1, int edgeIndex1, int edgeIndex2) {
|
||||
Edge edge = processMissedEdges(edgeIndex1, edgeIndex2, edge1Before2.second);
|
||||
Cost cost1Before2 = calCost(edge, edge1Before2.first,
|
||||
cacheStats.get(edge1Before2.second.getLeftExtendedNodes()),
|
||||
cacheStats.get(edge1Before2.second.getRightExtendedNodes()));
|
||||
Cost cost2Before1 = calCost(edge2Before1.second, edge1Before2.first,
|
||||
edge = processMissedEdges(edgeIndex1, edgeIndex2, edge2Before1.second);
|
||||
Cost cost2Before1 = calCost(edge, edge1Before2.first,
|
||||
cacheStats.get(edge1Before2.second.getLeftExtendedNodes()),
|
||||
cacheStats.get(edge1Before2.second.getRightExtendedNodes()));
|
||||
double benefit = Double.MAX_VALUE;
|
||||
|
||||
@ -60,7 +60,6 @@ import java.util.Set;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
* The Receiver is used for caching the plans that have been emitted and building the new plan
|
||||
@ -118,7 +117,7 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
List<Expression> hashConjuncts = new ArrayList<>();
|
||||
List<Expression> otherConjuncts = new ArrayList<>();
|
||||
|
||||
JoinType joinType = extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
|
||||
JoinType joinType = Edge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
|
||||
if (joinType == null) {
|
||||
return true;
|
||||
}
|
||||
@ -237,21 +236,6 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
return plans;
|
||||
}
|
||||
|
||||
private @Nullable JoinType extractJoinTypeAndConjuncts(List<Edge> edges, List<Expression> hashConjuncts,
|
||||
List<Expression> otherConjuncts) {
|
||||
JoinType joinType = null;
|
||||
for (Edge edge : edges) {
|
||||
if (edge.getJoinType() != joinType && joinType != null) {
|
||||
return null;
|
||||
}
|
||||
Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType());
|
||||
joinType = edge.getJoinType();
|
||||
hashConjuncts.addAll(edge.getHashJoinConjuncts());
|
||||
otherConjuncts.addAll(edge.getOtherJoinConjuncts());
|
||||
}
|
||||
return joinType;
|
||||
}
|
||||
|
||||
private boolean extractIsMarkJoin(List<Edge> edges) {
|
||||
boolean isMarkJoin = false;
|
||||
JoinType joinType = null;
|
||||
|
||||
Reference in New Issue
Block a user