[enhancement](Nereids): Use long bitmap in DPHyp (#14725)

This commit is contained in:
谢健
2022-12-01 20:47:45 +08:00
committed by GitHub
parent ba9a777554
commit 302da03b18
21 changed files with 556 additions and 554 deletions

View File

@ -77,7 +77,25 @@ public class JoinOrderJob extends Job {
throw new RuntimeException("DPHyp can not enumerate all sub graphs with limit=" + limit);
}
}
return planReceiver.getBestPlan(hyperGraph.getNodesMap());
Group optimized = planReceiver.getBestPlan(hyperGraph.getNodesMap());
return copyToMemo(optimized);
}
/**
 * Recursively copies the join tree rooted at {@code root} into the memo.
 *
 * <p>A group for which {@code isJoinGroup()} is false is returned unchanged. For a join
 * group, every child is processed first and the logical expression's child slot is
 * replaced with the returned group; the expression is then handed to
 * {@code Memo.copyInGroupExpression} and the statistics of {@code root} are set on
 * the group that call returns.
 *
 * @param root the group to copy
 * @return {@code root} itself for non-join groups, otherwise the group returned by the memo
 */
private Group copyToMemo(Group root) {
    if (root.isJoinGroup()) {
        GroupExpression joinExpression = root.getLogicalExpression();
        final int childCount = joinExpression.arity();
        // Children are rewritten before the parent expression is copied in.
        for (int childIndex = 0; childIndex < childCount; childIndex++) {
            Group copiedChild = copyToMemo(joinExpression.child(childIndex));
            joinExpression.setChild(childIndex, copiedChild);
        }
        Group memoGroup = context.getCascadesContext().getMemo().copyInGroupExpression(joinExpression);
        memoGroup.setStatistics(root.getStatistics());
        return memoGroup;
    }
    return root;
}
/**

View File

@ -17,12 +17,11 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.Bitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
/**
@ -38,16 +37,16 @@ public class CircleDetector {
// record the node in certain order, named i2n in paper
List<Integer> nodes = new ArrayList<>();
// stores the dependencies of each node
List<BitSet> directedEdges = new ArrayList<>();
List<Long> directedEdges = new ArrayList<>();
// the nodes that come after this node
List<BitSet> subNodes = new ArrayList<>();
List<Long> subNodes = new ArrayList<>();
CircleDetector(int size) {
for (int i = 0; i < size; i++) {
orders.add(i);
nodes.add(i);
directedEdges.add(Bitmap.newBitmap());
subNodes.add(Bitmap.newBitmap(i));
directedEdges.add(LongBitmap.newBitmap());
subNodes.add(LongBitmap.newBitmap(i));
}
}
@ -63,19 +62,21 @@ public class CircleDetector {
if (checkCircleWithEdge(node1, node2)) {
return false;
}
Bitmap.set(directedEdges.get(node1), node2);
directedEdges.set(node1, LongBitmap.set(directedEdges.get(node1), node2));
int order1 = orders.get(node1);
int order2 = orders.get(node2);
if (order1 >= order2) {
shift(order2, order1 + 1, subNodes.get(node2));
}
for (BitSet nodes : subNodes) {
// add all subNodes which contains node1 into subNodes of node2.
if (Bitmap.get(nodes, node1)) {
Bitmap.or(nodes, subNodes.get(node2));
int size = subNodes.size();
for (int i = 0; i < size; i++) {
// merge the subNodes of node2 into every subNodes entry that contains node1.
long nodes = subNodes.get(i);
if (LongBitmap.get(nodes, node1)) {
subNodes.set(i, LongBitmap.or(nodes, subNodes.get(node2)));
}
}
Bitmap.or(subNodes.get(node1), subNodes.get(node2));
subNodes.set(node1, LongBitmap.or(subNodes.get(node1), subNodes.get(node2)));
return true;
}
@ -86,13 +87,14 @@ public class CircleDetector {
* @param node2 the end node of the edge
*/
public void deleteDirectedEdge(int node1, int node2) {
Preconditions.checkArgument(Bitmap.get(directedEdges.get(node1), node2),
Preconditions.checkArgument(LongBitmap.get(directedEdges.get(node1), node2),
String.format("The edge %d -> %d is not existed", node1, node2));
for (BitSet nodes : subNodes) {
Bitmap.clear(nodes);
int size = subNodes.size();
for (int i = 0; i < size; i++) {
subNodes.set(i, LongBitmap.newBitmap());
}
int size = orders.size();
size = orders.size();
for (int i = 0; i < size; i++) {
getSubNodes(i);
}
@ -111,30 +113,30 @@ public class CircleDetector {
*/
public boolean checkCircleWithEdge(int node1, int node2) {
// return true when there is a circle
return Bitmap.get(subNodes.get(node2), node1);
return LongBitmap.get(subNodes.get(node2), node1);
}
private BitSet getSubNodes(int node) {
if (Bitmap.getCardinality(subNodes.get(node)) != 0) {
private long getSubNodes(int node) {
if (LongBitmap.getCardinality(subNodes.get(node)) != 0) {
return subNodes.get(node);
}
for (int nextNode : Bitmap.getIterator(directedEdges.get(node))) {
for (int nextNode : LongBitmap.getIterator(directedEdges.get(node))) {
Preconditions.checkArgument(orders.get(nextNode) > orders.get(node),
String.format("node %d must come after node %d", nextNode, node));
Bitmap.or(subNodes.get(node), getSubNodes(nextNode));
subNodes.set(node, LongBitmap.or(subNodes.get(node), getSubNodes(nextNode)));
}
return subNodes.get(node);
}
private void shift(int startOrder, int endOrder, BitSet visited) {
private void shift(int startOrder, int endOrder, long visited) {
// Reorder the nodes between startOrder and endOrder. The visited nodes are always kept
// before the other nodes, and their relative order is not changed. Because those two parts
// are not connected, we can do this safely.
List<Integer> shiftNodes = new ArrayList<>();
for (int o = startOrder; o < endOrder; o++) {
int node = nodes.get(o);
if (Bitmap.get(visited, node)) {
if (LongBitmap.get(visited, node)) {
shiftNodes.add(node);
} else {
// the relative orders of visited nodes are not changed
@ -158,7 +160,7 @@ public class CircleDetector {
StringBuilder builder = new StringBuilder();
int size = directedEdges.size();
for (int i = 0; i < size; i++) {
if (Bitmap.getCardinality(directedEdges.get(i)) != 0) {
if (LongBitmap.getCardinality(directedEdges.get(i)) != 0) {
builder.append(String.format("%d -> %s; ", i, directedEdges.get(i)));
}
}

View File

@ -17,13 +17,11 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.Bitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import java.util.BitSet;
/**
* Edge in HyperGraph
*/
@ -34,9 +32,9 @@ public class Edge {
// The endpoints (hyperNodes) of this hyperEdge.
// left and right may not overlap, and both must have at least one bit set.
private BitSet left = Bitmap.newBitmap();
private BitSet right = Bitmap.newBitmap();
private BitSet referenceNodes = Bitmap.newBitmap();
private long left = LongBitmap.newBitmap();
private long right = LongBitmap.newBitmap();
private long referenceNodes = LongBitmap.newBitmap();
/**
* Create simple edge.
@ -52,71 +50,64 @@ public class Edge {
}
public boolean isSimple() {
return Bitmap.getCardinality(left) == 1 && Bitmap.getCardinality(right) == 1;
return LongBitmap.getCardinality(left) == 1 && LongBitmap.getCardinality(right) == 1;
}
public void addLeftNode(BitSet left) {
Bitmap.or(this.left, left);
Bitmap.or(referenceNodes, left);
/**
 * Merges the nodes in {@code left} into the left endpoint of this edge and into
 * the reference-node bitmap.
 *
 * @param left bitmap of the nodes to add to the left side
 */
public void addLeftNode(long left) {
    referenceNodes = LongBitmap.or(referenceNodes, left);
    this.left = LongBitmap.or(this.left, left);
}
public void addLeftNodes(BitSet... bitSets) {
for (BitSet bitSet : bitSets) {
Bitmap.or(this.left, bitSet);
Bitmap.or(referenceNodes, bitSet);
/**
 * Merges every given bitmap into the left endpoint of this edge and into the
 * reference-node bitmap.
 *
 * @param bitmaps node bitmaps to add to the left side
 */
public void addLeftNodes(long... bitmaps) {
    for (int i = 0; i < bitmaps.length; i++) {
        long nodes = bitmaps[i];
        referenceNodes = LongBitmap.or(referenceNodes, nodes);
        this.left = LongBitmap.or(this.left, nodes);
    }
}
public void addRightNode(BitSet right) {
Bitmap.or(this.right, right);
Bitmap.or(referenceNodes, right);
/**
 * Merges the nodes in {@code right} into the right endpoint of this edge and into
 * the reference-node bitmap.
 *
 * @param right bitmap of the nodes to add to the right side
 */
public void addRightNode(long right) {
    referenceNodes = LongBitmap.or(referenceNodes, right);
    this.right = LongBitmap.or(this.right, right);
}
public void addRightNodes(BitSet... bitSets) {
for (BitSet bitSet : bitSets) {
Bitmap.or(this.right, bitSet);
Bitmap.or(referenceNodes, bitSet);
/**
 * Merges every given bitmap into the right endpoint of this edge and into the
 * reference-node bitmap.
 *
 * @param bitmaps node bitmaps to add to the right side
 */
public void addRightNodes(long... bitmaps) {
    for (long bitmap : bitmaps) {
        // A long bitmap is a primitive value, so the result of LongBitmap.or must be
        // assigned back; calling it without using the result changes nothing.
        this.right = LongBitmap.or(this.right, bitmap);
        referenceNodes = LongBitmap.or(referenceNodes, bitmap);
    }
}
public BitSet getLeft() {
/**
 * Returns the bitmap of the nodes on the left side of this edge.
 */
public long getLeft() {
return left;
}
public void setLeft(BitSet left) {
referenceNodes.clear();
/**
 * Replaces the left endpoint of this edge and resets the reference-node bitmap
 * through {@code LongBitmap.clear}, so that it no longer reflects the old endpoint.
 *
 * @param left the new left-side node bitmap
 */
public void setLeft(long left) {
    this.left = left;
    referenceNodes = LongBitmap.clear(referenceNodes);
}
public BitSet getRight() {
/**
 * Returns the bitmap of the nodes on the right side of this edge.
 */
public long getRight() {
return right;
}
public void setRight(BitSet right) {
referenceNodes.clear();
/**
 * Replaces the right endpoint of this edge and resets the reference-node bitmap
 * through {@code LongBitmap.clear}, so that it no longer reflects the old endpoint.
 *
 * @param right the new right-side node bitmap
 */
public void setRight(long right) {
    this.right = right;
    referenceNodes = LongBitmap.clear(referenceNodes);
}
public boolean isSub(Edge edge) {
// When this join's reference nodes are a subset of the other join's, this join must appear before that join
BitSet otherBitset = edge.getReferenceNodes();
return Bitmap.isSubset(getReferenceNodes(), otherBitset);
long otherBitmap = edge.getReferenceNodes();
return LongBitmap.isSubset(getReferenceNodes(), otherBitmap);
}
public BitSet getReferenceNodes() {
if (referenceNodes.cardinality() == 0) {
referenceNodes = Bitmap.newBitmapUnion(left, right);
/**
 * Returns the bitmap of all nodes referenced by this edge.
 *
 * <p>When the stored bitmap has no bit set, it is first rebuilt as the union of the
 * left and right endpoints.
 *
 * @return the reference-node bitmap
 */
public long getReferenceNodes() {
    boolean isEmpty = LongBitmap.getCardinality(referenceNodes) == 0;
    if (isEmpty) {
        referenceNodes = LongBitmap.newBitmapUnion(left, right);
    }
    return referenceNodes;
}
public Edge reverse(int index) {
Edge newEdge = new Edge(join, index);
newEdge.addLeftNode(right);
newEdge.addRightNode(left);
return newEdge;
}
public int getIndex() {
return index;
}
@ -134,7 +125,7 @@ public class Edge {
@Override
public String toString() {
return String.format("<%s - %s>", left, right);
return String.format("<%s - %s>", LongBitmap.toString(left), LongBitmap.toString(right));
}
}

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.Bitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
@ -30,7 +30,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
@ -56,7 +55,7 @@ public class GraphSimplifier {
// It cached the plan in simplification. we don't store it in hyper graph,
// because it's just used for simulating join. In fact, the graph simplifier
// just generate the partial order of join operator.
private HashMap<BitSet, Plan> cachePlan = new HashMap<>();
private HashMap<Long, Plan> cachePlan = new HashMap<>();
private Stack<SimplificationStep> appliedSteps = new Stack<SimplificationStep>();
private Stack<SimplificationStep> unAppliedSteps = new Stack<SimplificationStep>();
@ -99,7 +98,7 @@ public class GraphSimplifier {
for (int j = i + 1; j < edgeSize; j++) {
Edge edge1 = graph.getEdge(i);
Edge edge2 = graph.getEdge(j);
List<BitSet> superset = new ArrayList<>();
List<Long> superset = new ArrayList<>();
tryGetSuperset(edge1.getLeft(), edge2.getLeft(), superset);
tryGetSuperset(edge1.getLeft(), edge2.getRight(), superset);
tryGetSuperset(edge1.getRight(), edge2.getLeft(), superset);
@ -300,13 +299,13 @@ public class GraphSimplifier {
|| circleDetector.checkCircleWithEdge(edgeIndex2, edgeIndex1)) {
return Optional.empty();
}
BitSet left1 = edge1.getLeft();
BitSet right1 = edge1.getRight();
BitSet left2 = edge2.getLeft();
BitSet right2 = edge2.getRight();
long left1 = edge1.getLeft();
long right1 = edge1.getRight();
long left2 = edge2.getLeft();
long right2 = edge2.getRight();
Edge edge1Before2;
Edge edge2Before1;
List<BitSet> superBitset = new ArrayList<>();
List<Long> superBitset = new ArrayList<>();
if (tryGetSuperset(left1, left2, superBitset)) {
// (common Join1 right1) Join2 right2
edge1Before2 = threeLeftJoin(superBitset.get(0), edge1, right1, edge2, right2);
@ -335,35 +334,35 @@ public class GraphSimplifier {
return Optional.of(simplificationStep);
}
Edge threeLeftJoin(BitSet bitSet1, Edge edge1, BitSet bitSet2, Edge edge2, BitSet bitSet3) {
Edge threeLeftJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) {
// (plan1 edge1 plan2) edge2 plan3
// The join may have redundant table, e.g., t1,t2 join t3 join t2,t4
// Therefore, the cost is not accurate
Preconditions.checkArgument(
cachePlan.containsKey(bitSet1) && cachePlan.containsKey(bitSet2) && cachePlan.containsKey(bitSet3));
LogicalJoin leftPlan = simulateJoin(cachePlan.get(bitSet1), edge1.getJoin(), cachePlan.get(bitSet2));
LogicalJoin join = simulateJoin(leftPlan, edge2.getJoin(), cachePlan.get(bitSet3));
cachePlan.containsKey(bitmap1) && cachePlan.containsKey(bitmap2) && cachePlan.containsKey(bitmap3));
LogicalJoin leftPlan = simulateJoin(cachePlan.get(bitmap1), edge1.getJoin(), cachePlan.get(bitmap2));
LogicalJoin join = simulateJoin(leftPlan, edge2.getJoin(), cachePlan.get(bitmap3));
Edge edge = new Edge(join, -1);
BitSet newLeft = Bitmap.newBitmapUnion(bitSet1, bitSet2);
long newLeft = LongBitmap.newBitmapUnion(bitmap1, bitmap2);
// To avoid overlapping the left and the right, the newLeft is calculated, Note the
// newLeft is not totally include the bitset1 and bitset2, we use circle detector to trace the dependency
Bitmap.andNot(newLeft, bitSet3);
newLeft = LongBitmap.andNot(newLeft, bitmap3);
edge.addLeftNodes(newLeft);
edge.addRightNode(edge2.getRight());
cachePlan.put(newLeft, leftPlan);
return edge;
}
Edge threeRightJoin(BitSet bitSet1, Edge edge1, BitSet bitSet2, Edge edge2, BitSet bitSet3) {
Edge threeRightJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) {
Preconditions.checkArgument(
cachePlan.containsKey(bitSet1) && cachePlan.containsKey(bitSet2) && cachePlan.containsKey(bitSet3));
cachePlan.containsKey(bitmap1) && cachePlan.containsKey(bitmap2) && cachePlan.containsKey(bitmap3));
// plan1 edge1 (plan2 edge2 plan3)
LogicalJoin rightPlan = simulateJoin(cachePlan.get(bitSet2), edge2.getJoin(), cachePlan.get(bitSet3));
LogicalJoin join = simulateJoin(cachePlan.get(bitSet1), edge1.getJoin(), rightPlan);
LogicalJoin rightPlan = simulateJoin(cachePlan.get(bitmap2), edge2.getJoin(), cachePlan.get(bitmap3));
LogicalJoin join = simulateJoin(cachePlan.get(bitmap1), edge1.getJoin(), rightPlan);
Edge edge = new Edge(join, -1);
BitSet newRight = Bitmap.newBitmapUnion(bitSet2, bitSet3);
Bitmap.andNot(newRight, bitSet1);
long newRight = LongBitmap.newBitmapUnion(bitmap2, bitmap3);
newRight = LongBitmap.andNot(newRight, bitmap1);
edge.addLeftNode(edge1.getLeft());
edge.addRightNode(newRight);
cachePlan.put(newRight, rightPlan);
@ -403,12 +402,12 @@ public class GraphSimplifier {
return step;
}
private boolean tryGetSuperset(BitSet bitSet1, BitSet bitSet2, List<BitSet> superset) {
if (Bitmap.isSubset(bitSet1, bitSet2)) {
superset.add(bitSet2);
private boolean tryGetSuperset(long bitmap1, long bitmap2, List<Long> superset) {
if (LongBitmap.isSubset(bitmap1, bitmap2)) {
superset.add(bitmap2);
return true;
} else if (Bitmap.isSubset(bitSet2, bitSet1)) {
superset.add(bitSet1);
} else if (LongBitmap.isSubset(bitmap2, bitmap1)) {
superset.add(bitmap1);
return true;
}
return false;
@ -465,13 +464,13 @@ public class GraphSimplifier {
double benefit;
int beforeIndex;
int afterIndex;
BitSet newLeft;
BitSet newRight;
BitSet oldLeft;
BitSet oldRight;
long newLeft;
long newRight;
long oldLeft;
long oldRight;
SimplificationStep(double benefit, int beforeIndex, int afterIndex, BitSet newLeft, BitSet newRight,
BitSet oldLeft, BitSet oldRight) {
SimplificationStep(double benefit, int beforeIndex, int afterIndex, long newLeft, long newRight,
long oldLeft, long oldRight) {
this.afterIndex = afterIndex;
this.beforeIndex = beforeIndex;
this.benefit = benefit;

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.Bitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -29,7 +29,6 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.Set;
@ -49,8 +48,8 @@ public class HyperGraph {
return nodes;
}
public BitSet getNodesMap() {
return Bitmap.newBitmapBetween(0, nodes.size());
/**
 * Returns the bitmap produced by {@code LongBitmap.newBitmapBetween(0, nodes.size())}.
 * Presumably this has one bit set per node of the graph; confirm against LongBitmap.
 */
public long getNodesMap() {
    int nodeCount = nodes.size();
    return LongBitmap.newBitmapBetween(0, nodeCount);
}
public Edge getEdge(int index) {
@ -79,33 +78,35 @@ public class HyperGraph {
LogicalJoin singleJoin = new LogicalJoin(join.getJoinType(), ImmutableList.of(expression), join.left(),
join.right());
Edge edge = new Edge(singleJoin, edges.size());
BitSet bitSet = findNodes(expression.getInputSlots());
Preconditions.checkArgument(bitSet.cardinality() == 2,
long bitmap = findNodes(expression.getInputSlots());
Preconditions.checkArgument(LongBitmap.getCardinality(bitmap) == 2,
String.format("HyperGraph has not supported polynomial %s yet", expression));
int leftIndex = Bitmap.nextSetBit(bitSet, 0);
BitSet left = Bitmap.newBitmap(leftIndex);
int leftIndex = LongBitmap.nextSetBit(bitmap, 0);
long left = LongBitmap.newBitmap(leftIndex);
edge.addLeftNode(left);
int rightIndex = Bitmap.nextSetBit(bitSet, leftIndex + 1);
BitSet right = Bitmap.newBitmap(rightIndex);
int rightIndex = LongBitmap.nextSetBit(bitmap, leftIndex + 1);
long right = LongBitmap.newBitmap(rightIndex);
edge.addRightNode(right);
edge.getReferenceNodes().stream().forEach(index -> nodes.get(index).attachEdge(edge));
for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
nodes.get(nodeIndex).attachEdge(edge);
}
edges.add(edge);
}
// In MySQL, each edge is also stored in reversed form in edges to reduce branch misses.
// We don't implement this trick yet.
}
private BitSet findNodes(Set<Slot> slots) {
BitSet bitSet = Bitmap.newBitmap();
private long findNodes(Set<Slot> slots) {
long bitmap = LongBitmap.newBitmap();
for (Node node : nodes) {
for (Slot slot : node.getPlan().getOutput()) {
if (slots.contains(slot)) {
Bitmap.set(bitSet, node.getIndex());
bitmap = LongBitmap.set(bitmap, node.getIndex());
break;
}
}
}
return bitSet;
return bitmap;
}
/**
@ -115,7 +116,7 @@ public class HyperGraph {
* @param newLeft The new left of updated edge
* @param newRight The new right of update edge
*/
public void modifyEdge(int edgeIndex, BitSet newLeft, BitSet newRight) {
public void modifyEdge(int edgeIndex, long newLeft, long newRight) {
// When modify an edge in hyper graph, we need to update the left and right nodes
// For these nodes that are only in the old edge, we need remove the edge from them
// For these nodes that are only in the new edge, we need to add the edge to them
@ -126,12 +127,12 @@ public class HyperGraph {
edges.get(edgeIndex).setRight(newRight);
}
private void updateEdges(Edge edge, BitSet oldNodes, BitSet newNodes) {
BitSet removeNodes = Bitmap.newBitmapDiff(oldNodes, newNodes);
Bitmap.getIterator(removeNodes).forEach(index -> nodes.get(index).removeEdge(edge));
private void updateEdges(Edge edge, long oldNodes, long newNodes) {
long removeNodes = LongBitmap.newBitmapDiff(oldNodes, newNodes);
LongBitmap.getIterator(removeNodes).forEach(index -> nodes.get(index).removeEdge(edge));
BitSet addedNodes = Bitmap.newBitmapDiff(newNodes, oldNodes);
Bitmap.getIterator(addedNodes).forEach(index -> nodes.get(index).attachEdge(edge));
long addedNodes = LongBitmap.newBitmapDiff(newNodes, oldNodes);
LongBitmap.getIterator(addedNodes).forEach(index -> nodes.get(index).attachEdge(edge));
}
/**
@ -168,8 +169,8 @@ public class HyperGraph {
arrowHead = ",arrowhead=none";
}
int leftIndex = edge.getLeft().nextSetBit(0);
int rightIndex = edge.getRight().nextSetBit(0);
int leftIndex = LongBitmap.lowestOneIndex(edge.getLeft());
int rightIndex = LongBitmap.lowestOneIndex(edge.getRight());
builder.append(String.format("%s -> %s [label=\"%s\"%s]\n", graphvisNodes.get(leftIndex),
graphvisNodes.get(rightIndex), label, arrowHead));
} else {
@ -178,7 +179,7 @@ public class HyperGraph {
String leftLabel = "";
String rightLabel = "";
if (edge.getLeft().cardinality() == 1) {
if (LongBitmap.getCardinality(edge.getLeft()) == 1) {
rightLabel = label;
} else {
leftLabel = label;
@ -186,16 +187,16 @@ public class HyperGraph {
int finalI = i;
String finalLeftLabel = leftLabel;
edge.getLeft().stream().forEach(nodeIndex -> {
for (int nodeIndex : LongBitmap.getIterator(edge.getLeft())) {
builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]\n",
graphvisNodes.get(nodeIndex), finalI, finalLeftLabel));
});
}
String finalRightLabel = rightLabel;
edge.getRight().stream().forEach(nodeIndex -> {
for (int nodeIndex : LongBitmap.getIterator(edge.getRight())) {
builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]\n",
graphvisNodes.get(nodeIndex), finalI, finalRightLabel));
});
}
}
}
builder.append("}\n");

View File

@ -17,14 +17,12 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.Bitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.plans.Plan;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import javax.annotation.Nullable;
/**
* HyperGraph Node.
@ -33,54 +31,18 @@ public class Node {
private final int index;
private Group group;
private List<Edge> edges = new ArrayList<>();
// We split these into simple edges (only one node on each side) and complex edges (others)
// because we can often quickly discard all simple edges by testing the set of interesting nodes
// against the “simple_neighborhood” bitmap. These data will be calculated before enumerate.
private List<Edge> complexEdges = new ArrayList<>();
private BitSet simpleNeighborhood = new BitSet();
private List<Edge> simpleEdges = new ArrayList<>();
private BitSet complexNeighborhood = new BitSet();
public Node(int index, Group group) {
this.group = group;
this.index = index;
}
/**
* Try to find the edge between this node and nodes
*
* @param nodes the other side of the edge
* @return The edge between this node and parameters
*/
@Nullable
public Edge tryGetEdgeWith(BitSet nodes) {
if (Bitmap.isOverlap(simpleNeighborhood, nodes)) {
for (Edge edge : simpleEdges) {
if (Bitmap.isSubset(edge.getLeft(), nodes) || Bitmap.isSubset(edge.getRight(), nodes)) {
return edge;
}
}
throw new RuntimeException(String.format("There is no simple Edge <%d - %s>", index, nodes));
} else if (Bitmap.isOverlap(complexNeighborhood, nodes)) {
for (Edge edge : complexEdges) {
// TODO: Right now we check all edges. But due to the simple cmp, we can only check that the edge with
// one side that equal to this node
if ((Bitmap.isSubset(edge.getLeft(), nodes) && Bitmap.isSubset(edge.getRight(),
Bitmap.newBitmap(index))) || (Bitmap.isSubset(edge.getRight(), nodes) && Bitmap.isSubset(
edge.getLeft(), Bitmap.newBitmap(index)))) {
return edge;
}
}
}
return null;
}
public int getIndex() {
return index;
}
public BitSet getNodeMap() {
return Bitmap.newBitmap(index);
/**
 * Returns the bitmap created by {@code LongBitmap.newBitmap(index)} for this node's index.
 */
public long getNodeMap() {
    final long selfBitmap = LongBitmap.newBitmap(index);
    return selfBitmap;
}
public Plan getPlan() {

View File

@ -17,8 +17,8 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.Bitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.SubsetIterator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmapSubsetIterator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.AbstractReceiver;
import com.google.common.base.Preconditions;
@ -71,11 +71,11 @@ public class SubgraphEnumerator {
neighborhoodCalculator = new NeighborhoodCalculator();
// We skip the last element because it can't generate valid csg-cmp pair
BitSet forbiddenNodes = Bitmap.newBitmapBetween(0, size - 1);
long forbiddenNodes = LongBitmap.newBitmapBetween(0, size - 1);
for (int i = size - 2; i >= 0; i--) {
BitSet csg = Bitmap.newBitmap(i);
Bitmap.unset(forbiddenNodes, i);
if (!emitCsg(csg) || !enumerateCsgRec(csg, Bitmap.newBitmap(forbiddenNodes))) {
long csg = LongBitmap.newBitmap(i);
forbiddenNodes = LongBitmap.unset(forbiddenNodes, i);
if (!emitCsg(csg) || !enumerateCsgRec(csg, LongBitmap.clone(forbiddenNodes))) {
return false;
}
}
@ -84,11 +84,11 @@ public class SubgraphEnumerator {
// The general purpose of EnumerateCsgRec is to extend a given set csg, which
// induces a connected subgraph of G to a larger set with the same property.
private boolean enumerateCsgRec(BitSet csg, BitSet forbiddenNodes) {
BitSet neighborhood = neighborhoodCalculator.calcNeighborhood(csg, forbiddenNodes, edgeCalculator);
SubsetIterator subsetIterator = Bitmap.getSubsetIterator(neighborhood);
for (BitSet subset : subsetIterator) {
BitSet newCsg = Bitmap.newBitmapUnion(csg, subset);
private boolean enumerateCsgRec(long csg, long forbiddenNodes) {
long neighborhood = neighborhoodCalculator.calcNeighborhood(csg, forbiddenNodes, edgeCalculator);
LongBitmapSubsetIterator subsetIterator = LongBitmap.getSubsetIterator(neighborhood);
for (long subset : subsetIterator) {
long newCsg = LongBitmap.newBitmapUnion(csg, subset);
if (receiver.contain(newCsg)) {
edgeCalculator.unionEdges(csg, subset);
if (!emitCsg(newCsg)) {
@ -96,22 +96,22 @@ public class SubgraphEnumerator {
}
}
}
Bitmap.or(forbiddenNodes, neighborhood);
forbiddenNodes = LongBitmap.or(forbiddenNodes, neighborhood);
subsetIterator.reset();
for (BitSet subset : subsetIterator) {
BitSet newCsg = Bitmap.newBitmapUnion(csg, subset);
if (!enumerateCsgRec(newCsg, Bitmap.newBitmap(forbiddenNodes))) {
for (long subset : subsetIterator) {
long newCsg = LongBitmap.newBitmapUnion(csg, subset);
if (!enumerateCsgRec(newCsg, LongBitmap.clone(forbiddenNodes))) {
return false;
}
}
return true;
}
private boolean enumerateCmpRec(BitSet csg, BitSet cmp, BitSet forbiddenNodes) {
BitSet neighborhood = neighborhoodCalculator.calcNeighborhood(cmp, forbiddenNodes, edgeCalculator);
SubsetIterator subsetIterator = new SubsetIterator(neighborhood);
for (BitSet subset : subsetIterator) {
BitSet newCmp = Bitmap.newBitmapUnion(cmp, subset);
private boolean enumerateCmpRec(long csg, long cmp, long forbiddenNodes) {
long neighborhood = neighborhoodCalculator.calcNeighborhood(cmp, forbiddenNodes, edgeCalculator);
LongBitmapSubsetIterator subsetIterator = new LongBitmapSubsetIterator(neighborhood);
for (long subset : subsetIterator) {
long newCmp = LongBitmap.newBitmapUnion(cmp, subset);
// We need to check whether Cmp is connected and then try to find hyper edge
if (receiver.contain(newCmp)) {
edgeCalculator.unionEdges(cmp, subset);
@ -125,11 +125,11 @@ public class SubgraphEnumerator {
}
}
}
Bitmap.or(forbiddenNodes, neighborhood);
forbiddenNodes = LongBitmap.or(forbiddenNodes, neighborhood);
subsetIterator.reset();
for (BitSet subset : subsetIterator) {
BitSet newCmp = Bitmap.newBitmapUnion(cmp, subset);
if (!enumerateCmpRec(csg, newCmp, Bitmap.newBitmap(forbiddenNodes))) {
for (long subset : subsetIterator) {
long newCmp = LongBitmap.newBitmapUnion(cmp, subset);
if (!enumerateCmpRec(csg, newCmp, LongBitmap.clone(forbiddenNodes))) {
return false;
}
}
@ -139,14 +139,14 @@ public class SubgraphEnumerator {
// EmitCsg takes as an argument a non-empty, proper subset csg of HyperGraph , which
// induces a connected subgraph. It is then responsible to generate the seeds for
// all cmp such that (csg, cmp) becomes a csg-cmp-pair.
private boolean emitCsg(BitSet csg) {
BitSet forbiddenNodes = Bitmap.newBitmapBetween(0, Bitmap.nextSetBit(csg, 0));
Bitmap.or(forbiddenNodes, csg);
BitSet neighborhoods = neighborhoodCalculator.calcNeighborhood(csg, Bitmap.newBitmap(forbiddenNodes),
private boolean emitCsg(long csg) {
long forbiddenNodes = LongBitmap.newBitmapBetween(0, LongBitmap.nextSetBit(csg, 0));
forbiddenNodes = LongBitmap.or(forbiddenNodes, csg);
long neighborhoods = neighborhoodCalculator.calcNeighborhood(csg, LongBitmap.clone(forbiddenNodes),
edgeCalculator);
for (int nodeIndex : Bitmap.getReverseIterator(neighborhoods)) {
BitSet cmp = Bitmap.newBitmap(nodeIndex);
for (int nodeIndex : LongBitmap.getReverseIterator(neighborhoods)) {
long cmp = LongBitmap.newBitmap(nodeIndex);
// whether there is an edge between csg and cmp
List<Edge> edges = edgeCalculator.connectCsgCmp(csg, cmp);
if (edges.isEmpty()) {
@ -165,9 +165,9 @@ public class SubgraphEnumerator {
// 2. The cmp is {t2} and expanded from {t2} to {t2, t3}
// We don't want get {t2, t3} twice. So In first enumeration, we
// can exclude {t2}
BitSet newForbiddenNodes = Bitmap.newBitmapBetween(0, nodeIndex + 1);
Bitmap.and(newForbiddenNodes, neighborhoods);
Bitmap.or(newForbiddenNodes, forbiddenNodes);
long newForbiddenNodes = LongBitmap.newBitmapBetween(0, nodeIndex + 1);
newForbiddenNodes = LongBitmap.and(newForbiddenNodes, neighborhoods);
newForbiddenNodes = LongBitmap.or(newForbiddenNodes, forbiddenNodes);
if (!enumerateCmpRec(csg, cmp, newForbiddenNodes)) {
return false;
}
@ -183,21 +183,22 @@ public class SubgraphEnumerator {
// expand csg and cmp. In fact, we just need a seed node that can be expanded
// to all subgraph. That is any one node of hyper nodes. In fact, the neighborhoods
// is the minimum set that we choose one node from above v.
public BitSet calcNeighborhood(BitSet subgraph, BitSet forbiddenNodes, EdgeCalculator edgeCalculator) {
BitSet neighborhoods = Bitmap.newBitmap();
edgeCalculator.foundSimpleEdgesContain(subgraph)
.forEach(edge -> neighborhoods.or(edge.getReferenceNodes()));
Bitmap.or(forbiddenNodes, subgraph);
Bitmap.andNot(neighborhoods, forbiddenNodes);
Bitmap.or(forbiddenNodes, neighborhoods);
public long calcNeighborhood(long subgraph, long forbiddenNodes, EdgeCalculator edgeCalculator) {
long neighborhoods = LongBitmap.newBitmap();
for (Edge edge : edgeCalculator.foundSimpleEdgesContain(subgraph)) {
neighborhoods = LongBitmap.or(neighborhoods, edge.getReferenceNodes());
}
forbiddenNodes = LongBitmap.or(forbiddenNodes, subgraph);
neighborhoods = LongBitmap.andNot(neighborhoods, forbiddenNodes);
forbiddenNodes = LongBitmap.or(forbiddenNodes, neighborhoods);
for (Edge edge : edgeCalculator.foundComplexEdgesContain(subgraph)) {
BitSet left = edge.getLeft();
BitSet right = edge.getRight();
if (Bitmap.isSubset(left, subgraph) && !Bitmap.isOverlap(right, forbiddenNodes)) {
Bitmap.set(neighborhoods, right.nextSetBit(0));
} else if (Bitmap.isSubset(right, subgraph) && !Bitmap.isOverlap(left, forbiddenNodes)) {
Bitmap.set(neighborhoods, left.nextSetBit(0));
long left = edge.getLeft();
long right = edge.getRight();
if (LongBitmap.isSubset(left, subgraph) && !LongBitmap.isOverlap(right, forbiddenNodes)) {
neighborhoods = LongBitmap.set(neighborhoods, LongBitmap.lowestOneIndex(right));
} else if (LongBitmap.isSubset(right, subgraph) && !LongBitmap.isOverlap(left, forbiddenNodes)) {
neighborhoods = LongBitmap.set(neighborhoods, LongBitmap.lowestOneIndex(left));
}
}
return neighborhoods;
@ -212,17 +213,17 @@ public class SubgraphEnumerator {
// because we can often quickly discard all simple edges by testing the set of interesting nodes
// against the “simple_neighborhood” bitmap. These data will be calculated before enumerate.
HashMap<BitSet, BitSet> containSimpleEdges = new HashMap<>();
HashMap<BitSet, BitSet> containComplexEdges = new HashMap<>();
HashMap<Long, BitSet> containSimpleEdges = new HashMap<>();
HashMap<Long, BitSet> containComplexEdges = new HashMap<>();
// It caches all edges that overlap with this subgraph. All these edges must be
// complex edges
HashMap<BitSet, BitSet> overlapEdges = new HashMap<>();
HashMap<Long, BitSet> overlapEdges = new HashMap<>();
EdgeCalculator(List<Edge> edges) {
this.edges = edges;
}
public void initSubgraph(BitSet subgraph) {
public void initSubgraph(long subgraph) {
BitSet simpleContains = new BitSet();
BitSet complexContains = new BitSet();
BitSet overlaps = new BitSet();
@ -249,26 +250,29 @@ public class SubgraphEnumerator {
containComplexEdges.put(subgraph, complexContains);
}
public void unionEdges(BitSet bitSet1, BitSet bitSet2) {
public void unionEdges(long bitmap1, long bitmap2) {
// When union two sub graphs, we only need to check overlap edges.
// However, if all reference nodes are contained by the subgraph,
// we should remove it.
if (!containSimpleEdges.containsKey(bitSet1)) {
initSubgraph(bitSet1);
if (!containSimpleEdges.containsKey(bitmap1)) {
initSubgraph(bitmap1);
}
if (!containSimpleEdges.containsKey(bitSet2)) {
initSubgraph(bitSet2);
if (!containSimpleEdges.containsKey(bitmap2)) {
initSubgraph(bitmap2);
}
BitSet subgraph = Bitmap.newBitmapUnion(bitSet1, bitSet2);
long subgraph = LongBitmap.newBitmapUnion(bitmap1, bitmap2);
if (containSimpleEdges.containsKey(subgraph)) {
return;
}
BitSet simpleContains = Bitmap.newBitmapUnion(containSimpleEdges.get(bitSet1),
containSimpleEdges.get(bitSet2));
BitSet complexContains = Bitmap.newBitmapUnion(containComplexEdges.get(bitSet1),
containComplexEdges.get(bitSet2));
BitSet overlaps = Bitmap.newBitmapUnion(overlapEdges.get(bitSet1),
overlapEdges.get(bitSet2));
// Seed each accumulator with the union of what is cached for the two operands.
// Every cache must feed its own accumulator: or-ing all three caches into
// simpleContains would leave complexContains and overlaps empty.
BitSet simpleContains = new BitSet();
simpleContains.or(containSimpleEdges.get(bitmap1));
simpleContains.or(containSimpleEdges.get(bitmap2));
BitSet complexContains = new BitSet();
complexContains.or(containComplexEdges.get(bitmap1));
complexContains.or(containComplexEdges.get(bitmap2));
BitSet overlaps = new BitSet();
overlaps.or(overlapEdges.get(bitmap1));
overlaps.or(overlapEdges.get(bitmap2));
for (int index : overlaps.stream().toArray()) {
Edge edge = edges.get(index);
if (isContainEdge(subgraph, edge)) {
@ -287,19 +291,22 @@ public class SubgraphEnumerator {
overlapEdges.put(subgraph, overlaps);
}
public List<Edge> connectCsgCmp(BitSet bitSet1, BitSet bitSet2) {
public List<Edge> connectCsgCmp(long csg, long cmp) {
Preconditions.checkArgument(
containSimpleEdges.containsKey(bitSet1) && containSimpleEdges.containsKey(bitSet2));
containSimpleEdges.containsKey(csg) && containSimpleEdges.containsKey(cmp));
List<Edge> foundEdges = new ArrayList<>();
BitSet edgeMap = Bitmap.newBitmapIntersect(containSimpleEdges.get(bitSet1),
containSimpleEdges.get(bitSet2));
Bitmap.or(edgeMap, Bitmap.newBitmapIntersect(containComplexEdges.get(bitSet1),
containComplexEdges.get(bitSet2)));
BitSet edgeMap = new BitSet();
edgeMap.or(containSimpleEdges.get(csg));
edgeMap.and(containSimpleEdges.get(cmp));
BitSet complexes = new BitSet();
complexes.or(containComplexEdges.get(csg));
complexes.and(containComplexEdges.get(cmp));
edgeMap.or(complexes);
edgeMap.stream().forEach(index -> foundEdges.add(edges.get(index)));
return foundEdges;
}
public List<Edge> foundEdgesContain(BitSet subgraph) {
public List<Edge> foundEdgesContain(long subgraph) {
Preconditions.checkArgument(containSimpleEdges.containsKey(subgraph));
BitSet edgeMap = containSimpleEdges.get(subgraph);
edgeMap.or(containComplexEdges.get(subgraph));
@ -308,7 +315,7 @@ public class SubgraphEnumerator {
return foundEdges;
}
public List<Edge> foundSimpleEdgesContain(BitSet subgraph) {
public List<Edge> foundSimpleEdgesContain(long subgraph) {
List<Edge> foundEdges = new ArrayList<>();
if (!containSimpleEdges.containsKey(subgraph)) {
return foundEdges;
@ -318,7 +325,7 @@ public class SubgraphEnumerator {
return foundEdges;
}
public List<Edge> foundComplexEdgesContain(BitSet subgraph) {
public List<Edge> foundComplexEdgesContain(long subgraph) {
List<Edge> foundEdges = new ArrayList<>();
if (!containComplexEdges.containsKey(subgraph)) {
return foundEdges;
@ -328,24 +335,24 @@ public class SubgraphEnumerator {
return foundEdges;
}
public int getEdgeSizeContain(BitSet subgraph) {
/**
 * Counts the edges cached as contained by the given subgraph.
 *
 * @param subgraph bitmap of the subgraph; it must already be a key of {@code containSimpleEdges}
 * @return the number of simple edges plus the number of complex edges cached for {@code subgraph}
 * @throws IllegalArgumentException if {@code subgraph} is not a key of {@code containSimpleEdges}
 */
public int getEdgeSizeContain(long subgraph) {
    Preconditions.checkArgument(containSimpleEdges.containsKey(subgraph));
    // Complex edges are kept in their own cache; counting the simple cache twice ignored them.
    return containSimpleEdges.get(subgraph).cardinality() + containComplexEdges.get(subgraph).cardinality();
}
private boolean isContainEdge(BitSet subgraph, Edge edge) {
int containLeft = Bitmap.isSubset(edge.getLeft(), subgraph) ? 0 : 1;
int containRight = Bitmap.isSubset(edge.getRight(), subgraph) ? 0 : 1;
/**
 * Tells whether exactly one side of the edge is a subset of the subgraph.
 *
 * @param subgraph bitmap of the subgraph
 * @param edge the edge whose left and right node sets are tested
 * @return true when one side is fully inside {@code subgraph} and the other side is not
 */
private boolean isContainEdge(long subgraph, Edge edge) {
    boolean leftInside = LongBitmap.isSubset(edge.getLeft(), subgraph);
    boolean rightInside = LongBitmap.isSubset(edge.getRight(), subgraph);
    // "exactly one of the two" is the exclusive or of the two flags
    return leftInside != rightInside;
}
private boolean isOverlapEdge(BitSet subgraph, Edge edge) {
int overlapLeft = Bitmap.isOverlap(edge.getLeft(), subgraph) ? 0 : 1;
int overlapRight = Bitmap.isOverlap(edge.getRight(), subgraph) ? 0 : 1;
/**
 * Tells whether exactly one side of the edge shares at least one node with the subgraph.
 *
 * @param subgraph bitmap of the subgraph
 * @param edge the edge whose left and right node sets are tested
 * @return true when one side overlaps {@code subgraph} and the other side does not
 */
private boolean isOverlapEdge(long subgraph, Edge edge) {
    boolean leftTouches = LongBitmap.isOverlap(edge.getLeft(), subgraph);
    boolean rightTouches = LongBitmap.isOverlap(edge.getRight(), subgraph);
    // "exactly one of the two" is the exclusive or of the two flags
    return leftTouches != rightTouches;
}
private BitSet removeInvalidEdges(BitSet subgraph, BitSet edgeMap) {
private BitSet removeInvalidEdges(long subgraph, BitSet edgeMap) {
for (int index : edgeMap.stream().toArray()) {
Edge edge = edges.get(index);
if (!isOverlapEdge(subgraph, edge)) {

View File

@ -1,134 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap;
import java.util.BitSet;
/**
 * Static convenience wrappers around {@link BitSet}.
 *
 * <p>Methods whose name starts with {@code new} allocate and return a fresh {@link BitSet};
 * the remaining mutators ({@code set}, {@code unset}, {@code clear}, {@code or}, {@code and},
 * {@code andNot}) change their first argument in place.
 */
public class Bitmap {
    /** Returns true when every bit set in {@code bitSet1} is also set in {@code bitSet2}. */
    public static boolean isSubset(BitSet bitSet1, BitSet bitSet2) {
        // bitSet1 is a subset exactly when nothing is left after removing bitSet2 from it
        return newBitmapDiff(bitSet1, bitSet2).isEmpty();
    }

    /** Builds a new bitmap in which exactly the given indexes are set. */
    public static BitSet newBitmap(int... values) {
        BitSet created = new BitSet();
        for (int i = 0; i < values.length; i++) {
            created.set(values[i]);
        }
        return created;
    }

    /** Builds a new bitmap holding the same bits as {@code bitSet}. */
    public static BitSet newBitmap(BitSet bitSet) {
        BitSet copy = new BitSet();
        copy.or(bitSet);
        return copy;
    }

    /** Builds a new empty bitmap. */
    public static BitSet newBitmap() {
        return new BitSet();
    }

    /** Builds a new bitmap holding the union of all the given bitmaps. */
    public static BitSet newBitmapUnion(BitSet... bitSets) {
        BitSet union = new BitSet();
        for (int i = 0; i < bitSets.length; i++) {
            union.or(bitSets[i]);
        }
        return union;
    }

    /** Builds a new bitmap holding {@code bitSet1 - bitSet2}. */
    public static BitSet newBitmapDiff(BitSet bitSet1, BitSet bitSet2) {
        BitSet difference = newBitmap(bitSet1);
        difference.andNot(bitSet2);
        return difference;
    }

    /** Builds a new bitmap holding {@code bitSet1 ∩ bitSet2}. */
    public static BitSet newBitmapIntersect(BitSet bitSet1, BitSet bitSet2) {
        BitSet common = newBitmap(bitSet1);
        common.and(bitSet2);
        return common;
    }

    /** Builds a new bitmap with the indexes from {@code start} (inclusive) to {@code end} (exclusive) set. */
    public static BitSet newBitmapBetween(int start, int end) {
        BitSet range = new BitSet();
        range.set(start, end);
        return range;
    }

    /** Delegates to {@link BitSet#nextSetBit(int)}. */
    public static int nextSetBit(BitSet bitSet, int fromIndex) {
        return bitSet.nextSetBit(fromIndex);
    }

    /** Returns whether the bit at {@code index} is set. */
    public static boolean get(BitSet bitSet, int index) {
        return bitSet.get(index);
    }

    /** Sets the bit at {@code index} in place. */
    public static void set(BitSet bitSet, int index) {
        bitSet.set(index);
    }

    /** Clears the bit at {@code index} in place. */
    public static void unset(BitSet bitSet, int index) {
        bitSet.clear(index);
    }

    /** Clears every bit of {@code bitSet} in place. */
    public static void clear(BitSet bitSet) {
        bitSet.clear();
    }

    /** Returns the number of set bits. */
    public static int getCardinality(BitSet bitSet) {
        return bitSet.cardinality();
    }

    /** Returns a {@link BitSetIterator} created for {@code bitSet}. */
    public static BitSetIterator getIterator(BitSet bitSet) {
        return new BitSetIterator(bitSet);
    }

    /** Returns a {@link ReverseBitSetIterator} created for {@code bitSet}. */
    public static ReverseBitSetIterator getReverseIterator(BitSet bitSet) {
        return new ReverseBitSetIterator(bitSet);
    }

    /** Adds every bit of {@code bitSet2} to {@code bitSet1} in place. */
    public static void or(BitSet bitSet1, BitSet bitSet2) {
        bitSet1.or(bitSet2);
    }

    /** Returns true when the two bitmaps have at least one set bit in common. */
    public static boolean isOverlap(BitSet bitSet1, BitSet bitSet2) {
        return bitSet1.intersects(bitSet2);
    }

    /** Removes every bit of {@code bitSet2} from {@code bitSet1} in place. */
    public static void andNot(BitSet bitSet1, BitSet bitSet2) {
        bitSet1.andNot(bitSet2);
    }

    /** Keeps in {@code bitSet1} only the bits that are also set in {@code bitSet2}, in place. */
    public static void and(BitSet bitSet1, BitSet bitSet2) {
        bitSet1.and(bitSet2);
    }

    /** Returns a {@link SubsetIterator} created for {@code bitSet}. */
    public static SubsetIterator getSubsetIterator(BitSet bitSet) {
        return new SubsetIterator(bitSet);
    }
}

View File

@ -0,0 +1,156 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap;
import java.util.BitSet;
/**
 * Helper methods that use a single {@code long} as a bitmap over the indexes {@code [0, 64)}.
 *
 * <p>Bit {@code i} of the value represents index {@code i}. A {@code long} is a primitive, so
 * nothing here changes its argument: helpers such as {@link #set(long, int)} or
 * {@link #or(long, long)} return the updated bitmap and the caller must keep the result.
 *
 * <p>Java only uses the low six bits of the shift distance when shifting a {@code long}, so the
 * helpers that shift by a caller-supplied index ({@link #newBitmap(int)}, {@link #get(long, int)},
 * {@link #set(long, int)}, {@link #unset(long, int)}, {@link #newBitmapBetween(int, int)}) wrap an
 * index outside {@code [0, 64)} instead of rejecting it. Callers must pass valid indexes.
 */
public class LongBitmap {
    private static final long MASK = 0xffffffffffffffffL;
    private static final int SIZE = Long.SIZE;

    /** Returns true when every bit set in {@code bitmap1} is also set in {@code bitmap2}. */
    public static boolean isSubset(long bitmap1, long bitmap2) {
        return (bitmap1 | bitmap2) == bitmap2;
    }

    /** Builds a bitmap in which exactly the given indexes are set. */
    public static long newBitmap(int... values) {
        long bitmap = 0;
        for (int v : values) {
            bitmap |= (1L << v);
        }
        return bitmap;
    }

    /** Builds a bitmap in which only {@code value} is set. */
    public static long newBitmap(int value) {
        return 1L << value;
    }

    /** Builds the empty bitmap. */
    public static long newBitmap() {
        return 0;
    }

    /** Returns {@code bitmap} itself; a primitive needs no copying. */
    public static long clone(long bitmap) {
        return bitmap;
    }

    /** Returns the union of all the given bitmaps. */
    public static long newBitmapUnion(long... bitmaps) {
        long u = 0;
        for (long bitmap : bitmaps) {
            u |= bitmap;
        }
        return u;
    }

    /** Returns the union of the two bitmaps without allocating a varargs array. */
    public static long newBitmapUnion(long b1, long b2) {
        return b1 | b2;
    }

    /** Returns {@code bitmap1 - bitmap2}: the bits of {@code bitmap1} that are not in {@code bitmap2}. */
    public static long newBitmapDiff(long bitmap1, long bitmap2) {
        return bitmap1 & (~bitmap2);
    }

    /** Returns {@code bitmap1 ∩ bitmap2}. */
    public static long newBitmapIntersect(long bitmap1, long bitmap2) {
        return bitmap1 & bitmap2;
    }

    /** Builds a bitmap with the indexes from {@code start} (inclusive) to {@code end} (exclusive) set. */
    public static long newBitmapBetween(int start, int end) {
        long bitmap = 0;
        for (int i = start; i < end; i++) {
            bitmap |= (1L << i);
        }
        return bitmap;
    }

    /**
     * Finds the lowest set bit whose index is at least {@code fromIndex}.
     *
     * <p>Unlike {@link BitSet#nextSetBit(int)}, which reports a miss with {@code -1}, this method
     * reports a miss with {@code 64}; that value is kept so existing callers see no change.
     *
     * @param bitmap the bitmap to search
     * @param fromIndex the first index to look at; a negative value searches the whole bitmap
     * @return the index of the found bit, or {@code 64} when there is no such bit
     */
    public static int nextSetBit(long bitmap, int fromIndex) {
        if (fromIndex >= SIZE) {
            // The shift distance would wrap modulo 64 and wrongly bring the low bits back in.
            return SIZE;
        }
        long candidates = bitmap & (MASK << Math.max(fromIndex, 0));
        // numberOfTrailingZeros(0) is 64, which doubles as the "not found" value.
        return Long.numberOfTrailingZeros(candidates);
    }

    /** Returns whether the bit at {@code index} is set. */
    public static boolean get(long bitmap, int index) {
        return (bitmap & (1L << index)) != 0;
    }

    /** Returns {@code bitmap} with the bit at {@code index} set. */
    public static long set(long bitmap, int index) {
        return bitmap | (1L << index);
    }

    /** Returns {@code bitmap} with the bit at {@code index} cleared. */
    public static long unset(long bitmap, int index) {
        return bitmap & (~(1L << index));
    }

    /** Returns the empty bitmap; the argument is ignored. */
    public static long clear(long bitSet) {
        return 0;
    }

    /** Returns the number of set bits. */
    public static int getCardinality(long bitmap) {
        return Long.bitCount(bitmap);
    }

    /** Returns a {@link LongBitmapIterator} created for {@code bitmap}. */
    public static LongBitmapIterator getIterator(long bitmap) {
        return new LongBitmapIterator(bitmap);
    }

    /** Returns a {@link LongBitmapReverseIterator} created for {@code bitmap}. */
    public static LongBitmapReverseIterator getReverseIterator(long bitmap) {
        return new LongBitmapReverseIterator(bitmap);
    }

    /** Returns {@code bitmap1 | bitmap2}. */
    public static long or(long bitmap1, long bitmap2) {
        return bitmap1 | bitmap2;
    }

    /** Returns true when the two bitmaps have at least one set bit in common. */
    public static boolean isOverlap(long bitmap1, long bitmap2) {
        return (bitmap1 & bitmap2) != 0;
    }

    /** Returns {@code bitmap1} without the bits of {@code bitmap2}. */
    public static long andNot(long bitmap1, long bitmap2) {
        return (bitmap1 & ~(bitmap2));
    }

    /** Returns {@code bitmap1 & bitmap2}. */
    public static long and(long bitmap1, long bitmap2) {
        return bitmap1 & bitmap2;
    }

    /** Returns a {@link LongBitmapSubsetIterator} created for {@code bitmap}. */
    public static LongBitmapSubsetIterator getSubsetIterator(long bitmap) {
        return new LongBitmapSubsetIterator(bitmap);
    }

    /** Returns {@code bitmap} with its lowest set bit cleared; {@code 0} stays {@code 0}. */
    public static long clearLowestBit(long bitmap) {
        return bitmap & (bitmap - 1);
    }

    /**
     * Finds the highest set bit whose index is at most {@code fromIndex}.
     *
     * @param bitmap the bitmap to search
     * @param fromIndex the last index to look at; values above 63 search the whole bitmap
     * @return the index of the found bit, or {@code -1} when there is no such bit
     *         (always the case for a negative {@code fromIndex})
     */
    public static int previousSetBit(long bitmap, int fromIndex) {
        if (fromIndex < 0) {
            // No index lies at or below a negative bound; shifting would wrap and search everything.
            return -1;
        }
        int highest = Math.min(fromIndex, SIZE - 1);
        // Keep only the bits 0..highest, then locate the topmost survivor.
        long candidates = bitmap & (MASK >>> (SIZE - 1 - highest));
        return SIZE - Long.numberOfLeadingZeros(candidates) - 1;
    }

    /** Returns the index of the lowest set bit, or {@code 64} when {@code bitmap} is empty. */
    public static int lowestOneIndex(long bitmap) {
        return Long.numberOfTrailingZeros(bitmap);
    }

    /** Formats the bitmap like {@link BitSet#toString()}, e.g. {@code {0, 3}}. */
    public static String toString(long bitmap) {
        long[] longs = {bitmap};
        BitSet bitSet = BitSet.valueOf(longs);
        return bitSet.toString();
    }
}

View File

@ -19,19 +19,17 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap;
import org.jetbrains.annotations.NotNull;
import java.util.BitSet;
import java.util.Iterator;
/**
* This is an Iterator for iterating bitmap
*/
public class BitSetIterator implements Iterable<Integer> {
int lastIndex = -1;
int readNum = 0;
BitSet bitSet;
public class LongBitmapIterator implements Iterable<Integer> {
private long bitmap;
private int lastIndex = 0;
BitSetIterator(BitSet bitSet) {
this.bitSet = bitSet;
LongBitmapIterator(long bitmap) {
this.bitmap = bitmap;
}
@NotNull
@ -40,13 +38,13 @@ public class BitSetIterator implements Iterable<Integer> {
class Iter implements Iterator<Integer> {
@Override
public boolean hasNext() {
return (readNum < bitSet.cardinality());
return bitmap != 0;
}
@Override
public Integer next() {
lastIndex = bitSet.nextSetBit(lastIndex + 1);
readNum += 1;
lastIndex = LongBitmap.nextSetBit(bitmap, lastIndex);
bitmap = LongBitmap.clearLowestBit(bitmap);
return lastIndex;
}

View File

@ -19,20 +19,18 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap;
import org.jetbrains.annotations.NotNull;
import java.util.BitSet;
import java.util.Iterator;
/**
* This is an Iterator for iterating bitmap descending
*/
public class ReverseBitSetIterator implements Iterable<Integer> {
int lastIndex = 0;
int readNum = 0;
BitSet bitSet;
public class LongBitmapReverseIterator implements Iterable<Integer> {
private int lastIndex = 0;
private long bitmap;
ReverseBitSetIterator(BitSet bitSet) {
this.bitSet = bitSet;
lastIndex = bitSet.size();
LongBitmapReverseIterator(long bitmap) {
this.bitmap = bitmap;
lastIndex = 63;
}
@NotNull
@ -41,13 +39,13 @@ public class ReverseBitSetIterator implements Iterable<Integer> {
class Iter implements Iterator<Integer> {
@Override
public boolean hasNext() {
return (readNum < bitSet.cardinality());
return bitmap != 0;
}
@Override
public Integer next() {
lastIndex = bitSet.previousSetBit(lastIndex - 1);
readNum += 1;
lastIndex = LongBitmap.previousSetBit(bitmap, lastIndex);
bitmap = LongBitmap.unset(bitmap, lastIndex);
return lastIndex;
}

View File

@ -19,63 +19,44 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap;
import org.jetbrains.annotations.NotNull;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
/**
* This is a class for iterating all true subset of a bitset, referenced in
* https://groups.google.com/forum/#!msg/rec.games.chess/KnJvBnhgDKU/yCi5yBx18PQJ
*/
public class SubsetIterator implements Iterable<BitSet> {
List<BitSet> subsets = new ArrayList<>();
int cursor = 0;
public class LongBitmapSubsetIterator implements Iterable<Long> {
private long bitmap;
private long state;
/**
* Generate all subset for this bitSet
*
* @param bitSet The bitset that need to be generated
* @param bitmap The bitset that need to be generated
*/
public SubsetIterator(BitSet bitSet) {
long[] setVal = bitSet.toLongArray();
int len = setVal.length;
long[] baseVal = new long[len];
subsets.add(new BitSet());
for (int i = 0; i < len; i++) {
long subVal = (-setVal[i]) & setVal[i];
int size = subsets.size();
while (subVal != 0) {
baseVal[i] = subVal;
for (int j = 0; j < size; j++) {
BitSet newSubset = BitSet.valueOf(baseVal);
newSubset.or(subsets.get(j));
subsets.add(newSubset);
}
subVal = (subVal - setVal[i]) & setVal[i];
}
baseVal[i] = 0;
}
// remove empty subset
subsets.remove(0);
public LongBitmapSubsetIterator(long bitmap) {
this.bitmap = bitmap;
this.state = (-bitmap) & bitmap;
}
public void reset() {
cursor = 0;
state = (-bitmap) & bitmap;
}
@NotNull
@Override
public Iterator<BitSet> iterator() {
class Iter implements Iterator<BitSet> {
public Iterator<Long> iterator() {
class Iter implements Iterator<Long> {
@Override
public boolean hasNext() {
return (cursor < subsets.size());
return state != 0;
}
@Override
public BitSet next() {
return subsets.get(cursor++);
public Long next() {
Long subset = state;
state = (state - bitmap) & bitmap;
return subset;
}
@Override

View File

@ -20,20 +20,19 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.memo.Group;
import java.util.BitSet;
import java.util.List;
/**
* A interface of receiver
*/
public interface AbstractReceiver {
public boolean emitCsgCmp(BitSet csg, BitSet cmp, List<Edge> edges);
public boolean emitCsgCmp(long csg, long cmp, List<Edge> edges);
public void addGroup(BitSet bitSet, Group group);
public void addGroup(long bitSet, Group group);
public boolean contain(BitSet bitSet);
public boolean contain(long bitSet);
public void reset();
public Group getBestPlan(BitSet bitSet);
public Group getBestPlan(long bitSet);
}

View File

@ -18,11 +18,11 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.memo.Group;
import com.google.common.base.Preconditions;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
@ -33,7 +33,7 @@ public class Counter implements AbstractReceiver {
// limit define the max number of csg-cmp pair in this Receiver
private int limit;
private int emitCount = 0;
private HashMap<BitSet, Integer> counter = new HashMap<>();
private HashMap<Long, Integer> counter = new HashMap<>();
public Counter() {
this.limit = Integer.MAX_VALUE;
@ -51,30 +51,28 @@ public class Counter implements AbstractReceiver {
* @param edges the join operator
* @return the left and the right can be connected by the edge
*/
public boolean emitCsgCmp(BitSet left, BitSet right, List<Edge> edges) {
public boolean emitCsgCmp(long left, long right, List<Edge> edges) {
Preconditions.checkArgument(counter.containsKey(left));
Preconditions.checkArgument(counter.containsKey(right));
emitCount += 1;
if (emitCount > limit) {
return false;
}
BitSet bitSet = new BitSet();
bitSet.or(left);
bitSet.or(right);
if (!counter.containsKey(bitSet)) {
counter.put(bitSet, counter.get(left) * counter.get(right));
long bitmap = LongBitmap.newBitmapUnion(left, right);
if (!counter.containsKey(bitmap)) {
counter.put(bitmap, counter.get(left) * counter.get(right));
} else {
counter.put(bitSet, counter.get(bitSet) + counter.get(left) * counter.get(right));
counter.put(bitmap, counter.get(bitmap) + counter.get(left) * counter.get(right));
}
return true;
}
public void addGroup(BitSet bitSet, Group group) {
counter.put(bitSet, 1);
public void addGroup(long bitmap, Group group) {
counter.put(bitmap, 1);
}
public boolean contain(BitSet bitSet) {
return counter.containsKey(bitSet);
public boolean contain(long bitmap) {
return counter.containsKey(bitmap);
}
public void reset() {
@ -82,15 +80,15 @@ public class Counter implements AbstractReceiver {
emitCount = 0;
}
public Group getBestPlan(BitSet bitSet) {
public Group getBestPlan(long bitmap) {
throw new RuntimeException("Counter does not support getBestPlan()");
}
public int getCount(BitSet bitSet) {
return counter.get(bitSet);
public int getCount(long bitmap) {
return counter.get(bitmap);
}
public HashMap<BitSet, Integer> getAllCount() {
public HashMap<Long, Integer> getAllCount() {
return counter;
}

View File

@ -18,7 +18,7 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.Bitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.PhysicalProperties;
@ -31,7 +31,6 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
@ -40,7 +39,7 @@ import java.util.List;
*/
public class PlanReceiver implements AbstractReceiver {
// limit define the max number of csg-cmp pair in this Receiver
HashMap<BitSet, Group> planTable = new HashMap<>();
HashMap<Long, Group> planTable = new HashMap<>();
int limit;
int emitCount = 0;
@ -61,14 +60,14 @@ public class PlanReceiver implements AbstractReceiver {
* @return the left and the right can be connected by the edge
*/
@Override
public boolean emitCsgCmp(BitSet left, BitSet right, List<Edge> edges) {
public boolean emitCsgCmp(long left, long right, List<Edge> edges) {
Preconditions.checkArgument(planTable.containsKey(left));
Preconditions.checkArgument(planTable.containsKey(right));
emitCount += 1;
if (emitCount > limit) {
return false;
}
BitSet fullKey = Bitmap.newBitmapUnion(left, right);
long fullKey = LongBitmap.newBitmapUnion(left, right);
Group group1 = constructGroup(left, right, edges);
Group group2 = constructGroup(right, left, edges);
Group winnerGroup;
@ -88,13 +87,13 @@ public class PlanReceiver implements AbstractReceiver {
}
@Override
public void addGroup(BitSet bitSet, Group group) {
planTable.put(bitSet, group);
public void addGroup(long bitmap, Group group) {
planTable.put(bitmap, group);
}
@Override
public boolean contain(BitSet bitSet) {
return planTable.containsKey(bitSet);
public boolean contain(long bitmap) {
return planTable.containsKey(bitmap);
}
@Override
@ -104,9 +103,9 @@ public class PlanReceiver implements AbstractReceiver {
}
@Override
public Group getBestPlan(BitSet bitSet) {
Preconditions.checkArgument(planTable.containsKey(bitSet));
return planTable.get(bitSet);
public Group getBestPlan(long bitmap) {
Preconditions.checkArgument(planTable.containsKey(bitmap));
return planTable.get(bitmap);
}
private double getSimpleCost(Plan plan) {
@ -116,7 +115,7 @@ public class PlanReceiver implements AbstractReceiver {
return plan.getGroupExpression().get().getCostByProperties(PhysicalProperties.ANY);
}
private Group constructGroup(BitSet left, BitSet right, List<Edge> edges) {
private Group constructGroup(long left, long right, List<Edge> edges) {
Preconditions.checkArgument(planTable.containsKey(left));
Preconditions.checkArgument(planTable.containsKey(right));
Group leftGroup = planTable.get(left);

View File

@ -460,6 +460,16 @@ public class Memo {
return CopyInResult.of(false, existedLogicalExpression);
}
/**
 * Wraps the given group expression in a newly created group and registers both in this memo.
 * It is used in DPHyp after a new group expression has been constructed.
 *
 * <p>The method does not look for an already registered equal expression: it always creates
 * a group with the next id from {@code groupIdGenerator}.
 *
 * @param newGroupExpression the expression to register; its plan supplies the logical
 *                           properties of the new group
 * @return the newly created group
 */
public Group copyInGroupExpression(GroupExpression newGroupExpression) {
// the new group takes its logical properties from the expression's own plan
Group newGroup = new Group(groupIdGenerator.getNextId(), newGroupExpression,
newGroupExpression.getPlan().getLogicalProperties());
// register the group under its id and the expression under itself
groups.put(newGroup.getGroupId(), newGroup);
groupExpressions.put(newGroupExpression, newGroupExpression);
return newGroup;
}
private CopyInResult rewriteByNewGroupExpression(Group targetGroup, Plan newPlan,
GroupExpression newGroupExpression) {
if (targetGroup == null) {

View File

@ -1,43 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.Bitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.SubsetIterator;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.BitSet;
import java.util.HashSet;
/**
 * Tests for the {@link BitSet} based subset iterator.
 */
public class BitSetTest {
    /**
     * Every value produced by the iterator must be a subset of the source bitmap, and the
     * number of distinct values must be 2^n - 1 for a bitmap with n set bits (the count the
     * original assertion expects).
     */
    @Test
    void subsetIteratorTest() {
        BitSet bitSet = new BitSet();
        // use bits from two different 64-bit words of the BitSet
        bitSet.set(0, 3);
        bitSet.set(64, 67);
        SubsetIterator subsetIterator = new SubsetIterator(bitSet);
        HashSet<BitSet> subsets = new HashSet<>();
        for (BitSet subset : subsetIterator) {
            Assertions.assertTrue(Bitmap.isSubset(subset, bitSet));
            subsets.add(subset);
        }
        // expected value first, computed with integer arithmetic instead of Math.pow
        Assertions.assertEquals((1 << bitSet.cardinality()) - 1, subsets.size());
    }
}

View File

@ -0,0 +1,64 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmapSubsetIterator;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.HashSet;
/**
 * Tests for the {@code long} based bitmap helpers and their iterators.
 */
public class BitmapTest {
    /**
     * Every value produced by the subset iterator must be a subset of the source bitmap, and
     * the number of distinct values must be 2^n - 1 for a bitmap with n set bits. The same
     * must hold again after {@code reset()}.
     */
    @Test
    void subsetIteratorTest() {
        int[] ints = {1, 3, 5, 12, 21, 32, 43, 54, 60, 63};
        long bitmap = LongBitmap.newBitmap(ints);
        // computed with integer arithmetic instead of Math.pow
        int expectedSubsets = (1 << ints.length) - 1;
        LongBitmapSubsetIterator subsetIterator = new LongBitmapSubsetIterator(bitmap);
        HashSet<Long> subsets = new HashSet<>();
        for (long subset : subsetIterator) {
            Assertions.assertTrue(LongBitmap.isSubset(subset, bitmap));
            subsets.add(subset);
        }
        Assertions.assertEquals(expectedSubsets, subsets.size());

        subsetIterator.reset();
        subsets.clear();
        for (long subset : subsetIterator) {
            Assertions.assertTrue(LongBitmap.isSubset(subset, bitmap));
            subsets.add(subset);
        }
        Assertions.assertEquals(expectedSubsets, subsets.size());
    }

    /**
     * The forward iterator must yield the set indexes in ascending order and the reverse
     * iterator in descending order, each visiting every index exactly once.
     */
    @Test
    void iteratorTest() {
        int[] ints = {1, 3, 6, 12, 43};
        long bitmap = LongBitmap.newBitmap(ints);

        int index = 0;
        for (int v : LongBitmap.getIterator(bitmap)) {
            Assertions.assertEquals(ints[index], v);
            index += 1;
        }
        // without this check an iterator that stops early would pass unnoticed
        Assertions.assertEquals(ints.length, index);

        index = 0;
        for (int v : LongBitmap.getReverseIterator(bitmap)) {
            Assertions.assertEquals(ints[ints.length - 1 - index], v);
            index += 1;
        }
        Assertions.assertEquals(ints.length, index);
    }
}

View File

@ -22,7 +22,6 @@ import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.util.HyperGraphBuilder;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
public class GraphSimplifierTest {
@ -163,7 +162,6 @@ public class GraphSimplifierTest {
totalTime / times));
}
@Disabled
@Test
void testComplexQuery() {
HyperGraph hyperGraph = new HyperGraphBuilder()

View File

@ -17,8 +17,8 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.Bitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.SubsetIterator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmapSubsetIterator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.util.HyperGraphBuilder;
@ -26,7 +26,6 @@ import org.apache.doris.nereids.util.HyperGraphBuilder;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
@ -48,9 +47,8 @@ public class SubgraphEnumeratorTest {
Counter counter = new Counter();
SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(counter, hyperGraph);
subgraphEnumerator.enumerate();
BitSet fullSet = new BitSet();
fullSet.set(0, 5);
HashMap<BitSet, Integer> cache = new HashMap<>();
long fullSet = LongBitmap.newBitmapBetween(0, 5);
HashMap<Long, Integer> cache = new HashMap<>();
countAndCheck(fullSet, hyperGraph, counter.getAllCount(), cache);
}
@ -69,12 +67,11 @@ public class SubgraphEnumeratorTest {
.addEdge(JoinType.INNER_JOIN, 1, 2)
.addEdge(JoinType.INNER_JOIN, 2, 3)
.build();
BitSet fullSet = new BitSet();
fullSet.set(0, 4);
long fullSet = LongBitmap.newBitmapBetween(0, 4);
Counter counter = new Counter();
SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(counter, hyperGraph);
subgraphEnumerator.enumerate();
HashMap<BitSet, Integer> cache = new HashMap<>();
HashMap<Long, Integer> cache = new HashMap<>();
countAndCheck(fullSet, hyperGraph, counter.getAllCount(), cache);
}
@ -82,14 +79,13 @@ public class SubgraphEnumeratorTest {
void testRandomQuery() {
int tableNum = 10;
int edgeNum = 40;
BitSet fullSet = new BitSet();
fullSet.set(0, tableNum);
long fullSet = LongBitmap.newBitmapBetween(0, tableNum);
for (int i = 0; i < 10; i++) {
HyperGraph hyperGraph = new HyperGraphBuilder().randomBuildWith(tableNum, edgeNum);
Counter counter = new Counter();
SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(counter, hyperGraph);
subgraphEnumerator.enumerate();
HashMap<BitSet, Integer> cache = new HashMap<>();
HashMap<Long, Integer> cache = new HashMap<>();
countAndCheck(fullSet, hyperGraph, counter.getAllCount(), cache);
}
}
@ -109,34 +105,33 @@ public class SubgraphEnumeratorTest {
String.format("enumerate %d tables %d edges cost %f ms", tableNum, edgeNum, endTime - startTime));
}
private int countAndCheck(BitSet bitSet, HyperGraph hyperGraph, HashMap<BitSet, Integer> counter,
HashMap<BitSet, Integer> cache) {
if (cache.containsKey(bitSet)) {
return cache.get(bitSet);
private int countAndCheck(long bitmap, HyperGraph hyperGraph, HashMap<Long, Integer> counter,
HashMap<Long, Integer> cache) {
if (cache.containsKey(bitmap)) {
return cache.get(bitmap);
}
if (bitSet.cardinality() == 1) {
Assertions.assertEquals(counter.get(bitSet), 1,
String.format("The csg-cmp pairs of %s should be %d rather than %s", bitSet, 1,
counter.get(bitSet)));
cache.put(bitSet, 1);
if (LongBitmap.getCardinality(bitmap) == 1) {
Assertions.assertEquals(counter.get(bitmap), 1,
String.format("The csg-cmp pairs of %s should be %d rather than %s", bitmap, 1,
counter.get(bitmap)));
cache.put(bitmap, 1);
return 1;
}
SubsetIterator subsetIterator = new SubsetIterator(bitSet);
LongBitmapSubsetIterator subsetIterator = new LongBitmapSubsetIterator(bitmap);
int count = 0;
HashSet<BitSet> visited = new HashSet<>();
for (BitSet subset : subsetIterator) {
BitSet left = subset;
BitSet right = new BitSet();
right.or(bitSet);
right.andNot(left);
HashSet<Long> visited = new HashSet<>();
for (long subset : subsetIterator) {
long left = subset;
long right = LongBitmap.clone(bitmap);
right = LongBitmap.andNot(right, left);
if (visited.contains(left) || visited.contains(right)) {
continue;
}
visited.add(left);
visited.add(right);
for (Edge edge : hyperGraph.getEdges()) {
if ((Bitmap.isSubset(edge.getLeft(), left) && Bitmap.isSubset(edge.getRight(), right)) || (
Bitmap.isSubset(edge.getLeft(), right) && Bitmap.isSubset(edge.getRight(), left))) {
if ((LongBitmap.isSubset(edge.getLeft(), left) && LongBitmap.isSubset(edge.getRight(), right)) || (
LongBitmap.isSubset(edge.getLeft(), right) && LongBitmap.isSubset(edge.getRight(), left))) {
count += countAndCheck(left, hyperGraph, counter, cache) * countAndCheck(right, hyperGraph,
counter, cache);
break;
@ -144,14 +139,14 @@ public class SubgraphEnumeratorTest {
}
}
if (count == 0) {
Assertions.assertEquals(counter.get(bitSet), null,
String.format("The plan %s should be invalid", bitSet));
Assertions.assertEquals(counter.get(bitmap), null,
String.format("The plan %s should be invalid", bitmap));
} else {
Assertions.assertEquals(counter.get(bitSet), count,
String.format("The csg-cmp pairs of %s should be %d rather than %d", bitSet, count,
counter.get(bitSet)));
Assertions.assertEquals(counter.get(bitmap), count,
String.format("The csg-cmp pairs of %s should be %d rather than %d", bitmap, count,
counter.get(bitmap)));
}
cache.put(bitSet, count);
cache.put(bitmap, count);
return count;
}
}

View File

@ -23,7 +23,6 @@ import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.Bitmap;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -118,9 +117,13 @@ public class HyperGraphBuilder {
Preconditions.checkArgument(node2 >= 0 && node1 < rowCounts.size(),
String.format("%d must in [%d, %d)", node1, 0, rowCounts.size()));
BitSet leftBitmap = Bitmap.newBitmap(node1);
BitSet rightBitmap = Bitmap.newBitmap(node2);
BitSet fullBitmap = Bitmap.newBitmapUnion(leftBitmap, rightBitmap);
BitSet leftBitmap = new BitSet();
leftBitmap.set(node1);
BitSet rightBitmap = new BitSet();
rightBitmap.set(node2);
BitSet fullBitmap = new BitSet();
fullBitmap.or(leftBitmap);
fullBitmap.or(rightBitmap);
Optional<BitSet> fullKey = findPlan(fullBitmap);
if (!fullKey.isPresent()) {
Optional<BitSet> leftKey = findPlan(leftBitmap);