[fix](nereids) dphyper join reorder may lost join condition in some case (#16995)
when emitCsgCmp, we should check if there is some missed edges should be used as connection edge. If there is missed edge but can't be used as connection edge, the emitCsgCmp should return and seek for another plan.
This commit is contained in:
@ -19,10 +19,17 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph;
|
||||
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Edge in HyperGraph
|
||||
*/
|
||||
@ -35,6 +42,10 @@ public class Edge {
|
||||
// left and right may not overlap, and both must have at least one bit set.
|
||||
private long left = LongBitmap.newBitmap();
|
||||
private long right = LongBitmap.newBitmap();
|
||||
|
||||
private long originalLeft = LongBitmap.newBitmap();
|
||||
private long originalRight = LongBitmap.newBitmap();
|
||||
|
||||
private long referenceNodes = LongBitmap.newBitmap();
|
||||
|
||||
/**
|
||||
@ -100,6 +111,22 @@ public class Edge {
|
||||
this.right = right;
|
||||
}
|
||||
|
||||
public long getOriginalLeft() {
|
||||
return originalLeft;
|
||||
}
|
||||
|
||||
public void setOriginalLeft(long left) {
|
||||
this.originalLeft = left;
|
||||
}
|
||||
|
||||
public long getOriginalRight() {
|
||||
return originalRight;
|
||||
}
|
||||
|
||||
public void setOriginalRight(long right) {
|
||||
this.originalRight = right;
|
||||
}
|
||||
|
||||
public boolean isSub(Edge edge) {
|
||||
// When this join reference nodes is a subset of other join, then this join must appear before that join
|
||||
long otherBitmap = edge.getReferenceNodes();
|
||||
@ -122,9 +149,20 @@ public class Edge {
|
||||
}
|
||||
|
||||
public Expression getExpression() {
|
||||
Preconditions.checkArgument(join.getExpressions().size() == 1);
|
||||
return join.getExpressions().get(0);
|
||||
}
|
||||
|
||||
public List<? extends Expression> getExpressions() {
|
||||
return join.getExpressions();
|
||||
}
|
||||
|
||||
public final Set<Slot> getInputSlots() {
|
||||
Set<Slot> slots = new HashSet<>();
|
||||
join.getExpressions().stream().forEach(expression -> slots.addAll(expression.getInputSlots()));
|
||||
return slots;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("<%s - %s>", LongBitmap.toString(left), LongBitmap.toString(right));
|
||||
|
||||
@ -24,17 +24,18 @@ import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.JoinHint;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
@ -131,13 +132,31 @@ public class HyperGraph {
|
||||
public void addEdge(Group group) {
|
||||
Preconditions.checkArgument(group.isJoinGroup());
|
||||
LogicalJoin<? extends Plan, ? extends Plan> join = (LogicalJoin) group.getLogicalExpression().getPlan();
|
||||
for (Expression expression : join.getExpressions()) {
|
||||
LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), ImmutableList.of(expression), join.left(),
|
||||
join.right());
|
||||
Edge edge = new Edge(singleJoin, edges.size());
|
||||
HashMap<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new HashMap<>();
|
||||
for (Expression expression : join.getHashJoinConjuncts()) {
|
||||
Pair<Long, Long> ends = findEnds(expression);
|
||||
if (!conjuncts.containsKey(ends)) {
|
||||
conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>()));
|
||||
}
|
||||
conjuncts.get(ends).first.add(expression);
|
||||
}
|
||||
for (Expression expression : join.getOtherJoinConjuncts()) {
|
||||
Pair<Long, Long> ends = findEnds(expression);
|
||||
if (!conjuncts.containsKey(ends)) {
|
||||
conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>()));
|
||||
}
|
||||
conjuncts.get(ends).second.add(expression);
|
||||
}
|
||||
for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> entry : conjuncts
|
||||
.entrySet()) {
|
||||
LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first,
|
||||
entry.getValue().second, JoinHint.NONE, join.left(), join.right());
|
||||
Edge edge = new Edge(singleJoin, edges.size());
|
||||
Pair<Long, Long> ends = entry.getKey();
|
||||
edge.setLeft(ends.first);
|
||||
edge.setOriginalLeft(ends.first);
|
||||
edge.setRight(ends.second);
|
||||
edge.setOriginalRight(ends.second);
|
||||
for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
|
||||
nodes.get(nodeIndex).attachEdge(edge);
|
||||
}
|
||||
|
||||
@ -105,6 +105,12 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
Preconditions.checkArgument(planTable.containsKey(left));
|
||||
Preconditions.checkArgument(planTable.containsKey(right));
|
||||
|
||||
// check if the missed edges can be correctly connected by add it to edges
|
||||
// if not, the plan is invalid because of the missed edges, just return and seek for another valid plan
|
||||
if (!processMissedEdges(left, right, edges)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
Memo memo = jobContext.getCascadesContext().getMemo();
|
||||
emitCount += 1;
|
||||
if (emitCount > limit) {
|
||||
@ -151,7 +157,7 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
usdEdges.put(LongBitmap.newBitmapUnion(left, right), usedEdgesBitmap);
|
||||
for (Edge edge : hyperGraph.getEdges()) {
|
||||
if (!usedEdgesBitmap.get(edge.getIndex())) {
|
||||
outputSlots.addAll(edge.getExpression().getInputSlots());
|
||||
outputSlots.addAll(edge.getInputSlots());
|
||||
}
|
||||
}
|
||||
hyperGraph.getComplexProject()
|
||||
@ -162,6 +168,47 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
return outputSlots;
|
||||
}
|
||||
|
||||
// check if the missed edges can be used to connect left and right together with edges
|
||||
// return true if no missed edge or the missed edge can be used to connect left and right
|
||||
// the returned edges includes missed edges if there is any.
|
||||
private boolean processMissedEdges(long left, long right, List<Edge> edges) {
|
||||
boolean canAddMisssedEdges = true;
|
||||
|
||||
// find all reference nodes assume left and right sub graph is connected
|
||||
BitSet usedEdgesBitmap = new BitSet();
|
||||
usedEdgesBitmap.or(usdEdges.get(left));
|
||||
usedEdgesBitmap.or(usdEdges.get(right));
|
||||
edges.stream().forEach(edge -> usedEdgesBitmap.set(edge.getIndex()));
|
||||
long allReferenceNodes = getAllReferenceNodes(usedEdgesBitmap);
|
||||
|
||||
// check all edges
|
||||
// the edge is a missed edge if the edge is not used and its reference nodes is a subset of allReferenceNodes
|
||||
for (Edge edge : hyperGraph.getEdges()) {
|
||||
if (LongBitmap.isSubset(edge.getReferenceNodes(), allReferenceNodes) && !usedEdgesBitmap.get(
|
||||
edge.getIndex())) {
|
||||
// check the missed edge can be used to connect left and right together with edges
|
||||
// if the missed edge meet the 2 conditions, it is a valid edge
|
||||
// 1. the edge's left child's referenced nodes is subset of the left
|
||||
// 2. the edge's original right node is subset of right
|
||||
canAddMisssedEdges = canAddMisssedEdges && LongBitmap.isSubset(edge.getLeft(),
|
||||
left) && LongBitmap.isSubset(edge.getOriginalRight(), right);
|
||||
|
||||
// always add the missed edge to edges
|
||||
// because the caller will return immediately if canAddMisssedEdges is false
|
||||
edges.add(edge);
|
||||
}
|
||||
}
|
||||
return canAddMisssedEdges;
|
||||
}
|
||||
|
||||
private long getAllReferenceNodes(BitSet edgesBitmap) {
|
||||
long nodes = LongBitmap.newBitmap();
|
||||
for (int i = edgesBitmap.nextSetBit(0); i >= 0; i = edgesBitmap.nextSetBit(i + 1)) {
|
||||
nodes = LongBitmap.or(nodes, hyperGraph.getEdge(i).getReferenceNodes());
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
|
||||
private void proposeAllDistributedPlans(GroupExpression groupExpression) {
|
||||
jobContext.getCascadesContext().pushJob(new CostAndEnforcerJob(groupExpression,
|
||||
new JobContext(jobContext.getCascadesContext(), PhysicalProperties.ANY, Double.MAX_VALUE)));
|
||||
@ -200,11 +247,12 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
for (Edge edge : edges) {
|
||||
Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType());
|
||||
joinType = edge.getJoinType();
|
||||
Expression expression = edge.getExpression();
|
||||
if (expression instanceof EqualTo) {
|
||||
hashConjuncts.add(edge.getExpression());
|
||||
} else {
|
||||
otherConjuncts.add(expression);
|
||||
for (Expression expression : edge.getExpressions()) {
|
||||
if (expression instanceof EqualTo) {
|
||||
hashConjuncts.add(expression);
|
||||
} else {
|
||||
otherConjuncts.add(expression);
|
||||
}
|
||||
}
|
||||
}
|
||||
return joinType;
|
||||
@ -231,6 +279,8 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
@Override
|
||||
public void reset() {
|
||||
planTable.clear();
|
||||
projectsOnSubgraph.clear();
|
||||
usdEdges.clear();
|
||||
emitCount = 0;
|
||||
}
|
||||
|
||||
|
||||
@ -85,10 +85,10 @@ public class HyperGraphTest {
|
||||
+ " LOGICAL_OLAP_SCAN3 [label=\"LOGICAL_OLAP_SCAN3 \n"
|
||||
+ " rowCount=40.00\"];\n"
|
||||
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN1 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "LOGICAL_OLAP_SCAN1 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "LOGICAL_OLAP_SCAN2 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n"
|
||||
+ "}\n";
|
||||
Assertions.assertEquals(dottyGraph, target);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user