diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java index 003a8d4a2f..3ffd159e14 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java @@ -37,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.PlanUtils; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import java.util.ArrayList; @@ -45,6 +46,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; /** * The graph is a join graph, whose node is the leaf plan and edge is a join operator. @@ -56,12 +58,17 @@ public class HyperGraph { private final HashMap slotToNodeMap = new HashMap<>(); // record all edges that can be placed on the subgraph private final Map treeEdgesCache = new HashMap<>(); + private final Set finalOutputs; // Record the complex project expression for some subgraph // e.g. project (a + b) // |-- join(t1.a = t2.b) private final HashMap> complexProject = new HashMap<>(); + HyperGraph(Set finalOutputs) { + this.finalOutputs = ImmutableSet.copyOf(finalOutputs); + } + public List getEdges() { return edges; } @@ -127,7 +134,7 @@ public class HyperGraph { * @param group The group that is the end node in graph * @return return the node index */ - public int addDPHyperNode(Group group) { + private int addDPHyperNode(Group group) { for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) { Preconditions.checkArgument(!slotToNodeMap.containsKey(slot)); slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size())); @@ -142,7 +149,7 @@ public class HyperGraph { * @param plan The plan that is the end node in graph * @return return the node index */ - public int addStructInfoNode(Plan plan) { + private int addStructInfoNode(Plan plan) { for (Slot slot : plan.getOutput()) { Preconditions.checkArgument(!slotToNodeMap.containsKey(slot)); slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size())); @@ -151,6 +158,15 @@ public class HyperGraph { return nodes.size() - 1; } + private int addStructInfoNode(List childGraphs) { + for (Slot slot : childGraphs.get(0).finalOutputs) { + Preconditions.checkArgument(!slotToNodeMap.containsKey(slot)); + slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size())); + } + nodes.add(new StructInfoNode(nodes.size(), childGraphs)); + return nodes.size() - 1; + } + public void updateNode(int idx, Group group) { Preconditions.checkArgument(nodes.get(idx) instanceof DPhyperNode); nodes.set(idx, ((DPhyperNode) nodes.get(idx)).withGroup(group)); @@ -160,6 +176,19 @@ public class HyperGraph { return complexProject; } + private void addEdgeOfInfo(Edge edge) { + long nodeMap = calNodeMap(edge.getInputSlots()); + Preconditions.checkArgument(LongBitmap.getCardinality(nodeMap) > 1, + "edge must have more than one ends"); + this.edges.add(new Edge(edge.getJoin(), edges.size(), null, null, null)); + long left = LongBitmap.newBitmap(LongBitmap.nextSetBit(nodeMap, 0)); + long right = LongBitmap.newBitmapDiff(nodeMap, left); + edge.setLeftRequiredNodes(left); + edge.setLeftExtendedNodes(left); + edge.setRightRequiredNodes(right); + edge.setRightExtendedNodes(right); + } + /** * try to add edge for join group * @@ -320,16 +349,44 @@ public class HyperGraph { return bitmap; } - public static HyperGraph toStructInfo(Plan plan) { - Preconditions.checkArgument(plan.getGroupExpression().isPresent(), - "HyperGraph requires a GroupExpression in ", plan); - HyperGraph hyperGraph = new HyperGraph(); + public static List toStructInfo(Plan plan) { + HyperGraph hyperGraph = new HyperGraph(plan.getOutputSet()); hyperGraph.buildStructInfo(plan); - return hyperGraph; + return hyperGraph.flatChildren(); + } + + private List flatChildren() { + if (nodes.stream().noneMatch(n -> ((StructInfoNode) n).needToFlat())) { + return Lists.newArrayList(this); + } + List res = new ArrayList<>(); + res.add(new HyperGraph(finalOutputs)); + for (AbstractNode node : nodes) { + res = flatChild((StructInfoNode) node, res); + } + for (Edge edge : edges) { + res.forEach(g -> g.addEdgeOfInfo(edge)); + } + return res; + } + + private List flatChild(StructInfoNode infoNode, List hyperGraphs) { + if (!infoNode.needToFlat()) { + hyperGraphs.forEach(g -> g.addStructInfoNode(infoNode.getPlan())); + return hyperGraphs; + } + return hyperGraphs.stream().flatMap(g -> + infoNode.getGraphs().stream().map(subGraph -> { + HyperGraph hyperGraph = new HyperGraph(g.finalOutputs); + hyperGraph.addStructInfo(g); + hyperGraph.addStructInfo(subGraph); + return hyperGraph; + }) + ).collect(Collectors.toList()); } public static HyperGraph toDPhyperGraph(Group group) { - HyperGraph hyperGraph = new HyperGraph(); + HyperGraph hyperGraph = new HyperGraph(group.getLogicalProperties().getOutputSet()); hyperGraph.buildDPhyperGraph(group.getLogicalExpressions().get(0)); return hyperGraph; } @@ -362,15 +419,28 @@ public class HyperGraph { return Pair.of(new BitSet(), LongBitmap.newBitmap(idx)); } + private void addStructInfo(HyperGraph other) { + int offset = this.getNodes().size(); + other.getNodes().forEach(n -> this.addStructInfoNode(n.getPlan())); + other.getComplexProject().forEach((t, projectList) -> + projectList.forEach(e -> this.addAlias((Alias) e, t << offset))); + other.getEdges().forEach(this::addEdgeOfInfo); + } + // Build Graph for matching mv private Pair buildStructInfo(Plan plan) { if (plan instanceof GroupPlan) { Group group = ((GroupPlan) plan).getGroup(); - if (group.getHyperGraph() == null) { - buildStructInfo(group.getLogicalExpressions().get(0).getPlan()); - } else { - //TODO: merge Group + buildStructInfo(group.getLogicalExpressions().get(0).getPlan()); + List childGraphs = ((GroupPlan) plan).getGroup().getHyperGraphs(); + if (childGraphs.size() != 0) { + int idx = addStructInfoNode(childGraphs); + return Pair.of(new BitSet(), LongBitmap.newBitmap(idx)); } + GroupExpression groupExpression = group.getLogicalExpressions().get(0); + buildStructInfo(groupExpression.getPlan() + .withChildren( + groupExpression.children().stream().map(GroupPlan::new).collect(Collectors.toList()))); } // process Project if (isValidProject(plan)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java index 0cca2e1aa3..29618fbd4a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.node; import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; import org.apache.doris.nereids.trees.plans.Plan; import java.util.ArrayList; @@ -27,6 +28,9 @@ import java.util.List; * HyperGraph Node. */ public class StructInfoNode extends AbstractNode { + + private List graphs = new ArrayList<>(); + public StructInfoNode(int index, Plan plan, List edges) { super(plan, index, edges); } @@ -34,4 +38,18 @@ public class StructInfoNode extends AbstractNode { public StructInfoNode(int index, Plan plan) { this(index, plan, new ArrayList<>()); } + + public StructInfoNode(int index, List graphs) { + this(index, graphs.get(0).getNode(0).getPlan(), new ArrayList<>()); + this.graphs = graphs; + } + + public boolean needToFlat() { + return !graphs.isEmpty(); + } + + public List getGraphs() { + return graphs; + } + } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java index 8d09c8afac..3f20fcc391 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java @@ -49,7 +49,6 @@ import java.util.Objects; import java.util.Optional; import java.util.function.Function; import java.util.stream.Collectors; -import javax.annotation.Nullable; /** * Representation for group in cascades optimizer. @@ -419,8 +418,8 @@ public class Group { return false; } - public @Nullable HyperGraph getHyperGraph() { - return null; + public List getHyperGraphs() { + return new ArrayList<>(); } public boolean isProjectGroup() {