[fix](Nereids) fix ends calculation when there are constant project (#22265)
This commit is contained in:
@ -17,6 +17,7 @@
|
||||
|
||||
package org.apache.doris.nereids.jobs.joinorder;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.jobs.Job;
|
||||
@ -26,6 +27,7 @@ import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.GraphSimplifier;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.SubgraphEnumerator;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.PlanReceiver;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
@ -111,20 +113,22 @@ public class JoinOrderJob extends Job {
|
||||
*
|
||||
* @param group root group, should be join type
|
||||
* @param hyperGraph build hyperGraph
|
||||
*
|
||||
* @return a pair of the edges of the topmost join in this group's subtree and the bitmap of all nodes in that subtree
|
||||
*/
|
||||
public BitSet buildGraph(Group group, HyperGraph hyperGraph) {
|
||||
public Pair<BitSet, Long> buildGraph(Group group, HyperGraph hyperGraph) {
|
||||
if (group.isProjectGroup()) {
|
||||
BitSet edgeMap = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
|
||||
processProjectPlan(hyperGraph, group);
|
||||
return edgeMap;
|
||||
Pair<BitSet, Long> res = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
|
||||
processProjectPlan(hyperGraph, group, res.second);
|
||||
return res;
|
||||
}
|
||||
if (!group.isValidJoinGroup()) {
|
||||
hyperGraph.addNode(optimizePlan(group));
|
||||
return new BitSet();
|
||||
int idx = hyperGraph.addNode(optimizePlan(group));
|
||||
return Pair.of(new BitSet(), LongBitmap.newBitmap(idx));
|
||||
}
|
||||
BitSet leftEdgeMap = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
|
||||
BitSet rightEdgeMap = buildGraph(group.getLogicalExpression().child(1), hyperGraph);
|
||||
return hyperGraph.addEdge(group, leftEdgeMap, rightEdgeMap);
|
||||
Pair<BitSet, Long> left = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
|
||||
Pair<BitSet, Long> right = buildGraph(group.getLogicalExpression().child(1), hyperGraph);
|
||||
return Pair.of(hyperGraph.addEdge(group, left, right), LongBitmap.or(left.second, right.second));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -133,14 +137,14 @@ public class JoinOrderJob extends Job {
|
||||
* 2. If it's an alias that may be used in the join operator, we need to add it to graph
|
||||
* 3. If it's other expression, we can ignore them and add it after optimizing
|
||||
*/
|
||||
private void processProjectPlan(HyperGraph hyperGraph, Group group) {
|
||||
private void processProjectPlan(HyperGraph hyperGraph, Group group, long subTreeNodes) {
|
||||
LogicalProject<? extends Plan> logicalProject
|
||||
= (LogicalProject<? extends Plan>) group.getLogicalExpression()
|
||||
.getPlan();
|
||||
|
||||
for (NamedExpression expr : logicalProject.getProjects()) {
|
||||
if (expr instanceof Alias) {
|
||||
hyperGraph.addAlias((Alias) expr);
|
||||
hyperGraph.addAlias((Alias) expr, subTreeNodes);
|
||||
} else if (!expr.isSlot()) {
|
||||
otherProject.add(expr);
|
||||
}
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
@ -26,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.util.BitSet;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
@ -38,23 +40,36 @@ public class Edge {
|
||||
final LogicalJoin<? extends Plan, ? extends Plan> join;
|
||||
final double selectivity;
|
||||
|
||||
// The endpoints (hyperNodes) of this hyperEdge.
|
||||
// left and right may not overlap, and both must have at least one bit set.
|
||||
private long left = LongBitmap.newBitmap();
|
||||
private long right = LongBitmap.newBitmap();
|
||||
|
||||
private long originalLeft = LongBitmap.newBitmap();
|
||||
private long originalRight = LongBitmap.newBitmap();
|
||||
// "RequiredNodes" refers to the nodes that can activate this edge based on
|
||||
// specific requirements. These requirements are established during the building process.
|
||||
// "ExtendedNodes" encompasses both the "RequiredNodes" and any additional nodes
|
||||
// added by the graph simplifier.
|
||||
private long leftRequiredNodes = LongBitmap.newBitmap();
|
||||
private long rightRequiredNodes = LongBitmap.newBitmap();
|
||||
private long leftExtendedNodes = LongBitmap.newBitmap();
|
||||
private long rightExtendedNodes = LongBitmap.newBitmap();
|
||||
|
||||
private long referenceNodes = LongBitmap.newBitmap();
|
||||
|
||||
// record the left child edges and right child edges in the original plan tree
|
||||
private BitSet leftChildEdges;
|
||||
private BitSet rightChildEdges;
|
||||
|
||||
// record the edges in the same operator
|
||||
private BitSet curJoinEdges = new BitSet();
|
||||
// record all nodes in the subtree rooted at this operator. It's the T function in the paper
|
||||
private Long subTreeNodes;
|
||||
|
||||
/**
|
||||
* Create simple edge.
|
||||
*/
|
||||
public Edge(LogicalJoin join, int index) {
|
||||
public Edge(LogicalJoin join, int index, BitSet leftChildEdges, BitSet rightChildEdges, Long subTreeNodes) {
|
||||
this.index = index;
|
||||
this.join = join;
|
||||
this.selectivity = 1.0;
|
||||
this.leftChildEdges = leftChildEdges;
|
||||
this.rightChildEdges = rightChildEdges;
|
||||
this.subTreeNodes = subTreeNodes;
|
||||
}
|
||||
|
||||
public LogicalJoin getJoin() {
|
||||
@ -66,65 +81,107 @@ public class Edge {
|
||||
}
|
||||
|
||||
public boolean isSimple() {
|
||||
return LongBitmap.getCardinality(left) == 1 && LongBitmap.getCardinality(right) == 1;
|
||||
return LongBitmap.getCardinality(leftExtendedNodes) == 1 && LongBitmap.getCardinality(rightExtendedNodes) == 1;
|
||||
}
|
||||
|
||||
public void addLeftNode(long left) {
|
||||
this.left = LongBitmap.or(this.left, left);
|
||||
this.leftExtendedNodes = LongBitmap.or(this.leftExtendedNodes, left);
|
||||
referenceNodes = LongBitmap.or(referenceNodes, left);
|
||||
}
|
||||
|
||||
public void addLeftNodes(long... bitmaps) {
|
||||
for (long bitmap : bitmaps) {
|
||||
this.left = LongBitmap.or(this.left, bitmap);
|
||||
this.leftExtendedNodes = LongBitmap.or(this.leftExtendedNodes, bitmap);
|
||||
referenceNodes = LongBitmap.or(referenceNodes, bitmap);
|
||||
}
|
||||
}
|
||||
|
||||
public void addRightNode(long right) {
|
||||
this.right = LongBitmap.or(this.right, right);
|
||||
this.rightExtendedNodes = LongBitmap.or(this.rightExtendedNodes, right);
|
||||
referenceNodes = LongBitmap.or(referenceNodes, right);
|
||||
}
|
||||
|
||||
public void addRightNodes(long... bitmaps) {
|
||||
for (long bitmap : bitmaps) {
|
||||
LongBitmap.or(this.right, bitmap);
|
||||
LongBitmap.or(this.rightExtendedNodes, bitmap);
|
||||
LongBitmap.or(referenceNodes, bitmap);
|
||||
}
|
||||
}
|
||||
|
||||
public long getLeft() {
|
||||
return left;
|
||||
public long getSubTreeNodes() {
|
||||
return this.subTreeNodes;
|
||||
}
|
||||
|
||||
public void setLeft(long left) {
|
||||
public long getLeftExtendedNodes() {
|
||||
return leftExtendedNodes;
|
||||
}
|
||||
|
||||
public BitSet getLeftChildEdges() {
|
||||
return leftChildEdges;
|
||||
}
|
||||
|
||||
public Pair<BitSet, Long> getLeftEdgeNodes(List<Edge> edges) {
|
||||
return Pair.of(leftChildEdges, getLeftSubNodes(edges));
|
||||
}
|
||||
|
||||
public Pair<BitSet, Long> getRightEdgeNodes(List<Edge> edges) {
|
||||
return Pair.of(rightChildEdges, getRightSubNodes(edges));
|
||||
}
|
||||
|
||||
public long getLeftSubNodes(List<Edge> edges) {
|
||||
if (leftChildEdges.isEmpty()) {
|
||||
return leftRequiredNodes;
|
||||
}
|
||||
return edges.get(leftChildEdges.nextSetBit(0)).getSubTreeNodes();
|
||||
}
|
||||
|
||||
public long getRightSubNodes(List<Edge> edges) {
|
||||
if (rightChildEdges.isEmpty()) {
|
||||
return rightRequiredNodes;
|
||||
}
|
||||
return edges.get(rightChildEdges.nextSetBit(0)).getSubTreeNodes();
|
||||
}
|
||||
|
||||
public void setLeftExtendedNodes(long leftExtendedNodes) {
|
||||
referenceNodes = LongBitmap.clear(referenceNodes);
|
||||
this.left = left;
|
||||
this.leftExtendedNodes = leftExtendedNodes;
|
||||
}
|
||||
|
||||
public long getRight() {
|
||||
return right;
|
||||
public long getRightExtendedNodes() {
|
||||
return rightExtendedNodes;
|
||||
}
|
||||
|
||||
public void setRight(long right) {
|
||||
public BitSet getRightChildEdges() {
|
||||
return rightChildEdges;
|
||||
}
|
||||
|
||||
public void setRightExtendedNodes(long rightExtendedNodes) {
|
||||
referenceNodes = LongBitmap.clear(referenceNodes);
|
||||
this.right = right;
|
||||
this.rightExtendedNodes = rightExtendedNodes;
|
||||
}
|
||||
|
||||
public long getOriginalLeft() {
|
||||
return originalLeft;
|
||||
public long getLeftRequiredNodes() {
|
||||
return leftRequiredNodes;
|
||||
}
|
||||
|
||||
public void setOriginalLeft(long left) {
|
||||
this.originalLeft = left;
|
||||
public void setLeftRequiredNodes(long left) {
|
||||
this.leftRequiredNodes = left;
|
||||
}
|
||||
|
||||
public long getOriginalRight() {
|
||||
return originalRight;
|
||||
public long getRightRequiredNodes() {
|
||||
return rightRequiredNodes;
|
||||
}
|
||||
|
||||
public void setOriginalRight(long right) {
|
||||
this.originalRight = right;
|
||||
public void setRightRequiredNodes(long right) {
|
||||
this.rightRequiredNodes = right;
|
||||
}
|
||||
|
||||
public void addCurJoinEdges(BitSet edges) {
|
||||
curJoinEdges.or(edges);
|
||||
}
|
||||
|
||||
public BitSet getCurJoinEdges() {
|
||||
return curJoinEdges;
|
||||
}
|
||||
|
||||
public boolean isSub(Edge edge) {
|
||||
@ -135,11 +192,15 @@ public class Edge {
|
||||
|
||||
public long getReferenceNodes() {
|
||||
if (LongBitmap.getCardinality(referenceNodes) == 0) {
|
||||
referenceNodes = LongBitmap.newBitmapUnion(left, right);
|
||||
referenceNodes = LongBitmap.newBitmapUnion(leftExtendedNodes, rightExtendedNodes);
|
||||
}
|
||||
return referenceNodes;
|
||||
}
|
||||
|
||||
public long getRequireNodes() {
|
||||
return LongBitmap.newBitmapUnion(leftRequiredNodes, rightRequiredNodes);
|
||||
}
|
||||
|
||||
public int getIndex() {
|
||||
return index;
|
||||
}
|
||||
@ -165,7 +226,8 @@ public class Edge {
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("<%s - %s>", LongBitmap.toString(left), LongBitmap.toString(right));
|
||||
return String.format("<%s - %s>", LongBitmap.toString(leftExtendedNodes), LongBitmap.toString(
|
||||
rightExtendedNodes));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -108,10 +108,10 @@ public class GraphSimplifier {
|
||||
Edge edge1 = graph.getEdge(i);
|
||||
Edge edge2 = graph.getEdge(j);
|
||||
List<Long> superset = new ArrayList<>();
|
||||
tryGetSuperset(edge1.getLeft(), edge2.getLeft(), superset);
|
||||
tryGetSuperset(edge1.getLeft(), edge2.getRight(), superset);
|
||||
tryGetSuperset(edge1.getRight(), edge2.getLeft(), superset);
|
||||
tryGetSuperset(edge1.getRight(), edge2.getRight(), superset);
|
||||
tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getLeftExtendedNodes(), superset);
|
||||
tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getRightExtendedNodes(), superset);
|
||||
tryGetSuperset(edge1.getRightExtendedNodes(), edge2.getLeftExtendedNodes(), superset);
|
||||
tryGetSuperset(edge1.getRightExtendedNodes(), edge2.getRightExtendedNodes(), superset);
|
||||
if (!circleDetector.checkCircleWithEdge(i, j) && !circleDetector.checkCircleWithEdge(j, i)
|
||||
&& !edge2.isSub(edge1) && !edge1.isSub(edge2) && !superset.isEmpty()) {
|
||||
return false;
|
||||
@ -213,8 +213,8 @@ public class GraphSimplifier {
|
||||
BestSimplification bestSimplification = priorityQueue.poll();
|
||||
bestSimplification.isInQueue = false;
|
||||
SimplificationStep bestStep = bestSimplification.getStep();
|
||||
while (bestSimplification.bestNeighbor == -1 || !circleDetector.tryAddDirectedEdge(bestStep.beforeIndex,
|
||||
bestStep.afterIndex)) {
|
||||
while (bestSimplification.bestNeighbor == -1
|
||||
|| !circleDetector.tryAddDirectedEdge(bestStep.beforeIndex, bestStep.afterIndex)) {
|
||||
processNeighbors(bestStep.afterIndex, 0, edgeSize);
|
||||
if (priorityQueue.isEmpty()) {
|
||||
return null;
|
||||
@ -307,10 +307,14 @@ public class GraphSimplifier {
|
||||
|| circleDetector.checkCircleWithEdge(edgeIndex2, edgeIndex1)) {
|
||||
return Optional.empty();
|
||||
}
|
||||
long left1 = edge1.getLeft();
|
||||
long right1 = edge1.getRight();
|
||||
long left2 = edge2.getLeft();
|
||||
long right2 = edge2.getRight();
|
||||
long left1 = edge1.getLeftExtendedNodes();
|
||||
long right1 = edge1.getRightExtendedNodes();
|
||||
long left2 = edge2.getLeftExtendedNodes();
|
||||
long right2 = edge2.getRightExtendedNodes();
|
||||
if (!cacheStats.containsKey(left1) || !cacheStats.containsKey(right1)
|
||||
|| !cacheStats.containsKey(left2) || !cacheStats.containsKey(right2)) {
|
||||
return Optional.empty();
|
||||
}
|
||||
Pair<Statistics, Edge> edge1Before2;
|
||||
Pair<Statistics, Edge> edge2Before1;
|
||||
List<Long> superBitset = new ArrayList<>();
|
||||
@ -351,13 +355,14 @@ public class GraphSimplifier {
|
||||
Statistics leftStats = JoinEstimation.estimate(cacheStats.get(bitmap1), cacheStats.get(bitmap2),
|
||||
edge1.getJoin());
|
||||
Statistics joinStats = JoinEstimation.estimate(leftStats, cacheStats.get(bitmap3), edge2.getJoin());
|
||||
Edge edge = new Edge(edge2.getJoin(), -1);
|
||||
Edge edge = new Edge(
|
||||
edge2.getJoin(), -1, edge2.getLeftChildEdges(), edge2.getRightChildEdges(), edge2.getSubTreeNodes());
|
||||
long newLeft = LongBitmap.newBitmapUnion(bitmap1, bitmap2);
|
||||
// To keep the left and the right from overlapping, bitmap3 is removed from newLeft. Note that
|
||||
// newLeft may therefore not fully include bitmap1 and bitmap2; we use the circle detector to trace the dependency
|
||||
newLeft = LongBitmap.andNot(newLeft, bitmap3);
|
||||
edge.addLeftNodes(newLeft);
|
||||
edge.addRightNode(edge2.getRight());
|
||||
edge.addRightNode(edge2.getRightExtendedNodes());
|
||||
cacheStats.put(newLeft, leftStats);
|
||||
cacheCost.put(newLeft, calCost(edge2, leftStats, cacheStats.get(bitmap1), cacheStats.get(bitmap2)));
|
||||
return Pair.of(joinStats, edge);
|
||||
@ -370,11 +375,12 @@ public class GraphSimplifier {
|
||||
Statistics rightStats = JoinEstimation.estimate(cacheStats.get(bitmap2), cacheStats.get(bitmap3),
|
||||
edge2.getJoin());
|
||||
Statistics joinStats = JoinEstimation.estimate(cacheStats.get(bitmap1), rightStats, edge1.getJoin());
|
||||
Edge edge = new Edge(edge1.getJoin(), -1);
|
||||
Edge edge = new Edge(
|
||||
edge1.getJoin(), -1, edge1.getLeftChildEdges(), edge1.getRightChildEdges(), edge1.getSubTreeNodes());
|
||||
|
||||
long newRight = LongBitmap.newBitmapUnion(bitmap2, bitmap3);
|
||||
newRight = LongBitmap.andNot(newRight, bitmap1);
|
||||
edge.addLeftNode(edge1.getLeft());
|
||||
edge.addLeftNode(edge1.getLeftExtendedNodes());
|
||||
edge.addRightNode(newRight);
|
||||
cacheStats.put(newRight, rightStats);
|
||||
cacheCost.put(newRight, calCost(edge2, rightStats, cacheStats.get(bitmap2), cacheStats.get(bitmap3)));
|
||||
@ -384,11 +390,11 @@ public class GraphSimplifier {
|
||||
private SimplificationStep orderJoin(Pair<Statistics, Edge> edge1Before2,
|
||||
Pair<Statistics, Edge> edge2Before1, int edgeIndex1, int edgeIndex2) {
|
||||
Cost cost1Before2 = calCost(edge1Before2.second, edge1Before2.first,
|
||||
cacheStats.get(edge1Before2.second.getLeft()),
|
||||
cacheStats.get(edge1Before2.second.getRight()));
|
||||
cacheStats.get(edge1Before2.second.getLeftExtendedNodes()),
|
||||
cacheStats.get(edge1Before2.second.getRightExtendedNodes()));
|
||||
Cost cost2Before1 = calCost(edge2Before1.second, edge1Before2.first,
|
||||
cacheStats.get(edge1Before2.second.getLeft()),
|
||||
cacheStats.get(edge1Before2.second.getRight()));
|
||||
cacheStats.get(edge1Before2.second.getLeftExtendedNodes()),
|
||||
cacheStats.get(edge1Before2.second.getRightExtendedNodes()));
|
||||
double benefit = Double.MAX_VALUE;
|
||||
SimplificationStep step;
|
||||
// Choose the plan with smaller cost and make the simplification step to replace the old edge by it.
|
||||
@ -397,17 +403,17 @@ public class GraphSimplifier {
|
||||
benefit = cost2Before1.getValue() / cost1Before2.getValue();
|
||||
}
|
||||
// choose edge1Before2
|
||||
step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, edge1Before2.second.getLeft(),
|
||||
edge1Before2.second.getRight(), graph.getEdge(edgeIndex2).getLeft(),
|
||||
graph.getEdge(edgeIndex2).getRight());
|
||||
step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, edge1Before2.second.getLeftExtendedNodes(),
|
||||
edge1Before2.second.getRightExtendedNodes(), graph.getEdge(edgeIndex2).getLeftExtendedNodes(),
|
||||
graph.getEdge(edgeIndex2).getRightExtendedNodes());
|
||||
} else {
|
||||
if (cost2Before1.getValue() != 0) {
|
||||
benefit = cost1Before2.getValue() / cost2Before1.getValue();
|
||||
}
|
||||
// choose edge2Before1
|
||||
step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.second.getLeft(),
|
||||
edge2Before1.second.getRight(), graph.getEdge(edgeIndex1).getLeft(),
|
||||
graph.getEdge(edgeIndex1).getRight());
|
||||
step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.second.getLeftExtendedNodes(),
|
||||
edge2Before1.second.getRightExtendedNodes(), graph.getEdge(edgeIndex1).getLeftExtendedNodes(),
|
||||
graph.getEdge(edgeIndex1).getRightExtendedNodes());
|
||||
}
|
||||
return step;
|
||||
}
|
||||
@ -438,8 +444,8 @@ public class GraphSimplifier {
|
||||
join.left(),
|
||||
join.right());
|
||||
cost = CostCalculator.calculateCost(nestedLoopJoin, planContext);
|
||||
cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getLeft()), 0);
|
||||
cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getRight()), 1);
|
||||
cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getLeftExtendedNodes()), 0);
|
||||
cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getRightExtendedNodes()), 1);
|
||||
} else {
|
||||
PhysicalHashJoin hashJoin = new PhysicalHashJoin<>(
|
||||
join.getJoinType(),
|
||||
@ -451,8 +457,8 @@ public class GraphSimplifier {
|
||||
join.left(),
|
||||
join.right());
|
||||
cost = CostCalculator.calculateCost(hashJoin, planContext);
|
||||
cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getLeft()), 0);
|
||||
cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getRight()), 1);
|
||||
cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getLeftExtendedNodes()), 0);
|
||||
cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getRightExtendedNodes()), 1);
|
||||
}
|
||||
|
||||
return cost;
|
||||
|
||||
@ -24,7 +24,6 @@ import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.JoinHint;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
@ -88,7 +87,7 @@ public class HyperGraph {
|
||||
*
|
||||
* @param alias The alias Expression in project Operator
|
||||
*/
|
||||
public boolean addAlias(Alias alias) {
|
||||
public boolean addAlias(Alias alias, long subTreeNodes) {
|
||||
Slot aliasSlot = alias.toSlot();
|
||||
if (slotToNodeMap.containsKey(aliasSlot)) {
|
||||
return true;
|
||||
@ -97,12 +96,21 @@ public class HyperGraph {
|
||||
for (Slot slot : alias.getInputSlots()) {
|
||||
bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot));
|
||||
}
|
||||
// This case is hit when there are constant aliases, such as:
|
||||
// select * from t1 join (
|
||||
// select *, 1 as b1 from t2)
|
||||
// on t1.b = b1
|
||||
// so the slot simply references all nodes in the subtree
|
||||
if (bitmap == 0) {
|
||||
bitmap = subTreeNodes;
|
||||
}
|
||||
Preconditions.checkArgument(bitmap > 0, "slot must belong to some table");
|
||||
slotToNodeMap.put(aliasSlot, bitmap);
|
||||
if (!complexProject.containsKey(bitmap)) {
|
||||
complexProject.put(bitmap, new ArrayList<>());
|
||||
} else if (!(alias.child() instanceof SlotReference)) {
|
||||
alias = (Alias) PlanUtils.mergeProjections(complexProject.get(bitmap), Lists.newArrayList(alias)).get(0);
|
||||
}
|
||||
alias = (Alias) PlanUtils.mergeProjections(complexProject.get(bitmap), Lists.newArrayList(alias)).get(0);
|
||||
|
||||
complexProject.get(bitmap).add(alias);
|
||||
return true;
|
||||
}
|
||||
@ -111,8 +119,9 @@ public class HyperGraph {
|
||||
* add end node to HyperGraph
|
||||
*
|
||||
* @param group The group that is the end node in graph
|
||||
* @return return the node index
|
||||
*/
|
||||
public void addNode(Group group) {
|
||||
public int addNode(Group group) {
|
||||
Preconditions.checkArgument(!group.isValidJoinGroup());
|
||||
for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) {
|
||||
Preconditions.checkArgument(!slotToNodeMap.containsKey(slot));
|
||||
@ -120,6 +129,7 @@ public class HyperGraph {
|
||||
}
|
||||
nodeSet.add(group);
|
||||
nodes.add(new Node(nodes.size(), group));
|
||||
return nodes.size() - 1;
|
||||
}
|
||||
|
||||
public boolean isNodeGroup(Group group) {
|
||||
@ -135,138 +145,126 @@ public class HyperGraph {
|
||||
*
|
||||
* @param group The join group
|
||||
*/
|
||||
public BitSet addEdge(Group group, BitSet leftEdgeMap, BitSet rightEdgeMap) {
|
||||
public BitSet addEdge(Group group, Pair<BitSet, Long> leftEdgeNodes, Pair<BitSet, Long> rightEdgeNodes) {
|
||||
Preconditions.checkArgument(group.isValidJoinGroup());
|
||||
LogicalJoin<? extends Plan, ? extends Plan> join = (LogicalJoin) group.getLogicalExpression().getPlan();
|
||||
HashMap<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new HashMap<>();
|
||||
|
||||
for (Expression expression : join.getHashJoinConjuncts()) {
|
||||
Pair<Long, Long> ends = findEnds(expression);
|
||||
// TODO: avoid calling calculateEnds if calNodeMap's results are same
|
||||
Pair<Long, Long> ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes,
|
||||
rightEdgeNodes);
|
||||
if (!conjuncts.containsKey(ends)) {
|
||||
conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>()));
|
||||
}
|
||||
conjuncts.get(ends).first.add(expression);
|
||||
}
|
||||
for (Expression expression : join.getOtherJoinConjuncts()) {
|
||||
Pair<Long, Long> ends = findEnds(expression);
|
||||
Pair<Long, Long> ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes,
|
||||
rightEdgeNodes);
|
||||
if (!conjuncts.containsKey(ends)) {
|
||||
conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>()));
|
||||
}
|
||||
conjuncts.get(ends).second.add(expression);
|
||||
}
|
||||
|
||||
BitSet edgeMap = new BitSet();
|
||||
edgeMap.or(leftEdgeMap);
|
||||
edgeMap.or(rightEdgeMap);
|
||||
|
||||
BitSet curJoinEdges = new BitSet();
|
||||
for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> entry : conjuncts
|
||||
.entrySet()) {
|
||||
LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first,
|
||||
entry.getValue().second, JoinHint.NONE, join.getMarkJoinSlotReference(),
|
||||
Lists.newArrayList(join.left(), join.right()));
|
||||
Edge edge = new Edge(singleJoin, edges.size());
|
||||
Edge edge = new Edge(singleJoin, edges.size(), leftEdgeNodes.first, rightEdgeNodes.first,
|
||||
LongBitmap.newBitmapUnion(leftEdgeNodes.second, rightEdgeNodes.second));
|
||||
Pair<Long, Long> ends = entry.getKey();
|
||||
initEdgeEnds(ends, edge, leftEdgeMap, rightEdgeMap);
|
||||
edge.setLeftRequiredNodes(ends.first);
|
||||
edge.setLeftExtendedNodes(ends.first);
|
||||
edge.setRightRequiredNodes(ends.second);
|
||||
edge.setRightExtendedNodes(ends.second);
|
||||
for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
|
||||
nodes.get(nodeIndex).attachEdge(edge);
|
||||
}
|
||||
edgeMap.set(edge.getIndex());
|
||||
curJoinEdges.set(edge.getIndex());
|
||||
edges.add(edge);
|
||||
}
|
||||
|
||||
return edgeMap;
|
||||
curJoinEdges.stream().forEach(i -> edges.get(i).addCurJoinEdges(curJoinEdges));
|
||||
curJoinEdges.stream().forEach(i -> edges.get(i).addCurJoinEdges(curJoinEdges));
|
||||
curJoinEdges.stream().forEach(i -> makeConflictRules(edges.get(i)));
|
||||
return curJoinEdges;
|
||||
// In MySQL, each edge is also stored reversed in edges to reduce branch misses.
|
||||
// We don't implement this trick now.
|
||||
}
|
||||
|
||||
// Make edge with CD-A algorithm in
|
||||
// Make edge with CD-C algorithm in
|
||||
// "On the Correct and Complete Enumeration of the Core Search Space"
|
||||
private void initEdgeEnds(Pair<Long, Long> ends, Edge edge, BitSet leftEdges, BitSet rightEdges) {
|
||||
long left = ends.first;
|
||||
long right = ends.second;
|
||||
for (int i = leftEdges.nextSetBit(0); i >= 0; i = leftEdges.nextSetBit(i + 1)) {
|
||||
Edge lEdge = edges.get(i);
|
||||
if (!JoinType.isAssoc(lEdge.getJoinType(), edge.getJoinType())) {
|
||||
left = LongBitmap.or(left, lEdge.getLeft());
|
||||
private void makeConflictRules(Edge edgeB) {
|
||||
BitSet leftSubTreeEdges = subTreeEdges(edgeB.getLeftChildEdges());
|
||||
BitSet rightSubTreeEdges = subTreeEdges(edgeB.getRightChildEdges());
|
||||
long leftRequired = edgeB.getLeftRequiredNodes();
|
||||
long rightRequired = edgeB.getRightRequiredNodes();
|
||||
|
||||
for (int i = leftSubTreeEdges.nextSetBit(0); i >= 0; i = leftSubTreeEdges.nextSetBit(i + 1)) {
|
||||
Edge childA = edges.get(i);
|
||||
if (!JoinType.isAssoc(childA.getJoinType(), edgeB.getJoinType())) {
|
||||
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(edges));
|
||||
}
|
||||
if (!JoinType.isLAssoc(lEdge.getJoinType(), edge.getJoinType())) {
|
||||
left = LongBitmap.or(left, lEdge.getRight());
|
||||
}
|
||||
}
|
||||
for (int i = rightEdges.nextSetBit(0); i >= 0; i = rightEdges.nextSetBit(i + 1)) {
|
||||
Edge rEdge = edges.get(i);
|
||||
if (!JoinType.isAssoc(rEdge.getJoinType(), edge.getJoinType())) {
|
||||
right = LongBitmap.or(right, rEdge.getRight());
|
||||
}
|
||||
if (!JoinType.isRAssoc(rEdge.getJoinType(), edge.getJoinType())) {
|
||||
right = LongBitmap.or(right, rEdge.getLeft());
|
||||
if (!JoinType.isLAssoc(childA.getJoinType(), edgeB.getJoinType())) {
|
||||
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(edges));
|
||||
}
|
||||
}
|
||||
|
||||
edge.setOriginalLeft(left);
|
||||
edge.setOriginalRight(right);
|
||||
edge.setLeft(left);
|
||||
edge.setRight(right);
|
||||
for (int i = rightSubTreeEdges.nextSetBit(0); i >= 0; i = rightSubTreeEdges.nextSetBit(i + 1)) {
|
||||
Edge childA = edges.get(i);
|
||||
if (!JoinType.isAssoc(edgeB.getJoinType(), childA.getJoinType())) {
|
||||
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(edges));
|
||||
}
|
||||
if (!JoinType.isRAssoc(edgeB.getJoinType(), childA.getJoinType())) {
|
||||
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(edges));
|
||||
}
|
||||
}
|
||||
edgeB.setLeftRequiredNodes(leftRequired);
|
||||
edgeB.setRightRequiredNodes(rightRequired);
|
||||
edgeB.setLeftExtendedNodes(leftRequired);
|
||||
edgeB.setRightExtendedNodes(rightRequired);
|
||||
}
|
||||
|
||||
private int findRoot(List<Integer> parent, int idx) {
|
||||
int root = parent.get(idx);
|
||||
if (root != idx) {
|
||||
root = findRoot(parent, root);
|
||||
}
|
||||
parent.set(idx, root);
|
||||
return root;
|
||||
private BitSet subTreeEdges(Edge edge) {
|
||||
BitSet bitSet = new BitSet();
|
||||
bitSet.or(subTreeEdges(edge.getLeftChildEdges()));
|
||||
bitSet.or(subTreeEdges(edge.getRightChildEdges()));
|
||||
bitSet.set(edge.getIndex());
|
||||
return bitSet;
|
||||
}
|
||||
|
||||
private boolean isConnected(long bitmap, long excludeBitmap) {
|
||||
if (LongBitmap.getCardinality(bitmap) == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// use unionSet to check whether the bitmap is connected
|
||||
List<Integer> parent = new ArrayList<>();
|
||||
for (int i = 0; i < nodes.size(); i++) {
|
||||
parent.add(i, i);
|
||||
}
|
||||
for (Edge edge : edges) {
|
||||
if (LongBitmap.isOverlap(edge.getLeft(), excludeBitmap)
|
||||
|| LongBitmap.isOverlap(edge.getRight(), excludeBitmap)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int root = findRoot(parent, LongBitmap.nextSetBit(edge.getLeft(), 0));
|
||||
for (int idx : LongBitmap.getIterator(edge.getLeft())) {
|
||||
parent.set(idx, root);
|
||||
}
|
||||
for (int idx : LongBitmap.getIterator(edge.getRight())) {
|
||||
parent.set(idx, root);
|
||||
}
|
||||
}
|
||||
|
||||
int root = findRoot(parent, LongBitmap.nextSetBit(bitmap, 0));
|
||||
for (int idx : LongBitmap.getIterator(bitmap)) {
|
||||
if (root != findRoot(parent, idx)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
private BitSet subTreeEdges(BitSet edgeSet) {
|
||||
BitSet bitSet = new BitSet();
|
||||
edgeSet.stream()
|
||||
.mapToObj(i -> subTreeEdges(edges.get(i)))
|
||||
.forEach(b -> bitSet.or(b));
|
||||
return bitSet;
|
||||
}
|
||||
|
||||
private Pair<Long, Long> findEnds(Expression expression) {
|
||||
long bitmap = calNodeMap(expression.getInputSlots());
|
||||
int cardinality = LongBitmap.getCardinality(bitmap);
|
||||
Preconditions.checkArgument(cardinality > 1);
|
||||
for (long subset : LongBitmap.getSubsetIterator(bitmap)) {
|
||||
long left = subset;
|
||||
long right = LongBitmap.newBitmapDiff(bitmap, left);
|
||||
// when the graph without right node has a connected-sub-graph contains left nodes
|
||||
// and the graph without left node has a connected-sub-graph contains right nodes.
|
||||
// we can generate an edge for this expression
|
||||
if (isConnected(left, right) && isConnected(right, left)) {
|
||||
return Pair.of(left, right);
|
||||
}
|
||||
// Try to calculate the ends of an expression.
|
||||
// left = ref_nodes \cap left_tree , right = ref_nodes \cap right_tree
|
||||
// if left = 0, recursively calculate it in left tree
|
||||
private Pair<Long, Long> calculateEnds(long allNodes, Pair<BitSet, Long> leftEdgeNodes,
|
||||
Pair<BitSet, Long> rightEdgeNodes) {
|
||||
long left = LongBitmap.newBitmapIntersect(allNodes, leftEdgeNodes.second);
|
||||
long right = LongBitmap.newBitmapIntersect(allNodes, rightEdgeNodes.second);
|
||||
if (left == 0) {
|
||||
Preconditions.checkArgument(leftEdgeNodes.first.cardinality() > 0,
|
||||
"the number of the table which expression reference is less 2");
|
||||
Pair<BitSet, Long> llEdgesNodes = edges.get(leftEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes(edges);
|
||||
Pair<BitSet, Long> lrEdgesNodes = edges.get(leftEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes(edges);
|
||||
return calculateEnds(allNodes, llEdgesNodes, lrEdgesNodes);
|
||||
}
|
||||
throw new RuntimeException("DPhyper meets unconnected subgraph");
|
||||
if (right == 0) {
|
||||
Preconditions.checkArgument(rightEdgeNodes.first.cardinality() > 0,
|
||||
"the number of the table which expression reference is less 2");
|
||||
Pair<BitSet, Long> rlEdgesNodes = edges.get(rightEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes(edges);
|
||||
Pair<BitSet, Long> rrEdgesNodes = edges.get(rightEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes(edges);
|
||||
return calculateEnds(allNodes, rlEdgesNodes, rrEdgesNodes);
|
||||
}
|
||||
return Pair.of(left, right);
|
||||
}
|
||||
|
||||
private long calNodeMap(Set<Slot> slots) {
|
||||
@ -291,10 +289,10 @@ public class HyperGraph {
|
||||
// For these nodes that are only in the old edge, we need remove the edge from them
|
||||
// For these nodes that are only in the new edge, we need to add the edge to them
|
||||
Edge edge = edges.get(edgeIndex);
|
||||
updateEdges(edge, edge.getLeft(), newLeft);
|
||||
updateEdges(edge, edge.getRight(), newRight);
|
||||
edges.get(edgeIndex).setLeft(newLeft);
|
||||
edges.get(edgeIndex).setRight(newRight);
|
||||
updateEdges(edge, edge.getLeftExtendedNodes(), newLeft);
|
||||
updateEdges(edge, edge.getRightExtendedNodes(), newRight);
|
||||
edges.get(edgeIndex).setLeftExtendedNodes(newLeft);
|
||||
edges.get(edgeIndex).setRightExtendedNodes(newRight);
|
||||
}
|
||||
|
||||
private void updateEdges(Edge edge, long oldNodes, long newNodes) {
|
||||
@ -339,8 +337,8 @@ public class HyperGraph {
|
||||
arrowHead = ",arrowhead=none";
|
||||
}
|
||||
|
||||
int leftIndex = LongBitmap.lowestOneIndex(edge.getLeft());
|
||||
int rightIndex = LongBitmap.lowestOneIndex(edge.getRight());
|
||||
int leftIndex = LongBitmap.lowestOneIndex(edge.getLeftExtendedNodes());
|
||||
int rightIndex = LongBitmap.lowestOneIndex(edge.getRightExtendedNodes());
|
||||
builder.append(String.format("%s -> %s [label=\"%s\"%s]\n", graphvisNodes.get(leftIndex),
|
||||
graphvisNodes.get(rightIndex), label, arrowHead));
|
||||
} else {
|
||||
@ -349,7 +347,7 @@ public class HyperGraph {
|
||||
|
||||
String leftLabel = "";
|
||||
String rightLabel = "";
|
||||
if (LongBitmap.getCardinality(edge.getLeft()) == 1) {
|
||||
if (LongBitmap.getCardinality(edge.getLeftExtendedNodes()) == 1) {
|
||||
rightLabel = label;
|
||||
} else {
|
||||
leftLabel = label;
|
||||
@ -357,13 +355,13 @@ public class HyperGraph {
|
||||
|
||||
int finalI = i;
|
||||
String finalLeftLabel = leftLabel;
|
||||
for (int nodeIndex : LongBitmap.getIterator(edge.getLeft())) {
|
||||
for (int nodeIndex : LongBitmap.getIterator(edge.getLeftExtendedNodes())) {
|
||||
builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]\n",
|
||||
graphvisNodes.get(nodeIndex), finalI, finalLeftLabel));
|
||||
}
|
||||
|
||||
String finalRightLabel = rightLabel;
|
||||
for (int nodeIndex : LongBitmap.getIterator(edge.getRight())) {
|
||||
for (int nodeIndex : LongBitmap.getIterator(edge.getRightExtendedNodes())) {
|
||||
builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]\n",
|
||||
graphvisNodes.get(nodeIndex), finalI, finalRightLabel));
|
||||
}
|
||||
|
||||
@ -225,8 +225,8 @@ public class SubgraphEnumerator {
|
||||
neighborhoods = LongBitmap.andNot(neighborhoods, forbiddenNodes);
|
||||
forbiddenNodes = LongBitmap.or(forbiddenNodes, neighborhoods);
|
||||
for (Edge edge : edgeCalculator.foundComplexEdgesContain(subgraph)) {
|
||||
long left = edge.getLeft();
|
||||
long right = edge.getRight();
|
||||
long left = edge.getLeftExtendedNodes();
|
||||
long right = edge.getRightExtendedNodes();
|
||||
if (LongBitmap.isSubset(left, subgraph) && !LongBitmap.isOverlap(right, forbiddenNodes)) {
|
||||
neighborhoods = LongBitmap.set(neighborhoods, LongBitmap.lowestOneIndex(right));
|
||||
} else if (LongBitmap.isSubset(right, subgraph) && !LongBitmap.isOverlap(left, forbiddenNodes)) {
|
||||
@ -362,14 +362,14 @@ public class SubgraphEnumerator {
|
||||
}
|
||||
|
||||
private boolean isContainEdge(long subgraph, Edge edge) {
|
||||
int containLeft = LongBitmap.isSubset(edge.getLeft(), subgraph) ? 0 : 1;
|
||||
int containRight = LongBitmap.isSubset(edge.getRight(), subgraph) ? 0 : 1;
|
||||
int containLeft = LongBitmap.isSubset(edge.getLeftExtendedNodes(), subgraph) ? 0 : 1;
|
||||
int containRight = LongBitmap.isSubset(edge.getRightExtendedNodes(), subgraph) ? 0 : 1;
|
||||
return containLeft + containRight == 1;
|
||||
}
|
||||
|
||||
private boolean isOverlapEdge(long subgraph, Edge edge) {
|
||||
int overlapLeft = LongBitmap.isOverlap(edge.getLeft(), subgraph) ? 0 : 1;
|
||||
int overlapRight = LongBitmap.isOverlap(edge.getRight(), subgraph) ? 0 : 1;
|
||||
int overlapLeft = LongBitmap.isOverlap(edge.getLeftExtendedNodes(), subgraph) ? 0 : 1;
|
||||
int overlapRight = LongBitmap.isOverlap(edge.getRightExtendedNodes(), subgraph) ? 0 : 1;
|
||||
return overlapLeft + overlapRight == 1;
|
||||
}
|
||||
|
||||
|
||||
@ -60,6 +60,7 @@ import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
@ -102,15 +103,14 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
public boolean emitCsgCmp(long left, long right, List<Edge> edges) {
|
||||
Preconditions.checkArgument(planTable.containsKey(left));
|
||||
Preconditions.checkArgument(planTable.containsKey(right));
|
||||
|
||||
processMissedEdges(left, right, edges);
|
||||
|
||||
Memo memo = jobContext.getCascadesContext().getMemo();
|
||||
emitCount += 1;
|
||||
if (emitCount > limit) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Memo memo = jobContext.getCascadesContext().getMemo();
|
||||
GroupPlan leftPlan = new GroupPlan(planTable.get(left));
|
||||
GroupPlan rightPlan = new GroupPlan(planTable.get(right));
|
||||
|
||||
@ -118,6 +118,7 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
// In this step, we don't generate logical expression because they are useless in DPhyp.
|
||||
List<Expression> hashConjuncts = new ArrayList<>();
|
||||
List<Expression> otherConjuncts = new ArrayList<>();
|
||||
|
||||
JoinType joinType = extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
|
||||
if (joinType == null) {
|
||||
return true;
|
||||
@ -126,6 +127,7 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
|
||||
List<Plan> physicalJoins = proposeAllPhysicalJoins(joinType, leftPlan, rightPlan, hashConjuncts,
|
||||
otherConjuncts);
|
||||
|
||||
List<Plan> physicalPlans = proposeProject(physicalJoins, edges, left, right);
|
||||
|
||||
// Second, we copy all physical plan to Group and generate properties and calculate cost
|
||||
@ -188,7 +190,7 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
// find the edge which is not in usedEdgesBitmap and its referenced nodes is subset of allReferenceNodes
|
||||
for (Edge edge : hyperGraph.getEdges()) {
|
||||
long referenceNodes =
|
||||
LongBitmap.newBitmapUnion(edge.getOriginalLeft(), edge.getOriginalRight());
|
||||
LongBitmap.newBitmapUnion(edge.getLeftRequiredNodes(), edge.getRightRequiredNodes());
|
||||
if (LongBitmap.isSubset(referenceNodes, allReferenceNodes)
|
||||
&& !usedEdgesBitmap.get(edge.getIndex())) {
|
||||
// add the missed edge to edges
|
||||
@ -344,7 +346,6 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
long fullKey = LongBitmap.newBitmapUnion(left, right);
|
||||
List<Slot> outputs = allChild.get(0).getOutput();
|
||||
Set<Slot> outputSet = allChild.get(0).getOutputSet();
|
||||
List<NamedExpression> allProjects = Lists.newArrayList();
|
||||
|
||||
List<NamedExpression> complexProjects = new ArrayList<>();
|
||||
// Calculate complex expression should be done by current(fullKey) node
|
||||
@ -358,43 +359,23 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
|
||||
// complexProjectMap is created by a bottom up traverse of join tree, so child node is put before parent node
|
||||
// in the bitmaps
|
||||
bitmaps.sort(Long::compare);
|
||||
for (long bitmap : bitmaps) {
|
||||
if (complexProjects.isEmpty()) {
|
||||
complexProjects = complexProjectMap.get(bitmap);
|
||||
complexProjects.addAll(complexProjectMap.get(bitmap));
|
||||
} else {
|
||||
// The top project of (T1, T2, T3) is different after reorder
|
||||
// we need merge Project1 and Project2 as Project4 after reorder
|
||||
// T1 join T2 join T3:
|
||||
// Project1(a, e + f)
|
||||
// join(a = e)
|
||||
// Project2(a, b + d as e)
|
||||
// join(a = c)
|
||||
// T1(a, b)
|
||||
// T2(c, d)
|
||||
// T3(e, f)
|
||||
//
|
||||
// after reorder:
|
||||
// T1 join T3 join T2:
|
||||
// Project4(a, b + d + f)
|
||||
// join(a = c)
|
||||
// Project3(a, b, f)
|
||||
// join(a = e)
|
||||
// T1(a, b)
|
||||
// T3(e, f)
|
||||
// T2(c, d)
|
||||
//
|
||||
complexProjects =
|
||||
PlanUtils.mergeProjections(complexProjects, complexProjectMap.get(bitmap));
|
||||
// Rewrite project expression by its children
|
||||
complexProjects.addAll(
|
||||
PlanUtils.mergeProjections(complexProjects, complexProjectMap.get(bitmap)));
|
||||
}
|
||||
}
|
||||
allProjects.addAll(complexProjects);
|
||||
|
||||
// calculate required columns by all parents
|
||||
Set<Slot> requireSlots = calculateRequiredSlots(left, right, edges);
|
||||
|
||||
// add output slots belong to required slots to project list
|
||||
allProjects.addAll(outputs.stream().filter(e -> requireSlots.contains(e))
|
||||
.collect(Collectors.toList()));
|
||||
List<NamedExpression> allProjects = Stream.concat(
|
||||
outputs.stream().filter(e -> requireSlots.contains(e)),
|
||||
complexProjects.stream().filter(e -> requireSlots.contains(e.toSlot()))
|
||||
).collect(Collectors.toList());
|
||||
|
||||
// propose physical project
|
||||
if (allProjects.isEmpty()) {
|
||||
@ -416,7 +397,13 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
.map(c -> new PhysicalProject<>(projects, projectProperties, c))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
Preconditions.checkState(!projects.isEmpty() && projects.size() == allProjects.size());
|
||||
if (!(!projects.isEmpty() && projects.size() == allProjects.size())) {
|
||||
Set<NamedExpression> s1 = projects.stream().collect(Collectors.toSet());
|
||||
List<NamedExpression> s2 = allProjects.stream().filter(e -> !s1.contains(e)).collect(Collectors.toList());
|
||||
System.out.println(s2);
|
||||
}
|
||||
Preconditions.checkState(!projects.isEmpty() && projects.size() == allProjects.size(),
|
||||
" there are some projects left " + projects + allProjects);
|
||||
|
||||
return allChild;
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.datasets.tpch.TPCHTestBase;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.HyperGraphBuilder;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
@ -32,23 +33,37 @@ import java.util.Set;
|
||||
|
||||
public class OtherJoinTest extends TPCHTestBase {
|
||||
@Test
|
||||
public void randomTest() {
|
||||
public void test() {
|
||||
for (int t = 3; t < 10; t++) {
|
||||
for (int e = t - 1; e <= (t * (t - 1)) / 2; e++) {
|
||||
for (int i = 0; i < 10; i++) {
|
||||
System.out.println(String.valueOf(t) + " " + e + ": " + i);
|
||||
randomTest(t, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void randomTest(int tableNum, int edgeNum) {
|
||||
HyperGraphBuilder hyperGraphBuilder = new HyperGraphBuilder();
|
||||
Plan plan = hyperGraphBuilder
|
||||
.randomBuildPlanWith(10, 20);
|
||||
Set<List<Integer>> res1 = hyperGraphBuilder.evaluate(plan);
|
||||
.randomBuildPlanWith(tableNum, edgeNum);
|
||||
plan = new LogicalProject(plan.getOutput(), plan);
|
||||
Set<List<String>> res1 = hyperGraphBuilder.evaluate(plan);
|
||||
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan);
|
||||
hyperGraphBuilder.initStats(cascadesContext);
|
||||
Plan optimizedPlan = PlanChecker.from(cascadesContext)
|
||||
.dpHypOptimize()
|
||||
.getBestPlanTree();
|
||||
.dpHypOptimize()
|
||||
.getBestPlanTree();
|
||||
|
||||
Set<List<Integer>> res2 = hyperGraphBuilder.evaluate(optimizedPlan);
|
||||
Set<List<String>> res2 = hyperGraphBuilder.evaluate(optimizedPlan);
|
||||
if (!res1.equals(res2)) {
|
||||
System.out.println(res1);
|
||||
System.out.println(res2);
|
||||
System.out.println(plan.treeString());
|
||||
System.out.println(optimizedPlan.treeString());
|
||||
cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan);
|
||||
PlanChecker.from(cascadesContext).dpHypOptimize().getBestPlanTree();
|
||||
System.out.println(res1);
|
||||
System.out.println(res2);
|
||||
}
|
||||
Assertions.assertTrue(res1.equals(res2));
|
||||
|
||||
|
||||
@ -130,8 +130,8 @@ public class SubgraphEnumeratorTest {
|
||||
visited.add(left);
|
||||
visited.add(right);
|
||||
for (Edge edge : hyperGraph.getEdges()) {
|
||||
if ((LongBitmap.isSubset(edge.getLeft(), left) && LongBitmap.isSubset(edge.getRight(), right)) || (
|
||||
LongBitmap.isSubset(edge.getLeft(), right) && LongBitmap.isSubset(edge.getRight(), left))) {
|
||||
if ((LongBitmap.isSubset(edge.getLeftExtendedNodes(), left) && LongBitmap.isSubset(edge.getRightExtendedNodes(), right)) || (
|
||||
LongBitmap.isSubset(edge.getLeftExtendedNodes(), right) && LongBitmap.isSubset(edge.getRightExtendedNodes(), left))) {
|
||||
count += countAndCheck(left, hyperGraph, counter, cache) * countAndCheck(right, hyperGraph,
|
||||
counter, cache);
|
||||
break;
|
||||
|
||||
@ -87,6 +87,20 @@ public class JoinOrderJobTest extends SqlTestBase {
|
||||
.dpHypOptimize();
|
||||
}
|
||||
|
||||
@Test
|
||||
protected void testConstantJoin() {
|
||||
String sql = "select count(*) \n"
|
||||
+ "from \n"
|
||||
+ "T1 \n"
|
||||
+ " join (\n"
|
||||
+ "select * , now() as t from T2 \n"
|
||||
+ ") subTable on T1.id = t; \n";
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
.dpHypOptimize();
|
||||
}
|
||||
|
||||
@Test
|
||||
protected void testCountJoin() {
|
||||
String sql = "select count(*) \n"
|
||||
|
||||
@ -17,9 +17,9 @@
|
||||
|
||||
package org.apache.doris.nereids.sqltest;
|
||||
|
||||
import org.apache.doris.nereids.properties.DistributionSpecGather;
|
||||
import org.apache.doris.nereids.properties.DistributionSpecHash;
|
||||
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.rules.rewrite.ReorderJoin;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
@ -50,12 +50,8 @@ public class JoinTest extends SqlTestBase {
|
||||
.getBestPlanTree();
|
||||
// generate colocate join plan without physicalDistribute
|
||||
System.out.println(plan.treeString());
|
||||
Assertions.assertFalse(plan.anyMatch(p -> {
|
||||
if (p instanceof PhysicalDistribute) {
|
||||
return !(((PhysicalDistribute<?>) p).getDistributionSpec() instanceof DistributionSpecGather);
|
||||
}
|
||||
return false;
|
||||
}));
|
||||
Assertions.assertFalse(plan.anyMatch(p -> p instanceof PhysicalDistribute
|
||||
&& ((PhysicalDistribute) p).getDistributionSpec() instanceof DistributionSpecHash));
|
||||
sql = "select * from T1 join T0 on T1.score = T0.score and T1.id = T0.id;";
|
||||
plan = PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
@ -63,12 +59,8 @@ public class JoinTest extends SqlTestBase {
|
||||
.optimize()
|
||||
.getBestPlanTree();
|
||||
// generate colocate join plan without physicalDistribute
|
||||
Assertions.assertFalse(plan.anyMatch(p -> {
|
||||
if (p instanceof PhysicalDistribute) {
|
||||
return !(((PhysicalDistribute<?>) p).getDistributionSpec() instanceof DistributionSpecGather);
|
||||
}
|
||||
return false;
|
||||
}));
|
||||
Assertions.assertFalse(plan.anyMatch(p -> p instanceof PhysicalDistribute
|
||||
&& ((PhysicalDistribute) p).getDistributionSpec() instanceof DistributionSpecHash));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -100,7 +92,7 @@ public class JoinTest extends SqlTestBase {
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
.optimize()
|
||||
.getBestPlanTree();
|
||||
.getBestPlanTree(PhysicalProperties.ANY);
|
||||
Assertions.assertEquals(
|
||||
ShuffleType.NATURAL,
|
||||
((DistributionSpecHash) ((PhysicalPlan) (plan.child(0).child(0)))
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
package org.apache.doris.nereids.util;
|
||||
|
||||
import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
|
||||
@ -26,6 +27,7 @@ import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
@ -35,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
|
||||
import org.apache.doris.statistics.ColumnStatistic;
|
||||
import org.apache.doris.statistics.Statistics;
|
||||
import org.apache.doris.statistics.StatisticsCacheKey;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -90,6 +93,12 @@ public class HyperGraphBuilder {
|
||||
return buildHyperGraph(plan);
|
||||
}
|
||||
|
||||
public Plan buildPlan() {
|
||||
assert plans.size() == 1 : "there are cross join";
|
||||
Plan plan = plans.values().iterator().next();
|
||||
return plan;
|
||||
}
|
||||
|
||||
public Plan buildJoinPlan() {
|
||||
assert plans.size() == 1 : "there are cross join";
|
||||
Plan plan = plans.values().iterator().next();
|
||||
@ -166,9 +175,14 @@ public class HyperGraphBuilder {
|
||||
for (Group group : context.getMemo().getGroups()) {
|
||||
GroupExpression groupExpression = group.getLogicalExpression();
|
||||
if (groupExpression.getPlan() instanceof LogicalOlapScan) {
|
||||
LogicalOlapScan scan = (LogicalOlapScan) groupExpression.getPlan();
|
||||
Statistics stats = injectRowcount((LogicalOlapScan) groupExpression.getPlan());
|
||||
groupExpression.setStatDerived(true);
|
||||
group.setStatistics(stats);
|
||||
for (Expression expr : stats.columnStatistics().keySet()) {
|
||||
SlotReference slot = (SlotReference) expr;
|
||||
Env.getCurrentEnv().getStatisticsCache().putCache(
|
||||
new StatisticsCacheKey(scan.getTable().getId(), -1, slot.getName()),
|
||||
stats.columnStatistics().get(expr));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -364,7 +378,7 @@ public class HyperGraphBuilder {
|
||||
return hashConjunts;
|
||||
}
|
||||
|
||||
public Set<List<Integer>> evaluate(Plan plan) {
|
||||
public Set<List<String>> evaluate(Plan plan) {
|
||||
JoinEvaluator evaluator = new JoinEvaluator(rowCounts);
|
||||
Map<Slot, List<Integer>> res = evaluator.evaluate(plan);
|
||||
int rowCount = 0;
|
||||
@ -376,11 +390,12 @@ public class HyperGraphBuilder {
|
||||
(slot1, slot2) ->
|
||||
String.CASE_INSENSITIVE_ORDER.compare(slot1.toString(), slot2.toString()))
|
||||
.collect(Collectors.toList());
|
||||
Set<List<Integer>> tuples = new HashSet<>();
|
||||
Set<List<String>> tuples = new HashSet<>();
|
||||
tuples.add(keySet.stream().map(s -> s.toString()).collect(Collectors.toList()));
|
||||
for (int i = 0; i < rowCount; i++) {
|
||||
List<Integer> tuple = new ArrayList<>();
|
||||
List<String> tuple = new ArrayList<>();
|
||||
for (Slot key : keySet) {
|
||||
tuple.add(res.get(key).get(i));
|
||||
tuple.add(String.valueOf(res.get(key).get(i)));
|
||||
}
|
||||
tuples.add(tuple);
|
||||
}
|
||||
|
||||
@ -33,7 +33,6 @@ import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
|
||||
import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteBottomUpJob;
|
||||
import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteTopDownJob;
|
||||
import org.apache.doris.nereids.jobs.rewrite.RootPlanTreeRewriteJob;
|
||||
import org.apache.doris.nereids.memo.CopyInResult;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.memo.Memo;
|
||||
@ -243,23 +242,13 @@ public class PlanChecker {
|
||||
|
||||
public PlanChecker dpHypOptimize() {
|
||||
double now = System.currentTimeMillis();
|
||||
cascadesContext.getStatementContext().setDpHyp(true);
|
||||
cascadesContext.getConnectContext().getSessionVariable().enableDPHypOptimizer = true;
|
||||
Group root = cascadesContext.getMemo().getRoot();
|
||||
boolean changeRoot = false;
|
||||
if (root.isValidJoinGroup()) {
|
||||
// If the root group is join group, DPHyp can change the root group.
|
||||
// To keep the root group is not changed, we add a dummy project operator above join
|
||||
List<Slot> outputs = root.getLogicalExpression().getPlan().getOutput();
|
||||
LogicalPlan plan = new LogicalProject(outputs, root.getLogicalExpression().getPlan());
|
||||
CopyInResult copyInResult = cascadesContext.getMemo().copyIn(plan, null, false);
|
||||
root = copyInResult.correspondingExpression.getOwnerGroup();
|
||||
changeRoot = true;
|
||||
}
|
||||
cascadesContext.pushJob(new JoinOrderJob(root, cascadesContext.getCurrentJobContext()));
|
||||
cascadesContext.pushJob(new DeriveStatsJob(root.getLogicalExpression(),
|
||||
cascadesContext.getCurrentJobContext()));
|
||||
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
|
||||
if (changeRoot) {
|
||||
cascadesContext.getMemo().setRoot(root.getLogicalExpression().child(0));
|
||||
}
|
||||
// if the root is not join, we need to optimize again.
|
||||
optimize();
|
||||
System.out.println("DPhyp:" + (System.currentTimeMillis() - now));
|
||||
return this;
|
||||
@ -602,7 +591,7 @@ public class PlanChecker {
|
||||
}
|
||||
|
||||
public PhysicalPlan getBestPlanTree() {
|
||||
return chooseBestPlan(cascadesContext.getMemo().getRoot(), PhysicalProperties.ANY);
|
||||
return chooseBestPlan(cascadesContext.getMemo().getRoot(), PhysicalProperties.GATHER);
|
||||
}
|
||||
|
||||
public PhysicalPlan getBestPlanTree(PhysicalProperties properties) {
|
||||
|
||||
Reference in New Issue
Block a user