From 6ddbd204e7416ae2a86acfc8f11396da21e77243 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E5=81=A5?= Date: Thu, 15 Dec 2022 21:59:55 +0800 Subject: [PATCH] [fix](Nereids): Update plan when prune column in DPHyp (#14880) --- .../apache/doris/nereids/NereidsPlanner.java | 21 +++++ .../nereids/jobs/joinorder/JoinOrderJob.java | 64 +++++++++++-- .../jobs/joinorder/hypergraph/HyperGraph.java | 88 ++++++++++++++---- .../jobs/joinorder/hypergraph/Node.java | 4 + .../hypergraph/SubgraphEnumerator.java | 22 ++--- .../hypergraph/receiver/AbstractReceiver.java | 13 ++- .../hypergraph/receiver/Counter.java | 8 +- .../hypergraph/receiver/PlanReceiver.java | 49 ++++++++-- .../org/apache/doris/nereids/memo/Group.java | 5 + .../org/apache/doris/nereids/memo/Memo.java | 20 +++- .../nereids/trees/expressions/Expression.java | 4 + .../jobs/joinorder/JoinOrderJobTest.java | 58 ------------ .../nereids/jobs/joinorder/TPCHTest.java | 36 ++++++++ .../nereids/sqltest/JoinOrderJobTest.java | 91 +++++++++++++++++++ .../doris/nereids/util/HyperGraphBuilder.java | 10 +- .../doris/nereids/util/PlanChecker.java | 25 ++++- 16 files changed, 393 insertions(+), 125 deletions(-) delete mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJobTest.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/TPCHTest.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinOrderJobTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java index ee75a01ed9..433f96f183 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java @@ -29,6 +29,7 @@ import org.apache.doris.nereids.glue.translator.PlanTranslatorContext; import org.apache.doris.nereids.jobs.batch.NereidsRewriteJobExecutor; import org.apache.doris.nereids.jobs.batch.OptimizeRulesJob; import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob; +import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.metrics.event.CounterEvent; @@ -36,9 +37,11 @@ import org.apache.doris.nereids.processor.post.PlanPostProcessors; import org.apache.doris.nereids.processor.pre.PlanPreprocessors; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.planner.PlanFragment; import org.apache.doris.planner.Planner; @@ -209,6 +212,24 @@ public class NereidsPlanner extends Planner { } private void joinReorder() { + Group root = getRoot(); + boolean changeRoot = false; + if (root.isJoinGroup()) { + // If the root group is join group, DPHyp can change the root group. + // To keep the root group is not changed, we add a project operator above join + List outputs = root.getLogicalExpression().getPlan().getOutput(); + GroupExpression newExpr = new GroupExpression( + new LogicalProject(outputs, root.getLogicalExpression().getPlan()), + Lists.newArrayList(root)); + root = new Group(); + root.addGroupExpression(newExpr); + changeRoot = true; + } + cascadesContext.pushJob(new JoinOrderJob(root, cascadesContext.getCurrentJobContext())); + cascadesContext.getJobScheduler().executeJobPool(cascadesContext); + if (changeRoot) { + cascadesContext.getMemo().setRoot(root.getLogicalExpression().child(0)); + } } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java index 1ef38df1d4..bdfaaa95f2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java @@ -17,24 +17,36 @@ package org.apache.doris.nereids.jobs.joinorder; +import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.jobs.Job; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.jobs.JobType; +import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob; import org.apache.doris.nereids.jobs.joinorder.hypergraph.GraphSimplifier; import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; import org.apache.doris.nereids.jobs.joinorder.hypergraph.SubgraphEnumerator; import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.PlanReceiver; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Set; /** * Join Order job with DPHyp */ public class JoinOrderJob extends Job { private final Group group; + private final Set otherProject = new HashSet<>(); public JoinOrderJob(Group group, JobContext context) { super(JobType.JOIN_ORDER, context); @@ -43,12 +55,16 @@ public class JoinOrderJob extends Job { @Override public void execute() throws AnalysisException { - Preconditions.checkArgument(!group.isJoinGroup()); GroupExpression rootExpr = group.getLogicalExpression(); int arity = rootExpr.arity(); for (int i = 0; i < arity; i++) { rootExpr.setChild(i, optimizePlan(rootExpr.child(i))); } + CascadesContext cascadesContext = context.getCascadesContext(); + cascadesContext.topDownRewrite(new ColumnPruning()); + cascadesContext.pushJob( + new DeriveStatsJob(group.getLogicalExpression(), cascadesContext.getCurrentJobContext())); + cascadesContext.getJobScheduler().executeJobPool(cascadesContext); } private Group optimizePlan(Group group) { @@ -66,7 +82,7 @@ public class JoinOrderJob extends Job { private Group optimizeJoin(Group group) { HyperGraph hyperGraph = new HyperGraph(); buildGraph(group, hyperGraph); - // Right now, we just hardcode the limit with 10000, maybe we need a better way to set it + // TODO: Right now, we just hardcode the limit with 10000, maybe we need a better way to set it int limit = 10000; PlanReceiver planReceiver = new PlanReceiver(limit); SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(planReceiver, hyperGraph); @@ -77,13 +93,22 @@ public class JoinOrderJob extends Job { throw new RuntimeException("DPHyp can not enumerate all sub graphs with limit=" + limit); } } - Group optimized = planReceiver.getBestPlan(hyperGraph.getNodesMap()); - return copyToMemo(optimized); + Group memoRoot = copyToMemo(optimized); + + // For other projects, such as project constant or project nullable, we construct a new project above root + if (otherProject.size() != 0) { + otherProject.addAll(memoRoot.getLogicalExpression().getPlan().getOutput()); + LogicalProject logicalProject = new LogicalProject(new ArrayList<>(otherProject), + memoRoot.getLogicalExpression().getPlan()); + GroupExpression groupExpression = new GroupExpression(logicalProject, Lists.newArrayList(group)); + memoRoot = context.getCascadesContext().getMemo().copyInGroupExpression(groupExpression); + } + return memoRoot; } private Group copyToMemo(Group root) { - if (!root.isJoinGroup()) { + if (root.getGroupId() != null) { return root; } GroupExpression groupExpression = root.getLogicalExpression(); @@ -105,6 +130,11 @@ public class JoinOrderJob extends Job { * @param hyperGraph build hyperGraph */ public void buildGraph(Group group, HyperGraph hyperGraph) { + if (group.isProjectGroup()) { + buildGraph(group.getLogicalExpression().child(0), hyperGraph); + processProjectPlan(hyperGraph, group); + return; + } if (!group.isJoinGroup()) { hyperGraph.addNode(optimizePlan(group)); return; @@ -113,4 +143,26 @@ public class JoinOrderJob extends Job { buildGraph(group.getLogicalExpression().child(1), hyperGraph); hyperGraph.addEdge(group); } + + /** + * Process project expression in HyperGraph + * 1. If it's a simple expression for column pruning, we just ignore it + * 2. If it's an alias that may be used in the join operator, we need to add it to graph + * 3. If it's other expression, we can ignore them and add it after optimizing + * 4. If it's a project only associate with one table, it's seen as an endNode just like a table + */ + private void processProjectPlan(HyperGraph hyperGraph, Group group) { + LogicalProject logicalProject = (LogicalProject) group.getLogicalExpression() + .getPlan(); + + for (NamedExpression expr : logicalProject.getProjects()) { + if (expr.isAlias()) { + if (!hyperGraph.addAlias((Alias) expr, group)) { + break; + } + } else if (!expr.isSlot()) { + otherProject.add(expr); + } + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java index 7e3f96af02..365a12a727 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java @@ -19,7 +19,9 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; @@ -29,6 +31,8 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Set; @@ -37,8 +41,11 @@ import java.util.Set; * It's used for join ordering */ public class HyperGraph { - private List edges = new ArrayList<>(); - private List nodes = new ArrayList<>(); + private final List edges = new ArrayList<>(); + private final List nodes = new ArrayList<>(); + private final HashSet nodeSet = new HashSet<>(); + private final HashMap slotToNodeMap = new HashMap<>(); + private final HashMap complexProject = new HashMap<>(); public List getEdges() { return edges; @@ -60,12 +67,59 @@ public class HyperGraph { return nodes.get(index); } + /** + * Store the relation between Alias Slot and Original Slot and its expression + * e.g. a = b + * project((c + d) as b) + * Note if the alias if the alias only associated with one endNode, + * e.g. a = b + * project((c + 1) as b) + * we need to replace the group of that node with this project group. + * + * @param alias The alias Expression in project Operator + */ + public boolean addAlias(Alias alias, Group group) { + long bitmap = LongBitmap.newBitmap(); + for (Slot slot : alias.getInputSlots()) { + bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot)); + } + Slot aliasSlot = alias.toSlot(); + Preconditions.checkArgument(!slotToNodeMap.containsKey(aliasSlot)); + slotToNodeMap.put(aliasSlot, bitmap); + if (LongBitmap.getCardinality(bitmap) == 1) { + int index = LongBitmap.lowestOneIndex(bitmap); + nodeSet.remove(nodes.get(index).getGroup()); + nodeSet.add(group); + nodes.get(index).replaceGroupWith(group); + return false; + } + complexProject.put(bitmap, alias); + return true; + } + + /** + * add end node to HyperGraph + * + * @param group The group that is the end node in graph + */ public void addNode(Group group) { Preconditions.checkArgument(!group.isJoinGroup()); - // TODO: replace plan with group expression or others + for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) { + Preconditions.checkArgument(!slotToNodeMap.containsKey(slot)); + slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size())); + } + nodeSet.add(group); nodes.add(new Node(nodes.size(), group)); } + public boolean isNodeGroup(Group group) { + return nodeSet.contains(group); + } + + public HashMap getComplexProject() { + return complexProject; + } + /** * try to add edge for join group * @@ -78,15 +132,12 @@ public class HyperGraph { LogicalJoin singleJoin = new LogicalJoin(join.getJoinType(), ImmutableList.of(expression), join.left(), join.right()); Edge edge = new Edge(singleJoin, edges.size()); - long bitmap = findNodes(expression.getInputSlots()); - Preconditions.checkArgument(LongBitmap.getCardinality(bitmap) == 2, - String.format("HyperGraph has not supported polynomial %s yet", expression)); - int leftIndex = LongBitmap.nextSetBit(bitmap, 0); - long left = LongBitmap.newBitmap(leftIndex); - edge.addLeftNode(left); - int rightIndex = LongBitmap.nextSetBit(bitmap, leftIndex + 1); - long right = LongBitmap.newBitmap(rightIndex); - edge.addRightNode(right); + Preconditions.checkArgument(expression.children().size() == 2); + + long left = calNodeMap(expression.child(0).getInputSlots()); + edge.setLeft(left); + long right = calNodeMap(expression.child(1).getInputSlots()); + edge.setRight(right); for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) { nodes.get(nodeIndex).attachEdge(edge); } @@ -96,15 +147,12 @@ public class HyperGraph { // We don't implement this trick now. } - private long findNodes(Set slots) { + private long calNodeMap(Set slots) { + Preconditions.checkArgument(slots.size() != 0); long bitmap = LongBitmap.newBitmap(); - for (Node node : nodes) { - for (Slot slot : node.getPlan().getOutput()) { - if (slots.contains(slot)) { - bitmap = LongBitmap.set(bitmap, node.getIndex()); - break; - } - } + for (Slot slot : slots) { + Preconditions.checkArgument(slotToNodeMap.containsKey(slot)); + bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot)); } return bitmap; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Node.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Node.java index e25be46ad2..ee26e31629 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Node.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Node.java @@ -37,6 +37,10 @@ public class Node { this.index = index; } + public void replaceGroupWith(Group group) { + this.group = group; + } + public int getIndex() { return index; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java index 4cf59f37b5..c88bf23b6c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java @@ -120,7 +120,7 @@ public class SubgraphEnumerator { if (edges.isEmpty()) { continue; } - if (!receiver.emitCsgCmp(csg, newCmp, edges)) { + if (!receiver.emitCsgCmp(csg, newCmp, edges, hyperGraph.getComplexProject())) { return false; } } @@ -144,16 +144,15 @@ public class SubgraphEnumerator { forbiddenNodes = LongBitmap.or(forbiddenNodes, csg); long neighborhoods = neighborhoodCalculator.calcNeighborhood(csg, LongBitmap.clone(forbiddenNodes), edgeCalculator); - for (int nodeIndex : LongBitmap.getReverseIterator(neighborhoods)) { long cmp = LongBitmap.newBitmap(nodeIndex); // whether there is an edge between csg and cmp List edges = edgeCalculator.connectCsgCmp(csg, cmp); - if (edges.isEmpty()) { - continue; - } - if (!receiver.emitCsgCmp(csg, cmp, edges)) { - return false; + + if (!edges.isEmpty()) { + if (!receiver.emitCsgCmp(csg, cmp, edges, hyperGraph.getComplexProject())) { + return false; + } } // In order to avoid enumerate repeated cmp, e.g., @@ -191,7 +190,6 @@ public class SubgraphEnumerator { forbiddenNodes = LongBitmap.or(forbiddenNodes, subgraph); neighborhoods = LongBitmap.andNot(neighborhoods, forbiddenNodes); forbiddenNodes = LongBitmap.or(forbiddenNodes, neighborhoods); - for (Edge edge : edgeCalculator.foundComplexEdgesContain(subgraph)) { long left = edge.getLeft(); long right = edge.getRight(); @@ -268,11 +266,11 @@ public class SubgraphEnumerator { simpleContains.or(containSimpleEdges.get(bitmap1)); simpleContains.or(containSimpleEdges.get(bitmap2)); BitSet complexContains = new BitSet(); - simpleContains.or(containComplexEdges.get(bitmap1)); - simpleContains.or(containComplexEdges.get(bitmap2)); + complexContains.or(containComplexEdges.get(bitmap1)); + complexContains.or(containComplexEdges.get(bitmap2)); BitSet overlaps = new BitSet(); - simpleContains.or(overlapEdges.get(bitmap1)); - simpleContains.or(overlapEdges.get(bitmap2)); + overlaps.or(overlapEdges.get(bitmap1)); + overlaps.or(overlapEdges.get(bitmap2)); for (int index : overlaps.stream().toArray()) { Edge edge = edges.get(index); if (isContainEdge(subgraph, edge)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/AbstractReceiver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/AbstractReceiver.java index 975c351248..d72da70a7c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/AbstractReceiver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/AbstractReceiver.java @@ -19,20 +19,23 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver; import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge; import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import java.util.HashMap; import java.util.List; /** * A interface of receiver */ public interface AbstractReceiver { - public boolean emitCsgCmp(long csg, long cmp, List edges); + boolean emitCsgCmp(long csg, long cmp, List edges, + HashMap projectExpression); - public void addGroup(long bitSet, Group group); + void addGroup(long bitSet, Group group); - public boolean contain(long bitSet); + boolean contain(long bitSet); - public void reset(); + void reset(); - public Group getBestPlan(long bitSet); + Group getBestPlan(long bitSet); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/Counter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/Counter.java index 7326bee3f5..9dcaf266e2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/Counter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/Counter.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver; import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import com.google.common.base.Preconditions; @@ -31,9 +32,9 @@ import java.util.List; */ public class Counter implements AbstractReceiver { // limit define the max number of csg-cmp pair in this Receiver - private int limit; + private final int limit; private int emitCount = 0; - private HashMap counter = new HashMap<>(); + private final HashMap counter = new HashMap<>(); public Counter() { this.limit = Integer.MAX_VALUE; @@ -51,7 +52,8 @@ public class Counter implements AbstractReceiver { * @param edges the join operator * @return the left and the right can be connected by the edge */ - public boolean emitCsgCmp(long left, long right, List edges) { + public boolean emitCsgCmp(long left, long right, List edges, + HashMap projectExpression) { Preconditions.checkArgument(counter.containsKey(left)); Preconditions.checkArgument(counter.containsKey(right)); emitCount += 1; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java index 93a2f0ce47..26f6c846ca 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java @@ -24,8 +24,9 @@ import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.stats.StatsCalculator; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; @@ -60,7 +61,8 @@ public class PlanReceiver implements AbstractReceiver { * @return the left and the right can be connected by the edge */ @Override - public boolean emitCsgCmp(long left, long right, List edges) { + public boolean emitCsgCmp(long left, long right, List edges, + HashMap projectExpression) { Preconditions.checkArgument(planTable.containsKey(left)); Preconditions.checkArgument(planTable.containsKey(right)); emitCount += 1; @@ -81,9 +83,12 @@ public class PlanReceiver implements AbstractReceiver { if (!planTable.containsKey(fullKey) || planTable.get(fullKey).getLogicalExpression().getCostByProperties(PhysicalProperties.ANY) > winnerGroup.getLogicalExpression().getCostByProperties(PhysicalProperties.ANY)) { + // When we decide to store the new Plan, we need to add the complex project to it. + winnerGroup = tryAddProject(winnerGroup, projectExpression, fullKey); planTable.put(fullKey, winnerGroup); } return true; + } @Override @@ -108,11 +113,39 @@ public class PlanReceiver implements AbstractReceiver { return planTable.get(bitmap); } - private double getSimpleCost(Plan plan) { - if (!(plan instanceof LogicalJoin)) { - return plan.getGroupExpression().get().getOwnerGroup().getStatistics().getRowCount(); + private double getSimpleCost(Group group) { + if (!group.isJoinGroup()) { + return group.getStatistics().getRowCount(); } - return plan.getGroupExpression().get().getCostByProperties(PhysicalProperties.ANY); + return group.getLogicalExpression().getCostByProperties(PhysicalProperties.ANY); + } + + private Group tryAddProject(Group group, HashMap projectExpression, long fullKey) { + List projects = new ArrayList<>(); + List removedKey = new ArrayList<>(); + for (Long bitmap : projectExpression.keySet()) { + if (LongBitmap.isSubset(bitmap, fullKey)) { + NamedExpression namedExpression = projectExpression.get(bitmap); + projects.add(namedExpression); + removedKey.add(bitmap); + } + } + for (Long bitmap : removedKey) { + projectExpression.remove(bitmap); + } + if (projects.size() != 0) { + LogicalProject logicalProject = new LogicalProject<>(projects, + group.getLogicalExpression().getPlan()); + GroupExpression groupExpression = new GroupExpression(logicalProject, Lists.newArrayList(group)); + groupExpression.updateLowestCostTable(PhysicalProperties.ANY, + Lists.newArrayList(PhysicalProperties.ANY, PhysicalProperties.ANY), + group.getLogicalExpression().getCostByProperties(PhysicalProperties.ANY)); + Group projectGroup = new Group(); + projectGroup.addGroupExpression(groupExpression); + StatsCalculator.estimate(groupExpression); + return projectGroup; + } + return group; } private Group constructGroup(long left, long right, List edges) { @@ -120,10 +153,8 @@ public class PlanReceiver implements AbstractReceiver { Preconditions.checkArgument(planTable.containsKey(right)); Group leftGroup = planTable.get(left); Group rightGroup = planTable.get(right); - Plan leftPlan = leftGroup.getLogicalExpression().getPlan(); - Plan rightPlan = rightGroup.getLogicalExpression().getPlan(); - double cost = getSimpleCost(leftPlan) + getSimpleCost(rightPlan); + double cost = getSimpleCost(leftGroup) + getSimpleCost(rightGroup); List conditions = new ArrayList<>(); for (Edge edge : edges) { conditions.addAll(edge.getJoin().getExpressions()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java index cfce2e6d28..fb2adb9ac6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.TreeStringUtils; import org.apache.doris.statistics.StatsDeriveResult; @@ -322,6 +323,10 @@ public class Group { return getLogicalExpression().getPlan() instanceof LogicalJoin; } + public boolean isProjectGroup() { + return getLogicalExpression().getPlan() instanceof LogicalProject; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java index 1c24ed1772..ba6a1d73fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java @@ -60,7 +60,7 @@ public class Memo { private final Map groups = Maps.newLinkedHashMap(); // we could not use Set, because Set does not have get method. private final Map groupExpressions = Maps.newHashMap(); - private final Group root; + private Group root; // FOR TEST ONLY public Memo() { @@ -71,10 +71,24 @@ public class Memo { root = init(plan); } + public static long getStateId() { + return stateId; + } + public Group getRoot() { return root; } + /** + * This function used to update the root group when DPHyp change the root Group + * Note it only used in DPHyp + * + * @param root The new root Group + */ + public void setRoot(Group root) { + this.root = root; + } + public List getGroups() { return ImmutableList.copyOf(groups.values()); } @@ -83,10 +97,6 @@ public class Memo { return groupExpressions; } - public static long getStateId() { - return stateId; - } - /** * Add plan to Memo. * diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index 9ffa29b45e..bdf3d37763 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -143,6 +143,10 @@ public abstract class Expression extends AbstractTreeNode implements return this instanceof Slot; } + public boolean isAlias() { + return this instanceof Alias; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJobTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJobTest.java deleted file mode 100644 index 93edf18663..0000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJobTest.java +++ /dev/null @@ -1,58 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.jobs.joinorder; - -import org.apache.doris.common.Pair; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.util.LogicalPlanBuilder; -import org.apache.doris.nereids.util.MemoTestUtils; -import org.apache.doris.nereids.util.PlanChecker; -import org.apache.doris.nereids.util.PlanConstructor; - -import com.google.common.collect.Lists; -import org.junit.jupiter.api.Test; - -public class JoinOrderJobTest { - private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(1, "t1", 0); - private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(2, "t2", 0); - private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(3, "t3", 0); - private final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(4, "t4", 0); - private final LogicalOlapScan scan5 = PlanConstructor.newLogicalOlapScan(5, "t5", 0); - - @Test - void testJoinOrderJob() { - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .hashJoinUsing( - new LogicalPlanBuilder(scan2) - .hashJoinUsing(scan3, JoinType.INNER_JOIN, Pair.of(0, 1)) - .hashJoinUsing(scan4, JoinType.INNER_JOIN, Pair.of(0, 1)) - .hashJoinUsing(scan5, JoinType.INNER_JOIN, Pair.of(0, 1)) - .build(), - JoinType.INNER_JOIN, Pair.of(0, 1) - ) - .project(Lists.newArrayList(1)) - .build(); - System.out.println(plan.treeString()); - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .deriveStats() - .orderJoin() - .printlnTree(); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/TPCHTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/TPCHTest.java new file mode 100644 index 0000000000..7eaf021459 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/TPCHTest.java @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.jobs.joinorder; + +import org.apache.doris.nereids.datasets.tpch.TPCHTestBase; +import org.apache.doris.nereids.datasets.tpch.TPCHUtils; +import org.apache.doris.nereids.util.PlanChecker; + +import org.junit.jupiter.api.Test; + +public class TPCHTest extends TPCHTestBase { + @Test + void testQ5() { + PlanChecker.from(connectContext) + .analyze(TPCHUtils.Q5) + .rewrite() + .deriveStats() + .orderJoin() + .optimize(); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinOrderJobTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinOrderJobTest.java new file mode 100644 index 0000000000..ac11c5ebbb --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinOrderJobTest.java @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.sqltest; + +import org.apache.doris.nereids.util.PlanChecker; + +import org.junit.jupiter.api.Test; + +public class JoinOrderJobTest extends SqlTestBase { + @Test + protected void testSimpleSQL() { + String sql = "select * from T1, T2, T3, T4 " + + "where " + + "T1.id = T2.id and " + + "T2.score = T3.score and " + + "T3.id = T4.id"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .deriveStats() + .orderJoin() + .printlnTree(); + } + + @Test + protected void testSimpleSQLWithProject() { + String sql = "select T1.id from T1, T2, T3, T4 " + + "where " + + "T1.id = T2.id and " + + "T2.score = T3.score and " + + "T3.id = T4.id"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .deriveStats() + .orderJoin() + .printlnTree(); + } + + @Test + protected void testComplexProject() { + String sql = "select count(*) \n" + + "from \n" + + "T1, \n" + + "(\n" + + "select (T2.score + T3.score) as score from T2 join T3 on T2.id = T3.id" + + ") subTable, \n" + + "( \n" + + "select (T4.id*2) as id from T4" + + ") doubleT4 \n" + + "where \n" + + "T1.id = doubleT4.id and \n" + + "T1.score = subTable.score;\n"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .deriveStats() + .orderJoin() + .printlnTree(); + } + + @Test + protected void test() { + String sql = "select count(*) \n" + + "from \n" + + "T1 \n" + + " join (\n" + + "select (1) from T2" + + ") subTable; \n"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .deriveStats() + .printlnTree(); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java index 7a107880b9..49e80acccf 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java @@ -46,9 +46,9 @@ import java.util.Optional; import java.util.Set; public class HyperGraphBuilder { - private List rowCounts = new ArrayList<>(); - private HashMap plans = new HashMap<>(); - private HashMap> schemas = new HashMap<>(); + private final List rowCounts = new ArrayList<>(); + private final HashMap plans = new HashMap<>(); + private final HashMap> schemas = new HashMap<>(); public HyperGraph build() { assert plans.size() == 1 : "there are cross join"; @@ -215,9 +215,9 @@ public class HyperGraphBuilder { List conditions = new ArrayList<>(join.getExpressions()); Set inputs = condition.getInputSlots(); if (leftSlots.containsAll(inputs)) { - left = (LogicalJoin) attachCondition(condition, (LogicalJoin) left); + left = attachCondition(condition, (LogicalJoin) left); } else if (rightSlots.containsAll(inputs)) { - right = (LogicalJoin) attachCondition(condition, (LogicalJoin) right); + right = attachCondition(condition, (LogicalJoin) right); } else { conditions.add(condition); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index a86ce9ed6c..fa7376885e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.glue.LogicalPlanAdapter; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.jobs.batch.NereidsRewriteJobExecutor; +import org.apache.doris.nereids.jobs.batch.OptimizeRulesJob; import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob; import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob; import org.apache.doris.nereids.memo.Group; @@ -38,8 +39,10 @@ import org.apache.doris.nereids.rules.RuleFactory; import org.apache.doris.nereids.rules.RuleSet; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.OriginStatement; @@ -165,6 +168,11 @@ public class PlanChecker { return this; } + public PlanChecker optimize() { + new OptimizeRulesJob(cascadesContext).execute(); + return this; + } + public PlanChecker implement() { Plan plan = transformToPhysicalPlan(cascadesContext.getMemo().getRoot()); Assertions.assertTrue(plan instanceof PhysicalPlan); @@ -269,9 +277,22 @@ public class PlanChecker { } public PlanChecker orderJoin() { - cascadesContext.pushJob( - new JoinOrderJob(cascadesContext.getMemo().getRoot(), cascadesContext.getCurrentJobContext())); + Group root = cascadesContext.getMemo().getRoot(); + boolean changeRoot = false; + if (root.isJoinGroup()) { + List outputs = root.getLogicalExpression().getPlan().getOutput(); + GroupExpression newExpr = new GroupExpression( + new LogicalProject(outputs, root.getLogicalExpression().getPlan()), + Lists.newArrayList(root)); + root = new Group(); + root.addGroupExpression(newExpr); + changeRoot = true; + } + cascadesContext.pushJob(new JoinOrderJob(root, cascadesContext.getCurrentJobContext())); cascadesContext.getJobScheduler().executeJobPool(cascadesContext); + if (changeRoot) { + cascadesContext.getMemo().setRoot(root.getLogicalExpression().child(0)); + } return this; }