[feature](Nereids): add filter edge in hyperGraph (#28006)

This commit is contained in:
谢健
2023-12-11 14:36:43 +08:00
committed by GitHub
parent 593cc92501
commit c2d6fbbc85
19 changed files with 402 additions and 245 deletions

View File

@ -19,6 +19,8 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
@ -81,7 +83,7 @@ public class GraphSimplifier {
*/
public GraphSimplifier(HyperGraph graph) {
this.graph = graph;
edgeSize = graph.getEdges().size();
edgeSize = graph.getJoinEdges().size();
for (int i = 0; i < edgeSize; i++) {
BestSimplification bestSimplification = new BestSimplification();
simplifications.add(bestSimplification);
@ -91,7 +93,7 @@ public class GraphSimplifier {
cacheStats.put(node.getNodeMap(), dPhyperNode.getGroup().getStatistics());
cacheCost.put(node.getNodeMap(), dPhyperNode.getRowCount());
}
validEdges = graph.getEdges().stream()
validEdges = graph.getJoinEdges().stream()
.filter(e -> {
for (Slot slot : e.getJoin().getConditionSlot()) {
boolean contains = false;
@ -136,8 +138,8 @@ public class GraphSimplifier {
public boolean isTotalOrder() {
for (int i = 0; i < edgeSize; i++) {
for (int j = i + 1; j < edgeSize; j++) {
Edge edge1 = graph.getEdge(i);
Edge edge2 = graph.getEdge(j);
Edge edge1 = graph.getJoinEdge(i);
Edge edge2 = graph.getJoinEdge(j);
List<Long> superset = new ArrayList<>();
tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getLeftExtendedNodes(), superset);
tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getRightExtendedNodes(), superset);
@ -342,8 +344,8 @@ public class GraphSimplifier {
}
private Optional<SimplificationStep> makeSimplificationStep(int edgeIndex1, int edgeIndex2) {
Edge edge1 = graph.getEdge(edgeIndex1);
Edge edge2 = graph.getEdge(edgeIndex2);
JoinEdge edge1 = graph.getJoinEdge(edgeIndex1);
JoinEdge edge2 = graph.getJoinEdge(edgeIndex2);
if (edge1.isSub(edge2) || edge2.isSub(edge1)
|| circleDetector.checkCircleWithEdge(edgeIndex1, edgeIndex2)
|| circleDetector.checkCircleWithEdge(edgeIndex2, edgeIndex1)
@ -358,8 +360,8 @@ public class GraphSimplifier {
|| !cacheStats.containsKey(left2) || !cacheStats.containsKey(right2)) {
return Optional.empty();
}
Edge edge1Before2;
Edge edge2Before1;
JoinEdge edge1Before2;
JoinEdge edge2Before1;
List<Long> superBitset = new ArrayList<>();
if (tryGetSuperset(left1, left2, superBitset)) {
// (common Join1 right1) Join2 right2
@ -394,36 +396,34 @@ public class GraphSimplifier {
return Optional.of(simplificationStep);
}
private Edge constructEdge(long leftNodes, Edge edge, long rightNodes) {
private JoinEdge constructEdge(long leftNodes, JoinEdge edge, long rightNodes) {
LogicalJoin<? extends Plan, ? extends Plan> join;
if (graph.getEdges().size() > 64 * 63 / 8) {
if (graph.getJoinEdges().size() > 64 * 63 / 8) {
// If there are too many edges, it is advisable to return the "edge" directly
// to avoid lengthy enumeration time.
join = edge.getJoin();
} else {
BitSet validEdgesMap = graph.getEdgesInOperator(leftNodes, rightNodes);
List<Expression> hashConditions = validEdgesMap.stream()
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
.mapToObj(i -> graph.getJoinEdge(i).getJoin().getHashJoinConjuncts())
.flatMap(Collection::stream)
.collect(Collectors.toList());
List<Expression> otherConditions = validEdgesMap.stream()
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
.mapToObj(i -> graph.getJoinEdge(i).getJoin().getHashJoinConjuncts())
.flatMap(Collection::stream)
.collect(Collectors.toList());
join = edge.getJoin().withJoinConjuncts(hashConditions, otherConditions);
}
Edge newEdge = new Edge(
join,
edge.getIndex(), edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes());
newEdge.setLeftRequiredNodes(edge.getLeftRequiredNodes());
newEdge.setRightRequiredNodes(edge.getRightRequiredNodes());
newEdge.addLeftNode(leftNodes);
newEdge.addRightNode(rightNodes);
JoinEdge newEdge = new JoinEdge(join, edge.getIndex(),
edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes(),
edge.getLeftRequiredNodes(), edge.getRightRequiredNodes());
newEdge.addLeftExtendNode(leftNodes);
newEdge.addRightExtendNode(rightNodes);
return newEdge;
}
private void deriveStats(Edge edge, long leftBitmap, long rightBitmap) {
private void deriveStats(JoinEdge edge, long leftBitmap, long rightBitmap) {
// The bitmap may differ from the edge's reference slots.
// Taking into account the order: edge1<{1} - {2}> edge2<{1,3} - {4}>.
// Actually, we are considering the sequence {1,3} - {2} - {4}
@ -438,7 +438,7 @@ public class GraphSimplifier {
cacheStats.put(bitmap, joinStats);
}
private double calCost(Edge edge, long leftBitmap, long rightBitmap) {
private double calCost(JoinEdge edge, long leftBitmap, long rightBitmap) {
long bitmap = LongBitmap.newBitmapUnion(leftBitmap, rightBitmap);
Preconditions.checkArgument(cacheStats.containsKey(leftBitmap) && cacheStats.containsKey(rightBitmap)
&& cacheStats.containsKey(bitmap),
@ -461,7 +461,7 @@ public class GraphSimplifier {
return cost;
}
private @Nullable Edge threeLeftJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) {
private @Nullable JoinEdge threeLeftJoin(long bitmap1, JoinEdge edge1, long bitmap2, JoinEdge edge2, long bitmap3) {
// (plan1 edge1 plan2) edge2 plan3
// if the left and right is overlapping, just return null.
Preconditions.checkArgument(
@ -471,7 +471,7 @@ public class GraphSimplifier {
if (LongBitmap.isOverlap(newLeft, bitmap3)) {
return null;
}
Edge newEdge = constructEdge(newLeft, edge2, bitmap3);
JoinEdge newEdge = constructEdge(newLeft, edge2, bitmap3);
deriveStats(edge1, bitmap1, bitmap2);
deriveStats(newEdge, newLeft, bitmap3);
@ -481,15 +481,16 @@ public class GraphSimplifier {
return newEdge;
}
private @Nullable Edge threeRightJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) {
Preconditions.checkArgument(
cacheStats.containsKey(bitmap1) && cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3));
private @Nullable JoinEdge threeRightJoin(long bitmap1, JoinEdge edge1, long bitmap2,
JoinEdge edge2, long bitmap3) {
Preconditions.checkArgument(cacheStats.containsKey(bitmap1)
&& cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3));
// plan1 edge1 (plan2 edge2 plan3)
long newRight = LongBitmap.newBitmapUnion(bitmap2, bitmap3);
if (LongBitmap.isOverlap(bitmap1, newRight)) {
return null;
}
Edge newEdge = constructEdge(bitmap1, edge1, newRight);
JoinEdge newEdge = constructEdge(bitmap1, edge1, newRight);
deriveStats(edge2, bitmap2, bitmap3);
deriveStats(newEdge, bitmap1, newRight);
@ -498,8 +499,8 @@ public class GraphSimplifier {
return newEdge;
}
private SimplificationStep orderJoin(Edge edge1Before2,
Edge edge2Before1, int edgeIndex1, int edgeIndex2) {
private SimplificationStep orderJoin(JoinEdge edge1Before2,
JoinEdge edge2Before1, int edgeIndex1, int edgeIndex2) {
double cost1Before2 = calCost(edge1Before2,
edge1Before2.getLeftExtendedNodes(), edge1Before2.getRightExtendedNodes());
double cost2Before1 = calCost(edge2Before1,
@ -515,16 +516,16 @@ public class GraphSimplifier {
step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2,
edge1Before2.getLeftExtendedNodes(),
edge1Before2.getRightExtendedNodes(),
graph.getEdge(edgeIndex2).getLeftExtendedNodes(),
graph.getEdge(edgeIndex2).getRightExtendedNodes());
graph.getJoinEdge(edgeIndex2).getLeftExtendedNodes(),
graph.getJoinEdge(edgeIndex2).getRightExtendedNodes());
} else {
if (cost2Before1 != 0) {
benefit = cost1Before2 / cost2Before1;
}
// choose edge2Before1
step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.getLeftExtendedNodes(),
edge2Before1.getRightExtendedNodes(), graph.getEdge(edgeIndex1).getLeftExtendedNodes(),
graph.getEdge(edgeIndex1).getRightExtendedNodes());
edge2Before1.getRightExtendedNodes(), graph.getJoinEdge(edgeIndex1).getLeftExtendedNodes(),
graph.getJoinEdge(edgeIndex1).getRightExtendedNodes());
}
return step;
}
@ -545,9 +546,9 @@ public class GraphSimplifier {
*/
private void extractJoinDependencies() {
for (int i = 0; i < edgeSize; i++) {
Edge edge1 = graph.getEdge(i);
Edge edge1 = graph.getJoinEdge(i);
for (int j = i + 1; j < edgeSize; j++) {
Edge edge2 = graph.getEdge(j);
Edge edge2 = graph.getJoinEdge(j);
if (edge1.isSub(edge2)) {
Preconditions.checkArgument(circleDetector.tryAddDirectedEdge(i, j),
"Edge %s violates Edge %s", edge1, edge2);

View File

@ -19,11 +19,15 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.FilterEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -32,6 +36,7 @@ import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.PlanUtils;
@ -53,7 +58,8 @@ import java.util.stream.Collectors;
* It's used for join ordering
*/
public class HyperGraph {
private final List<Edge> edges = new ArrayList<>();
private final List<JoinEdge> joinEdges = new ArrayList<>();
private final List<FilterEdge> filterEdges = new ArrayList<>();
private final List<AbstractNode> nodes = new ArrayList<>();
private final HashMap<Slot, Long> slotToNodeMap = new HashMap<>();
// record all edges that can be placed on the subgraph
@ -69,8 +75,8 @@ public class HyperGraph {
this.finalOutputs = ImmutableSet.copyOf(finalOutputs);
}
/**
 * Returns the list that stores the join edges of this graph.
 * The list itself is returned, not a copy.
 */
public List<Edge> getEdges() {
return edges;
public List<JoinEdge> getJoinEdges() {
return joinEdges;
}
public List<AbstractNode> getNodes() {
@ -81,8 +87,12 @@ public class HyperGraph {
return LongBitmap.newBitmapBetween(0, nodes.size());
}
/**
 * Returns the join edge stored at position {@code index} of the join-edge list.
 */
public Edge getEdge(int index) {
return edges.get(index);
public JoinEdge getJoinEdge(int index) {
return joinEdges.get(index);
}
/**
 * Returns the filter edge stored at position {@code index} of the filter-edge list.
 */
public FilterEdge getFilterEdge(int index) {
    FilterEdge filterEdge = filterEdges.get(index);
    return filterEdge;
}
public AbstractNode getNode(int index) {
@ -176,17 +186,14 @@ public class HyperGraph {
return complexProject;
}
/**
 * Re-creates {@code edge} (taken from another graph) as a join edge of this graph.
 *
 * The ends are derived from the nodes that the edge's input slots map to in this graph:
 * the lowest set bit of that node map becomes the left end, all remaining bits the right end.
 *
 * NOTE(review): the new edge is created with null child-edge sets. The sub-node accessors of
 * Edge call isEmpty() on those sets, so confirm they are never invoked on edges built here.
 */
private void addEdgeOfInfo(Edge edge) {
private void addEdgeOfInfo(JoinEdge edge) {
long nodeMap = calNodeMap(edge.getInputSlots());
// an edge must connect at least two distinct nodes
Preconditions.checkArgument(LongBitmap.getCardinality(nodeMap) > 1,
"edge must have more than one ends");
this.edges.add(new Edge(edge.getJoin(), edges.size(), null, null, null));
// left end: the single lowest node; right end: every other referenced node
long left = LongBitmap.newBitmap(LongBitmap.nextSetBit(nodeMap, 0));
long right = LongBitmap.newBitmapDiff(nodeMap, left);
edge.setLeftRequiredNodes(left);
edge.setLeftExtendedNodes(left);
edge.setRightRequiredNodes(right);
edge.setRightExtendedNodes(right);
this.joinEdges.add(new JoinEdge(edge.getJoin(), joinEdges.size(),
null, null, 0, left, right));
}
/**
@ -194,7 +201,8 @@ public class HyperGraph {
*
* @param join The join plan
*/
public BitSet addEdge(LogicalJoin<?, ?> join, Pair<BitSet, Long> leftEdgeNodes, Pair<BitSet, Long> rightEdgeNodes) {
private BitSet addJoin(LogicalJoin<?, ?> join,
Pair<BitSet, Long> leftEdgeNodes, Pair<BitSet, Long> rightEdgeNodes) {
HashMap<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new HashMap<>();
for (Expression expression : join.getHashJoinConjuncts()) {
// TODO: avoid calling calculateEnds if calNodeMap's results are same
@ -217,59 +225,77 @@ public class HyperGraph {
BitSet curJoinEdges = new BitSet();
for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> entry : conjuncts
.entrySet()) {
LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first,
LogicalJoin<?, ?> singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first,
entry.getValue().second, JoinHint.NONE, join.getMarkJoinSlotReference(),
Lists.newArrayList(join.left(), join.right()));
Edge edge = new Edge(singleJoin, edges.size(), leftEdgeNodes.first, rightEdgeNodes.first,
LongBitmap.newBitmapUnion(leftEdgeNodes.second, rightEdgeNodes.second));
Pair<Long, Long> ends = entry.getKey();
edge.setLeftRequiredNodes(ends.first);
edge.setLeftExtendedNodes(ends.first);
edge.setRightRequiredNodes(ends.second);
edge.setRightExtendedNodes(ends.second);
JoinEdge edge = new JoinEdge(singleJoin, joinEdges.size(), leftEdgeNodes.first, rightEdgeNodes.first,
LongBitmap.newBitmapUnion(leftEdgeNodes.second, rightEdgeNodes.second), ends.first, ends.second);
for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
nodes.get(nodeIndex).attachEdge(edge);
}
curJoinEdges.set(edge.getIndex());
edges.add(edge);
joinEdges.add(edge);
}
curJoinEdges.stream().forEach(i -> edges.get(i).addCurJoinEdges(curJoinEdges));
curJoinEdges.stream().forEach(i -> edges.get(i).addCurJoinEdges(curJoinEdges));
curJoinEdges.stream().forEach(i -> makeConflictRules(edges.get(i)));
curJoinEdges.stream().forEach(i -> joinEdges.get(i).addCurJoinEdges(curJoinEdges));
curJoinEdges.stream().forEach(i -> makeJoinConflictRules(joinEdges.get(i)));
curJoinEdges.stream().forEach(i -> makeFilterConflictRules(joinEdges.get(i)));
return curJoinEdges;
// In MySQL, each edge is also stored a second time in reversed form to reduce branch misses.
// We don't implement this trick now.
}
/**
 * Wraps {@code filter} in a new {@link FilterEdge}, appends it to the filter-edge list and
 * returns a bit set in which only the index of that new edge is set.
 *
 * @param filter the filter plan to register
 * @param childEdgeNodes edge set and node bitmap produced for the filter's child
 * @return a bit set containing exactly the index of the created filter edge
 */
private BitSet addFilter(LogicalFilter<?> filter, Pair<BitSet, Long> childEdgeNodes) {
    // the position the edge will occupy in the list doubles as its index
    int nextIndex = filterEdges.size();
    FilterEdge filterEdge = new FilterEdge(filter, nextIndex,
            childEdgeNodes.first, childEdgeNodes.second, childEdgeNodes.second);
    filterEdges.add(filterEdge);

    BitSet createdEdges = new BitSet();
    createdEdges.set(filterEdge.getIndex());
    return createdEdges;
}
/**
 * Marks {@code joinEdge} as rejected on every filter edge that must not be moved across it.
 *
 * <p>A filter whose reference nodes all lie in the join's left (right) sub nodes rejects the
 * join when the join type is not contained in
 * {@code PushDownFilterThroughJoin.COULD_PUSH_THROUGH_LEFT} ({@code COULD_PUSH_THROUGH_RIGHT}).
 *
 * @param joinEdge the join edge to test against all filter edges collected so far
 */
private void makeFilterConflictRules(JoinEdge joinEdge) {
    long leftSide = joinEdge.getLeftSubNodes(joinEdges);
    long rightSide = joinEdge.getRightSubNodes(joinEdges);
    for (FilterEdge filterEdge : filterEdges) {
        if (LongBitmap.isSubset(filterEdge.getReferenceNodes(), leftSide)) {
            if (!PushDownFilterThroughJoin.COULD_PUSH_THROUGH_LEFT.contains(joinEdge.getJoinType())) {
                filterEdge.addRejectJoin(joinEdge);
            }
        }
        if (LongBitmap.isSubset(filterEdge.getReferenceNodes(), rightSide)) {
            if (!PushDownFilterThroughJoin.COULD_PUSH_THROUGH_RIGHT.contains(joinEdge.getJoinType())) {
                filterEdge.addRejectJoin(joinEdge);
            }
        }
    }
}
// Build the conflict rules of an edge with the CD-C algorithm described in the paper
// "On the Correct and Complete Enumeration of the Core Search Space".
// Starting from the required nodes of edgeB, each side is widened with the sub nodes of every
// join edge found below that side whose join type cannot be reordered with edgeB's join type.
private void makeConflictRules(Edge edgeB) {
private void makeJoinConflictRules(JoinEdge edgeB) {
// all join edges located under the left / right child of edgeB
BitSet leftSubTreeEdges = subTreeEdges(edgeB.getLeftChildEdges());
BitSet rightSubTreeEdges = subTreeEdges(edgeB.getRightChildEdges());
long leftRequired = edgeB.getLeftRequiredNodes();
long rightRequired = edgeB.getRightRequiredNodes();
// left side: childA is evaluated before edgeB, so the pair is checked as (childA, edgeB)
for (int i = leftSubTreeEdges.nextSetBit(0); i >= 0; i = leftSubTreeEdges.nextSetBit(i + 1)) {
Edge childA = edges.get(i);
JoinEdge childA = joinEdges.get(i);
if (!JoinType.isAssoc(childA.getJoinType(), edgeB.getJoinType())) {
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(edges));
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(joinEdges));
}
if (!JoinType.isLAssoc(childA.getJoinType(), edgeB.getJoinType())) {
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(edges));
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(joinEdges));
}
}
// right side: the pair is checked in the opposite order, (edgeB, childA)
for (int i = rightSubTreeEdges.nextSetBit(0); i >= 0; i = rightSubTreeEdges.nextSetBit(i + 1)) {
Edge childA = edges.get(i);
JoinEdge childA = joinEdges.get(i);
if (!JoinType.isAssoc(edgeB.getJoinType(), childA.getJoinType())) {
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(edges));
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(joinEdges));
}
if (!JoinType.isRAssoc(edgeB.getJoinType(), childA.getJoinType())) {
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(edges));
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(joinEdges));
}
}
// store the widened node sets back on the edge
edgeB.setLeftRequiredNodes(leftRequired);
edgeB.setRightRequiredNodes(rightRequired);
edgeB.setLeftExtendedNodes(leftRequired);
edgeB.setRightExtendedNodes(rightRequired);
}
@ -277,7 +303,7 @@ public class HyperGraph {
// Returns the indexes of the join edges selected by comparing their reference nodes with the
// sub-tree nodes of the given edge.
// NOTE(review): the other isSubset call sites in this file pass (candidate, container), e.g.
// isSubset(edge.getReferenceNodes(), treeNodesMap). Here the sub-tree nodes come first, which
// selects edges whose reference nodes cover the whole sub tree - confirm the order is intended.
private BitSet subTreeEdge(Edge edge) {
long subTreeNodes = edge.getSubTreeNodes();
BitSet subEdges = new BitSet();
edges.stream()
joinEdges.stream()
.filter(e -> LongBitmap.isSubset(subTreeNodes, e.getReferenceNodes()))
.forEach(e -> subEdges.set(e.getIndex()));
return subEdges;
@ -286,7 +312,7 @@ public class HyperGraph {
// Unions the results of subTreeEdge for every edge index set in edgeSet.
// An empty edgeSet yields an empty result.
private BitSet subTreeEdges(BitSet edgeSet) {
BitSet bitSet = new BitSet();
edgeSet.stream()
.mapToObj(i -> subTreeEdge(edges.get(i)))
.mapToObj(i -> subTreeEdge(joinEdges.get(i)))
.forEach(bitSet::or);
return bitSet;
}
@ -301,15 +327,19 @@ public class HyperGraph {
if (left == 0) {
Preconditions.checkArgument(leftEdgeNodes.first.cardinality() > 0,
"the number of the table which expression reference is less 2");
Pair<BitSet, Long> llEdgesNodes = edges.get(leftEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes(edges);
Pair<BitSet, Long> lrEdgesNodes = edges.get(leftEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes(edges);
Pair<BitSet, Long> llEdgesNodes = joinEdges.get(leftEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes(
joinEdges);
Pair<BitSet, Long> lrEdgesNodes = joinEdges.get(leftEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes(
joinEdges);
return calculateEnds(allNodes, llEdgesNodes, lrEdgesNodes);
}
if (right == 0) {
Preconditions.checkArgument(rightEdgeNodes.first.cardinality() > 0,
"the number of the table which expression reference is less 2");
Pair<BitSet, Long> rlEdgesNodes = edges.get(rightEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes(edges);
Pair<BitSet, Long> rrEdgesNodes = edges.get(rightEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes(edges);
Pair<BitSet, Long> rlEdgesNodes = joinEdges.get(rightEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes(
joinEdges);
Pair<BitSet, Long> rrEdgesNodes = joinEdges.get(rightEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes(
joinEdges);
return calculateEnds(allNodes, rlEdgesNodes, rrEdgesNodes);
}
return Pair.of(left, right);
@ -329,7 +359,7 @@ public class HyperGraph {
public BitSet getEdgesInTree(long treeNodesMap) {
if (!treeEdgesCache.containsKey(treeNodesMap)) {
BitSet edgesMap = new BitSet();
for (Edge edge : edges) {
for (Edge edge : joinEdges) {
if (LongBitmap.isSubset(edge.getReferenceNodes(), treeNodesMap)) {
edgesMap.set(edge.getIndex());
}
@ -364,7 +394,7 @@ public class HyperGraph {
for (AbstractNode node : nodes) {
res = flatChild((StructInfoNode) node, res);
}
for (Edge edge : edges) {
for (JoinEdge edge : joinEdges) {
res.forEach(g -> g.addEdgeOfInfo(edge));
}
return res;
@ -376,12 +406,12 @@ public class HyperGraph {
return hyperGraphs;
}
return hyperGraphs.stream().flatMap(g ->
infoNode.getGraphs().stream().map(subGraph -> {
HyperGraph hyperGraph = new HyperGraph(g.finalOutputs);
hyperGraph.addStructInfo(g);
hyperGraph.addStructInfo(subGraph);
return hyperGraph;
})
infoNode.getGraphs().stream().map(subGraph -> {
HyperGraph hyperGraph = new HyperGraph(g.finalOutputs);
hyperGraph.addStructInfo(g);
hyperGraph.addStructInfo(subGraph);
return hyperGraph;
})
).collect(Collectors.toList());
}
@ -410,7 +440,7 @@ public class HyperGraph {
LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) groupExpression.getPlan();
Pair<BitSet, Long> left = this.buildDPhyperGraph(groupExpression.child(0).getLogicalExpressions().get(0));
Pair<BitSet, Long> right = this.buildDPhyperGraph(groupExpression.child(1).getLogicalExpressions().get(0));
return Pair.of(this.addEdge(join, left, right),
return Pair.of(this.addJoin(join, left, right),
LongBitmap.or(left.second, right.second));
}
@ -424,10 +454,10 @@ public class HyperGraph {
other.getNodes().forEach(n -> this.addStructInfoNode(n.getPlan()));
other.getComplexProject().forEach((t, projectList) ->
projectList.forEach(e -> this.addAlias((Alias) e, t << offset)));
other.getEdges().forEach(this::addEdgeOfInfo);
other.getJoinEdges().forEach(this::addEdgeOfInfo);
}
// Build Graph for matching mv
// Build Graph for matching mv, return join edge set and nodes in this plan
private Pair<BitSet, Long> buildStructInfo(Plan plan) {
if (plan instanceof GroupPlan) {
Group group = ((GroupPlan) plan).getGroup();
@ -454,14 +484,21 @@ public class HyperGraph {
}
// process Join
if (isValidJoin(plan)) {
if (isValidJoinForStructInfo(plan)) {
LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) plan;
Pair<BitSet, Long> left = this.buildStructInfo(plan.child(0));
Pair<BitSet, Long> right = this.buildStructInfo(plan.child(1));
return Pair.of(this.addEdge(join, left, right),
return Pair.of(this.addJoin(join, left, right),
LongBitmap.or(left.second, right.second));
}
if (isValidFilter(plan)) {
LogicalFilter<?> filter = (LogicalFilter<?>) plan;
Pair<BitSet, Long> child = this.buildStructInfo(filter.child());
this.addFilter(filter, child);
return Pair.of(new BitSet(), child.second);
}
// process Other Node
int idx = this.addStructInfoNode(plan);
return Pair.of(new BitSet(), LongBitmap.newBitmap(idx));
@ -480,6 +517,23 @@ public class HyperGraph {
&& !join.getExpressions().isEmpty();
}
/**
 * Checks whether {@code plan} can become join edges when building struct info.
 *
 * <p>Any {@link LogicalJoin} is accepted as long as it is not a mark join and has at least
 * one expression. The join type itself is not inspected.
 *
 * @param plan the plan to test
 * @return {@code true} if the plan is a non-mark join with a non-empty expression list
 */
public static boolean isValidJoinForStructInfo(Plan plan) {
    if (plan instanceof LogicalJoin) {
        LogicalJoin<?, ?> candidate = (LogicalJoin<?, ?>) plan;
        if (candidate.isMarkJoin()) {
            return false;
        }
        return !candidate.getExpressions().isEmpty();
    }
    return false;
}
/**
 * Checks whether {@code plan} is a {@link LogicalFilter}.
 *
 * @param plan the plan to test; {@code null} yields {@code false}
 * @return {@code true} if the plan is a logical filter
 */
public static boolean isValidFilter(Plan plan) {
    return LogicalFilter.class.isInstance(plan);
}
/**
* the project with alias and slot
*/
@ -502,14 +556,14 @@ public class HyperGraph {
// When modify an edge in hyper graph, we need to update the left and right nodes
// For these nodes that are only in the old edge, we need remove the edge from them
// For these nodes that are only in the new edge, we need to add the edge to them
Edge edge = edges.get(edgeIndex);
Edge edge = joinEdges.get(edgeIndex);
if (treeEdgesCache.containsKey(edge.getReferenceNodes())) {
treeEdgesCache.get(edge.getReferenceNodes()).set(edgeIndex, false);
}
updateEdges(edge, edge.getLeftExtendedNodes(), newLeft);
updateEdges(edge, edge.getRightExtendedNodes(), newRight);
edges.get(edgeIndex).setLeftExtendedNodes(newLeft);
edges.get(edgeIndex).setRightExtendedNodes(newRight);
joinEdges.get(edgeIndex).setLeftExtendedNodes(newLeft);
joinEdges.get(edgeIndex).setRightExtendedNodes(newRight);
if (treeEdgesCache.containsKey(edge.getReferenceNodes())) {
treeEdgesCache.get(edge.getReferenceNodes()).set(edgeIndex, true);
}
@ -534,7 +588,7 @@ public class HyperGraph {
*/
public String toDottyHyperGraph() {
StringBuilder builder = new StringBuilder();
builder.append(String.format("digraph G { # %d edges\n", edges.size()));
builder.append(String.format("digraph G { # %d edges\n", joinEdges.size()));
List<String> graphvisNodes = new ArrayList<>();
for (AbstractNode node : nodes) {
String nodeName = node.getName();
@ -550,11 +604,11 @@ public class HyperGraph {
nodeID, nodeName, rowCount));
graphvisNodes.add(nodeName);
}
for (int i = 0; i < edges.size(); i += 1) {
Edge edge = edges.get(i);
for (int i = 0; i < joinEdges.size(); i += 1) {
JoinEdge edge = joinEdges.get(i);
// TODO: add cardinality to label
String label = String.format("%.2f", edge.getSelectivity());
if (edges.get(i).isSimple()) {
if (joinEdges.get(i).isSimple()) {
String arrowHead = "";
if (edge.getJoin().getJoinType() == JoinType.INNER_JOIN) {
arrowHead = ",arrowhead=none";

View File

@ -19,6 +19,8 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmapSubsetIterator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.AbstractReceiver;
@ -79,7 +81,7 @@ public class SubgraphEnumerator {
int size = nodes.size();
// Init edgeCalculator
edgeCalculator = new EdgeCalculator(hyperGraph.getEdges());
edgeCalculator = new EdgeCalculator(hyperGraph.getJoinEdges());
for (AbstractNode node : nodes) {
edgeCalculator.initSubgraph(node.getNodeMap());
}
@ -149,7 +151,7 @@ public class SubgraphEnumerator {
edgeCalculator.unionEdges(cmp, subset);
if (receiver.contain(newCmp)) {
// We check all edges for finding an edge.
List<Edge> edges = edgeCalculator.connectCsgCmp(csg, newCmp);
List<JoinEdge> edges = edgeCalculator.connectCsgCmp(csg, newCmp);
if (edges.isEmpty()) {
continue;
}
@ -185,7 +187,7 @@ public class SubgraphEnumerator {
for (int nodeIndex : LongBitmap.getReverseIterator(neighborhoods)) {
long cmp = LongBitmap.newBitmap(nodeIndex);
// whether there is an edge between csg and cmp
List<Edge> edges = edgeCalculator.connectCsgCmp(csg, cmp);
List<JoinEdge> edges = edgeCalculator.connectCsgCmp(csg, cmp);
if (!edges.isEmpty()) {
if (!receiver.emitCsgCmp(csg, cmp, edges)) {
return false;
@ -241,7 +243,7 @@ public class SubgraphEnumerator {
}
static class EdgeCalculator {
final List<Edge> edges;
final List<JoinEdge> edges;
// It cached all edges that contained by this subgraph, Note we always
// use bitset store edge map because the number of edges can be very large
// We split these into simple edges (only one node on each side) and complex edges (others)
@ -254,7 +256,7 @@ public class SubgraphEnumerator {
// complex edges
HashMap<Long, BitSet> overlapEdges = new HashMap<>();
EdgeCalculator(List<Edge> edges) {
EdgeCalculator(List<JoinEdge> edges) {
this.edges = edges;
}
@ -326,10 +328,10 @@ public class SubgraphEnumerator {
overlapEdges.put(subgraph, overlaps);
}
public List<Edge> connectCsgCmp(long csg, long cmp) {
public List<JoinEdge> connectCsgCmp(long csg, long cmp) {
Preconditions.checkArgument(
containSimpleEdges.containsKey(csg) && containSimpleEdges.containsKey(cmp));
List<Edge> foundEdges = new ArrayList<>();
List<JoinEdge> foundEdges = new ArrayList<>();
BitSet edgeMap = new BitSet();
edgeMap.or(containSimpleEdges.get(csg));
edgeMap.and(containSimpleEdges.get(cmp));

View File

@ -15,98 +15,68 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
package org.apache.doris.nereids.jobs.joinorder.hypergraph.edge;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import com.google.common.base.Preconditions;
import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.annotation.Nullable;
/**
* Edge in HyperGraph
*/
public class Edge {
final int index;
final LogicalJoin<? extends Plan, ? extends Plan> join;
final double selectivity;
public abstract class Edge {
private final int index;
private final double selectivity;
// "RequiredNodes" refers to the nodes that can activate this edge based on
// specific requirements. These requirements are established during the building process.
// "ExtendNodes" encompasses both the "RequiredNodes" and any additional nodes
// added by the graph simplifier.
private long leftRequiredNodes = LongBitmap.newBitmap();
private long rightRequiredNodes = LongBitmap.newBitmap();
private long leftExtendedNodes = LongBitmap.newBitmap();
private long rightExtendedNodes = LongBitmap.newBitmap();
private long referenceNodes = LongBitmap.newBitmap();
private final long leftRequiredNodes;
private final long rightRequiredNodes;
private long leftExtendedNodes;
private long rightExtendedNodes;
// record the left child edges and right child edges in origin plan tree
private final BitSet leftChildEdges;
private final BitSet rightChildEdges;
// record the edges in the same operator
private final BitSet curJoinEdges = new BitSet();
private final BitSet curOperatorEdges = new BitSet();
// record all sub nodes behind in this operator. It's T function in paper
private final long subTreeNodes;
/**
 * Creates an edge whose extended nodes initially equal its required nodes.
 *
 * <p>The selectivity is initialised to {@code 1.0}.
 *
 * @param index position of this edge in the list that owns it
 * @param leftChildEdges edges under the left child in the original plan tree
 * @param rightChildEdges edges under the right child in the original plan tree
 * @param subTreeNodes all nodes below the operator this edge belongs to
 * @param leftRequiredNodes nodes required on the left side; also the initial left extended nodes
 * @param rightRequiredNodes nodes required on the right side; also the initial right extended nodes
 */
Edge(int index, BitSet leftChildEdges, BitSet rightChildEdges,
        long subTreeNodes, long leftRequiredNodes, long rightRequiredNodes) {
    this.index = index;
    this.selectivity = 1.0;
    this.leftChildEdges = leftChildEdges;
    this.rightChildEdges = rightChildEdges;
    // rightRequiredNodes is a primitive like its siblings: the fields are primitive longs,
    // so a boxed parameter would only add an unboxing step and a hidden null hazard.
    this.leftRequiredNodes = leftRequiredNodes;
    this.rightRequiredNodes = rightRequiredNodes;
    this.leftExtendedNodes = leftRequiredNodes;
    this.rightExtendedNodes = rightRequiredNodes;
    this.subTreeNodes = subTreeNodes;
}
public LogicalJoin<? extends Plan, ? extends Plan> getJoin() {
return join;
}
public JoinType getJoinType() {
return join.getJoinType();
}
public boolean isSimple() {
return LongBitmap.getCardinality(leftExtendedNodes) == 1 && LongBitmap.getCardinality(rightExtendedNodes) == 1;
}
public void addLeftNode(long left) {
public void addLeftExtendNode(long left) {
this.leftExtendedNodes = LongBitmap.or(this.leftExtendedNodes, left);
referenceNodes = LongBitmap.or(referenceNodes, left);
}
public void addLeftNodes(long... bitmaps) {
for (long bitmap : bitmaps) {
this.leftExtendedNodes = LongBitmap.or(this.leftExtendedNodes, bitmap);
referenceNodes = LongBitmap.or(referenceNodes, bitmap);
}
}
public void addRightNode(long right) {
public void addRightExtendNode(long right) {
this.rightExtendedNodes = LongBitmap.or(this.rightExtendedNodes, right);
referenceNodes = LongBitmap.or(referenceNodes, right);
}
public void addRightNodes(long... bitmaps) {
for (long bitmap : bitmaps) {
LongBitmap.or(this.rightExtendedNodes, bitmap);
LongBitmap.or(referenceNodes, bitmap);
}
}
public long getSubTreeNodes() {
@ -121,22 +91,22 @@ public class Edge {
return leftChildEdges;
}
public Pair<BitSet, Long> getLeftEdgeNodes(List<Edge> edges) {
public Pair<BitSet, Long> getLeftEdgeNodes(List<JoinEdge> edges) {
return Pair.of(leftChildEdges, getLeftSubNodes(edges));
}
public Pair<BitSet, Long> getRightEdgeNodes(List<Edge> edges) {
public Pair<BitSet, Long> getRightEdgeNodes(List<JoinEdge> edges) {
return Pair.of(rightChildEdges, getRightSubNodes(edges));
}
public long getLeftSubNodes(List<Edge> edges) {
public long getLeftSubNodes(List<JoinEdge> edges) {
if (leftChildEdges.isEmpty()) {
return leftRequiredNodes;
}
return edges.get(leftChildEdges.nextSetBit(0)).getSubTreeNodes();
}
public long getRightSubNodes(List<Edge> edges) {
public long getRightSubNodes(List<JoinEdge> edges) {
if (rightChildEdges.isEmpty()) {
return rightRequiredNodes;
}
@ -144,7 +114,6 @@ public class Edge {
}
public void setLeftExtendedNodes(long leftExtendedNodes) {
referenceNodes = LongBitmap.clear(referenceNodes);
this.leftExtendedNodes = leftExtendedNodes;
}
@ -157,7 +126,6 @@ public class Edge {
}
public void setRightExtendedNodes(long rightExtendedNodes) {
referenceNodes = LongBitmap.clear(referenceNodes);
this.rightExtendedNodes = rightExtendedNodes;
}
@ -165,24 +133,16 @@ public class Edge {
return leftRequiredNodes;
}
/**
 * Replaces the left required nodes of this edge.
 *
 * @param left the new value of the {@code leftRequiredNodes} bitmap
 */
public void setLeftRequiredNodes(long left) {
this.leftRequiredNodes = left;
}
/**
 * Returns the right required nodes of this edge.
 *
 * @return the current value of the {@code rightRequiredNodes} bitmap
 */
public long getRightRequiredNodes() {
return rightRequiredNodes;
}
/**
 * Replaces the right required nodes of this edge.
 *
 * @param right the new value of the {@code rightRequiredNodes} bitmap
 */
public void setRightRequiredNodes(long right) {
this.rightRequiredNodes = right;
}
public void addCurJoinEdges(BitSet edges) {
curJoinEdges.or(edges);
curOperatorEdges.or(edges);
}
public BitSet getCurJoinEdges() {
return curJoinEdges;
public BitSet getCurOperatorEdges() {
return curOperatorEdges;
}
public boolean isSub(Edge edge) {
@ -192,10 +152,7 @@ public class Edge {
}
public long getReferenceNodes() {
if (LongBitmap.getCardinality(referenceNodes) == 0) {
referenceNodes = LongBitmap.newBitmapUnion(leftExtendedNodes, rightExtendedNodes);
}
return referenceNodes;
return LongBitmap.newBitmapUnion(leftExtendedNodes, rightExtendedNodes);
}
public long getRequireNodes() {
@ -210,51 +167,14 @@ public class Edge {
return selectivity;
}
/**
 * Returns the single expression of the join held by this edge.
 *
 * @return the only element of {@code join.getExpressions()}
 * @throws IllegalArgumentException if the join does not have exactly one expression
 */
public Expression getExpression() {
// Callers of this accessor rely on the join carrying exactly one expression.
Preconditions.checkArgument(join.getExpressions().size() == 1);
return join.getExpressions().get(0);
}
public abstract Set<Slot> getInputSlots();
public List<Expression> getHashJoinConjuncts() {
return join.getHashJoinConjuncts();
}
public List<Expression> getOtherJoinConjuncts() {
return join.getOtherJoinConjuncts();
}
/**
 * Collects the input slots of every expression of the join.
 *
 * @return a newly created set holding the union of the input slots of all
 *         join expressions
 */
public final Set<Slot> getInputSlots() {
    Set<Slot> collected = new HashSet<>();
    for (Expression joinExpression : join.getExpressions()) {
        collected.addAll(joinExpression.getInputSlots());
    }
    return collected;
}
public abstract List<? extends Expression> getExpressions();
/**
 * Renders the edge as {@code <left - right>}, using the textual form of the
 * left and right extended node bitmaps.
 */
@Override
public String toString() {
    return "<" + LongBitmap.toString(leftExtendedNodes)
            + " - " + LongBitmap.toString(rightExtendedNodes) + ">";
}
/**
 * Extracts the common join type of the given edges and appends their conjuncts
 * to the supplied output lists.
 *
 * <p>Edges are processed in order. As soon as an edge has a join type different
 * from the one seen so far, the method returns {@code null}; the conjuncts of the
 * edges processed before that point have already been appended to the output lists.
 *
 * @param edges          the edges to inspect
 * @param hashConjuncts  output list receiving the hash join conjuncts of each edge
 * @param otherConjuncts output list receiving the other join conjuncts of each edge
 * @return the join type shared by all edges, or {@code null} when {@code edges} is
 *         empty or the edges do not all have the same join type
 */
public static @Nullable JoinType extractJoinTypeAndConjuncts(List<Edge> edges,
        List<Expression> hashConjuncts, List<Expression> otherConjuncts) {
    JoinType joinType = null;
    for (Edge edge : edges) {
        JoinType edgeJoinType = edge.getJoinType();
        // A mismatch ends the extraction; past this guard the types are known to
        // agree, so no further consistency check is needed.
        if (joinType != null && edgeJoinType != joinType) {
            return null;
        }
        joinType = edgeJoinType;
        hashConjuncts.addAll(edge.getHashJoinConjuncts());
        otherConjuncts.addAll(edge.getOtherJoinConjuncts());
    }
    return joinType;
}
/**
 * Creates an edge around the given join without wiring it into a graph.
 *
 * <p>The edge is constructed with {@code -1} as its second constructor argument,
 * {@code null} for the third and fourth, and {@code 0L} for the last one.
 * NOTE(review): presumably these are the index, the child edge sets and the
 * sub tree nodes; confirm against the constructor.
 *
 * @param join the join to wrap
 * @return a new edge holding {@code join}
 */
public static Edge createTempEdge(LogicalJoin join) {
return new Edge(join, -1, null, null, 0L);
}
}

View File

@ -0,0 +1,61 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.joinorder.hypergraph.edge;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.Set;
/**
 * An edge of the hyper graph that stands for a {@link LogicalFilter}.
 *
 * <p>In addition to the state kept by {@link Edge}, a filter edge remembers the
 * indexes of the join edges registered through {@link #addRejectJoin(JoinEdge)}.
 */
public class FilterEdge extends Edge {
    private final LogicalFilter<? extends Plan> filter;
    private final List<Integer> rejectEdges = new ArrayList<>();

    /**
     * Creates a filter edge.
     *
     * <p>The child edges and the required nodes are handed to the {@link Edge}
     * constructor in its first child-edge / required-node positions; the second
     * positions receive an empty {@link BitSet} and {@code 0L}.
     *
     * @param filter            the filter this edge represents
     * @param index             the index of this edge
     * @param childEdges        the edges below this filter
     * @param subTreeNodes      bitmap of the nodes in the sub tree of this edge
     * @param childRequireNodes bitmap of the nodes required by this filter
     */
    public FilterEdge(LogicalFilter<? extends Plan> filter, int index,
            BitSet childEdges, long subTreeNodes, long childRequireNodes) {
        super(index, childEdges, new BitSet(), subTreeNodes, childRequireNodes, 0L);
        this.filter = filter;
    }

    /**
     * Records the index of the given join edge in the reject list.
     *
     * @param joinEdge the join edge whose index is appended
     */
    public void addRejectJoin(JoinEdge joinEdge) {
        int rejectedIndex = joinEdge.getIndex();
        rejectEdges.add(rejectedIndex);
    }

    /**
     * Returns the indexes recorded by {@link #addRejectJoin(JoinEdge)}, in insertion order.
     *
     * @return the internal list of recorded edge indexes (not a copy)
     */
    public List<Integer> getRejectEdges() {
        return rejectEdges;
    }

    /**
     * Returns the input slots of the underlying filter.
     */
    @Override
    public Set<Slot> getInputSlots() {
        return filter.getInputSlots();
    }

    /**
     * Returns the expressions of the underlying filter.
     */
    @Override
    public List<? extends Expression> getExpressions() {
        return filter.getExpressions();
    }
}

View File

@ -0,0 +1,98 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.joinorder.hypergraph.edge;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import com.google.common.base.Preconditions;
import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.annotation.Nullable;
/**
 * An edge of the hyper graph that stands for a {@link LogicalJoin}.
 */
public class JoinEdge extends Edge {
    private final LogicalJoin<? extends Plan, ? extends Plan> join;

    /**
     * Creates a join edge.
     *
     * @param join              the join this edge represents
     * @param index             the index of this edge
     * @param leftChildEdges    the edges below the left side of the join
     * @param rightChildEdges   the edges below the right side of the join
     * @param subTreeNodes      bitmap of the nodes in the sub tree of this edge
     * @param leftRequireNodes  bitmap of the nodes required on the left side
     * @param rightRequireNodes bitmap of the nodes required on the right side
     */
    public JoinEdge(LogicalJoin<? extends Plan, ? extends Plan> join, int index,
            BitSet leftChildEdges, BitSet rightChildEdges, long subTreeNodes,
            long leftRequireNodes, long rightRequireNodes) {
        super(index, leftChildEdges, rightChildEdges, subTreeNodes, leftRequireNodes, rightRequireNodes);
        this.join = join;
    }

    /**
     * Returns the join type of the underlying join.
     */
    public JoinType getJoinType() {
        return join.getJoinType();
    }

    /**
     * Returns the underlying join.
     */
    public LogicalJoin<? extends Plan, ? extends Plan> getJoin() {
        return join;
    }

    /**
     * Extracts the common join type of the given edges and appends their conjuncts
     * to the supplied output lists.
     *
     * <p>Edges are processed in order. As soon as an edge has a join type different
     * from the one seen so far, the method returns {@code null}; the conjuncts of the
     * edges processed before that point have already been appended to the output lists.
     *
     * @param edges          the join edges to inspect
     * @param hashConjuncts  output list receiving the hash join conjuncts of each edge
     * @param otherConjuncts output list receiving the other join conjuncts of each edge
     * @return the join type shared by all edges, or {@code null} when {@code edges} is
     *         empty or the edges do not all have the same join type
     */
    public static @Nullable JoinType extractJoinTypeAndConjuncts(List<JoinEdge> edges,
            List<Expression> hashConjuncts, List<Expression> otherConjuncts) {
        JoinType joinType = null;
        for (JoinEdge edge : edges) {
            JoinType edgeJoinType = edge.getJoinType();
            // A mismatch ends the extraction; past this guard the types are known to
            // agree, so no further consistency check is needed.
            if (joinType != null && edgeJoinType != joinType) {
                return null;
            }
            joinType = edgeJoinType;
            hashConjuncts.addAll(edge.getHashJoinConjuncts());
            otherConjuncts.addAll(edge.getOtherJoinConjuncts());
        }
        return joinType;
    }

    /**
     * Returns the single expression of the underlying join.
     *
     * @return the only element of {@code join.getExpressions()}
     * @throws IllegalArgumentException if the join does not have exactly one expression
     */
    public Expression getExpression() {
        List<? extends Expression> expressions = join.getExpressions();
        Preconditions.checkArgument(expressions.size() == 1,
                "expected exactly one join expression but found %s", expressions.size());
        return expressions.get(0);
    }

    /**
     * Returns all expressions of the underlying join.
     */
    @Override
    public List<? extends Expression> getExpressions() {
        return join.getExpressions();
    }

    /**
     * Returns the hash join conjuncts of the underlying join.
     */
    public List<Expression> getHashJoinConjuncts() {
        return join.getHashJoinConjuncts();
    }

    /**
     * Returns the other (non-hash) join conjuncts of the underlying join.
     */
    public List<Expression> getOtherJoinConjuncts() {
        return join.getOtherJoinConjuncts();
    }

    /**
     * Collects the input slots of every expression of the underlying join.
     *
     * @return a newly created set holding the union of the input slots of all
     *         join expressions
     */
    @Override
    public Set<Slot> getInputSlots() {
        Set<Slot> slots = new HashSet<>();
        join.getExpressions().forEach(expression -> slots.addAll(expression.getInputSlots()));
        return slots;
    }
}

View File

@ -17,8 +17,8 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.memo.Group;
import com.google.common.base.Preconditions;

View File

@ -17,8 +17,8 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.memo.Group;
import java.util.List;
@ -26,7 +26,7 @@ import java.util.List;
* An interface for receivers
*/
public interface AbstractReceiver {
boolean emitCsgCmp(long csg, long cmp, List<Edge> edges);
boolean emitCsgCmp(long csg, long cmp, List<JoinEdge> edges);
void addGroup(long bitSet, Group group);

View File

@ -17,8 +17,8 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.memo.Group;
import com.google.common.base.Preconditions;
@ -51,7 +51,7 @@ public class Counter implements AbstractReceiver {
* @param edges the join operator
* @return the left and the right can be connected by the edge
*/
public boolean emitCsgCmp(long left, long right, List<Edge> edges) {
public boolean emitCsgCmp(long left, long right, List<JoinEdge> edges) {
Preconditions.checkArgument(counter.containsKey(left));
Preconditions.checkArgument(counter.containsKey(right));
emitCount += 1;

View File

@ -20,9 +20,10 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.cascades.CostAndEnforcerJob;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.memo.CopyInResult;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
@ -103,7 +104,7 @@ public class PlanReceiver implements AbstractReceiver {
* @return the left and the right can be connected by the edge
*/
@Override
public boolean emitCsgCmp(long left, long right, List<Edge> edges) {
public boolean emitCsgCmp(long left, long right, List<JoinEdge> edges) {
Preconditions.checkArgument(planTable.containsKey(left));
Preconditions.checkArgument(planTable.containsKey(right));
processMissedEdges(left, right, edges);
@ -122,7 +123,7 @@ public class PlanReceiver implements AbstractReceiver {
List<Expression> hashConjuncts = new ArrayList<>();
List<Expression> otherConjuncts = new ArrayList<>();
JoinType joinType = Edge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
JoinType joinType = JoinEdge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
if (joinType == null) {
return true;
}
@ -149,7 +150,7 @@ public class PlanReceiver implements AbstractReceiver {
// be aware that the requiredOutputSlots is a superset of the actual output of current node
// check proposeProject method to get how to create a project node for the outputs of current node.
private Set<Slot> calculateRequiredSlots(long left, long right, List<Edge> edges) {
private Set<Slot> calculateRequiredSlots(long left, long right, List<JoinEdge> edges) {
// required output slots = final outputs + slot of unused edges + complex project exprs(if there is any)
// 1. add finalOutputs to requiredOutputSlots
Set<Slot> requiredOutputSlots = new HashSet<>(this.finalOutputs);
@ -162,7 +163,7 @@ public class PlanReceiver implements AbstractReceiver {
// 2. add unused edges' input slots to requiredOutputSlots
usdEdges.put(LongBitmap.newBitmapUnion(left, right), usedEdgesBitmap);
for (Edge edge : hyperGraph.getEdges()) {
for (Edge edge : hyperGraph.getJoinEdges()) {
if (!usedEdgesBitmap.get(edge.getIndex())) {
requiredOutputSlots.addAll(edge.getInputSlots());
}
@ -180,7 +181,7 @@ public class PlanReceiver implements AbstractReceiver {
}
// add any missed edge into edges to connect left and right
private void processMissedEdges(long left, long right, List<Edge> edges) {
private void processMissedEdges(long left, long right, List<JoinEdge> edges) {
// find all used edges
BitSet usedEdgesBitmap = new BitSet();
usedEdgesBitmap.or(usdEdges.get(left));
@ -191,9 +192,8 @@ public class PlanReceiver implements AbstractReceiver {
long allReferenceNodes = LongBitmap.or(left, right);
// find every edge that is not in usedEdgesBitmap and whose referenced nodes are a subset of allReferenceNodes
for (Edge edge : hyperGraph.getEdges()) {
long referenceNodes =
LongBitmap.newBitmapUnion(edge.getLeftRequiredNodes(), edge.getRightRequiredNodes());
for (JoinEdge edge : hyperGraph.getJoinEdges()) {
long referenceNodes = LongBitmap.newBitmapUnion(edge.getLeftRequiredNodes(), edge.getRightRequiredNodes());
if (LongBitmap.isSubset(referenceNodes, allReferenceNodes)
&& !usedEdgesBitmap.get(edge.getIndex())) {
// add the missed edge to edges
@ -220,8 +220,8 @@ public class PlanReceiver implements AbstractReceiver {
List<Plan> plans = Lists.newArrayList();
if (JoinUtils.shouldNestedLoopJoin(joinType, hashConjuncts)) {
plans.add(new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts,
Optional.empty(), joinProperties,
left, right));
Optional.empty(), joinProperties,
left, right));
if (joinType.isSwapJoinType()) {
plans.add(new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, Optional.empty(),
joinProperties,
@ -241,17 +241,6 @@ public class PlanReceiver implements AbstractReceiver {
return plans;
}
/**
 * Tells whether any of the given edges belongs to a mark join.
 *
 * @param edges the edges to inspect; all of them must have the same join type
 * @return {@code true} if at least one edge's join reports {@code isMarkJoin()}
 * @throws IllegalArgumentException if two consecutive edges have different join types
 */
private boolean extractIsMarkJoin(List<Edge> edges) {
    JoinType expectedType = null;
    boolean foundMarkJoin = false;
    for (Edge current : edges) {
        JoinType currentType = current.getJoinType();
        Preconditions.checkArgument(expectedType == null || expectedType == currentType);
        if (current.getJoin().isMarkJoin()) {
            foundMarkJoin = true;
        }
        expectedType = currentType;
    }
    return foundMarkJoin;
}
@Override
public void addGroup(long bitmap, Group group) {
Preconditions.checkArgument(LongBitmap.getCardinality(bitmap) == 1);
@ -330,7 +319,7 @@ public class PlanReceiver implements AbstractReceiver {
}
}
private List<Plan> proposeProject(List<Plan> allChild, List<Edge> edges, long left, long right) {
private List<Plan> proposeProject(List<Plan> allChild, List<JoinEdge> edges, long left, long right) {
long fullKey = LongBitmap.newBitmapUnion(left, right);
List<Slot> outputs = allChild.get(0).getOutput();
Set<Slot> outputSet = allChild.get(0).getOutputSet();

View File

@ -17,8 +17,8 @@
package org.apache.doris.nereids.rules.exploration.mv;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
@ -91,7 +91,7 @@ public abstract class AbstractMaterializedViewJoinRule extends AbstractMateriali
return false;
}
}
for (Edge edge : hyperGraph.getEdges()) {
for (JoinEdge edge : hyperGraph.getJoinEdges()) {
if (!edge.getJoin().accept(StructInfo.JOIN_PATTERN_CHECKER,
SUPPORTED_JOIN_TYPE_SET)) {
return false;

View File

@ -96,7 +96,7 @@ public class StructInfo {
this.predicates = Predicates.of();
// Collect predicate from join condition in hyper graph
this.hyperGraph.getEdges().forEach(edge -> {
this.hyperGraph.getJoinEdges().forEach(edge -> {
List<Expression> hashJoinConjuncts = edge.getHashJoinConjuncts();
hashJoinConjuncts.forEach(conjunctExpr -> {
predicates.addPredicate(conjunctExpr);

View File

@ -41,7 +41,7 @@ import java.util.Set;
public class PushDownFilterThroughJoin extends OneRewriteRuleFactory {
public static final PushDownFilterThroughJoin INSTANCE = new PushDownFilterThroughJoin();
private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_LEFT = ImmutableList.of(
public static final ImmutableList<JoinType> COULD_PUSH_THROUGH_LEFT = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.LEFT_OUTER_JOIN,
JoinType.LEFT_SEMI_JOIN,
@ -50,7 +50,7 @@ public class PushDownFilterThroughJoin extends OneRewriteRuleFactory {
JoinType.CROSS_JOIN
);
private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_RIGHT = ImmutableList.of(
public static final ImmutableList<JoinType> COULD_PUSH_THROUGH_RIGHT = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.RIGHT_OUTER_JOIN,
JoinType.RIGHT_SEMI_JOIN,

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
import org.apache.doris.nereids.trees.expressions.Alias;
@ -261,7 +262,7 @@ class GraphSimplifierTest {
GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph);
graphSimplifier.simplifyGraph(1);
for (Edge edge : hyperGraph.getEdges()) {
for (Edge edge : hyperGraph.getJoinEdges()) {
System.out.println(edge);
}
Assertions.assertTrue(subgraphEnumerator.enumerate());

View File

@ -102,7 +102,7 @@ public class HyperGraphTest {
HyperGraphBuilder hyperGraphBuilder = new HyperGraphBuilder();
HyperGraph hyperGraph = hyperGraphBuilder.randomBuildWith(tableNum, edgeNum);
Assertions.assertEquals(hyperGraph.getNodes().size(), tableNum);
Assertions.assertEquals(hyperGraph.getEdges().size(), edgeNum);
Assertions.assertEquals(hyperGraph.getJoinEdges().size(), edgeNum);
}
}
}

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmapSubsetIterator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.util.HyperGraphBuilder;
@ -129,7 +130,7 @@ public class SubgraphEnumeratorTest {
}
visited.add(left);
visited.add(right);
for (Edge edge : hyperGraph.getEdges()) {
for (Edge edge : hyperGraph.getJoinEdges()) {
if ((LongBitmap.isSubset(edge.getLeftExtendedNodes(), left) && LongBitmap.isSubset(edge.getRightExtendedNodes(), right)) || (
LongBitmap.isSubset(edge.getLeftExtendedNodes(), right) && LongBitmap.isSubset(edge.getRightExtendedNodes(), left))) {
count += countAndCheck(left, hyperGraph, counter, cache) * countAndCheck(right, hyperGraph,

View File

@ -66,4 +66,34 @@ class BuildStructInfoTest extends SqlTestBase {
}));
}
/**
 * Checks the filter edges built for a left outer join:
 * a filter below the right input records join edge 0 in its reject list,
 * while a filter below the left input records nothing.
 */
@Test
void testFilter() {
    // Case 1: the filter sits on the right input of the left outer join.
    String rightFilteredSql = "select * from T1 left outer join "
            + " (select id from T2 where id = 1) T2 "
            + "on T1.id = T2.id ";
    PlanChecker.from(connectContext)
            .analyze(rightFilteredSql)
            .rewrite()
            .matches(logicalJoin()
                    .when(join -> {
                        HyperGraph graph = HyperGraph.toStructInfo(join).get(0);
                        Assertions.assertTrue(graph.getJoinEdge(0).getJoinType().isLeftOuterJoin());
                        int rejectedIndex = graph.getFilterEdge(0).getRejectEdges().get(0);
                        Assertions.assertEquals(0, rejectedIndex);
                        return true;
                    }));

    // Case 2: the filter sits on the left input of the left outer join.
    String leftFilteredSql = "select * from (select id from T1 where id = 0) T1 left outer join T2 "
            + "on T1.id = T2.id ";
    PlanChecker.from(connectContext)
            .analyze(leftFilteredSql)
            .rewrite()
            .matches(logicalJoin()
                    .when(join -> {
                        HyperGraph graph = HyperGraph.toStructInfo(join).get(0);
                        Assertions.assertTrue(graph.getJoinEdge(0).getJoinType().isLeftOuterJoin());
                        Assertions.assertTrue(graph.getFilterEdge(0).getRejectEdges().isEmpty());
                        return true;
                    }));
}
}