[feature](Nereids): support merge graph in group (#27353)
This commit is contained in:
@ -37,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@ -45,6 +46,7 @@ import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* The graph is a join graph, whose node is the leaf plan and edge is a join operator.
|
||||
@ -56,12 +58,17 @@ public class HyperGraph {
|
||||
private final HashMap<Slot, Long> slotToNodeMap = new HashMap<>();
|
||||
// record all edges that can be placed on the subgraph
|
||||
private final Map<Long, BitSet> treeEdgesCache = new HashMap<>();
|
||||
private final Set<Slot> finalOutputs;
|
||||
|
||||
// Record the complex project expression for some subgraph
|
||||
// e.g. project (a + b)
|
||||
// |-- join(t1.a = t2.b)
|
||||
private final HashMap<Long, List<NamedExpression>> complexProject = new HashMap<>();
|
||||
|
||||
HyperGraph(Set<Slot> finalOutputs) {
|
||||
this.finalOutputs = ImmutableSet.copyOf(finalOutputs);
|
||||
}
|
||||
|
||||
public List<Edge> getEdges() {
|
||||
return edges;
|
||||
}
|
||||
@ -127,7 +134,7 @@ public class HyperGraph {
|
||||
* @param group The group that is the end node in graph
|
||||
* @return return the node index
|
||||
*/
|
||||
public int addDPHyperNode(Group group) {
|
||||
private int addDPHyperNode(Group group) {
|
||||
for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) {
|
||||
Preconditions.checkArgument(!slotToNodeMap.containsKey(slot));
|
||||
slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size()));
|
||||
@ -142,7 +149,7 @@ public class HyperGraph {
|
||||
* @param plan The plan that is the end node in graph
|
||||
* @return return the node index
|
||||
*/
|
||||
public int addStructInfoNode(Plan plan) {
|
||||
private int addStructInfoNode(Plan plan) {
|
||||
for (Slot slot : plan.getOutput()) {
|
||||
Preconditions.checkArgument(!slotToNodeMap.containsKey(slot));
|
||||
slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size()));
|
||||
@ -151,6 +158,15 @@ public class HyperGraph {
|
||||
return nodes.size() - 1;
|
||||
}
|
||||
|
||||
private int addStructInfoNode(List<HyperGraph> childGraphs) {
|
||||
for (Slot slot : childGraphs.get(0).finalOutputs) {
|
||||
Preconditions.checkArgument(!slotToNodeMap.containsKey(slot));
|
||||
slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size()));
|
||||
}
|
||||
nodes.add(new StructInfoNode(nodes.size(), childGraphs));
|
||||
return nodes.size() - 1;
|
||||
}
|
||||
|
||||
public void updateNode(int idx, Group group) {
|
||||
Preconditions.checkArgument(nodes.get(idx) instanceof DPhyperNode);
|
||||
nodes.set(idx, ((DPhyperNode) nodes.get(idx)).withGroup(group));
|
||||
@ -160,6 +176,19 @@ public class HyperGraph {
|
||||
return complexProject;
|
||||
}
|
||||
|
||||
private void addEdgeOfInfo(Edge edge) {
|
||||
long nodeMap = calNodeMap(edge.getInputSlots());
|
||||
Preconditions.checkArgument(LongBitmap.getCardinality(nodeMap) > 1,
|
||||
"edge must have more than one ends");
|
||||
this.edges.add(new Edge(edge.getJoin(), edges.size(), null, null, null));
|
||||
long left = LongBitmap.newBitmap(LongBitmap.nextSetBit(nodeMap, 0));
|
||||
long right = LongBitmap.newBitmapDiff(nodeMap, left);
|
||||
edge.setLeftRequiredNodes(left);
|
||||
edge.setLeftExtendedNodes(left);
|
||||
edge.setRightRequiredNodes(right);
|
||||
edge.setRightExtendedNodes(right);
|
||||
}
|
||||
|
||||
/**
|
||||
* try to add edge for join group
|
||||
*
|
||||
@ -320,16 +349,44 @@ public class HyperGraph {
|
||||
return bitmap;
|
||||
}
|
||||
|
||||
public static HyperGraph toStructInfo(Plan plan) {
|
||||
Preconditions.checkArgument(plan.getGroupExpression().isPresent(),
|
||||
"HyperGraph requires a GroupExpression in ", plan);
|
||||
HyperGraph hyperGraph = new HyperGraph();
|
||||
public static List<HyperGraph> toStructInfo(Plan plan) {
|
||||
HyperGraph hyperGraph = new HyperGraph(plan.getOutputSet());
|
||||
hyperGraph.buildStructInfo(plan);
|
||||
return hyperGraph;
|
||||
return hyperGraph.flatChildren();
|
||||
}
|
||||
|
||||
private List<HyperGraph> flatChildren() {
|
||||
if (nodes.stream().noneMatch(n -> ((StructInfoNode) n).needToFlat())) {
|
||||
return Lists.newArrayList(this);
|
||||
}
|
||||
List<HyperGraph> res = new ArrayList<>();
|
||||
res.add(new HyperGraph(finalOutputs));
|
||||
for (AbstractNode node : nodes) {
|
||||
res = flatChild((StructInfoNode) node, res);
|
||||
}
|
||||
for (Edge edge : edges) {
|
||||
res.forEach(g -> g.addEdgeOfInfo(edge));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
private List<HyperGraph> flatChild(StructInfoNode infoNode, List<HyperGraph> hyperGraphs) {
|
||||
if (!infoNode.needToFlat()) {
|
||||
hyperGraphs.forEach(g -> g.addStructInfoNode(infoNode.getPlan()));
|
||||
return hyperGraphs;
|
||||
}
|
||||
return hyperGraphs.stream().flatMap(g ->
|
||||
infoNode.getGraphs().stream().map(subGraph -> {
|
||||
HyperGraph hyperGraph = new HyperGraph(g.finalOutputs);
|
||||
hyperGraph.addStructInfo(g);
|
||||
hyperGraph.addStructInfo(subGraph);
|
||||
return hyperGraph;
|
||||
})
|
||||
).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static HyperGraph toDPhyperGraph(Group group) {
|
||||
HyperGraph hyperGraph = new HyperGraph();
|
||||
HyperGraph hyperGraph = new HyperGraph(group.getLogicalProperties().getOutputSet());
|
||||
hyperGraph.buildDPhyperGraph(group.getLogicalExpressions().get(0));
|
||||
return hyperGraph;
|
||||
}
|
||||
@ -362,15 +419,28 @@ public class HyperGraph {
|
||||
return Pair.of(new BitSet(), LongBitmap.newBitmap(idx));
|
||||
}
|
||||
|
||||
private void addStructInfo(HyperGraph other) {
|
||||
int offset = this.getNodes().size();
|
||||
other.getNodes().forEach(n -> this.addStructInfoNode(n.getPlan()));
|
||||
other.getComplexProject().forEach((t, projectList) ->
|
||||
projectList.forEach(e -> this.addAlias((Alias) e, t << offset)));
|
||||
other.getEdges().forEach(this::addEdgeOfInfo);
|
||||
}
|
||||
|
||||
// Build Graph for matching mv
|
||||
private Pair<BitSet, Long> buildStructInfo(Plan plan) {
|
||||
if (plan instanceof GroupPlan) {
|
||||
Group group = ((GroupPlan) plan).getGroup();
|
||||
if (group.getHyperGraph() == null) {
|
||||
buildStructInfo(group.getLogicalExpressions().get(0).getPlan());
|
||||
} else {
|
||||
//TODO: merge Group
|
||||
buildStructInfo(group.getLogicalExpressions().get(0).getPlan());
|
||||
List<HyperGraph> childGraphs = ((GroupPlan) plan).getGroup().getHyperGraphs();
|
||||
if (childGraphs.size() != 0) {
|
||||
int idx = addStructInfoNode(childGraphs);
|
||||
return Pair.of(new BitSet(), LongBitmap.newBitmap(idx));
|
||||
}
|
||||
GroupExpression groupExpression = group.getLogicalExpressions().get(0);
|
||||
buildStructInfo(groupExpression.getPlan()
|
||||
.withChildren(
|
||||
groupExpression.children().stream().map(GroupPlan::new).collect(Collectors.toList())));
|
||||
}
|
||||
// process Project
|
||||
if (isValidProject(plan)) {
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
|
||||
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@ -27,6 +28,9 @@ import java.util.List;
|
||||
* HyperGraph Node.
|
||||
*/
|
||||
public class StructInfoNode extends AbstractNode {
|
||||
|
||||
private List<HyperGraph> graphs = new ArrayList<>();
|
||||
|
||||
public StructInfoNode(int index, Plan plan, List<Edge> edges) {
|
||||
super(plan, index, edges);
|
||||
}
|
||||
@ -34,4 +38,18 @@ public class StructInfoNode extends AbstractNode {
|
||||
public StructInfoNode(int index, Plan plan) {
|
||||
this(index, plan, new ArrayList<>());
|
||||
}
|
||||
|
||||
public StructInfoNode(int index, List<HyperGraph> graphs) {
|
||||
this(index, graphs.get(0).getNode(0).getPlan(), new ArrayList<>());
|
||||
this.graphs = graphs;
|
||||
}
|
||||
|
||||
public boolean needToFlat() {
|
||||
return !graphs.isEmpty();
|
||||
}
|
||||
|
||||
public List<HyperGraph> getGraphs() {
|
||||
return graphs;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -49,7 +49,6 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
* Representation for group in cascades optimizer.
|
||||
@ -419,8 +418,8 @@ public class Group {
|
||||
return false;
|
||||
}
|
||||
|
||||
public @Nullable HyperGraph getHyperGraph() {
|
||||
return null;
|
||||
public List<HyperGraph> getHyperGraphs() {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
public boolean isProjectGroup() {
|
||||
|
||||
Reference in New Issue
Block a user