[feature](Nereids): support merge graph in group (#27353)

This commit is contained in:
谢健
2023-11-27 11:48:38 +08:00
committed by GitHub
parent 0e1e4c8508
commit 331effdb20
3 changed files with 102 additions and 15 deletions

View File

@ -37,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import java.util.ArrayList;
@ -45,6 +46,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
* The graph is a join graph whose nodes are leaf plans and whose edges are join operators.
@ -56,12 +58,17 @@ public class HyperGraph {
// Maps an output slot to the bitmap of the node registered for it
// (filled in by the add*Node methods, one entry per slot)
private final HashMap<Slot, Long> slotToNodeMap = new HashMap<>();
// record all edges that can be placed on the subgraph
private final Map<Long, BitSet> treeEdgesCache = new HashMap<>();
// Output slots handed to the constructor, kept as an immutable copy
private final Set<Slot> finalOutputs;
// Record the complex project expression for some subgraph
// e.g. project (a + b)
// |-- join(t1.a = t2.b)
private final HashMap<Long, List<NamedExpression>> complexProject = new HashMap<>();
/**
 * Creates a graph that remembers the given output slots.
 *
 * @param finalOutputs output slots of the plan or group this graph is built for;
 *                     an immutable copy is stored, so later changes to the argument are not seen
 */
HyperGraph(Set<Slot> finalOutputs) {
    Set<Slot> frozenOutputs = ImmutableSet.copyOf(finalOutputs);
    this.finalOutputs = frozenOutputs;
}
/**
 * Returns the edges of this graph.
 *
 * @return the backing edge list itself (not a copy)
 */
public List<Edge> getEdges() {
    return this.edges;
}
@ -127,7 +134,7 @@ public class HyperGraph {
* @param group The group that is the end node in graph
* @return return the node index
*/
public int addDPHyperNode(Group group) {
private int addDPHyperNode(Group group) {
for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) {
Preconditions.checkArgument(!slotToNodeMap.containsKey(slot));
slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size()));
@ -142,7 +149,7 @@ public class HyperGraph {
* @param plan The plan that is the end node in graph
* @return return the node index
*/
public int addStructInfoNode(Plan plan) {
private int addStructInfoNode(Plan plan) {
for (Slot slot : plan.getOutput()) {
Preconditions.checkArgument(!slotToNodeMap.containsKey(slot));
slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size()));
@ -151,6 +158,15 @@ public class HyperGraph {
return nodes.size() - 1;
}
/**
 * Adds one node that stands for a list of alternative child graphs.
 *
 * <p>Every final output slot of the first child graph is registered as belonging to the new node.
 *
 * @param childGraphs child graphs represented by the node; the first one supplies the slots
 * @return the index of the newly added node
 */
private int addStructInfoNode(List<HyperGraph> childGraphs) {
    // The new node is appended, so its index is the current node count.
    final int nodeIndex = nodes.size();
    for (Slot outputSlot : childGraphs.get(0).finalOutputs) {
        Preconditions.checkArgument(!slotToNodeMap.containsKey(outputSlot));
        slotToNodeMap.put(outputSlot, LongBitmap.newBitmap(nodeIndex));
    }
    nodes.add(new StructInfoNode(nodeIndex, childGraphs));
    return nodeIndex;
}
public void updateNode(int idx, Group group) {
Preconditions.checkArgument(nodes.get(idx) instanceof DPhyperNode);
nodes.set(idx, ((DPhyperNode) nodes.get(idx)).withGroup(group));
@ -160,6 +176,19 @@ public class HyperGraph {
return complexProject;
}
/**
 * Adds a copy of a join edge from another graph to this graph.
 *
 * <p>The end nodes are recomputed against this graph with {@code calNodeMap}: the first node
 * found by {@code nextSetBit} starting at index 0 becomes the left side, and the result of
 * {@code newBitmapDiff(nodeMap, left)} (presumably all remaining nodes) becomes the right side.
 *
 * <p>The node sets are stored on the newly created edge. The {@code edge} argument belongs to
 * another graph and may be passed to several graphs in a row, so it is only read, never modified.
 *
 * @param edge edge whose join and input slots are copied
 * @throws IllegalArgumentException if the input slots of the edge cover fewer than two nodes
 */
private void addEdgeOfInfo(Edge edge) {
    long nodeMap = calNodeMap(edge.getInputSlots());
    Preconditions.checkArgument(LongBitmap.getCardinality(nodeMap) > 1,
            "edge must have more than one ends");
    long left = LongBitmap.newBitmap(LongBitmap.nextSetBit(nodeMap, 0));
    long right = LongBitmap.newBitmapDiff(nodeMap, left);
    // The copy gets the next free edge index of this graph.
    Edge copiedEdge = new Edge(edge.getJoin(), edges.size(), null, null, null);
    copiedEdge.setLeftRequiredNodes(left);
    copiedEdge.setLeftExtendedNodes(left);
    copiedEdge.setRightRequiredNodes(right);
    copiedEdge.setRightExtendedNodes(right);
    this.edges.add(copiedEdge);
}
/**
* try to add edge for join group
*
@ -320,16 +349,44 @@ public class HyperGraph {
return bitmap;
}
public static HyperGraph toStructInfo(Plan plan) {
Preconditions.checkArgument(plan.getGroupExpression().isPresent(),
"HyperGraph requires a GroupExpression in ", plan);
HyperGraph hyperGraph = new HyperGraph();
public static List<HyperGraph> toStructInfo(Plan plan) {
HyperGraph hyperGraph = new HyperGraph(plan.getOutputSet());
hyperGraph.buildStructInfo(plan);
return hyperGraph;
return hyperGraph.flatChildren();
}
/**
 * Expands nodes that carry child graphs into separate flat graphs.
 *
 * <p>If no node needs flattening, this graph is returned as the only element. Otherwise the
 * nodes are processed in order by {@code flatChild}, starting from one empty graph, and
 * afterwards every edge of this graph is added to every resulting graph.
 *
 * @return the flattened graphs; a new mutable list in either case
 */
private List<HyperGraph> flatChildren() {
    boolean hasNestedNode = false;
    for (AbstractNode candidate : nodes) {
        if (((StructInfoNode) candidate).needToFlat()) {
            hasNestedNode = true;
            break;
        }
    }
    if (!hasNestedNode) {
        return Lists.newArrayList(this);
    }

    List<HyperGraph> flattened = new ArrayList<>();
    flattened.add(new HyperGraph(finalOutputs));
    for (AbstractNode node : nodes) {
        flattened = flatChild((StructInfoNode) node, flattened);
    }
    for (Edge edge : edges) {
        for (HyperGraph graph : flattened) {
            graph.addEdgeOfInfo(edge);
        }
    }
    return flattened;
}
/**
 * Adds one node to every graph built so far.
 *
 * <p>A node without child graphs is appended to each graph in place and the same list is
 * returned. A node with child graphs yields one new graph per pair of (existing graph,
 * child graph), each containing the content of the existing graph followed by the child graph.
 *
 * @param infoNode the node to add
 * @param hyperGraphs graphs built from the nodes processed so far
 * @return the input list, or a new list holding the combined graphs
 */
private List<HyperGraph> flatChild(StructInfoNode infoNode, List<HyperGraph> hyperGraphs) {
    if (!infoNode.needToFlat()) {
        for (HyperGraph graph : hyperGraphs) {
            graph.addStructInfoNode(infoNode.getPlan());
        }
        return hyperGraphs;
    }

    List<HyperGraph> combined = new ArrayList<>();
    for (HyperGraph base : hyperGraphs) {
        for (HyperGraph subGraph : infoNode.getGraphs()) {
            HyperGraph merged = new HyperGraph(base.finalOutputs);
            merged.addStructInfo(base);
            merged.addStructInfo(subGraph);
            combined.add(merged);
        }
    }
    return combined;
}
public static HyperGraph toDPhyperGraph(Group group) {
HyperGraph hyperGraph = new HyperGraph();
HyperGraph hyperGraph = new HyperGraph(group.getLogicalProperties().getOutputSet());
hyperGraph.buildDPhyperGraph(group.getLogicalExpressions().get(0));
return hyperGraph;
}
@ -362,15 +419,28 @@ public class HyperGraph {
return Pair.of(new BitSet(), LongBitmap.newBitmap(idx));
}
/**
 * Appends the content of another graph to this graph.
 *
 * <p>The plans of the other graph's nodes are added as new nodes, its complex project
 * aliases are re-registered with their key shifted left by the number of nodes this graph
 * had before the call, and its edges are added through {@code addEdgeOfInfo}.
 *
 * @param other the graph whose nodes, complex projects and edges are copied
 */
private void addStructInfo(HyperGraph other) {
    // Node count before the merge; used to shift the bitmap keys of the other graph.
    final int offset = this.getNodes().size();
    for (AbstractNode node : other.getNodes()) {
        this.addStructInfoNode(node.getPlan());
    }
    other.getComplexProject().forEach((nodeBitmap, projects) -> {
        for (NamedExpression project : projects) {
            this.addAlias((Alias) project, nodeBitmap << offset);
        }
    });
    for (Edge edge : other.getEdges()) {
        this.addEdgeOfInfo(edge);
    }
}
// Build Graph for matching mv
private Pair<BitSet, Long> buildStructInfo(Plan plan) {
if (plan instanceof GroupPlan) {
Group group = ((GroupPlan) plan).getGroup();
if (group.getHyperGraph() == null) {
buildStructInfo(group.getLogicalExpressions().get(0).getPlan());
} else {
//TODO: merge Group
buildStructInfo(group.getLogicalExpressions().get(0).getPlan());
List<HyperGraph> childGraphs = ((GroupPlan) plan).getGroup().getHyperGraphs();
if (childGraphs.size() != 0) {
int idx = addStructInfoNode(childGraphs);
return Pair.of(new BitSet(), LongBitmap.newBitmap(idx));
}
GroupExpression groupExpression = group.getLogicalExpressions().get(0);
buildStructInfo(groupExpression.getPlan()
.withChildren(
groupExpression.children().stream().map(GroupPlan::new).collect(Collectors.toList())));
}
// process Project
if (isValidProject(plan)) {

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.trees.plans.Plan;
import java.util.ArrayList;
@ -27,6 +28,9 @@ import java.util.List;
* HyperGraph Node.
*/
public class StructInfoNode extends AbstractNode {
// Child graphs represented by this node; starts empty and is replaced only by the
// constructor that takes a list of graphs
private List<HyperGraph> graphs = new ArrayList<>();
/**
 * Creates a node for a plan with the given edges.
 *
 * @param index index of the node, forwarded to the AbstractNode constructor
 * @param plan plan of the node, forwarded to the AbstractNode constructor
 * @param edges edges of the node, forwarded to the AbstractNode constructor
 */
public StructInfoNode(int index, Plan plan, List<Edge> edges) {
super(plan, index, edges);
}
@ -34,4 +38,18 @@ public class StructInfoNode extends AbstractNode {
/**
 * Creates a node for a plan with a fresh, empty edge list.
 *
 * @param index index of the node
 * @param plan plan of the node
 */
public StructInfoNode(int index, Plan plan) {
    this(index, plan, new ArrayList<Edge>());
}
/**
 * Creates a node that stands for a list of child graphs.
 *
 * <p>The plan of node 0 of the first graph is used as the plan of this node, and the edge
 * list starts empty. The list of graphs is stored as passed in, without copying.
 *
 * @param index index of the node
 * @param graphs child graphs represented by this node; the first graph is read immediately
 */
public StructInfoNode(int index, List<HyperGraph> graphs) {
    this(index, graphs.get(0).getNode(0).getPlan(), new ArrayList<Edge>());
    this.graphs = graphs;
}
/**
 * Tells whether this node carries child graphs.
 *
 * @return {@code true} if the list of child graphs is not empty
 */
public boolean needToFlat() {
    boolean hasChildGraphs = !this.graphs.isEmpty();
    return hasChildGraphs;
}
/**
 * Returns the child graphs of this node.
 *
 * @return the stored list itself (not a copy); empty if the node has no child graphs
 */
public List<HyperGraph> getGraphs() {
    return this.graphs;
}
}

View File

@ -49,7 +49,6 @@ import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
* Representation for group in cascades optimizer.
@ -419,8 +418,8 @@ public class Group {
return false;
}
public @Nullable HyperGraph getHyperGraph() {
return null;
public List<HyperGraph> getHyperGraphs() {
return new ArrayList<>();
}
public boolean isProjectGroup() {