diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java index 4172e54559..601aa001ae 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java @@ -31,6 +31,7 @@ import java.util.BitSet; import java.util.HashSet; import java.util.List; import java.util.Set; +import javax.annotation.Nullable; /** * Edge in HyperGraph @@ -214,10 +215,6 @@ public class Edge { return join.getExpressions().get(0); } - public List getExpressions() { - return join.getExpressions(); - } - public List getHashJoinConjuncts() { return join.getHashJoinConjuncts(); } @@ -237,5 +234,27 @@ public class Edge { return String.format("<%s - %s>", LongBitmap.toString(leftExtendedNodes), LongBitmap.toString( rightExtendedNodes)); } + + /** + * extract join type and conjuncts from edges + */ + public static @Nullable JoinType extractJoinTypeAndConjuncts(List edges, + List hashConjuncts, List otherConjuncts) { + JoinType joinType = null; + for (Edge edge : edges) { + if (edge.getJoinType() != joinType && joinType != null) { + return null; + } + Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType()); + joinType = edge.getJoinType(); + hashConjuncts.addAll(edge.getHashJoinConjuncts()); + otherConjuncts.addAll(edge.getOtherJoinConjuncts()); + } + return joinType; + } + + public static Edge createTempEdge(LogicalJoin join) { + return new Edge(join, -1, null, null, 0L); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java index 380dd698a4..1beb4a4f06 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java @@ -24,6 +24,8 @@ import org.apache.doris.nereids.cost.CostCalculator; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter; import org.apache.doris.nereids.stats.JoinEstimation; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; @@ -32,6 +34,7 @@ import org.apache.doris.statistics.Statistics; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.HashMap; @@ -39,6 +42,7 @@ import java.util.List; import java.util.Optional; import java.util.PriorityQueue; import java.util.Stack; +import java.util.stream.Collectors; import javax.annotation.Nullable; /** @@ -388,12 +392,38 @@ public class GraphSimplifier { return Pair.of(joinStats, edge); } + private Edge processMissedEdges(int edgeIndex1, int edgeIndex2, Edge edge) { + List edges = Lists.newArrayList(edge); + edges.addAll(graph.getEdges().stream() + .filter(e -> e.getIndex() != edgeIndex1 && e.getIndex() != edgeIndex2 + && LongBitmap.isSubset(e.getReferenceNodes(), edge.getReferenceNodes()) + && !LongBitmap.isSubset(e.getReferenceNodes(), edge.getLeftExtendedNodes()) + && !LongBitmap.isSubset(e.getReferenceNodes(), edge.getRightExtendedNodes())) + .collect(Collectors.toList())); + if (edges.size() > 1) { + List hashConjuncts = new ArrayList<>(); + List otherConjuncts = new ArrayList<>(); + JoinType joinType = Edge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts); + LogicalJoin oldJoin = edge.getJoin(); + LogicalJoin newJoin = new LogicalJoin<>(joinType, hashConjuncts, + otherConjuncts, oldJoin.getHint(), oldJoin.left(), oldJoin.right()); + Edge newEdge = Edge.createTempEdge(newJoin); + newEdge.setLeftExtendedNodes(edge.getLeftExtendedNodes()); + newEdge.setRightExtendedNodes(edge.getRightExtendedNodes()); + return newEdge; + } else { + return edge; + } + } + private SimplificationStep orderJoin(Pair edge1Before2, - Pair edge2Before1, int edgeIndex1, int edgeIndex2) { - Cost cost1Before2 = calCost(edge1Before2.second, edge1Before2.first, + Pair edge2Before1, int edgeIndex1, int edgeIndex2) { + Edge edge = processMissedEdges(edgeIndex1, edgeIndex2, edge1Before2.second); + Cost cost1Before2 = calCost(edge, edge1Before2.first, cacheStats.get(edge1Before2.second.getLeftExtendedNodes()), cacheStats.get(edge1Before2.second.getRightExtendedNodes())); - Cost cost2Before1 = calCost(edge2Before1.second, edge1Before2.first, + edge = processMissedEdges(edgeIndex1, edgeIndex2, edge2Before1.second); + Cost cost2Before1 = calCost(edge, edge1Before2.first, cacheStats.get(edge1Before2.second.getLeftExtendedNodes()), cacheStats.get(edge1Before2.second.getRightExtendedNodes())); double benefit = Double.MAX_VALUE; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java index b33c8134c4..099118c65e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java @@ -60,7 +60,6 @@ import java.util.Set; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; -import javax.annotation.Nullable; /** * The Receiver is used for cached the plan that has been emitted and build the new plan @@ -118,7 +117,7 @@ public class PlanReceiver implements AbstractReceiver { List hashConjuncts = new ArrayList<>(); List otherConjuncts = new ArrayList<>(); - JoinType joinType = extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts); + JoinType joinType = Edge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts); if (joinType == null) { return true; } @@ -237,21 +236,6 @@ public class PlanReceiver implements AbstractReceiver { return plans; } - private @Nullable JoinType extractJoinTypeAndConjuncts(List edges, List hashConjuncts, - List otherConjuncts) { - JoinType joinType = null; - for (Edge edge : edges) { - if (edge.getJoinType() != joinType && joinType != null) { - return null; - } - Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType()); - joinType = edge.getJoinType(); - hashConjuncts.addAll(edge.getHashJoinConjuncts()); - otherConjuncts.addAll(edge.getOtherJoinConjuncts()); - } - return joinType; - } - private boolean extractIsMarkJoin(List edges) { boolean isMarkJoin = false; JoinType joinType = null;