[fix](Nereids): Update plan when prune column in DPHyp (#14880)
This commit is contained in:
@ -29,6 +29,7 @@ import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
|
||||
import org.apache.doris.nereids.jobs.batch.NereidsRewriteJobExecutor;
|
||||
import org.apache.doris.nereids.jobs.batch.OptimizeRulesJob;
|
||||
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
|
||||
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.metrics.event.CounterEvent;
|
||||
@ -36,9 +37,11 @@ import org.apache.doris.nereids.processor.post.PlanPostProcessors;
|
||||
import org.apache.doris.nereids.processor.pre.PlanPreprocessors;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.planner.PlanFragment;
|
||||
import org.apache.doris.planner.Planner;
|
||||
@ -209,6 +212,24 @@ public class NereidsPlanner extends Planner {
|
||||
}
|
||||
|
||||
private void joinReorder() {
|
||||
Group root = getRoot();
|
||||
boolean changeRoot = false;
|
||||
if (root.isJoinGroup()) {
|
||||
// If the root group is join group, DPHyp can change the root group.
|
||||
// To keep the root group is not changed, we add a project operator above join
|
||||
List<Slot> outputs = root.getLogicalExpression().getPlan().getOutput();
|
||||
GroupExpression newExpr = new GroupExpression(
|
||||
new LogicalProject(outputs, root.getLogicalExpression().getPlan()),
|
||||
Lists.newArrayList(root));
|
||||
root = new Group();
|
||||
root.addGroupExpression(newExpr);
|
||||
changeRoot = true;
|
||||
}
|
||||
cascadesContext.pushJob(new JoinOrderJob(root, cascadesContext.getCurrentJobContext()));
|
||||
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
|
||||
if (changeRoot) {
|
||||
cascadesContext.getMemo().setRoot(root.getLogicalExpression().child(0));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -17,24 +17,36 @@
|
||||
|
||||
package org.apache.doris.nereids.jobs.joinorder;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.jobs.Job;
|
||||
import org.apache.doris.nereids.jobs.JobContext;
|
||||
import org.apache.doris.nereids.jobs.JobType;
|
||||
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.GraphSimplifier;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.SubgraphEnumerator;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.PlanReceiver;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Join Order job with DPHyp
|
||||
*/
|
||||
public class JoinOrderJob extends Job {
|
||||
private final Group group;
|
||||
private final Set<NamedExpression> otherProject = new HashSet<>();
|
||||
|
||||
public JoinOrderJob(Group group, JobContext context) {
|
||||
super(JobType.JOIN_ORDER, context);
|
||||
@ -43,12 +55,16 @@ public class JoinOrderJob extends Job {
|
||||
|
||||
@Override
|
||||
public void execute() throws AnalysisException {
|
||||
Preconditions.checkArgument(!group.isJoinGroup());
|
||||
GroupExpression rootExpr = group.getLogicalExpression();
|
||||
int arity = rootExpr.arity();
|
||||
for (int i = 0; i < arity; i++) {
|
||||
rootExpr.setChild(i, optimizePlan(rootExpr.child(i)));
|
||||
}
|
||||
CascadesContext cascadesContext = context.getCascadesContext();
|
||||
cascadesContext.topDownRewrite(new ColumnPruning());
|
||||
cascadesContext.pushJob(
|
||||
new DeriveStatsJob(group.getLogicalExpression(), cascadesContext.getCurrentJobContext()));
|
||||
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
|
||||
}
|
||||
|
||||
private Group optimizePlan(Group group) {
|
||||
@ -66,7 +82,7 @@ public class JoinOrderJob extends Job {
|
||||
private Group optimizeJoin(Group group) {
|
||||
HyperGraph hyperGraph = new HyperGraph();
|
||||
buildGraph(group, hyperGraph);
|
||||
// Right now, we just hardcode the limit with 10000, maybe we need a better way to set it
|
||||
// TODO: Right now, we just hardcode the limit with 10000, maybe we need a better way to set it
|
||||
int limit = 10000;
|
||||
PlanReceiver planReceiver = new PlanReceiver(limit);
|
||||
SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(planReceiver, hyperGraph);
|
||||
@ -77,13 +93,22 @@ public class JoinOrderJob extends Job {
|
||||
throw new RuntimeException("DPHyp can not enumerate all sub graphs with limit=" + limit);
|
||||
}
|
||||
}
|
||||
|
||||
Group optimized = planReceiver.getBestPlan(hyperGraph.getNodesMap());
|
||||
return copyToMemo(optimized);
|
||||
Group memoRoot = copyToMemo(optimized);
|
||||
|
||||
// For other projects, such as project constant or project nullable, we construct a new project above root
|
||||
if (otherProject.size() != 0) {
|
||||
otherProject.addAll(memoRoot.getLogicalExpression().getPlan().getOutput());
|
||||
LogicalProject logicalProject = new LogicalProject(new ArrayList<>(otherProject),
|
||||
memoRoot.getLogicalExpression().getPlan());
|
||||
GroupExpression groupExpression = new GroupExpression(logicalProject, Lists.newArrayList(group));
|
||||
memoRoot = context.getCascadesContext().getMemo().copyInGroupExpression(groupExpression);
|
||||
}
|
||||
return memoRoot;
|
||||
}
|
||||
|
||||
private Group copyToMemo(Group root) {
|
||||
if (!root.isJoinGroup()) {
|
||||
if (root.getGroupId() != null) {
|
||||
return root;
|
||||
}
|
||||
GroupExpression groupExpression = root.getLogicalExpression();
|
||||
@ -105,6 +130,11 @@ public class JoinOrderJob extends Job {
|
||||
* @param hyperGraph build hyperGraph
|
||||
*/
|
||||
public void buildGraph(Group group, HyperGraph hyperGraph) {
|
||||
if (group.isProjectGroup()) {
|
||||
buildGraph(group.getLogicalExpression().child(0), hyperGraph);
|
||||
processProjectPlan(hyperGraph, group);
|
||||
return;
|
||||
}
|
||||
if (!group.isJoinGroup()) {
|
||||
hyperGraph.addNode(optimizePlan(group));
|
||||
return;
|
||||
@ -113,4 +143,26 @@ public class JoinOrderJob extends Job {
|
||||
buildGraph(group.getLogicalExpression().child(1), hyperGraph);
|
||||
hyperGraph.addEdge(group);
|
||||
}
|
||||
|
||||
/**
|
||||
* Process project expression in HyperGraph
|
||||
* 1. If it's a simple expression for column pruning, we just ignore it
|
||||
* 2. If it's an alias that may be used in the join operator, we need to add it to graph
|
||||
* 3. If it's other expression, we can ignore them and add it after optimizing
|
||||
* 4. If it's a project only associate with one table, it's seen as an endNode just like a table
|
||||
*/
|
||||
private void processProjectPlan(HyperGraph hyperGraph, Group group) {
|
||||
LogicalProject<? extends Plan> logicalProject = (LogicalProject<? extends Plan>) group.getLogicalExpression()
|
||||
.getPlan();
|
||||
|
||||
for (NamedExpression expr : logicalProject.getProjects()) {
|
||||
if (expr.isAlias()) {
|
||||
if (!hyperGraph.addAlias((Alias) expr, group)) {
|
||||
break;
|
||||
}
|
||||
} else if (!expr.isSlot()) {
|
||||
otherProject.add(expr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,7 +19,9 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph;
|
||||
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
@ -29,6 +31,8 @@ import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
@ -37,8 +41,11 @@ import java.util.Set;
|
||||
* It's used for join ordering
|
||||
*/
|
||||
public class HyperGraph {
|
||||
private List<Edge> edges = new ArrayList<>();
|
||||
private List<Node> nodes = new ArrayList<>();
|
||||
private final List<Edge> edges = new ArrayList<>();
|
||||
private final List<Node> nodes = new ArrayList<>();
|
||||
private final HashSet<Group> nodeSet = new HashSet<>();
|
||||
private final HashMap<Slot, Long> slotToNodeMap = new HashMap<>();
|
||||
private final HashMap<Long, NamedExpression> complexProject = new HashMap<>();
|
||||
|
||||
public List<Edge> getEdges() {
|
||||
return edges;
|
||||
@ -60,12 +67,59 @@ public class HyperGraph {
|
||||
return nodes.get(index);
|
||||
}
|
||||
|
||||
/**
|
||||
* Store the relation between Alias Slot and Original Slot and its expression
|
||||
* e.g. a = b
|
||||
* project((c + d) as b)
|
||||
* Note if the alias if the alias only associated with one endNode,
|
||||
* e.g. a = b
|
||||
* project((c + 1) as b)
|
||||
* we need to replace the group of that node with this project group.
|
||||
*
|
||||
* @param alias The alias Expression in project Operator
|
||||
*/
|
||||
public boolean addAlias(Alias alias, Group group) {
|
||||
long bitmap = LongBitmap.newBitmap();
|
||||
for (Slot slot : alias.getInputSlots()) {
|
||||
bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot));
|
||||
}
|
||||
Slot aliasSlot = alias.toSlot();
|
||||
Preconditions.checkArgument(!slotToNodeMap.containsKey(aliasSlot));
|
||||
slotToNodeMap.put(aliasSlot, bitmap);
|
||||
if (LongBitmap.getCardinality(bitmap) == 1) {
|
||||
int index = LongBitmap.lowestOneIndex(bitmap);
|
||||
nodeSet.remove(nodes.get(index).getGroup());
|
||||
nodeSet.add(group);
|
||||
nodes.get(index).replaceGroupWith(group);
|
||||
return false;
|
||||
}
|
||||
complexProject.put(bitmap, alias);
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* add end node to HyperGraph
|
||||
*
|
||||
* @param group The group that is the end node in graph
|
||||
*/
|
||||
public void addNode(Group group) {
|
||||
Preconditions.checkArgument(!group.isJoinGroup());
|
||||
// TODO: replace plan with group expression or others
|
||||
for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) {
|
||||
Preconditions.checkArgument(!slotToNodeMap.containsKey(slot));
|
||||
slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size()));
|
||||
}
|
||||
nodeSet.add(group);
|
||||
nodes.add(new Node(nodes.size(), group));
|
||||
}
|
||||
|
||||
public boolean isNodeGroup(Group group) {
|
||||
return nodeSet.contains(group);
|
||||
}
|
||||
|
||||
public HashMap<Long, NamedExpression> getComplexProject() {
|
||||
return complexProject;
|
||||
}
|
||||
|
||||
/**
|
||||
* try to add edge for join group
|
||||
*
|
||||
@ -78,15 +132,12 @@ public class HyperGraph {
|
||||
LogicalJoin singleJoin = new LogicalJoin(join.getJoinType(), ImmutableList.of(expression), join.left(),
|
||||
join.right());
|
||||
Edge edge = new Edge(singleJoin, edges.size());
|
||||
long bitmap = findNodes(expression.getInputSlots());
|
||||
Preconditions.checkArgument(LongBitmap.getCardinality(bitmap) == 2,
|
||||
String.format("HyperGraph has not supported polynomial %s yet", expression));
|
||||
int leftIndex = LongBitmap.nextSetBit(bitmap, 0);
|
||||
long left = LongBitmap.newBitmap(leftIndex);
|
||||
edge.addLeftNode(left);
|
||||
int rightIndex = LongBitmap.nextSetBit(bitmap, leftIndex + 1);
|
||||
long right = LongBitmap.newBitmap(rightIndex);
|
||||
edge.addRightNode(right);
|
||||
Preconditions.checkArgument(expression.children().size() == 2);
|
||||
|
||||
long left = calNodeMap(expression.child(0).getInputSlots());
|
||||
edge.setLeft(left);
|
||||
long right = calNodeMap(expression.child(1).getInputSlots());
|
||||
edge.setRight(right);
|
||||
for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
|
||||
nodes.get(nodeIndex).attachEdge(edge);
|
||||
}
|
||||
@ -96,15 +147,12 @@ public class HyperGraph {
|
||||
// We don't implement this trick now.
|
||||
}
|
||||
|
||||
private long findNodes(Set<Slot> slots) {
|
||||
private long calNodeMap(Set<Slot> slots) {
|
||||
Preconditions.checkArgument(slots.size() != 0);
|
||||
long bitmap = LongBitmap.newBitmap();
|
||||
for (Node node : nodes) {
|
||||
for (Slot slot : node.getPlan().getOutput()) {
|
||||
if (slots.contains(slot)) {
|
||||
bitmap = LongBitmap.set(bitmap, node.getIndex());
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (Slot slot : slots) {
|
||||
Preconditions.checkArgument(slotToNodeMap.containsKey(slot));
|
||||
bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot));
|
||||
}
|
||||
return bitmap;
|
||||
}
|
||||
|
||||
@ -37,6 +37,10 @@ public class Node {
|
||||
this.index = index;
|
||||
}
|
||||
|
||||
public void replaceGroupWith(Group group) {
|
||||
this.group = group;
|
||||
}
|
||||
|
||||
public int getIndex() {
|
||||
return index;
|
||||
}
|
||||
|
||||
@ -120,7 +120,7 @@ public class SubgraphEnumerator {
|
||||
if (edges.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
if (!receiver.emitCsgCmp(csg, newCmp, edges)) {
|
||||
if (!receiver.emitCsgCmp(csg, newCmp, edges, hyperGraph.getComplexProject())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -144,16 +144,15 @@ public class SubgraphEnumerator {
|
||||
forbiddenNodes = LongBitmap.or(forbiddenNodes, csg);
|
||||
long neighborhoods = neighborhoodCalculator.calcNeighborhood(csg, LongBitmap.clone(forbiddenNodes),
|
||||
edgeCalculator);
|
||||
|
||||
for (int nodeIndex : LongBitmap.getReverseIterator(neighborhoods)) {
|
||||
long cmp = LongBitmap.newBitmap(nodeIndex);
|
||||
// whether there is an edge between csg and cmp
|
||||
List<Edge> edges = edgeCalculator.connectCsgCmp(csg, cmp);
|
||||
if (edges.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
if (!receiver.emitCsgCmp(csg, cmp, edges)) {
|
||||
return false;
|
||||
|
||||
if (!edges.isEmpty()) {
|
||||
if (!receiver.emitCsgCmp(csg, cmp, edges, hyperGraph.getComplexProject())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// In order to avoid enumerate repeated cmp, e.g.,
|
||||
@ -191,7 +190,6 @@ public class SubgraphEnumerator {
|
||||
forbiddenNodes = LongBitmap.or(forbiddenNodes, subgraph);
|
||||
neighborhoods = LongBitmap.andNot(neighborhoods, forbiddenNodes);
|
||||
forbiddenNodes = LongBitmap.or(forbiddenNodes, neighborhoods);
|
||||
|
||||
for (Edge edge : edgeCalculator.foundComplexEdgesContain(subgraph)) {
|
||||
long left = edge.getLeft();
|
||||
long right = edge.getRight();
|
||||
@ -268,11 +266,11 @@ public class SubgraphEnumerator {
|
||||
simpleContains.or(containSimpleEdges.get(bitmap1));
|
||||
simpleContains.or(containSimpleEdges.get(bitmap2));
|
||||
BitSet complexContains = new BitSet();
|
||||
simpleContains.or(containComplexEdges.get(bitmap1));
|
||||
simpleContains.or(containComplexEdges.get(bitmap2));
|
||||
complexContains.or(containComplexEdges.get(bitmap1));
|
||||
complexContains.or(containComplexEdges.get(bitmap2));
|
||||
BitSet overlaps = new BitSet();
|
||||
simpleContains.or(overlapEdges.get(bitmap1));
|
||||
simpleContains.or(overlapEdges.get(bitmap2));
|
||||
overlaps.or(overlapEdges.get(bitmap1));
|
||||
overlaps.or(overlapEdges.get(bitmap2));
|
||||
for (int index : overlaps.stream().toArray()) {
|
||||
Edge edge = edges.get(index);
|
||||
if (isContainEdge(subgraph, edge)) {
|
||||
|
||||
@ -19,20 +19,23 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;
|
||||
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* An interface of receiver
|
||||
*/
|
||||
public interface AbstractReceiver {
|
||||
public boolean emitCsgCmp(long csg, long cmp, List<Edge> edges);
|
||||
boolean emitCsgCmp(long csg, long cmp, List<Edge> edges,
|
||||
HashMap<Long, NamedExpression> projectExpression);
|
||||
|
||||
public void addGroup(long bitSet, Group group);
|
||||
void addGroup(long bitSet, Group group);
|
||||
|
||||
public boolean contain(long bitSet);
|
||||
boolean contain(long bitSet);
|
||||
|
||||
public void reset();
|
||||
void reset();
|
||||
|
||||
public Group getBestPlan(long bitSet);
|
||||
Group getBestPlan(long bitSet);
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
@ -31,9 +32,9 @@ import java.util.List;
|
||||
*/
|
||||
public class Counter implements AbstractReceiver {
|
||||
// limit define the max number of csg-cmp pair in this Receiver
|
||||
private int limit;
|
||||
private final int limit;
|
||||
private int emitCount = 0;
|
||||
private HashMap<Long, Integer> counter = new HashMap<>();
|
||||
private final HashMap<Long, Integer> counter = new HashMap<>();
|
||||
|
||||
public Counter() {
|
||||
this.limit = Integer.MAX_VALUE;
|
||||
@ -51,7 +52,8 @@ public class Counter implements AbstractReceiver {
|
||||
* @param edges the join operator
|
||||
* @return the left and the right can be connected by the edge
|
||||
*/
|
||||
public boolean emitCsgCmp(long left, long right, List<Edge> edges) {
|
||||
public boolean emitCsgCmp(long left, long right, List<Edge> edges,
|
||||
HashMap<Long, NamedExpression> projectExpression) {
|
||||
Preconditions.checkArgument(counter.containsKey(left));
|
||||
Preconditions.checkArgument(counter.containsKey(right));
|
||||
emitCount += 1;
|
||||
|
||||
@ -24,8 +24,9 @@ import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.stats.StatsCalculator;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
@ -60,7 +61,8 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
* @return the left and the right can be connected by the edge
|
||||
*/
|
||||
@Override
|
||||
public boolean emitCsgCmp(long left, long right, List<Edge> edges) {
|
||||
public boolean emitCsgCmp(long left, long right, List<Edge> edges,
|
||||
HashMap<Long, NamedExpression> projectExpression) {
|
||||
Preconditions.checkArgument(planTable.containsKey(left));
|
||||
Preconditions.checkArgument(planTable.containsKey(right));
|
||||
emitCount += 1;
|
||||
@ -81,9 +83,12 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
if (!planTable.containsKey(fullKey)
|
||||
|| planTable.get(fullKey).getLogicalExpression().getCostByProperties(PhysicalProperties.ANY)
|
||||
> winnerGroup.getLogicalExpression().getCostByProperties(PhysicalProperties.ANY)) {
|
||||
// When we decide to store the new Plan, we need to add the complex project to it.
|
||||
winnerGroup = tryAddProject(winnerGroup, projectExpression, fullKey);
|
||||
planTable.put(fullKey, winnerGroup);
|
||||
}
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -108,11 +113,39 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
return planTable.get(bitmap);
|
||||
}
|
||||
|
||||
private double getSimpleCost(Plan plan) {
|
||||
if (!(plan instanceof LogicalJoin)) {
|
||||
return plan.getGroupExpression().get().getOwnerGroup().getStatistics().getRowCount();
|
||||
private double getSimpleCost(Group group) {
|
||||
if (!group.isJoinGroup()) {
|
||||
return group.getStatistics().getRowCount();
|
||||
}
|
||||
return plan.getGroupExpression().get().getCostByProperties(PhysicalProperties.ANY);
|
||||
return group.getLogicalExpression().getCostByProperties(PhysicalProperties.ANY);
|
||||
}
|
||||
|
||||
private Group tryAddProject(Group group, HashMap<Long, NamedExpression> projectExpression, long fullKey) {
|
||||
List<NamedExpression> projects = new ArrayList<>();
|
||||
List<Long> removedKey = new ArrayList<>();
|
||||
for (Long bitmap : projectExpression.keySet()) {
|
||||
if (LongBitmap.isSubset(bitmap, fullKey)) {
|
||||
NamedExpression namedExpression = projectExpression.get(bitmap);
|
||||
projects.add(namedExpression);
|
||||
removedKey.add(bitmap);
|
||||
}
|
||||
}
|
||||
for (Long bitmap : removedKey) {
|
||||
projectExpression.remove(bitmap);
|
||||
}
|
||||
if (projects.size() != 0) {
|
||||
LogicalProject logicalProject = new LogicalProject<>(projects,
|
||||
group.getLogicalExpression().getPlan());
|
||||
GroupExpression groupExpression = new GroupExpression(logicalProject, Lists.newArrayList(group));
|
||||
groupExpression.updateLowestCostTable(PhysicalProperties.ANY,
|
||||
Lists.newArrayList(PhysicalProperties.ANY, PhysicalProperties.ANY),
|
||||
group.getLogicalExpression().getCostByProperties(PhysicalProperties.ANY));
|
||||
Group projectGroup = new Group();
|
||||
projectGroup.addGroupExpression(groupExpression);
|
||||
StatsCalculator.estimate(groupExpression);
|
||||
return projectGroup;
|
||||
}
|
||||
return group;
|
||||
}
|
||||
|
||||
private Group constructGroup(long left, long right, List<Edge> edges) {
|
||||
@ -120,10 +153,8 @@ public class PlanReceiver implements AbstractReceiver {
|
||||
Preconditions.checkArgument(planTable.containsKey(right));
|
||||
Group leftGroup = planTable.get(left);
|
||||
Group rightGroup = planTable.get(right);
|
||||
Plan leftPlan = leftGroup.getLogicalExpression().getPlan();
|
||||
Plan rightPlan = rightGroup.getLogicalExpression().getPlan();
|
||||
|
||||
double cost = getSimpleCost(leftPlan) + getSimpleCost(rightPlan);
|
||||
double cost = getSimpleCost(leftGroup) + getSimpleCost(rightGroup);
|
||||
List<Expression> conditions = new ArrayList<>();
|
||||
for (Edge edge : edges) {
|
||||
conditions.addAll(edge.getJoin().getExpressions());
|
||||
|
||||
@ -23,6 +23,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.TreeStringUtils;
|
||||
import org.apache.doris.statistics.StatsDeriveResult;
|
||||
|
||||
@ -322,6 +323,10 @@ public class Group {
|
||||
return getLogicalExpression().getPlan() instanceof LogicalJoin;
|
||||
}
|
||||
|
||||
public boolean isProjectGroup() {
|
||||
return getLogicalExpression().getPlan() instanceof LogicalProject;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
|
||||
@ -60,7 +60,7 @@ public class Memo {
|
||||
private final Map<GroupId, Group> groups = Maps.newLinkedHashMap();
|
||||
// we could not use Set, because Set does not have get method.
|
||||
private final Map<GroupExpression, GroupExpression> groupExpressions = Maps.newHashMap();
|
||||
private final Group root;
|
||||
private Group root;
|
||||
|
||||
// FOR TEST ONLY
|
||||
public Memo() {
|
||||
@ -71,10 +71,24 @@ public class Memo {
|
||||
root = init(plan);
|
||||
}
|
||||
|
||||
public static long getStateId() {
|
||||
return stateId;
|
||||
}
|
||||
|
||||
public Group getRoot() {
|
||||
return root;
|
||||
}
|
||||
|
||||
/**
|
||||
* This function used to update the root group when DPHyp change the root Group
|
||||
* Note it only used in DPHyp
|
||||
*
|
||||
* @param root The new root Group
|
||||
*/
|
||||
public void setRoot(Group root) {
|
||||
this.root = root;
|
||||
}
|
||||
|
||||
public List<Group> getGroups() {
|
||||
return ImmutableList.copyOf(groups.values());
|
||||
}
|
||||
@ -83,10 +97,6 @@ public class Memo {
|
||||
return groupExpressions;
|
||||
}
|
||||
|
||||
public static long getStateId() {
|
||||
return stateId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add plan to Memo.
|
||||
*
|
||||
|
||||
@ -143,6 +143,10 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
|
||||
return this instanceof Slot;
|
||||
}
|
||||
|
||||
public boolean isAlias() {
|
||||
return this instanceof Alias;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
|
||||
@ -1,58 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.jobs.joinorder;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class JoinOrderJobTest {
|
||||
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(1, "t1", 0);
|
||||
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(2, "t2", 0);
|
||||
private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(3, "t3", 0);
|
||||
private final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(4, "t4", 0);
|
||||
private final LogicalOlapScan scan5 = PlanConstructor.newLogicalOlapScan(5, "t5", 0);
|
||||
|
||||
@Test
|
||||
void testJoinOrderJob() {
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(
|
||||
new LogicalPlanBuilder(scan2)
|
||||
.hashJoinUsing(scan3, JoinType.INNER_JOIN, Pair.of(0, 1))
|
||||
.hashJoinUsing(scan4, JoinType.INNER_JOIN, Pair.of(0, 1))
|
||||
.hashJoinUsing(scan5, JoinType.INNER_JOIN, Pair.of(0, 1))
|
||||
.build(),
|
||||
JoinType.INNER_JOIN, Pair.of(0, 1)
|
||||
)
|
||||
.project(Lists.newArrayList(1))
|
||||
.build();
|
||||
System.out.println(plan.treeString());
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.deriveStats()
|
||||
.orderJoin()
|
||||
.printlnTree();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,36 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.jobs.joinorder;
|
||||
|
||||
import org.apache.doris.nereids.datasets.tpch.TPCHTestBase;
|
||||
import org.apache.doris.nereids.datasets.tpch.TPCHUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class TPCHTest extends TPCHTestBase {
|
||||
@Test
|
||||
void testQ5() {
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(TPCHUtils.Q5)
|
||||
.rewrite()
|
||||
.deriveStats()
|
||||
.orderJoin()
|
||||
.optimize();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,91 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.sqltest;
|
||||
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class JoinOrderJobTest extends SqlTestBase {
|
||||
@Test
|
||||
protected void testSimpleSQL() {
|
||||
String sql = "select * from T1, T2, T3, T4 "
|
||||
+ "where "
|
||||
+ "T1.id = T2.id and "
|
||||
+ "T2.score = T3.score and "
|
||||
+ "T3.id = T4.id";
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
.deriveStats()
|
||||
.orderJoin()
|
||||
.printlnTree();
|
||||
}
|
||||
|
||||
@Test
|
||||
protected void testSimpleSQLWithProject() {
|
||||
String sql = "select T1.id from T1, T2, T3, T4 "
|
||||
+ "where "
|
||||
+ "T1.id = T2.id and "
|
||||
+ "T2.score = T3.score and "
|
||||
+ "T3.id = T4.id";
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
.deriveStats()
|
||||
.orderJoin()
|
||||
.printlnTree();
|
||||
}
|
||||
|
||||
@Test
|
||||
protected void testComplexProject() {
|
||||
String sql = "select count(*) \n"
|
||||
+ "from \n"
|
||||
+ "T1, \n"
|
||||
+ "(\n"
|
||||
+ "select (T2.score + T3.score) as score from T2 join T3 on T2.id = T3.id"
|
||||
+ ") subTable, \n"
|
||||
+ "( \n"
|
||||
+ "select (T4.id*2) as id from T4"
|
||||
+ ") doubleT4 \n"
|
||||
+ "where \n"
|
||||
+ "T1.id = doubleT4.id and \n"
|
||||
+ "T1.score = subTable.score;\n";
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
.deriveStats()
|
||||
.orderJoin()
|
||||
.printlnTree();
|
||||
}
|
||||
|
||||
@Test
|
||||
protected void test() {
|
||||
String sql = "select count(*) \n"
|
||||
+ "from \n"
|
||||
+ "T1 \n"
|
||||
+ " join (\n"
|
||||
+ "select (1) from T2"
|
||||
+ ") subTable; \n";
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
.deriveStats()
|
||||
.printlnTree();
|
||||
}
|
||||
}
|
||||
@ -46,9 +46,9 @@ import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
public class HyperGraphBuilder {
|
||||
private List<Integer> rowCounts = new ArrayList<>();
|
||||
private HashMap<BitSet, LogicalPlan> plans = new HashMap<>();
|
||||
private HashMap<BitSet, List<Integer>> schemas = new HashMap<>();
|
||||
private final List<Integer> rowCounts = new ArrayList<>();
|
||||
private final HashMap<BitSet, LogicalPlan> plans = new HashMap<>();
|
||||
private final HashMap<BitSet, List<Integer>> schemas = new HashMap<>();
|
||||
|
||||
public HyperGraph build() {
|
||||
assert plans.size() == 1 : "there are cross join";
|
||||
@ -215,9 +215,9 @@ public class HyperGraphBuilder {
|
||||
List<Expression> conditions = new ArrayList<>(join.getExpressions());
|
||||
Set<Slot> inputs = condition.getInputSlots();
|
||||
if (leftSlots.containsAll(inputs)) {
|
||||
left = (LogicalJoin) attachCondition(condition, (LogicalJoin) left);
|
||||
left = attachCondition(condition, (LogicalJoin) left);
|
||||
} else if (rightSlots.containsAll(inputs)) {
|
||||
right = (LogicalJoin) attachCondition(condition, (LogicalJoin) right);
|
||||
right = attachCondition(condition, (LogicalJoin) right);
|
||||
} else {
|
||||
conditions.add(condition);
|
||||
}
|
||||
|
||||
@ -23,6 +23,7 @@ import org.apache.doris.nereids.StatementContext;
|
||||
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
|
||||
import org.apache.doris.nereids.jobs.JobContext;
|
||||
import org.apache.doris.nereids.jobs.batch.NereidsRewriteJobExecutor;
|
||||
import org.apache.doris.nereids.jobs.batch.OptimizeRulesJob;
|
||||
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
|
||||
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
@ -38,8 +39,10 @@ import org.apache.doris.nereids.rules.RuleFactory;
|
||||
import org.apache.doris.nereids.rules.RuleSet;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.qe.OriginStatement;
|
||||
@ -165,6 +168,11 @@ public class PlanChecker {
|
||||
return this;
|
||||
}
|
||||
|
||||
public PlanChecker optimize() {
|
||||
new OptimizeRulesJob(cascadesContext).execute();
|
||||
return this;
|
||||
}
|
||||
|
||||
public PlanChecker implement() {
|
||||
Plan plan = transformToPhysicalPlan(cascadesContext.getMemo().getRoot());
|
||||
Assertions.assertTrue(plan instanceof PhysicalPlan);
|
||||
@ -269,9 +277,22 @@ public class PlanChecker {
|
||||
}
|
||||
|
||||
public PlanChecker orderJoin() {
|
||||
cascadesContext.pushJob(
|
||||
new JoinOrderJob(cascadesContext.getMemo().getRoot(), cascadesContext.getCurrentJobContext()));
|
||||
Group root = cascadesContext.getMemo().getRoot();
|
||||
boolean changeRoot = false;
|
||||
if (root.isJoinGroup()) {
|
||||
List<Slot> outputs = root.getLogicalExpression().getPlan().getOutput();
|
||||
GroupExpression newExpr = new GroupExpression(
|
||||
new LogicalProject(outputs, root.getLogicalExpression().getPlan()),
|
||||
Lists.newArrayList(root));
|
||||
root = new Group();
|
||||
root.addGroupExpression(newExpr);
|
||||
changeRoot = true;
|
||||
}
|
||||
cascadesContext.pushJob(new JoinOrderJob(root, cascadesContext.getCurrentJobContext()));
|
||||
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
|
||||
if (changeRoot) {
|
||||
cascadesContext.getMemo().setRoot(root.getLogicalExpression().child(0));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user