[fix](Nereids) fix ends calculation when there are constant project (#22265)

This commit is contained in:
谢健
2023-07-31 14:10:44 +08:00
committed by GitHub
parent 147a148364
commit 8ccd8b4337
12 changed files with 341 additions and 259 deletions

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.jobs.joinorder;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.Job;
@ -26,6 +27,7 @@ import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.GraphSimplifier;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.SubgraphEnumerator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.PlanReceiver;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
@ -111,20 +113,22 @@ public class JoinOrderJob extends Job {
*
* @param group root group, should be join type
* @param hyperGraph build hyperGraph
*
* @return return edges of group's child and subTreeNodes of this group
*/
public BitSet buildGraph(Group group, HyperGraph hyperGraph) {
public Pair<BitSet, Long> buildGraph(Group group, HyperGraph hyperGraph) {
if (group.isProjectGroup()) {
BitSet edgeMap = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
processProjectPlan(hyperGraph, group);
return edgeMap;
Pair<BitSet, Long> res = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
processProjectPlan(hyperGraph, group, res.second);
return res;
}
if (!group.isValidJoinGroup()) {
hyperGraph.addNode(optimizePlan(group));
return new BitSet();
int idx = hyperGraph.addNode(optimizePlan(group));
return Pair.of(new BitSet(), LongBitmap.newBitmap(idx));
}
BitSet leftEdgeMap = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
BitSet rightEdgeMap = buildGraph(group.getLogicalExpression().child(1), hyperGraph);
return hyperGraph.addEdge(group, leftEdgeMap, rightEdgeMap);
Pair<BitSet, Long> left = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
Pair<BitSet, Long> right = buildGraph(group.getLogicalExpression().child(1), hyperGraph);
return Pair.of(hyperGraph.addEdge(group, left, right), LongBitmap.or(left.second, right.second));
}
/**
@ -133,14 +137,14 @@ public class JoinOrderJob extends Job {
* 2. If it's an alias that may be used in the join operator, we need to add it to graph
* 3. If it's other expression, we can ignore them and add it after optimizing
*/
private void processProjectPlan(HyperGraph hyperGraph, Group group) {
private void processProjectPlan(HyperGraph hyperGraph, Group group, long subTreeNodes) {
LogicalProject<? extends Plan> logicalProject
= (LogicalProject<? extends Plan>) group.getLogicalExpression()
.getPlan();
for (NamedExpression expr : logicalProject.getProjects()) {
if (expr instanceof Alias) {
hyperGraph.addAlias((Alias) expr);
hyperGraph.addAlias((Alias) expr, subTreeNodes);
} else if (!expr.isSlot()) {
otherProject.add(expr);
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -26,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import com.google.common.base.Preconditions;
import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@ -38,23 +40,36 @@ public class Edge {
final LogicalJoin<? extends Plan, ? extends Plan> join;
final double selectivity;
// The endpoints (hyperNodes) of this hyperEdge.
// left and right may not overlap, and both must have at least one bit set.
private long left = LongBitmap.newBitmap();
private long right = LongBitmap.newBitmap();
private long originalLeft = LongBitmap.newBitmap();
private long originalRight = LongBitmap.newBitmap();
// "RequiredNodes" refers to the nodes that can activate this edge based on
// specific requirements. These requirements are established during the building process.
// "ExtendNodes" encompasses both the "RequiredNodes" and any additional nodes
// added by the graph simplifier.
private long leftRequiredNodes = LongBitmap.newBitmap();
private long rightRequiredNodes = LongBitmap.newBitmap();
private long leftExtendedNodes = LongBitmap.newBitmap();
private long rightExtendedNodes = LongBitmap.newBitmap();
private long referenceNodes = LongBitmap.newBitmap();
// record the left child edges and right child edges in origin plan tree
private BitSet leftChildEdges;
private BitSet rightChildEdges;
// record the edges in the same operator
private BitSet curJoinEdges = new BitSet();
// Records all nodes in the subtree below this operator. It is the T function in the paper.
private Long subTreeNodes;
/**
* Create simple edge.
*/
public Edge(LogicalJoin join, int index) {
public Edge(LogicalJoin join, int index, BitSet leftChildEdges, BitSet rightChildEdges, Long subTreeNodes) {
this.index = index;
this.join = join;
this.selectivity = 1.0;
this.leftChildEdges = leftChildEdges;
this.rightChildEdges = rightChildEdges;
this.subTreeNodes = subTreeNodes;
}
public LogicalJoin getJoin() {
@ -66,65 +81,107 @@ public class Edge {
}
public boolean isSimple() {
return LongBitmap.getCardinality(left) == 1 && LongBitmap.getCardinality(right) == 1;
return LongBitmap.getCardinality(leftExtendedNodes) == 1 && LongBitmap.getCardinality(rightExtendedNodes) == 1;
}
public void addLeftNode(long left) {
this.left = LongBitmap.or(this.left, left);
this.leftExtendedNodes = LongBitmap.or(this.leftExtendedNodes, left);
referenceNodes = LongBitmap.or(referenceNodes, left);
}
public void addLeftNodes(long... bitmaps) {
for (long bitmap : bitmaps) {
this.left = LongBitmap.or(this.left, bitmap);
this.leftExtendedNodes = LongBitmap.or(this.leftExtendedNodes, bitmap);
referenceNodes = LongBitmap.or(referenceNodes, bitmap);
}
}
public void addRightNode(long right) {
this.right = LongBitmap.or(this.right, right);
this.rightExtendedNodes = LongBitmap.or(this.rightExtendedNodes, right);
referenceNodes = LongBitmap.or(referenceNodes, right);
}
/**
 * Adds every given node bitmap to the right end of this edge.
 *
 * <p>Mirrors {@code addLeftNodes}: each bitmap is OR-ed into {@code rightExtendedNodes}
 * and into the cached {@code referenceNodes}.
 *
 * @param bitmaps node bitmaps to merge into the right end; an empty array leaves the edge unchanged
 */
public void addRightNodes(long... bitmaps) {
    for (long bitmap : bitmaps) {
        // Both fields are primitive longs, so the result of LongBitmap.or must be assigned back;
        // calling it without using the return value leaves the edge untouched.
        this.rightExtendedNodes = LongBitmap.or(this.rightExtendedNodes, bitmap);
        referenceNodes = LongBitmap.or(referenceNodes, bitmap);
    }
}
public long getLeft() {
return left;
public long getSubTreeNodes() {
return this.subTreeNodes;
}
public void setLeft(long left) {
public long getLeftExtendedNodes() {
return leftExtendedNodes;
}
/**
 * Returns the left child edges recorded for this edge at construction time.
 *
 * @return the internal {@code BitSet} itself (not a copy)
 */
public BitSet getLeftChildEdges() {
    return this.leftChildEdges;
}
/**
 * Pairs the left child edges with the nodes of the left subtree.
 *
 * @param edges all edges of the graph, indexed by edge index
 * @return pair of (left child edges, result of {@code getLeftSubNodes(edges)})
 */
public Pair<BitSet, Long> getLeftEdgeNodes(List<Edge> edges) {
    long leftSubNodes = getLeftSubNodes(edges);
    return Pair.of(leftChildEdges, leftSubNodes);
}
/**
 * Pairs the right child edges with the nodes of the right subtree.
 *
 * @param edges all edges of the graph, indexed by edge index
 * @return pair of (right child edges, result of {@code getRightSubNodes(edges)})
 */
public Pair<BitSet, Long> getRightEdgeNodes(List<Edge> edges) {
    long rightSubNodes = getRightSubNodes(edges);
    return Pair.of(rightChildEdges, rightSubNodes);
}
/**
 * Returns the nodes of the left subtree of this edge.
 *
 * @param edges all edges of the graph, indexed by edge index
 * @return {@code leftRequiredNodes} when there is no left child edge; otherwise the
 *         subtree nodes of the left child edge with the lowest index
 */
public long getLeftSubNodes(List<Edge> edges) {
    if (!leftChildEdges.isEmpty()) {
        int firstChildIndex = leftChildEdges.nextSetBit(0);
        return edges.get(firstChildIndex).getSubTreeNodes();
    }
    // No child edge on the left: fall back to the required nodes of this edge.
    return leftRequiredNodes;
}
/**
 * Returns the nodes of the right subtree of this edge.
 *
 * @param edges all edges of the graph, indexed by edge index
 * @return {@code rightRequiredNodes} when there is no right child edge; otherwise the
 *         subtree nodes of the right child edge with the lowest index
 */
public long getRightSubNodes(List<Edge> edges) {
    if (!rightChildEdges.isEmpty()) {
        int firstChildIndex = rightChildEdges.nextSetBit(0);
        return edges.get(firstChildIndex).getSubTreeNodes();
    }
    // No child edge on the right: fall back to the required nodes of this edge.
    return rightRequiredNodes;
}
public void setLeftExtendedNodes(long leftExtendedNodes) {
referenceNodes = LongBitmap.clear(referenceNodes);
this.left = left;
this.leftExtendedNodes = leftExtendedNodes;
}
public long getRight() {
return right;
public long getRightExtendedNodes() {
return rightExtendedNodes;
}
public void setRight(long right) {
public BitSet getRightChildEdges() {
return rightChildEdges;
}
public void setRightExtendedNodes(long rightExtendedNodes) {
referenceNodes = LongBitmap.clear(referenceNodes);
this.right = right;
this.rightExtendedNodes = rightExtendedNodes;
}
public long getOriginalLeft() {
return originalLeft;
public long getLeftRequiredNodes() {
return leftRequiredNodes;
}
public void setOriginalLeft(long left) {
this.originalLeft = left;
public void setLeftRequiredNodes(long left) {
this.leftRequiredNodes = left;
}
public long getOriginalRight() {
return originalRight;
public long getRightRequiredNodes() {
return rightRequiredNodes;
}
public void setOriginalRight(long right) {
this.originalRight = right;
public void setRightRequiredNodes(long right) {
this.rightRequiredNodes = right;
}
/**
 * Merges the given edge indexes into the set of edges that belong to the same join operator.
 *
 * @param edges edge indexes to OR into {@code curJoinEdges}; the argument itself is not modified
 */
public void addCurJoinEdges(BitSet edges) {
    this.curJoinEdges.or(edges);
}
/**
 * Returns the edges recorded as belonging to the same join operator as this edge.
 *
 * @return the internal {@code BitSet} itself (not a copy)
 */
public BitSet getCurJoinEdges() {
    return this.curJoinEdges;
}
public boolean isSub(Edge edge) {
@ -135,11 +192,15 @@ public class Edge {
public long getReferenceNodes() {
if (LongBitmap.getCardinality(referenceNodes) == 0) {
referenceNodes = LongBitmap.newBitmapUnion(left, right);
referenceNodes = LongBitmap.newBitmapUnion(leftExtendedNodes, rightExtendedNodes);
}
return referenceNodes;
}
/**
 * Returns the union of the required nodes on both ends of this edge.
 *
 * @return bitmap union of {@code leftRequiredNodes} and {@code rightRequiredNodes}
 */
public long getRequireNodes() {
    long requiredOnBothEnds = LongBitmap.newBitmapUnion(leftRequiredNodes, rightRequiredNodes);
    return requiredOnBothEnds;
}
/**
 * Returns the index given to this edge at construction time.
 */
public int getIndex() {
    return this.index;
}
@ -165,7 +226,8 @@ public class Edge {
@Override
public String toString() {
return String.format("<%s - %s>", LongBitmap.toString(left), LongBitmap.toString(right));
return String.format("<%s - %s>", LongBitmap.toString(leftExtendedNodes), LongBitmap.toString(
rightExtendedNodes));
}
}

View File

@ -108,10 +108,10 @@ public class GraphSimplifier {
Edge edge1 = graph.getEdge(i);
Edge edge2 = graph.getEdge(j);
List<Long> superset = new ArrayList<>();
tryGetSuperset(edge1.getLeft(), edge2.getLeft(), superset);
tryGetSuperset(edge1.getLeft(), edge2.getRight(), superset);
tryGetSuperset(edge1.getRight(), edge2.getLeft(), superset);
tryGetSuperset(edge1.getRight(), edge2.getRight(), superset);
tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getLeftExtendedNodes(), superset);
tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getRightExtendedNodes(), superset);
tryGetSuperset(edge1.getRightExtendedNodes(), edge2.getLeftExtendedNodes(), superset);
tryGetSuperset(edge1.getRightExtendedNodes(), edge2.getRightExtendedNodes(), superset);
if (!circleDetector.checkCircleWithEdge(i, j) && !circleDetector.checkCircleWithEdge(j, i)
&& !edge2.isSub(edge1) && !edge1.isSub(edge2) && !superset.isEmpty()) {
return false;
@ -213,8 +213,8 @@ public class GraphSimplifier {
BestSimplification bestSimplification = priorityQueue.poll();
bestSimplification.isInQueue = false;
SimplificationStep bestStep = bestSimplification.getStep();
while (bestSimplification.bestNeighbor == -1 || !circleDetector.tryAddDirectedEdge(bestStep.beforeIndex,
bestStep.afterIndex)) {
while (bestSimplification.bestNeighbor == -1
|| !circleDetector.tryAddDirectedEdge(bestStep.beforeIndex, bestStep.afterIndex)) {
processNeighbors(bestStep.afterIndex, 0, edgeSize);
if (priorityQueue.isEmpty()) {
return null;
@ -307,10 +307,14 @@ public class GraphSimplifier {
|| circleDetector.checkCircleWithEdge(edgeIndex2, edgeIndex1)) {
return Optional.empty();
}
long left1 = edge1.getLeft();
long right1 = edge1.getRight();
long left2 = edge2.getLeft();
long right2 = edge2.getRight();
long left1 = edge1.getLeftExtendedNodes();
long right1 = edge1.getRightExtendedNodes();
long left2 = edge2.getLeftExtendedNodes();
long right2 = edge2.getRightExtendedNodes();
if (!cacheStats.containsKey(left1) || !cacheStats.containsKey(right1)
|| !cacheStats.containsKey(left2) || !cacheStats.containsKey(right2)) {
return Optional.empty();
}
Pair<Statistics, Edge> edge1Before2;
Pair<Statistics, Edge> edge2Before1;
List<Long> superBitset = new ArrayList<>();
@ -351,13 +355,14 @@ public class GraphSimplifier {
Statistics leftStats = JoinEstimation.estimate(cacheStats.get(bitmap1), cacheStats.get(bitmap2),
edge1.getJoin());
Statistics joinStats = JoinEstimation.estimate(leftStats, cacheStats.get(bitmap3), edge2.getJoin());
Edge edge = new Edge(edge2.getJoin(), -1);
Edge edge = new Edge(
edge2.getJoin(), -1, edge2.getLeftChildEdges(), edge2.getRightChildEdges(), edge2.getSubTreeNodes());
long newLeft = LongBitmap.newBitmapUnion(bitmap1, bitmap2);
// To keep the left and right ends from overlapping, the nodes of bitmap3 are removed from newLeft.
// Note that newLeft therefore may not fully contain bitmap1 and bitmap2; we use the circle detector
// to trace that dependency.
newLeft = LongBitmap.andNot(newLeft, bitmap3);
edge.addLeftNodes(newLeft);
edge.addRightNode(edge2.getRight());
edge.addRightNode(edge2.getRightExtendedNodes());
cacheStats.put(newLeft, leftStats);
cacheCost.put(newLeft, calCost(edge2, leftStats, cacheStats.get(bitmap1), cacheStats.get(bitmap2)));
return Pair.of(joinStats, edge);
@ -370,11 +375,12 @@ public class GraphSimplifier {
Statistics rightStats = JoinEstimation.estimate(cacheStats.get(bitmap2), cacheStats.get(bitmap3),
edge2.getJoin());
Statistics joinStats = JoinEstimation.estimate(cacheStats.get(bitmap1), rightStats, edge1.getJoin());
Edge edge = new Edge(edge1.getJoin(), -1);
Edge edge = new Edge(
edge1.getJoin(), -1, edge1.getLeftChildEdges(), edge1.getRightChildEdges(), edge1.getSubTreeNodes());
long newRight = LongBitmap.newBitmapUnion(bitmap2, bitmap3);
newRight = LongBitmap.andNot(newRight, bitmap1);
edge.addLeftNode(edge1.getLeft());
edge.addLeftNode(edge1.getLeftExtendedNodes());
edge.addRightNode(newRight);
cacheStats.put(newRight, rightStats);
cacheCost.put(newRight, calCost(edge2, rightStats, cacheStats.get(bitmap2), cacheStats.get(bitmap3)));
@ -384,11 +390,11 @@ public class GraphSimplifier {
private SimplificationStep orderJoin(Pair<Statistics, Edge> edge1Before2,
Pair<Statistics, Edge> edge2Before1, int edgeIndex1, int edgeIndex2) {
Cost cost1Before2 = calCost(edge1Before2.second, edge1Before2.first,
cacheStats.get(edge1Before2.second.getLeft()),
cacheStats.get(edge1Before2.second.getRight()));
cacheStats.get(edge1Before2.second.getLeftExtendedNodes()),
cacheStats.get(edge1Before2.second.getRightExtendedNodes()));
Cost cost2Before1 = calCost(edge2Before1.second, edge1Before2.first,
cacheStats.get(edge1Before2.second.getLeft()),
cacheStats.get(edge1Before2.second.getRight()));
cacheStats.get(edge1Before2.second.getLeftExtendedNodes()),
cacheStats.get(edge1Before2.second.getRightExtendedNodes()));
double benefit = Double.MAX_VALUE;
SimplificationStep step;
// Choose the plan with smaller cost and make the simplification step to replace the old edge by it.
@ -397,17 +403,17 @@ public class GraphSimplifier {
benefit = cost2Before1.getValue() / cost1Before2.getValue();
}
// choose edge1Before2
step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, edge1Before2.second.getLeft(),
edge1Before2.second.getRight(), graph.getEdge(edgeIndex2).getLeft(),
graph.getEdge(edgeIndex2).getRight());
step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, edge1Before2.second.getLeftExtendedNodes(),
edge1Before2.second.getRightExtendedNodes(), graph.getEdge(edgeIndex2).getLeftExtendedNodes(),
graph.getEdge(edgeIndex2).getRightExtendedNodes());
} else {
if (cost2Before1.getValue() != 0) {
benefit = cost1Before2.getValue() / cost2Before1.getValue();
}
// choose edge2Before1
step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.second.getLeft(),
edge2Before1.second.getRight(), graph.getEdge(edgeIndex1).getLeft(),
graph.getEdge(edgeIndex1).getRight());
step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.second.getLeftExtendedNodes(),
edge2Before1.second.getRightExtendedNodes(), graph.getEdge(edgeIndex1).getLeftExtendedNodes(),
graph.getEdge(edgeIndex1).getRightExtendedNodes());
}
return step;
}
@ -438,8 +444,8 @@ public class GraphSimplifier {
join.left(),
join.right());
cost = CostCalculator.calculateCost(nestedLoopJoin, planContext);
cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getLeft()), 0);
cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getRight()), 1);
cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getLeftExtendedNodes()), 0);
cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getRightExtendedNodes()), 1);
} else {
PhysicalHashJoin hashJoin = new PhysicalHashJoin<>(
join.getJoinType(),
@ -451,8 +457,8 @@ public class GraphSimplifier {
join.left(),
join.right());
cost = CostCalculator.calculateCost(hashJoin, planContext);
cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getLeft()), 0);
cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getRight()), 1);
cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getLeftExtendedNodes()), 0);
cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getRightExtendedNodes()), 1);
}
return cost;

View File

@ -24,7 +24,6 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -88,7 +87,7 @@ public class HyperGraph {
*
* @param alias The alias Expression in project Operator
*/
public boolean addAlias(Alias alias) {
public boolean addAlias(Alias alias, long subTreeNodes) {
Slot aliasSlot = alias.toSlot();
if (slotToNodeMap.containsKey(aliasSlot)) {
return true;
@ -97,12 +96,21 @@ public class HyperGraph {
for (Slot slot : alias.getInputSlots()) {
bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot));
}
// This case is hit when the alias is a constant, for example:
// select * from t1 join (
// select *, 1 as b1 from t2)
// on t1.b = b1
// Such an alias has no input slots, so let its slot reference all nodes of the subtree.
if (bitmap == 0) {
bitmap = subTreeNodes;
}
Preconditions.checkArgument(bitmap > 0, "slot must belong to some table");
slotToNodeMap.put(aliasSlot, bitmap);
if (!complexProject.containsKey(bitmap)) {
complexProject.put(bitmap, new ArrayList<>());
} else if (!(alias.child() instanceof SlotReference)) {
alias = (Alias) PlanUtils.mergeProjections(complexProject.get(bitmap), Lists.newArrayList(alias)).get(0);
}
alias = (Alias) PlanUtils.mergeProjections(complexProject.get(bitmap), Lists.newArrayList(alias)).get(0);
complexProject.get(bitmap).add(alias);
return true;
}
@ -111,8 +119,9 @@ public class HyperGraph {
* add end node to HyperGraph
*
* @param group The group that is the end node in graph
* @return return the node index
*/
public void addNode(Group group) {
public int addNode(Group group) {
Preconditions.checkArgument(!group.isValidJoinGroup());
for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) {
Preconditions.checkArgument(!slotToNodeMap.containsKey(slot));
@ -120,6 +129,7 @@ public class HyperGraph {
}
nodeSet.add(group);
nodes.add(new Node(nodes.size(), group));
return nodes.size() - 1;
}
public boolean isNodeGroup(Group group) {
@ -135,138 +145,126 @@ public class HyperGraph {
*
* @param group The join group
*/
public BitSet addEdge(Group group, BitSet leftEdgeMap, BitSet rightEdgeMap) {
public BitSet addEdge(Group group, Pair<BitSet, Long> leftEdgeNodes, Pair<BitSet, Long> rightEdgeNodes) {
Preconditions.checkArgument(group.isValidJoinGroup());
LogicalJoin<? extends Plan, ? extends Plan> join = (LogicalJoin) group.getLogicalExpression().getPlan();
HashMap<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new HashMap<>();
for (Expression expression : join.getHashJoinConjuncts()) {
Pair<Long, Long> ends = findEnds(expression);
// TODO: avoid calling calculateEnds if calNodeMap's results are same
Pair<Long, Long> ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes,
rightEdgeNodes);
if (!conjuncts.containsKey(ends)) {
conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>()));
}
conjuncts.get(ends).first.add(expression);
}
for (Expression expression : join.getOtherJoinConjuncts()) {
Pair<Long, Long> ends = findEnds(expression);
Pair<Long, Long> ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes,
rightEdgeNodes);
if (!conjuncts.containsKey(ends)) {
conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>()));
}
conjuncts.get(ends).second.add(expression);
}
BitSet edgeMap = new BitSet();
edgeMap.or(leftEdgeMap);
edgeMap.or(rightEdgeMap);
BitSet curJoinEdges = new BitSet();
for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> entry : conjuncts
.entrySet()) {
LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first,
entry.getValue().second, JoinHint.NONE, join.getMarkJoinSlotReference(),
Lists.newArrayList(join.left(), join.right()));
Edge edge = new Edge(singleJoin, edges.size());
Edge edge = new Edge(singleJoin, edges.size(), leftEdgeNodes.first, rightEdgeNodes.first,
LongBitmap.newBitmapUnion(leftEdgeNodes.second, rightEdgeNodes.second));
Pair<Long, Long> ends = entry.getKey();
initEdgeEnds(ends, edge, leftEdgeMap, rightEdgeMap);
edge.setLeftRequiredNodes(ends.first);
edge.setLeftExtendedNodes(ends.first);
edge.setRightRequiredNodes(ends.second);
edge.setRightExtendedNodes(ends.second);
for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
nodes.get(nodeIndex).attachEdge(edge);
}
edgeMap.set(edge.getIndex());
curJoinEdges.set(edge.getIndex());
edges.add(edge);
}
return edgeMap;
curJoinEdges.stream().forEach(i -> edges.get(i).addCurJoinEdges(curJoinEdges));
curJoinEdges.stream().forEach(i -> edges.get(i).addCurJoinEdges(curJoinEdges));
curJoinEdges.stream().forEach(i -> makeConflictRules(edges.get(i)));
return curJoinEdges;
// In MySQL, each edge is also reversed and stored in edges again to reduce branch misses.
// We don't implement this trick for now.
}
// Make edge with CD-A algorithm in
// Make edge with CD-C algorithm in
// On the correct and complete enumeration of the core search
private void initEdgeEnds(Pair<Long, Long> ends, Edge edge, BitSet leftEdges, BitSet rightEdges) {
long left = ends.first;
long right = ends.second;
for (int i = leftEdges.nextSetBit(0); i >= 0; i = leftEdges.nextSetBit(i + 1)) {
Edge lEdge = edges.get(i);
if (!JoinType.isAssoc(lEdge.getJoinType(), edge.getJoinType())) {
left = LongBitmap.or(left, lEdge.getLeft());
private void makeConflictRules(Edge edgeB) {
BitSet leftSubTreeEdges = subTreeEdges(edgeB.getLeftChildEdges());
BitSet rightSubTreeEdges = subTreeEdges(edgeB.getRightChildEdges());
long leftRequired = edgeB.getLeftRequiredNodes();
long rightRequired = edgeB.getRightRequiredNodes();
for (int i = leftSubTreeEdges.nextSetBit(0); i >= 0; i = leftSubTreeEdges.nextSetBit(i + 1)) {
Edge childA = edges.get(i);
if (!JoinType.isAssoc(childA.getJoinType(), edgeB.getJoinType())) {
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(edges));
}
if (!JoinType.isLAssoc(lEdge.getJoinType(), edge.getJoinType())) {
left = LongBitmap.or(left, lEdge.getRight());
}
}
for (int i = rightEdges.nextSetBit(0); i >= 0; i = rightEdges.nextSetBit(i + 1)) {
Edge rEdge = edges.get(i);
if (!JoinType.isAssoc(rEdge.getJoinType(), edge.getJoinType())) {
right = LongBitmap.or(right, rEdge.getRight());
}
if (!JoinType.isRAssoc(rEdge.getJoinType(), edge.getJoinType())) {
right = LongBitmap.or(right, rEdge.getLeft());
if (!JoinType.isLAssoc(childA.getJoinType(), edgeB.getJoinType())) {
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(edges));
}
}
edge.setOriginalLeft(left);
edge.setOriginalRight(right);
edge.setLeft(left);
edge.setRight(right);
for (int i = rightSubTreeEdges.nextSetBit(0); i >= 0; i = rightSubTreeEdges.nextSetBit(i + 1)) {
Edge childA = edges.get(i);
if (!JoinType.isAssoc(edgeB.getJoinType(), childA.getJoinType())) {
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(edges));
}
if (!JoinType.isRAssoc(edgeB.getJoinType(), childA.getJoinType())) {
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(edges));
}
}
edgeB.setLeftRequiredNodes(leftRequired);
edgeB.setRightRequiredNodes(rightRequired);
edgeB.setLeftExtendedNodes(leftRequired);
edgeB.setRightExtendedNodes(rightRequired);
}
private int findRoot(List<Integer> parent, int idx) {
int root = parent.get(idx);
if (root != idx) {
root = findRoot(parent, root);
}
parent.set(idx, root);
return root;
/**
 * Collects the index of the given edge together with the indexes of all edges
 * reachable through its left and right child edges.
 *
 * @param edge the edge whose subtree is collected
 * @return a new {@code BitSet} holding the collected edge indexes
 */
private BitSet subTreeEdges(Edge edge) {
    BitSet collected = new BitSet();
    BitSet fromLeft = subTreeEdges(edge.getLeftChildEdges());
    collected.or(fromLeft);
    BitSet fromRight = subTreeEdges(edge.getRightChildEdges());
    collected.or(fromRight);
    // The edge itself is part of its own subtree.
    collected.set(edge.getIndex());
    return collected;
}
private boolean isConnected(long bitmap, long excludeBitmap) {
if (LongBitmap.getCardinality(bitmap) == 1) {
return true;
}
// use unionSet to check whether the bitmap is connected
List<Integer> parent = new ArrayList<>();
for (int i = 0; i < nodes.size(); i++) {
parent.add(i, i);
}
for (Edge edge : edges) {
if (LongBitmap.isOverlap(edge.getLeft(), excludeBitmap)
|| LongBitmap.isOverlap(edge.getRight(), excludeBitmap)) {
continue;
}
int root = findRoot(parent, LongBitmap.nextSetBit(edge.getLeft(), 0));
for (int idx : LongBitmap.getIterator(edge.getLeft())) {
parent.set(idx, root);
}
for (int idx : LongBitmap.getIterator(edge.getRight())) {
parent.set(idx, root);
}
}
int root = findRoot(parent, LongBitmap.nextSetBit(bitmap, 0));
for (int idx : LongBitmap.getIterator(bitmap)) {
if (root != findRoot(parent, idx)) {
return false;
}
}
return true;
/**
 * Collects the subtree edges of every edge whose index is set in {@code edgeSet}.
 *
 * @param edgeSet indexes into {@code edges}; an empty set yields an empty result
 * @return a new {@code BitSet} holding the union of the per-edge subtree edge sets
 */
private BitSet subTreeEdges(BitSet edgeSet) {
    BitSet collected = new BitSet();
    // Visit the set bits in ascending order, as the stream form did.
    for (int i = edgeSet.nextSetBit(0); i >= 0; i = edgeSet.nextSetBit(i + 1)) {
        collected.or(subTreeEdges(edges.get(i)));
    }
    return collected;
}
private Pair<Long, Long> findEnds(Expression expression) {
long bitmap = calNodeMap(expression.getInputSlots());
int cardinality = LongBitmap.getCardinality(bitmap);
Preconditions.checkArgument(cardinality > 1);
for (long subset : LongBitmap.getSubsetIterator(bitmap)) {
long left = subset;
long right = LongBitmap.newBitmapDiff(bitmap, left);
// when the graph without right node has a connected-sub-graph contains left nodes
// and the graph without left node has a connected-sub-graph contains right nodes.
// we can generate an edge for this expression
if (isConnected(left, right) && isConnected(right, left)) {
return Pair.of(left, right);
}
// Try to calculate the ends of an expression.
// left = ref_nodes \cap left_tree , right = ref_nodes \cap right_tree
// if left = 0, recursively calculate it in left tree
private Pair<Long, Long> calculateEnds(long allNodes, Pair<BitSet, Long> leftEdgeNodes,
Pair<BitSet, Long> rightEdgeNodes) {
long left = LongBitmap.newBitmapIntersect(allNodes, leftEdgeNodes.second);
long right = LongBitmap.newBitmapIntersect(allNodes, rightEdgeNodes.second);
if (left == 0) {
Preconditions.checkArgument(leftEdgeNodes.first.cardinality() > 0,
"the number of the table which expression reference is less 2");
Pair<BitSet, Long> llEdgesNodes = edges.get(leftEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes(edges);
Pair<BitSet, Long> lrEdgesNodes = edges.get(leftEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes(edges);
return calculateEnds(allNodes, llEdgesNodes, lrEdgesNodes);
}
throw new RuntimeException("DPhyper meets unconnected subgraph");
if (right == 0) {
Preconditions.checkArgument(rightEdgeNodes.first.cardinality() > 0,
"the number of the table which expression reference is less 2");
Pair<BitSet, Long> rlEdgesNodes = edges.get(rightEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes(edges);
Pair<BitSet, Long> rrEdgesNodes = edges.get(rightEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes(edges);
return calculateEnds(allNodes, rlEdgesNodes, rrEdgesNodes);
}
return Pair.of(left, right);
}
private long calNodeMap(Set<Slot> slots) {
@ -291,10 +289,10 @@ public class HyperGraph {
// For these nodes that are only in the old edge, we need remove the edge from them
// For these nodes that are only in the new edge, we need to add the edge to them
Edge edge = edges.get(edgeIndex);
updateEdges(edge, edge.getLeft(), newLeft);
updateEdges(edge, edge.getRight(), newRight);
edges.get(edgeIndex).setLeft(newLeft);
edges.get(edgeIndex).setRight(newRight);
updateEdges(edge, edge.getLeftExtendedNodes(), newLeft);
updateEdges(edge, edge.getRightExtendedNodes(), newRight);
edges.get(edgeIndex).setLeftExtendedNodes(newLeft);
edges.get(edgeIndex).setRightExtendedNodes(newRight);
}
private void updateEdges(Edge edge, long oldNodes, long newNodes) {
@ -339,8 +337,8 @@ public class HyperGraph {
arrowHead = ",arrowhead=none";
}
int leftIndex = LongBitmap.lowestOneIndex(edge.getLeft());
int rightIndex = LongBitmap.lowestOneIndex(edge.getRight());
int leftIndex = LongBitmap.lowestOneIndex(edge.getLeftExtendedNodes());
int rightIndex = LongBitmap.lowestOneIndex(edge.getRightExtendedNodes());
builder.append(String.format("%s -> %s [label=\"%s\"%s]\n", graphvisNodes.get(leftIndex),
graphvisNodes.get(rightIndex), label, arrowHead));
} else {
@ -349,7 +347,7 @@ public class HyperGraph {
String leftLabel = "";
String rightLabel = "";
if (LongBitmap.getCardinality(edge.getLeft()) == 1) {
if (LongBitmap.getCardinality(edge.getLeftExtendedNodes()) == 1) {
rightLabel = label;
} else {
leftLabel = label;
@ -357,13 +355,13 @@ public class HyperGraph {
int finalI = i;
String finalLeftLabel = leftLabel;
for (int nodeIndex : LongBitmap.getIterator(edge.getLeft())) {
for (int nodeIndex : LongBitmap.getIterator(edge.getLeftExtendedNodes())) {
builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]\n",
graphvisNodes.get(nodeIndex), finalI, finalLeftLabel));
}
String finalRightLabel = rightLabel;
for (int nodeIndex : LongBitmap.getIterator(edge.getRight())) {
for (int nodeIndex : LongBitmap.getIterator(edge.getRightExtendedNodes())) {
builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]\n",
graphvisNodes.get(nodeIndex), finalI, finalRightLabel));
}

View File

@ -225,8 +225,8 @@ public class SubgraphEnumerator {
neighborhoods = LongBitmap.andNot(neighborhoods, forbiddenNodes);
forbiddenNodes = LongBitmap.or(forbiddenNodes, neighborhoods);
for (Edge edge : edgeCalculator.foundComplexEdgesContain(subgraph)) {
long left = edge.getLeft();
long right = edge.getRight();
long left = edge.getLeftExtendedNodes();
long right = edge.getRightExtendedNodes();
if (LongBitmap.isSubset(left, subgraph) && !LongBitmap.isOverlap(right, forbiddenNodes)) {
neighborhoods = LongBitmap.set(neighborhoods, LongBitmap.lowestOneIndex(right));
} else if (LongBitmap.isSubset(right, subgraph) && !LongBitmap.isOverlap(left, forbiddenNodes)) {
@ -362,14 +362,14 @@ public class SubgraphEnumerator {
}
private boolean isContainEdge(long subgraph, Edge edge) {
int containLeft = LongBitmap.isSubset(edge.getLeft(), subgraph) ? 0 : 1;
int containRight = LongBitmap.isSubset(edge.getRight(), subgraph) ? 0 : 1;
int containLeft = LongBitmap.isSubset(edge.getLeftExtendedNodes(), subgraph) ? 0 : 1;
int containRight = LongBitmap.isSubset(edge.getRightExtendedNodes(), subgraph) ? 0 : 1;
return containLeft + containRight == 1;
}
private boolean isOverlapEdge(long subgraph, Edge edge) {
int overlapLeft = LongBitmap.isOverlap(edge.getLeft(), subgraph) ? 0 : 1;
int overlapRight = LongBitmap.isOverlap(edge.getRight(), subgraph) ? 0 : 1;
int overlapLeft = LongBitmap.isOverlap(edge.getLeftExtendedNodes(), subgraph) ? 0 : 1;
int overlapRight = LongBitmap.isOverlap(edge.getRightExtendedNodes(), subgraph) ? 0 : 1;
return overlapLeft + overlapRight == 1;
}

View File

@ -60,6 +60,7 @@ import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
/**
@ -102,15 +103,14 @@ public class PlanReceiver implements AbstractReceiver {
public boolean emitCsgCmp(long left, long right, List<Edge> edges) {
Preconditions.checkArgument(planTable.containsKey(left));
Preconditions.checkArgument(planTable.containsKey(right));
processMissedEdges(left, right, edges);
Memo memo = jobContext.getCascadesContext().getMemo();
emitCount += 1;
if (emitCount > limit) {
return false;
}
Memo memo = jobContext.getCascadesContext().getMemo();
GroupPlan leftPlan = new GroupPlan(planTable.get(left));
GroupPlan rightPlan = new GroupPlan(planTable.get(right));
@ -118,6 +118,7 @@ public class PlanReceiver implements AbstractReceiver {
// In this step, we don't generate logical expression because they are useless in DPhyp.
List<Expression> hashConjuncts = new ArrayList<>();
List<Expression> otherConjuncts = new ArrayList<>();
JoinType joinType = extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
if (joinType == null) {
return true;
@ -126,6 +127,7 @@ public class PlanReceiver implements AbstractReceiver {
List<Plan> physicalJoins = proposeAllPhysicalJoins(joinType, leftPlan, rightPlan, hashConjuncts,
otherConjuncts);
List<Plan> physicalPlans = proposeProject(physicalJoins, edges, left, right);
// Second, we copy all physical plans to the Group, generate their properties and calculate their cost
@ -188,7 +190,7 @@ public class PlanReceiver implements AbstractReceiver {
// find the edges that are not in usedEdgesBitmap and whose referenced nodes are a subset of allReferenceNodes
for (Edge edge : hyperGraph.getEdges()) {
long referenceNodes =
LongBitmap.newBitmapUnion(edge.getOriginalLeft(), edge.getOriginalRight());
LongBitmap.newBitmapUnion(edge.getLeftRequiredNodes(), edge.getRightRequiredNodes());
if (LongBitmap.isSubset(referenceNodes, allReferenceNodes)
&& !usedEdgesBitmap.get(edge.getIndex())) {
// add the missed edge to edges
@ -344,7 +346,6 @@ public class PlanReceiver implements AbstractReceiver {
long fullKey = LongBitmap.newBitmapUnion(left, right);
List<Slot> outputs = allChild.get(0).getOutput();
Set<Slot> outputSet = allChild.get(0).getOutputSet();
List<NamedExpression> allProjects = Lists.newArrayList();
List<NamedExpression> complexProjects = new ArrayList<>();
// Calculate complex expression should be done by current(fullKey) node
@ -358,43 +359,23 @@ public class PlanReceiver implements AbstractReceiver {
// complexProjectMap is created by a bottom up traverse of join tree, so child node is put before parent node
// in the bitmaps
bitmaps.sort(Long::compare);
for (long bitmap : bitmaps) {
if (complexProjects.isEmpty()) {
complexProjects = complexProjectMap.get(bitmap);
complexProjects.addAll(complexProjectMap.get(bitmap));
} else {
// The top project of (T1, T2, T3) is different after reorder
// we need merge Project1 and Project2 as Project4 after reorder
// T1 join T2 join T3:
// Project1(a, e + f)
// join(a = e)
// Project2(a, b + d as e)
// join(a = c)
// T1(a, b)
// T2(c, d)
// T3(e, f)
//
// after reorder:
// T1 join T3 join T2:
// Project4(a, b + d + f)
// join(a = c)
// Project3(a, b, f)
// join(a = e)
// T1(a, b)
// T3(e, f)
// T2(c, d)
//
complexProjects =
PlanUtils.mergeProjections(complexProjects, complexProjectMap.get(bitmap));
// Rewrite project expression by its children
complexProjects.addAll(
PlanUtils.mergeProjections(complexProjects, complexProjectMap.get(bitmap)));
}
}
allProjects.addAll(complexProjects);
// calculate required columns by all parents
Set<Slot> requireSlots = calculateRequiredSlots(left, right, edges);
// add output slots belong to required slots to project list
allProjects.addAll(outputs.stream().filter(e -> requireSlots.contains(e))
.collect(Collectors.toList()));
List<NamedExpression> allProjects = Stream.concat(
outputs.stream().filter(e -> requireSlots.contains(e)),
complexProjects.stream().filter(e -> requireSlots.contains(e.toSlot()))
).collect(Collectors.toList());
// propose physical project
if (allProjects.isEmpty()) {
@ -416,7 +397,13 @@ public class PlanReceiver implements AbstractReceiver {
.map(c -> new PhysicalProject<>(projects, projectProperties, c))
.collect(Collectors.toList());
}
Preconditions.checkState(!projects.isEmpty() && projects.size() == allProjects.size());
if (!(!projects.isEmpty() && projects.size() == allProjects.size())) {
Set<NamedExpression> s1 = projects.stream().collect(Collectors.toSet());
List<NamedExpression> s2 = allProjects.stream().filter(e -> !s1.contains(e)).collect(Collectors.toList());
System.out.println(s2);
}
Preconditions.checkState(!projects.isEmpty() && projects.size() == allProjects.size(),
" there are some projects left " + projects + allProjects);
return allChild;
}

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.datasets.tpch.TPCHTestBase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.HyperGraphBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
@ -32,23 +33,37 @@ import java.util.Set;
public class OtherJoinTest extends TPCHTestBase {
/**
 * Drives {@code randomTest(t, e)} over a grid of sizes: for every {@code t} in [3, 10) and
 * every {@code e} from {@code t - 1} up to {@code t * (t - 1) / 2} inclusive, the randomized
 * check is run 10 times. The two arguments are passed as {@code tableNum} and {@code edgeNum}.
 */
@Test
// NOTE(review): two method signatures appear back to back below; this looks like the removed
// and added lines of a diff hunk shown together. Confirm which one is live.
public void randomTest() {
public void test() {
for (int t = 3; t < 10; t++) {
// presumably t - 1 is the fewest edges of a connected graph over t tables and
// t * (t - 1) / 2 the most edges of a simple graph; verify against randomBuildPlanWith
for (int e = t - 1; e <= (t * (t - 1)) / 2; e++) {
for (int i = 0; i < 10; i++) {
// Progress line printed as "<t> <e>: <i>" before each run.
System.out.println(String.valueOf(t) + " " + e + ": " + i);
randomTest(t, e);
}
}
}
}
private void randomTest(int tableNum, int edgeNum) {
HyperGraphBuilder hyperGraphBuilder = new HyperGraphBuilder();
Plan plan = hyperGraphBuilder
.randomBuildPlanWith(10, 20);
Set<List<Integer>> res1 = hyperGraphBuilder.evaluate(plan);
.randomBuildPlanWith(tableNum, edgeNum);
plan = new LogicalProject(plan.getOutput(), plan);
Set<List<String>> res1 = hyperGraphBuilder.evaluate(plan);
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan);
hyperGraphBuilder.initStats(cascadesContext);
Plan optimizedPlan = PlanChecker.from(cascadesContext)
.dpHypOptimize()
.getBestPlanTree();
.dpHypOptimize()
.getBestPlanTree();
Set<List<Integer>> res2 = hyperGraphBuilder.evaluate(optimizedPlan);
Set<List<String>> res2 = hyperGraphBuilder.evaluate(optimizedPlan);
if (!res1.equals(res2)) {
System.out.println(res1);
System.out.println(res2);
System.out.println(plan.treeString());
System.out.println(optimizedPlan.treeString());
cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan);
PlanChecker.from(cascadesContext).dpHypOptimize().getBestPlanTree();
System.out.println(res1);
System.out.println(res2);
}
Assertions.assertTrue(res1.equals(res2));

View File

@ -130,8 +130,8 @@ public class SubgraphEnumeratorTest {
visited.add(left);
visited.add(right);
for (Edge edge : hyperGraph.getEdges()) {
if ((LongBitmap.isSubset(edge.getLeft(), left) && LongBitmap.isSubset(edge.getRight(), right)) || (
LongBitmap.isSubset(edge.getLeft(), right) && LongBitmap.isSubset(edge.getRight(), left))) {
if ((LongBitmap.isSubset(edge.getLeftExtendedNodes(), left) && LongBitmap.isSubset(edge.getRightExtendedNodes(), right)) || (
LongBitmap.isSubset(edge.getLeftExtendedNodes(), right) && LongBitmap.isSubset(edge.getRightExtendedNodes(), left))) {
count += countAndCheck(left, hyperGraph, counter, cache) * countAndCheck(right, hyperGraph,
counter, cache);
break;

View File

@ -87,6 +87,20 @@ public class JoinOrderJobTest extends SqlTestBase {
.dpHypOptimize();
}
/**
 * Runs analyze, rewrite and DPHyp optimization on a join whose condition compares
 * {@code T1.id} with {@code t}, an alias for the {@code now()} expression projected by the
 * sub-query. The test has no explicit assertions; it passes when the pipeline completes
 * without throwing.
 */
@Test
protected void testConstantJoin() {
    final String query = "select count(*) \n"
            + "from \n"
            + "T1 \n"
            + " join (\n"
            + "select * , now() as t from T2 \n"
            + ") subTable on T1.id = t; \n";
    PlanChecker.from(connectContext).analyze(query).rewrite().dpHypOptimize();
}
@Test
protected void testCountJoin() {
String sql = "select count(*) \n"

View File

@ -17,9 +17,9 @@
package org.apache.doris.nereids.sqltest;
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.rewrite.ReorderJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
@ -50,12 +50,8 @@ public class JoinTest extends SqlTestBase {
.getBestPlanTree();
// generate colocate join plan without physicalDistribute
System.out.println(plan.treeString());
Assertions.assertFalse(plan.anyMatch(p -> {
if (p instanceof PhysicalDistribute) {
return !(((PhysicalDistribute<?>) p).getDistributionSpec() instanceof DistributionSpecGather);
}
return false;
}));
Assertions.assertFalse(plan.anyMatch(p -> p instanceof PhysicalDistribute
&& ((PhysicalDistribute) p).getDistributionSpec() instanceof DistributionSpecHash));
sql = "select * from T1 join T0 on T1.score = T0.score and T1.id = T0.id;";
plan = PlanChecker.from(connectContext)
.analyze(sql)
@ -63,12 +59,8 @@ public class JoinTest extends SqlTestBase {
.optimize()
.getBestPlanTree();
// generate colocate join plan without physicalDistribute
Assertions.assertFalse(plan.anyMatch(p -> {
if (p instanceof PhysicalDistribute) {
return !(((PhysicalDistribute<?>) p).getDistributionSpec() instanceof DistributionSpecGather);
}
return false;
}));
Assertions.assertFalse(plan.anyMatch(p -> p instanceof PhysicalDistribute
&& ((PhysicalDistribute) p).getDistributionSpec() instanceof DistributionSpecHash));
}
@Test
@ -100,7 +92,7 @@ public class JoinTest extends SqlTestBase {
.analyze(sql)
.rewrite()
.optimize()
.getBestPlanTree();
.getBestPlanTree(PhysicalProperties.ANY);
Assertions.assertEquals(
ShuffleType.NATURAL,
((DistributionSpecHash) ((PhysicalPlan) (plan.child(0).child(0)))

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.util;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
@ -26,6 +27,7 @@ import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
@ -35,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatisticsCacheKey;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -90,6 +93,12 @@ public class HyperGraphBuilder {
return buildHyperGraph(plan);
}
/**
 * Returns the first value held in {@code plans}.
 *
 * <p>When JVM assertions are enabled, this fails with "there are cross join" unless
 * {@code plans} holds exactly one entry.
 *
 * @return the first plan produced by iterating {@code plans.values()}
 */
public Plan buildPlan() {
    assert plans.size() == 1 : "there are cross join";
    return plans.values().iterator().next();
}
public Plan buildJoinPlan() {
assert plans.size() == 1 : "there are cross join";
Plan plan = plans.values().iterator().next();
@ -166,9 +175,14 @@ public class HyperGraphBuilder {
for (Group group : context.getMemo().getGroups()) {
GroupExpression groupExpression = group.getLogicalExpression();
if (groupExpression.getPlan() instanceof LogicalOlapScan) {
LogicalOlapScan scan = (LogicalOlapScan) groupExpression.getPlan();
Statistics stats = injectRowcount((LogicalOlapScan) groupExpression.getPlan());
groupExpression.setStatDerived(true);
group.setStatistics(stats);
for (Expression expr : stats.columnStatistics().keySet()) {
SlotReference slot = (SlotReference) expr;
Env.getCurrentEnv().getStatisticsCache().putCache(
new StatisticsCacheKey(scan.getTable().getId(), -1, slot.getName()),
stats.columnStatistics().get(expr));
}
}
}
}
@ -364,7 +378,7 @@ public class HyperGraphBuilder {
return hashConjunts;
}
public Set<List<Integer>> evaluate(Plan plan) {
public Set<List<String>> evaluate(Plan plan) {
JoinEvaluator evaluator = new JoinEvaluator(rowCounts);
Map<Slot, List<Integer>> res = evaluator.evaluate(plan);
int rowCount = 0;
@ -376,11 +390,12 @@ public class HyperGraphBuilder {
(slot1, slot2) ->
String.CASE_INSENSITIVE_ORDER.compare(slot1.toString(), slot2.toString()))
.collect(Collectors.toList());
Set<List<Integer>> tuples = new HashSet<>();
Set<List<String>> tuples = new HashSet<>();
tuples.add(keySet.stream().map(s -> s.toString()).collect(Collectors.toList()));
for (int i = 0; i < rowCount; i++) {
List<Integer> tuple = new ArrayList<>();
List<String> tuple = new ArrayList<>();
for (Slot key : keySet) {
tuple.add(res.get(key).get(i));
tuple.add(String.valueOf(res.get(key).get(i)));
}
tuples.add(tuple);
}

View File

@ -33,7 +33,6 @@ import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteBottomUpJob;
import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteTopDownJob;
import org.apache.doris.nereids.jobs.rewrite.RootPlanTreeRewriteJob;
import org.apache.doris.nereids.memo.CopyInResult;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
@ -243,23 +242,13 @@ public class PlanChecker {
public PlanChecker dpHypOptimize() {
double now = System.currentTimeMillis();
cascadesContext.getStatementContext().setDpHyp(true);
cascadesContext.getConnectContext().getSessionVariable().enableDPHypOptimizer = true;
Group root = cascadesContext.getMemo().getRoot();
boolean changeRoot = false;
if (root.isValidJoinGroup()) {
// If the root group is join group, DPHyp can change the root group.
// To keep the root group is not changed, we add a dummy project operator above join
List<Slot> outputs = root.getLogicalExpression().getPlan().getOutput();
LogicalPlan plan = new LogicalProject(outputs, root.getLogicalExpression().getPlan());
CopyInResult copyInResult = cascadesContext.getMemo().copyIn(plan, null, false);
root = copyInResult.correspondingExpression.getOwnerGroup();
changeRoot = true;
}
cascadesContext.pushJob(new JoinOrderJob(root, cascadesContext.getCurrentJobContext()));
cascadesContext.pushJob(new DeriveStatsJob(root.getLogicalExpression(),
cascadesContext.getCurrentJobContext()));
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
if (changeRoot) {
cascadesContext.getMemo().setRoot(root.getLogicalExpression().child(0));
}
// if the root is not join, we need to optimize again.
optimize();
System.out.println("DPhyp:" + (System.currentTimeMillis() - now));
return this;
@ -602,7 +591,7 @@ public class PlanChecker {
}
/**
 * Returns the result of {@code chooseBestPlan} for the memo's root group, using a fixed
 * physical-property requirement.
 *
 * @return the physical plan returned by {@code chooseBestPlan}
 */
public PhysicalPlan getBestPlanTree() {
// NOTE(review): two return statements follow (PhysicalProperties.ANY, then
// PhysicalProperties.GATHER); this looks like the removed and added lines of a diff hunk
// shown together. Confirm which requirement is live.
return chooseBestPlan(cascadesContext.getMemo().getRoot(), PhysicalProperties.ANY);
return chooseBestPlan(cascadesContext.getMemo().getRoot(), PhysicalProperties.GATHER);
}
public PhysicalPlan getBestPlanTree(PhysicalProperties properties) {