diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java index b25793090b..08e40f84a6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java @@ -73,11 +73,12 @@ public class JoinOrderJob extends Job { } private Group optimizeJoin(Group group) { - HyperGraph hyperGraph = HyperGraph.toDPhyperGraph(group); - for (AbstractNode node : hyperGraph.getNodes()) { + HyperGraph.Builder builder = HyperGraph.builderForDPhyper(group); + for (AbstractNode node : builder.getNodes()) { DPhyperNode dPhyperNode = (DPhyperNode) node; - hyperGraph.updateNode(node.getIndex(), optimizePlan(dPhyperNode.getGroup())); + builder.updateNode(node.getIndex(), optimizePlan(dPhyperNode.getGroup())); } + HyperGraph hyperGraph = builder.build(); // TODO: Right now, we just hardcode the limit with 10000, maybe we need a better way to set it int limit = 1000; PlanReceiver planReceiver = new PlanReceiver(this.context, limit, hyperGraph, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java index 8a0bd8daaa..8472837d79 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java @@ -42,6 +42,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.PlanUtils; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; @@ -58,21 +60,25 @@ import java.util.stream.Collectors; * It's used for join ordering */ public class HyperGraph { - private final List joinEdges = new ArrayList<>(); - private final List filterEdges = new ArrayList<>(); - private final List nodes = new ArrayList<>(); - private final HashMap slotToNodeMap = new HashMap<>(); // record all edges that can be placed on the subgraph private final Map treeEdgesCache = new HashMap<>(); + private final List joinEdges; + private final List filterEdges; + private final List nodes; private final Set finalOutputs; // Record the complex project expression for some subgraph // e.g. project (a + b) // |-- join(t1.a = t2.b) - private final HashMap> complexProject = new HashMap<>(); + private final Map> complexProject; - HyperGraph(Set finalOutputs) { + HyperGraph(Set finalOutputs, List joinEdges, List nodes, List filterEdges, + Map> complexProject) { this.finalOutputs = ImmutableSet.copyOf(finalOutputs); + this.joinEdges = ImmutableList.copyOf(joinEdges); + this.nodes = ImmutableList.copyOf(nodes); + this.complexProject = ImmutableMap.copyOf(complexProject); + this.filterEdges = ImmutableList.copyOf(filterEdges); } public List getJoinEdges() { @@ -103,191 +109,10 @@ public class HyperGraph { return nodes.get(index); } - /** - * Store the relation between Alias Slot and Original Slot and its expression - * e.g., - * a = b - * |--- project((c + d) as b) - *

- * a = b - * |--- project((c + 1) as b) - * - * @param alias The alias Expression in project Operator - */ - public boolean addAlias(Alias alias, long subTreeNodes) { - Slot aliasSlot = alias.toSlot(); - if (slotToNodeMap.containsKey(aliasSlot)) { - return true; - } - long bitmap = LongBitmap.newBitmap(); - for (Slot slot : alias.getInputSlots()) { - bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot)); - } - // The case hit when there are some constant aliases such as: - // select * from t1 join ( - // select *, 1 as b1 from t2) - // on t1.b = b1 - // just reference them all for this slot - if (bitmap == 0) { - bitmap = subTreeNodes; - } - Preconditions.checkArgument(bitmap > 0, "slot must belong to some table"); - slotToNodeMap.put(aliasSlot, bitmap); - if (!complexProject.containsKey(bitmap)) { - complexProject.put(bitmap, new ArrayList<>()); - } - alias = (Alias) PlanUtils.mergeProjections(complexProject.get(bitmap), Lists.newArrayList(alias)).get(0); - - complexProject.get(bitmap).add(alias); - return true; - } - - /** - * add end node to HyperGraph - * - * @param group The group that is the end node in graph - * @return return the node index - */ - private int addDPHyperNode(Group group) { - for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) { - Preconditions.checkArgument(!slotToNodeMap.containsKey(slot)); - slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size())); - } - nodes.add(new DPhyperNode(nodes.size(), group)); - return nodes.size() - 1; - } - - /** - * add end node to HyperGraph - * - * @param plan The plan that is the end node in graph - * @return return the node index - */ - private int addStructInfoNode(Plan plan) { - for (Slot slot : plan.getOutput()) { - Preconditions.checkArgument(!slotToNodeMap.containsKey(slot)); - slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size())); - } - nodes.add(new StructInfoNode(nodes.size(), plan)); - return nodes.size() - 1; - } - - private int addStructInfoNode(List childGraphs) { - for (Slot slot : childGraphs.get(0).finalOutputs) { - Preconditions.checkArgument(!slotToNodeMap.containsKey(slot)); - slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size())); - } - nodes.add(new StructInfoNode(nodes.size(), childGraphs)); - return nodes.size() - 1; - } - - public void updateNode(int idx, Group group) { - Preconditions.checkArgument(nodes.get(idx) instanceof DPhyperNode); - nodes.set(idx, ((DPhyperNode) nodes.get(idx)).withGroup(group)); - } - - public HashMap> getComplexProject() { + public Map> getComplexProject() { return complexProject; } - private void addEdgeOfInfo(JoinEdge edge) { - long nodeMap = calNodeMap(edge.getInputSlots()); - Preconditions.checkArgument(LongBitmap.getCardinality(nodeMap) > 1, - "edge must have more than one ends"); - long left = LongBitmap.newBitmap(LongBitmap.nextSetBit(nodeMap, 0)); - long right = LongBitmap.newBitmapDiff(nodeMap, left); - this.joinEdges.add(new JoinEdge(edge.getJoin(), joinEdges.size(), - null, null, 0, left, right)); - } - - /** - * try to add edge for join group - * - * @param join The join plan - */ - private BitSet addJoin(LogicalJoin join, - Pair leftEdgeNodes, Pair rightEdgeNodes) { - HashMap, Pair, List>> conjuncts = new HashMap<>(); - for (Expression expression : join.getHashJoinConjuncts()) { - // TODO: avoid calling calculateEnds if calNodeMap's results are same - Pair ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes, - rightEdgeNodes); - if (!conjuncts.containsKey(ends)) { - conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>())); - } - conjuncts.get(ends).first.add(expression); - } - for (Expression expression : join.getOtherJoinConjuncts()) { - Pair ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes, - rightEdgeNodes); - if (!conjuncts.containsKey(ends)) { - conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>())); - } - conjuncts.get(ends).second.add(expression); - } - - BitSet curJoinEdges = new BitSet(); - for (Map.Entry, Pair, List>> entry : conjuncts - .entrySet()) { - LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first, - entry.getValue().second, - new DistributeHint(DistributeType.NONE), join.getMarkJoinSlotReference(), - Lists.newArrayList(join.left(), join.right())); - Pair ends = entry.getKey(); - JoinEdge edge = new JoinEdge(singleJoin, joinEdges.size(), leftEdgeNodes.first, rightEdgeNodes.first, - LongBitmap.newBitmapUnion(leftEdgeNodes.second, rightEdgeNodes.second), ends.first, ends.second); - for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) { - nodes.get(nodeIndex).attachEdge(edge); - } - curJoinEdges.set(edge.getIndex()); - joinEdges.add(edge); - } - curJoinEdges.stream().forEach(i -> joinEdges.get(i).addCurJoinEdges(curJoinEdges)); - curJoinEdges.stream().forEach(i -> ConflictRulesMaker.makeJoinConflictRules(joinEdges.get(i), joinEdges)); - curJoinEdges.stream().forEach(i -> - ConflictRulesMaker.makeFilterConflictRules(joinEdges.get(i), joinEdges, filterEdges)); - return curJoinEdges; - // In MySQL, each edge is reversed and store in edges again for reducing the branch miss - // We don't implement this trick now. - } - - private BitSet addFilter(LogicalFilter filter, Pair childEdgeNodes) { - FilterEdge edge = new FilterEdge(filter, filterEdges.size(), childEdgeNodes.first, childEdgeNodes.second, - childEdgeNodes.second); - filterEdges.add(edge); - BitSet bitSet = new BitSet(); - bitSet.set(edge.getIndex()); - return bitSet; - } - - // Try to calculate the ends of an expression. - // left = ref_nodes \cap left_tree , right = ref_nodes \cap right_tree - // if left = 0, recursively calculate it in left tree - private Pair calculateEnds(long allNodes, Pair leftEdgeNodes, - Pair rightEdgeNodes) { - long left = LongBitmap.newBitmapIntersect(allNodes, leftEdgeNodes.second); - long right = LongBitmap.newBitmapIntersect(allNodes, rightEdgeNodes.second); - if (left == 0) { - Preconditions.checkArgument(leftEdgeNodes.first.cardinality() > 0, - "the number of the table which expression reference is less 2"); - Pair llEdgesNodes = joinEdges.get(leftEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes( - joinEdges); - Pair lrEdgesNodes = joinEdges.get(leftEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes( - joinEdges); - return calculateEnds(allNodes, llEdgesNodes, lrEdgesNodes); - } - if (right == 0) { - Preconditions.checkArgument(rightEdgeNodes.first.cardinality() > 0, - "the number of the table which expression reference is less 2"); - Pair rlEdgesNodes = joinEdges.get(rightEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes( - joinEdges); - Pair rrEdgesNodes = joinEdges.get(rightEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes( - joinEdges); - return calculateEnds(allNodes, rlEdgesNodes, rrEdgesNodes); - } - return Pair.of(left, right); - } - public BitSet getEdgesInOperator(long left, long right) { BitSet operatorEdgesMap = new BitSet(); operatorEdgesMap.or(getEdgesInTree(LongBitmap.or(left, right))); @@ -312,182 +137,6 @@ public class HyperGraph { return treeEdgesCache.get(treeNodesMap); } - private long calNodeMap(Set slots) { - Preconditions.checkArgument(slots.size() != 0); - long bitmap = LongBitmap.newBitmap(); - for (Slot slot : slots) { - Preconditions.checkArgument(slotToNodeMap.containsKey(slot)); - bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot)); - } - return bitmap; - } - - public static List toStructInfo(Plan plan) { - HyperGraph hyperGraph = new HyperGraph(plan.getOutputSet()); - hyperGraph.buildStructInfo(plan); - return hyperGraph.flatChildren(); - } - - private List flatChildren() { - if (nodes.stream().noneMatch(n -> ((StructInfoNode) n).needToFlat())) { - return Lists.newArrayList(this); - } - List res = new ArrayList<>(); - res.add(new HyperGraph(finalOutputs)); - for (AbstractNode node : nodes) { - res = flatChild((StructInfoNode) node, res); - } - for (JoinEdge edge : joinEdges) { - res.forEach(g -> g.addEdgeOfInfo(edge)); - } - return res; - } - - private List flatChild(StructInfoNode infoNode, List hyperGraphs) { - if (!infoNode.needToFlat()) { - hyperGraphs.forEach(g -> g.addStructInfoNode(infoNode.getPlan())); - return hyperGraphs; - } - return hyperGraphs.stream().flatMap(g -> - infoNode.getGraphs().stream().map(subGraph -> { - HyperGraph hyperGraph = new HyperGraph(g.finalOutputs); - hyperGraph.addStructInfo(g); - hyperGraph.addStructInfo(subGraph); - return hyperGraph; - }) - ).collect(Collectors.toList()); - } - - public static HyperGraph toDPhyperGraph(Group group) { - HyperGraph hyperGraph = new HyperGraph(group.getLogicalProperties().getOutputSet()); - hyperGraph.buildDPhyperGraph(group.getLogicalExpressions().get(0)); - return hyperGraph; - } - - // Build Graph for DPhyper - private Pair buildDPhyperGraph(GroupExpression groupExpression) { - // process Project - if (isValidProject(groupExpression.getPlan())) { - LogicalProject project = (LogicalProject) groupExpression.getPlan(); - Pair res = this.buildDPhyperGraph(groupExpression.child(0).getLogicalExpressions().get(0)); - for (NamedExpression expr : project.getProjects()) { - if (expr instanceof Alias) { - this.addAlias((Alias) expr, res.second); - } - } - return res; - } - - // process Join - if (isValidJoin(groupExpression.getPlan())) { - LogicalJoin join = (LogicalJoin) groupExpression.getPlan(); - Pair left = this.buildDPhyperGraph(groupExpression.child(0).getLogicalExpressions().get(0)); - Pair right = this.buildDPhyperGraph(groupExpression.child(1).getLogicalExpressions().get(0)); - return Pair.of(this.addJoin(join, left, right), - LongBitmap.or(left.second, right.second)); - } - - // process Other Node - int idx = this.addDPHyperNode(groupExpression.getOwnerGroup()); - return Pair.of(new BitSet(), LongBitmap.newBitmap(idx)); - } - - private void addStructInfo(HyperGraph other) { - int offset = this.getNodes().size(); - other.getNodes().forEach(n -> this.addStructInfoNode(n.getPlan())); - other.getComplexProject().forEach((t, projectList) -> - projectList.forEach(e -> this.addAlias((Alias) e, t << offset))); - other.getJoinEdges().forEach(this::addEdgeOfInfo); - } - - // Build Graph for matching mv, return join edge set and nodes in this plan - private Pair buildStructInfo(Plan plan) { - if (plan instanceof GroupPlan) { - Group group = ((GroupPlan) plan).getGroup(); - List childGraphs = ((GroupPlan) plan).getGroup().getHyperGraphs(); - if (childGraphs.size() != 0) { - int idx = addStructInfoNode(childGraphs); - return Pair.of(new BitSet(), LongBitmap.newBitmap(idx)); - } - GroupExpression groupExpression = group.getLogicalExpressions().get(0); - return buildStructInfo(groupExpression.getPlan() - .withChildren( - groupExpression.children().stream().map(GroupPlan::new).collect(Collectors.toList()))); - } - // process Project - if (isValidProject(plan)) { - LogicalProject project = (LogicalProject) plan; - Pair res = this.buildStructInfo(plan.child(0)); - for (NamedExpression expr : project.getProjects()) { - if (expr instanceof Alias) { - this.addAlias((Alias) expr, res.second); - } - } - return res; - } - - // process Join - if (isValidJoinForStructInfo(plan)) { - LogicalJoin join = (LogicalJoin) plan; - Pair left = this.buildStructInfo(plan.child(0)); - Pair right = this.buildStructInfo(plan.child(1)); - return Pair.of(this.addJoin(join, left, right), - LongBitmap.or(left.second, right.second)); - } - - if (isValidFilter(plan)) { - LogicalFilter filter = (LogicalFilter) plan; - Pair child = this.buildStructInfo(filter.child()); - this.addFilter(filter, child); - return Pair.of(new BitSet(), child.second); - } - - // process Other Node - int idx = this.addStructInfoNode(plan); - return Pair.of(new BitSet(), LongBitmap.newBitmap(idx)); - } - - /** - * inner join group without mark slot - */ - public static boolean isValidJoin(Plan plan) { - if (!(plan instanceof LogicalJoin)) { - return false; - } - LogicalJoin join = (LogicalJoin) plan; - return join.getJoinType() == JoinType.INNER_JOIN - && !join.isMarkJoin() - && !join.getExpressions().isEmpty(); - } - - /** - * inner join group without mark slot - */ - public static boolean isValidJoinForStructInfo(Plan plan) { - if (!(plan instanceof LogicalJoin)) { - return false; - } - - LogicalJoin join = (LogicalJoin) plan; - return !join.isMarkJoin() - && !join.getExpressions().isEmpty(); - } - - public static boolean isValidFilter(Plan plan) { - return plan instanceof LogicalFilter; - } - - /** - * the project with alias and slot - */ - public static boolean isValidProject(Plan plan) { - if (!(plan instanceof LogicalProject)) { - return false; - } - return ((LogicalProject) plan).getProjects().stream() - .allMatch(e -> e instanceof Slot || e instanceof Alias); - } - /** * Graph simplifier need to update the edge for join ordering * @@ -520,8 +169,45 @@ public class HyperGraph { LongBitmap.getIterator(addedNodes).forEach(index -> nodes.get(index).attachEdge(edge)); } - public int edgeSize() { - return joinEdges.size() + filterEdges.size(); + /** + * inner join group without mark slot + */ + public static boolean isValidJoinForStructInfo(Plan plan) { + if (!(plan instanceof LogicalJoin)) { + return false; + } + + LogicalJoin join = (LogicalJoin) plan; + return !join.isMarkJoin() + && !join.getExpressions().isEmpty(); + } + + public static boolean isValidFilter(Plan plan) { + return plan instanceof LogicalFilter; + } + + /** + * the project with alias and slot + */ + public static boolean isValidProject(Plan plan) { + if (!(plan instanceof LogicalProject)) { + return false; + } + return ((LogicalProject) plan).getProjects().stream() + .allMatch(e -> e instanceof Slot || e instanceof Alias); + } + + /** + * inner join group without mark slot + */ + public static boolean isValidJoin(Plan plan) { + if (!(plan instanceof LogicalJoin)) { + return false; + } + LogicalJoin join = (LogicalJoin) plan; + return join.getJoinType() == JoinType.INNER_JOIN + && !join.isMarkJoin() + && !join.getExpressions().isEmpty(); } /** @@ -535,7 +221,7 @@ public class HyperGraph { */ public String toDottyHyperGraph() { StringBuilder builder = new StringBuilder(); - builder.append(String.format("digraph G { # %d edges\n", joinEdges.size())); + builder.append(String.format("digraph G { # %d edges%n", joinEdges.size())); List graphvisNodes = new ArrayList<>(); for (AbstractNode node : nodes) { String nodeName = node.getName(); @@ -547,7 +233,7 @@ public class HyperGraph { double rowCount = (node instanceof DPhyperNode) ? ((DPhyperNode) node).getRowCount() : -1; - builder.append(String.format(" %s [label=\"%s \n rowCount=%.2f\"];\n", + builder.append(String.format(" %s [label=\"%s %n rowCount=%.2f\"];%n", nodeID, nodeName, rowCount)); graphvisNodes.add(nodeName); } @@ -563,11 +249,11 @@ public class HyperGraph { int leftIndex = LongBitmap.lowestOneIndex(edge.getLeftExtendedNodes()); int rightIndex = LongBitmap.lowestOneIndex(edge.getRightExtendedNodes()); - builder.append(String.format("%s -> %s [label=\"%s\"%s]\n", graphvisNodes.get(leftIndex), + builder.append(String.format("%s -> %s [label=\"%s\"%s]%n", graphvisNodes.get(leftIndex), graphvisNodes.get(rightIndex), label, arrowHead)); } else { // Hyper edge is considered as a tiny virtual node - builder.append(String.format("e%d [shape=circle, width=.001, label=\"\"]\n", i)); + builder.append(String.format("e%d [shape=circle, width=.001, label=\"\"]%n", i)); String leftLabel = ""; String rightLabel = ""; @@ -577,21 +263,312 @@ public class HyperGraph { leftLabel = label; } - int finalI = i; String finalLeftLabel = leftLabel; for (int nodeIndex : LongBitmap.getIterator(edge.getLeftExtendedNodes())) { - builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]\n", - graphvisNodes.get(nodeIndex), finalI, finalLeftLabel)); + builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]%n", + graphvisNodes.get(nodeIndex), i, finalLeftLabel)); } String finalRightLabel = rightLabel; for (int nodeIndex : LongBitmap.getIterator(edge.getRightExtendedNodes())) { - builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]\n", - graphvisNodes.get(nodeIndex), finalI, finalRightLabel)); + builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]%n", + graphvisNodes.get(nodeIndex), i, finalRightLabel)); } } } builder.append("}\n"); return builder.toString(); } + + public static HyperGraph.Builder builderForDPhyper(Group group) { + return new HyperGraph.Builder().buildHyperGraphForDPhyper(group); + } + + public static HyperGraph.Builder builderForMv(Plan plan) { + return new HyperGraph.Builder().buildHyperGraphForMv(plan); + } + + /** + * Builder of HyperGraph + */ + public static class Builder { + private final List joinEdges = new ArrayList<>(); + private final List filterEdges = new ArrayList<>(); + private final List nodes = new ArrayList<>(); + + // These hyperGraphs should be replaced nodes when building all + private final Map> replacedHyperGraphs = new HashMap<>(); + private final HashMap slotToNodeMap = new HashMap<>(); + private final Map> complexProject = new HashMap<>(); + private Set finalOutputs; + + public List getNodes() { + return nodes; + } + + private HyperGraph.Builder buildHyperGraphForDPhyper(Group group) { + finalOutputs = group.getLogicalProperties().getOutputSet(); + this.buildForDPhyper(group.getLogicalExpression()); + return this; + } + + private HyperGraph.Builder buildHyperGraphForMv(Plan plan) { + finalOutputs = plan.getOutputSet(); + this.buildForMv(plan); + return this; + } + + public HyperGraph build() { + return new HyperGraph(finalOutputs, joinEdges, nodes, filterEdges, complexProject); + } + + public List buildAll() { + return ImmutableList.of(build()); + } + + public void updateNode(int idx, Group group) { + Preconditions.checkArgument(nodes.get(idx) instanceof DPhyperNode); + nodes.set(idx, ((DPhyperNode) nodes.get(idx)).withGroup(group)); + } + + // Build Graph for DPhyper + private Pair buildForDPhyper(GroupExpression groupExpression) { + // process Project + if (isValidProject(groupExpression.getPlan())) { + LogicalProject project = (LogicalProject) groupExpression.getPlan(); + Pair res = this.buildForDPhyper(groupExpression.child(0).getLogicalExpressions().get(0)); + for (NamedExpression expr : project.getProjects()) { + if (expr instanceof Alias) { + this.addAlias((Alias) expr, res.second); + } + } + return res; + } + + // process Join + if (isValidJoin(groupExpression.getPlan())) { + LogicalJoin join = (LogicalJoin) groupExpression.getPlan(); + Pair left = + this.buildForDPhyper(groupExpression.child(0).getLogicalExpressions().get(0)); + Pair right = + this.buildForDPhyper(groupExpression.child(1).getLogicalExpressions().get(0)); + return Pair.of(this.addJoin(join, left, right), + LongBitmap.or(left.second, right.second)); + } + + // process Other Node + int idx = this.addDPHyperNode(groupExpression.getOwnerGroup()); + return Pair.of(new BitSet(), LongBitmap.newBitmap(idx)); + } + + // Build Graph for matching mv, return join edge set and nodes in this plan + private Pair buildForMv(Plan plan) { + if (plan instanceof GroupPlan) { + Group group = ((GroupPlan) plan).getGroup(); + GroupExpression groupExpression = group.getLogicalExpressions().get(0); + return buildForMv(groupExpression.getPlan() + .withChildren( + groupExpression.children().stream().map(GroupPlan::new).collect(Collectors.toList()))); + } + // process Project + if (isValidProject(plan)) { + LogicalProject project = (LogicalProject) plan; + Pair res = this.buildForMv(plan.child(0)); + for (NamedExpression expr : project.getProjects()) { + if (expr instanceof Alias) { + this.addAlias((Alias) expr, res.second); + } + } + return res; + } + + // process Join + if (isValidJoinForStructInfo(plan)) { + LogicalJoin join = (LogicalJoin) plan; + Pair left = this.buildForMv(plan.child(0)); + Pair right = this.buildForMv(plan.child(1)); + return Pair.of(this.addJoin(join, left, right), + LongBitmap.or(left.second, right.second)); + } + + if (isValidFilter(plan)) { + LogicalFilter filter = (LogicalFilter) plan; + Pair child = this.buildForMv(filter.child()); + this.addFilter(filter, child); + return Pair.of(new BitSet(), child.second); + } + + // process Other Node + int idx = this.addStructInfoNode(plan); + return Pair.of(new BitSet(), LongBitmap.newBitmap(idx)); + } + + /** + * Store the relation between Alias Slot and Original Slot and its expression + * e.g., + * a = b + * |--- project((c + d) as b) + *

+ * a = b + * |--- project((c + 1) as b) + * + * @param alias The alias Expression in project Operator + */ + public boolean addAlias(Alias alias, long subTreeNodes) { + Slot aliasSlot = alias.toSlot(); + if (slotToNodeMap.containsKey(aliasSlot)) { + return true; + } + long bitmap = LongBitmap.newBitmap(); + for (Slot slot : alias.getInputSlots()) { + bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot)); + } + // The case hit when there are some constant aliases such as: + // select * from t1 join ( + // select *, 1 as b1 from t2) + // on t1.b = b1 + // just reference them all for this slot + if (bitmap == 0) { + bitmap = subTreeNodes; + } + Preconditions.checkArgument(bitmap > 0, "slot must belong to some table"); + slotToNodeMap.put(aliasSlot, bitmap); + if (!complexProject.containsKey(bitmap)) { + complexProject.put(bitmap, new ArrayList<>()); + } + alias = (Alias) PlanUtils.mergeProjections(complexProject.get(bitmap), Lists.newArrayList(alias)).get(0); + + complexProject.get(bitmap).add(alias); + return true; + } + + /** + * add end node to HyperGraph + * + * @param group The group that is the end node in graph + * @return return the node index + */ + private int addDPHyperNode(Group group) { + for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) { + Preconditions.checkArgument(!slotToNodeMap.containsKey(slot)); + slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size())); + } + nodes.add(new DPhyperNode(nodes.size(), group)); + return nodes.size() - 1; + } + + /** + * add end node to HyperGraph + * + * @param plan The plan that is the end node in graph + * @return return the node index + */ + private int addStructInfoNode(Plan plan) { + for (Slot slot : plan.getOutput()) { + Preconditions.checkArgument(!slotToNodeMap.containsKey(slot)); + slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size())); + } + nodes.add(new StructInfoNode(nodes.size(), plan)); + return nodes.size() - 1; + } + + private long calNodeMap(Set slots) { + Preconditions.checkArgument(slots.size() != 0); + long bitmap = LongBitmap.newBitmap(); + for (Slot slot : slots) { + Preconditions.checkArgument(slotToNodeMap.containsKey(slot)); + bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot)); + } + return bitmap; + } + + /** + * try to add edge for join group + * + * @param join The join plan + */ + private BitSet addJoin(LogicalJoin join, + Pair leftEdgeNodes, Pair rightEdgeNodes) { + HashMap, Pair, List>> conjuncts = new HashMap<>(); + for (Expression expression : join.getHashJoinConjuncts()) { + // TODO: avoid calling calculateEnds if calNodeMap's results are same + Pair ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes, + rightEdgeNodes); + if (!conjuncts.containsKey(ends)) { + conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>())); + } + conjuncts.get(ends).first.add(expression); + } + for (Expression expression : join.getOtherJoinConjuncts()) { + Pair ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes, + rightEdgeNodes); + if (!conjuncts.containsKey(ends)) { + conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>())); + } + conjuncts.get(ends).second.add(expression); + } + + BitSet curJoinEdges = new BitSet(); + for (Map.Entry, Pair, List>> entry : conjuncts + .entrySet()) { + LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first, + entry.getValue().second, + new DistributeHint(DistributeType.NONE), join.getMarkJoinSlotReference(), + Lists.newArrayList(join.left(), join.right())); + Pair ends = entry.getKey(); + JoinEdge edge = new JoinEdge(singleJoin, joinEdges.size(), leftEdgeNodes.first, rightEdgeNodes.first, + LongBitmap.newBitmapUnion(leftEdgeNodes.second, rightEdgeNodes.second), + ends.first, ends.second); + for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) { + nodes.get(nodeIndex).attachEdge(edge); + } + curJoinEdges.set(edge.getIndex()); + joinEdges.add(edge); + } + curJoinEdges.stream().forEach(i -> joinEdges.get(i).addCurJoinEdges(curJoinEdges)); + curJoinEdges.stream().forEach(i -> ConflictRulesMaker.makeJoinConflictRules(joinEdges.get(i), joinEdges)); + curJoinEdges.stream().forEach(i -> + ConflictRulesMaker.makeFilterConflictRules(joinEdges.get(i), joinEdges, filterEdges)); + return curJoinEdges; + // In MySQL, each edge is reversed and store in edges again for reducing the branch miss + // We don't implement this trick now. + } + + private BitSet addFilter(LogicalFilter filter, Pair childEdgeNodes) { + FilterEdge edge = new FilterEdge(filter, filterEdges.size(), childEdgeNodes.first, childEdgeNodes.second, + childEdgeNodes.second); + filterEdges.add(edge); + BitSet bitSet = new BitSet(); + bitSet.set(edge.getIndex()); + return bitSet; + } + + // Try to calculate the ends of an expression. + // left = ref_nodes \cap left_tree , right = ref_nodes \cap right_tree + // if left = 0, recursively calculate it in left tree + private Pair calculateEnds(long allNodes, Pair leftEdgeNodes, + Pair rightEdgeNodes) { + long left = LongBitmap.newBitmapIntersect(allNodes, leftEdgeNodes.second); + long right = LongBitmap.newBitmapIntersect(allNodes, rightEdgeNodes.second); + if (left == 0) { + Preconditions.checkArgument(leftEdgeNodes.first.cardinality() > 0, + "the number of the table which expression reference is less 2"); + Pair llEdgesNodes = joinEdges.get(leftEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes( + joinEdges); + Pair lrEdgesNodes = joinEdges.get(leftEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes( + joinEdges); + return calculateEnds(allNodes, llEdgesNodes, lrEdgesNodes); + } + if (right == 0) { + Preconditions.checkArgument(rightEdgeNodes.first.cardinality() > 0, + "the number of the table which expression reference is less 2"); + Pair rlEdgesNodes = joinEdges.get(rightEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes( + joinEdges); + Pair rrEdgesNodes = joinEdges.get(rightEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes( + joinEdges); + return calculateEnds(allNodes, rlEdgesNodes, rrEdgesNodes); + } + return Pair.of(left, right); + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java index 9e3886bfdc..e32baba6a5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java @@ -17,7 +17,6 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.node; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.GroupPlan; @@ -44,8 +43,6 @@ import javax.annotation.Nullable; * HyperGraph Node. */ public class StructInfoNode extends AbstractNode { - - private List graphs = new ArrayList<>(); private final List> expressions; private final Set relationSet; @@ -59,11 +56,6 @@ public class StructInfoNode extends AbstractNode { this(index, plan, new ArrayList<>()); } - public StructInfoNode(int index, List graphs) { - this(index, graphs.get(0).getNode(0).getPlan(), new ArrayList<>()); - this.graphs = graphs; - } - private @Nullable List> collectExpressions(Plan plan) { if (plan instanceof LeafPlan) { return ImmutableList.of(); @@ -122,14 +114,6 @@ public class StructInfoNode extends AbstractNode { return plan.withChildren(children); } - public boolean needToFlat() { - return !graphs.isEmpty(); - } - - public List getGraphs() { - return graphs; - } - @Override public String toString() { return Utils.toSqlString("StructInfoNode[" + this.getName() + "]", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java index 3451d8e7c4..62f325a05d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java @@ -225,7 +225,7 @@ public class StructInfo { // if single table without join, the bottom is originalPlan.accept(PLAN_SPLITTER, planSplitContext); - List structInfos = HyperGraph.toStructInfo(planSplitContext.getBottomPlan()); + List structInfos = HyperGraph.builderForMv(planSplitContext.getBottomPlan()).buildAll(); return structInfos.stream() .map(hyperGraph -> StructInfo.of(originalPlan, planSplitContext.getTopPlan(), planSplitContext.getBottomPlan(), hyperGraph)) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/CompareOuterJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/CompareOuterJoinTest.java index bbf7746ec6..afdce7ca08 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/CompareOuterJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/CompareOuterJoinTest.java @@ -63,8 +63,8 @@ class CompareOuterJoinTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); Assertions.assertFalse( HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)).isInvalid()); } @@ -82,8 +82,8 @@ class CompareOuterJoinTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); Assertions.assertFalse( HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)).isInvalid()); } @@ -108,8 +108,8 @@ class CompareOuterJoinTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertEquals(1, res.getQueryExpressions().size()); Assertions.assertEquals("(id = 0)", res.getQueryExpressions().get(0).toSql()); @@ -135,8 +135,8 @@ class CompareOuterJoinTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); List exprList = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)).getQueryExpressions(); Assertions.assertEquals(0, exprList.size()); } @@ -162,8 +162,8 @@ class CompareOuterJoinTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertEquals(1, res.getQueryExpressions().size()); Assertions.assertEquals("(id = 0)", res.getQueryExpressions().get(0).toSql()); @@ -190,8 +190,8 @@ class CompareOuterJoinTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertTrue(res.isInvalid()); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/InferJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/InferJoinTest.java index 3f1ce3fc34..05f56c4de2 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/InferJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/InferJoinTest.java @@ -58,8 +58,8 @@ class InferJoinTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertFalse(res.isInvalid()); Assertions.assertEquals(1, res.getViewNoNullableSlot().size()); @@ -87,8 +87,8 @@ class InferJoinTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertFalse(res.isInvalid()); Assertions.assertEquals(1, res.getViewNoNullableSlot().size()); @@ -124,8 +124,8 @@ class InferJoinTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertFalse(res.isInvalid()); Assertions.assertEquals(1, res.getViewNoNullableSlot().size()); @@ -155,8 +155,8 @@ class InferJoinTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertTrue(res.isInvalid()); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/InferPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/InferPredicateTest.java index 5ea49a51f5..8bb1ede804 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/InferPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/InferPredicateTest.java @@ -53,8 +53,8 @@ class InferPredicateTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertFalse(res.isInvalid()); Assertions.assertEquals("(id = 1)", res.getQueryExpressions().get(0).toSql()); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/PullupExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/PullupExpressionTest.java index 44b6075394..6e65dba4f0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/PullupExpressionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/PullupExpressionTest.java @@ -53,8 +53,8 @@ class PullupExpressionTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertEquals(1, res.getQueryExpressions().size()); Assertions.assertEquals("(id = 1)", res.getQueryExpressions().get(0).toSql()); @@ -79,8 +79,8 @@ class PullupExpressionTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertEquals(1, res.getQueryExpressions().size()); Assertions.assertEquals("(score = score)", res.getQueryExpressions().get(0).toSql()); @@ -105,8 +105,8 @@ class PullupExpressionTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertEquals(2, res.getViewExpressions().size()); Assertions.assertEquals("(id = 1)", res.getViewExpressions().get(0).toSql()); @@ -132,8 +132,8 @@ class PullupExpressionTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertEquals(1, res.getViewExpressions().size()); Assertions.assertEquals("(score = score)", res.getViewExpressions().get(0).toSql()); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/BuildStructInfoTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/BuildStructInfoTest.java index bf5edfa45e..20324ff873 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/BuildStructInfoTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/BuildStructInfoTest.java @@ -41,7 +41,7 @@ class BuildStructInfoTest extends SqlTestBase { .deriveStats() .matches(logicalJoin() .when(j -> { - HyperGraph.toStructInfo(j); + HyperGraph.builderForMv(j); return true; })); @@ -58,7 +58,7 @@ class BuildStructInfoTest extends SqlTestBase { .deriveStats() .matches(logicalJoin() .when(j -> { - List hyperGraph = HyperGraph.toStructInfo(j); + List hyperGraph = HyperGraph.builderForMv(j).buildAll(); Assertions.assertTrue(hyperGraph.get(0).getNodes().stream() .allMatch(n -> n.getPlan() .collectToList(GroupPlan.class::isInstance).isEmpty())); @@ -77,7 +77,7 @@ class BuildStructInfoTest extends SqlTestBase { .rewrite() .matches(logicalJoin() .when(j -> { - HyperGraph structInfo = HyperGraph.toStructInfo(j).get(0); + HyperGraph structInfo = HyperGraph.builderForMv(j).buildAll().get(0); Assertions.assertTrue(structInfo.getJoinEdge(0).getJoinType().isLeftOuterJoin()); Assertions.assertEquals(0, structInfo.getFilterEdge(0).getLeftRejectEdge().size()); Assertions.assertEquals(1, structInfo.getFilterEdge(0).getRightRejectEdge().size()); @@ -91,7 +91,7 @@ class BuildStructInfoTest extends SqlTestBase { .rewrite() .matches(logicalJoin() .when(j -> { - HyperGraph structInfo = HyperGraph.toStructInfo(j).get(0); + HyperGraph structInfo = HyperGraph.builderForMv(j).buildAll().get(0); Assertions.assertTrue(structInfo.getJoinEdge(0).getJoinType().isLeftOuterJoin()); return true; })); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphAggTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphAggTest.java index 6032d00040..38ebc99c47 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphAggTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphAggTest.java @@ -49,7 +49,7 @@ class HyperGraphAggTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p2).buildAll().get(0); Assertions.assertEquals("id", Objects.requireNonNull(((StructInfoNode) h1.getNode(1)).getExpressions()).get(0).toSql()); } @@ -79,8 +79,8 @@ class HyperGraphAggTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertTrue(!res.isInvalid()); Assertions.assertEquals(2, res.getViewNoNullableSlot().size()); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparatorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparatorTest.java index b89934d457..1fd2ac86ab 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparatorTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparatorTest.java @@ -55,8 +55,8 @@ class HyperGraphComparatorTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertTrue(!res.isInvalid()); Assertions.assertEquals(2, res.getViewNoNullableSlot().size()); @@ -86,8 +86,8 @@ class HyperGraphComparatorTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertTrue(!res.isInvalid()); Assertions.assertEquals(2, res.getViewNoNullableSlot().size()); @@ -118,8 +118,8 @@ class HyperGraphComparatorTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertTrue(!res.isInvalid()); Assertions.assertEquals(2, res.getViewNoNullableSlot().size()); @@ -153,8 +153,8 @@ class HyperGraphComparatorTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); - HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); - HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); + HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertTrue(!res.isInvalid()); Assertions.assertEquals(2, res.getViewNoNullableSlot().size()); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java index f544e6ad8b..d167054c29 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java @@ -335,14 +335,14 @@ public class HyperGraphBuilder { plan); cascadesContext.getJobScheduler().executeJobPool(cascadesContext); injectRowcount(cascadesContext.getMemo().getRoot()); - return HyperGraph.toDPhyperGraph(cascadesContext.getMemo().getRoot()); + return HyperGraph.builderForDPhyper(cascadesContext.getMemo().getRoot()).build(); } public static HyperGraph buildHyperGraphFromPlan(Plan plan) { CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(MemoTestUtils.createConnectContext(), plan); cascadesContext.getJobScheduler().executeJobPool(cascadesContext); - return HyperGraph.toDPhyperGraph(cascadesContext.getMemo().getRoot()); + return HyperGraph.builderForDPhyper(cascadesContext.getMemo().getRoot()).build(); } private void injectRowcount(Group group) {