[fix](nereids) dphyper join reorder may lose join conditions in some cases (#16995)

In emitCsgCmp, we should check whether there are missed edges that need to be used as connection edges. If a missed edge exists but cannot be used as a connection edge, emitCsgCmp should return and look for another plan.
This commit is contained in:
starocean999
2023-02-27 10:36:11 +08:00
committed by GitHub
parent f228cfdd00
commit 2d5f32caf1
5 changed files with 269 additions and 13 deletions

View File

@ -19,10 +19,17 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import com.google.common.base.Preconditions;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
* Edge in HyperGraph
*/
@ -35,6 +42,10 @@ public class Edge {
// left and right may not overlap, and both must have at least one bit set.
private long left = LongBitmap.newBitmap();
private long right = LongBitmap.newBitmap();
private long originalLeft = LongBitmap.newBitmap();
private long originalRight = LongBitmap.newBitmap();
private long referenceNodes = LongBitmap.newBitmap();
/**
@ -100,6 +111,22 @@ public class Edge {
this.right = right;
}
/**
 * Returns the bitmap currently held in {@code originalLeft}.
 *
 * <p>NOTE(review): presumably the left end as recorded when the edge was created, kept apart
 * from {@code left} so that it is unaffected by later changes to {@code left} - confirm.
 *
 * @return the value of {@code originalLeft}
 */
public long getOriginalLeft() {
    return this.originalLeft;
}
/**
 * Stores the given bitmap in {@code originalLeft}; the {@code left} field is not modified.
 *
 * @param left bitmap to record as the original left end
 */
public void setOriginalLeft(long left) {
this.originalLeft = left;
}
/**
 * Returns the bitmap currently held in {@code originalRight}.
 *
 * <p>NOTE(review): presumably the right end as recorded when the edge was created, kept apart
 * from {@code right} so that it is unaffected by later changes to {@code right} - confirm.
 *
 * @return the value of {@code originalRight}
 */
public long getOriginalRight() {
    return this.originalRight;
}
/**
 * Stores the given bitmap in {@code originalRight}; the {@code right} field is not modified.
 *
 * @param right bitmap to record as the original right end
 */
public void setOriginalRight(long right) {
this.originalRight = right;
}
public boolean isSub(Edge edge) {
// When this join reference nodes is a subset of other join, then this join must appear before that join
long otherBitmap = edge.getReferenceNodes();
@ -122,9 +149,20 @@ public class Edge {
}
/**
 * Returns the only expression carried by this edge's join.
 *
 * @return the first (and only) element of {@code join.getExpressions()}
 * @throws IllegalArgumentException if the join does not carry exactly one expression
 */
public Expression getExpression() {
    List<? extends Expression> expressions = join.getExpressions();
    // Callers that may see several conjuncts on one edge must use getExpressions() instead.
    Preconditions.checkArgument(expressions.size() == 1);
    return expressions.get(0);
}
/**
 * Returns all expressions carried by this edge's join.
 *
 * @return the list produced by {@code join.getExpressions()}, passed through unchanged
 */
public List<? extends Expression> getExpressions() {
    return this.join.getExpressions();
}
/**
 * Collects the input slots of every expression carried by this edge's join.
 *
 * @return a newly created set holding the union of {@code getInputSlots()} over
 *         {@code join.getExpressions()}; empty when the join has no expressions
 */
public final Set<Slot> getInputSlots() {
    Set<Slot> slots = new HashSet<>();
    // A plain loop replaces stream().forEach(): the set is filled purely by side effect,
    // which is exactly what a for-each statement expresses.
    for (Expression expression : join.getExpressions()) {
        slots.addAll(expression.getInputSlots());
    }
    return slots;
}
@Override
public String toString() {
return String.format("<%s - %s>", LongBitmap.toString(left), LongBitmap.toString(right));

View File

@ -24,17 +24,18 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
@ -131,13 +132,31 @@ public class HyperGraph {
public void addEdge(Group group) {
Preconditions.checkArgument(group.isJoinGroup());
LogicalJoin<? extends Plan, ? extends Plan> join = (LogicalJoin) group.getLogicalExpression().getPlan();
for (Expression expression : join.getExpressions()) {
LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), ImmutableList.of(expression), join.left(),
join.right());
Edge edge = new Edge(singleJoin, edges.size());
HashMap<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new HashMap<>();
for (Expression expression : join.getHashJoinConjuncts()) {
Pair<Long, Long> ends = findEnds(expression);
if (!conjuncts.containsKey(ends)) {
conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>()));
}
conjuncts.get(ends).first.add(expression);
}
for (Expression expression : join.getOtherJoinConjuncts()) {
Pair<Long, Long> ends = findEnds(expression);
if (!conjuncts.containsKey(ends)) {
conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>()));
}
conjuncts.get(ends).second.add(expression);
}
for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> entry : conjuncts
.entrySet()) {
LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first,
entry.getValue().second, JoinHint.NONE, join.left(), join.right());
Edge edge = new Edge(singleJoin, edges.size());
Pair<Long, Long> ends = entry.getKey();
edge.setLeft(ends.first);
edge.setOriginalLeft(ends.first);
edge.setRight(ends.second);
edge.setOriginalRight(ends.second);
for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
nodes.get(nodeIndex).attachEdge(edge);
}

View File

@ -105,6 +105,12 @@ public class PlanReceiver implements AbstractReceiver {
Preconditions.checkArgument(planTable.containsKey(left));
Preconditions.checkArgument(planTable.containsKey(right));
// Check whether the missed edges can be added to edges as valid connection edges.
// If not, this plan is invalid because of the missed edges: just return and look for another valid plan.
if (!processMissedEdges(left, right, edges)) {
return true;
}
Memo memo = jobContext.getCascadesContext().getMemo();
emitCount += 1;
if (emitCount > limit) {
@ -151,7 +157,7 @@ public class PlanReceiver implements AbstractReceiver {
usdEdges.put(LongBitmap.newBitmapUnion(left, right), usedEdgesBitmap);
for (Edge edge : hyperGraph.getEdges()) {
if (!usedEdgesBitmap.get(edge.getIndex())) {
outputSlots.addAll(edge.getExpression().getInputSlots());
outputSlots.addAll(edge.getInputSlots());
}
}
hyperGraph.getComplexProject()
@ -162,6 +168,47 @@ public class PlanReceiver implements AbstractReceiver {
return outputSlots;
}
/**
 * Checks whether every missed edge can be used, together with {@code edges}, to connect
 * {@code left} and {@code right}, and appends all missed edges to {@code edges}.
 *
 * <p>An edge counts as missed when it is not used yet (neither by the edges recorded for the
 * two sub graphs nor by {@code edges}) although all of its reference nodes are covered by the
 * reference nodes of the used edges. A missed edge is valid when its left end is a subset of
 * {@code left} and its original right end is a subset of {@code right}.
 *
 * @param left bitmap of the left sub graph, also used as key into {@code usdEdges}
 * @param right bitmap of the right sub graph, also used as key into {@code usdEdges}
 * @param edges connection edges; every missed edge is appended in place, valid or not
 * @return true if there is no missed edge or every missed edge is valid
 */
private boolean processMissedEdges(long left, long right, List<Edge> edges) {
    // Edges in use once the left and right sub graphs are connected through the given edges.
    BitSet connectedEdges = new BitSet();
    connectedEdges.or(usdEdges.get(left));
    connectedEdges.or(usdEdges.get(right));
    for (Edge connection : edges) {
        connectedEdges.set(connection.getIndex());
    }
    long coveredNodes = getAllReferenceNodes(connectedEdges);

    boolean allMissedEdgesValid = true;
    for (Edge candidate : hyperGraph.getEdges()) {
        // Skip edges that are already used or that reference nodes outside the covered set.
        if (connectedEdges.get(candidate.getIndex())
                || !LongBitmap.isSubset(candidate.getReferenceNodes(), coveredNodes)) {
            continue;
        }
        // A missed edge is valid only if both conditions hold:
        // 1. its left end is a subset of left
        // 2. its original right end is a subset of right
        boolean leftFits = LongBitmap.isSubset(candidate.getLeft(), left);
        boolean rightFits = LongBitmap.isSubset(candidate.getOriginalRight(), right);
        if (!leftFits || !rightFits) {
            allMissedEdgesValid = false;
        }
        // Always append the missed edge: the caller returns immediately when the result is false,
        // so an invalid edge left in the list is never used.
        edges.add(candidate);
    }
    return allMissedEdgesValid;
}
/**
 * Unions the reference nodes of every edge whose index is set in {@code edgesBitmap}.
 *
 * @param edgesBitmap bit set of edge indexes, resolved through {@code hyperGraph.getEdge}
 * @return the combined bitmap; {@code LongBitmap.newBitmap()} when no bit is set
 */
private long getAllReferenceNodes(BitSet edgesBitmap) {
    // BitSet.stream() yields the set indexes from lowest to highest, like a nextSetBit loop.
    return edgesBitmap.stream()
            .mapToLong(edgeIndex -> hyperGraph.getEdge(edgeIndex).getReferenceNodes())
            .reduce(LongBitmap.newBitmap(), (covered, referenced) -> LongBitmap.or(covered, referenced));
}
private void proposeAllDistributedPlans(GroupExpression groupExpression) {
jobContext.getCascadesContext().pushJob(new CostAndEnforcerJob(groupExpression,
new JobContext(jobContext.getCascadesContext(), PhysicalProperties.ANY, Double.MAX_VALUE)));
@ -200,11 +247,12 @@ public class PlanReceiver implements AbstractReceiver {
for (Edge edge : edges) {
Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType());
joinType = edge.getJoinType();
Expression expression = edge.getExpression();
if (expression instanceof EqualTo) {
hashConjuncts.add(edge.getExpression());
} else {
otherConjuncts.add(expression);
for (Expression expression : edge.getExpressions()) {
if (expression instanceof EqualTo) {
hashConjuncts.add(expression);
} else {
otherConjuncts.add(expression);
}
}
}
return joinType;
@ -231,6 +279,8 @@ public class PlanReceiver implements AbstractReceiver {
/**
 * Clears the state this receiver has accumulated: empties {@code planTable},
 * {@code projectsOnSubgraph} and {@code usdEdges}, and sets {@code emitCount} back to zero.
 */
@Override
public void reset() {
    this.planTable.clear();
    this.projectsOnSubgraph.clear();
    this.usdEdges.clear();
    this.emitCount = 0;
}

View File

@ -85,10 +85,10 @@ public class HyperGraphTest {
+ " LOGICAL_OLAP_SCAN3 [label=\"LOGICAL_OLAP_SCAN3 \n"
+ " rowCount=40.00\"];\n"
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN1 [label=\"1.00\",arrowhead=none]\n"
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n"
+ "LOGICAL_OLAP_SCAN1 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n"
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n"
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n"
+ "LOGICAL_OLAP_SCAN2 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n"
+ "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n"
+ "}\n";
Assertions.assertEquals(dottyGraph, target);
}