[fix](Nereids): Update plan when prune column in DPHyp (#14880)

This commit is contained in:
谢健
2022-12-15 21:59:55 +08:00
committed by GitHub
parent 94e0955687
commit 6ddbd204e7
16 changed files with 393 additions and 125 deletions

View File

@ -29,6 +29,7 @@ import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.jobs.batch.NereidsRewriteJobExecutor;
import org.apache.doris.nereids.jobs.batch.OptimizeRulesJob;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.metrics.event.CounterEvent;
@ -36,9 +37,11 @@ import org.apache.doris.nereids.processor.post.PlanPostProcessors;
import org.apache.doris.nereids.processor.pre.PlanPreprocessors;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.planner.Planner;
@ -209,6 +212,24 @@ public class NereidsPlanner extends Planner {
}
/**
 * Runs the DPHyp join-order job starting from the current root group.
 *
 * <p>JoinOrderJob rejects a join group as its starting group, so when the root
 * is a join group it is first wrapped in a temporary project that forwards all
 * of the join's output slots. After the job finishes, the (possibly replaced)
 * child of that temporary project is installed as the memo root.
 */
private void joinReorder() {
    Group jobRoot = getRoot();
    final boolean wrappedInProject = jobRoot.isJoinGroup();
    if (wrappedInProject) {
        // DPHyp may swap a join group for a different one. Hanging the join under a
        // throw-away project gives the job a stable, non-join group to start from.
        Plan joinPlan = jobRoot.getLogicalExpression().getPlan();
        List<Slot> outputs = joinPlan.getOutput();
        GroupExpression projectExpr = new GroupExpression(
                new LogicalProject(outputs, joinPlan),
                Lists.newArrayList(jobRoot));
        jobRoot = new Group();
        jobRoot.addGroupExpression(projectExpr);
    }

    cascadesContext.pushJob(new JoinOrderJob(jobRoot, cascadesContext.getCurrentJobContext()));
    cascadesContext.getJobScheduler().executeJobPool(cascadesContext);

    if (wrappedInProject) {
        // Drop the temporary project again: its only child is the reordered join.
        cascadesContext.getMemo().setRoot(jobRoot.getLogicalExpression().child(0));
    }
}
/**

View File

@ -17,24 +17,36 @@
package org.apache.doris.nereids.jobs.joinorder;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.GraphSimplifier;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.SubgraphEnumerator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.PlanReceiver;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
/**
* Join Order job with DPHyp
*/
public class JoinOrderJob extends Job {
private final Group group;
private final Set<NamedExpression> otherProject = new HashSet<>();
public JoinOrderJob(Group group, JobContext context) {
super(JobType.JOIN_ORDER, context);
@ -43,12 +55,16 @@ public class JoinOrderJob extends Job {
/**
* Reorders the joins below {@code group} and then refreshes the plan.
*
* <p>Steps, in order:
* 1. every child of the group's logical expression is replaced by the result of
* {@code optimizePlan};
* 2. a top-down {@link ColumnPruning} rewrite is run on the cascades context;
* 3. a {@link DeriveStatsJob} for the group's logical expression is pushed and the
* job pool is executed.
*
* @throws IllegalArgumentException if {@code group} itself is a join group
*/
@Override
public void execute() throws AnalysisException {
// The starting group must not be a join group; only its children are reordered.
// NOTE(review): callers appear to wrap a join root in a project before creating this job - confirm.
Preconditions.checkArgument(!group.isJoinGroup());
GroupExpression rootExpr = group.getLogicalExpression();
int arity = rootExpr.arity();
for (int i = 0; i < arity; i++) {
// optimizePlan may return a different group, so the child is re-linked in place.
rootExpr.setChild(i, optimizePlan(rootExpr.child(i)));
}
CascadesContext cascadesContext = context.getCascadesContext();
// Re-run column pruning after the children were replaced.
cascadesContext.topDownRewrite(new ColumnPruning());
// Re-derive statistics for the updated plan.
cascadesContext.pushJob(
new DeriveStatsJob(group.getLogicalExpression(), cascadesContext.getCurrentJobContext()));
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
}
private Group optimizePlan(Group group) {
@ -66,7 +82,7 @@ public class JoinOrderJob extends Job {
private Group optimizeJoin(Group group) {
HyperGraph hyperGraph = new HyperGraph();
buildGraph(group, hyperGraph);
// Right now, we just hardcode the limit with 10000, maybe we need a better way to set it
// TODO: Right now, we just hardcode the limit with 10000, maybe we need a better way to set it
int limit = 10000;
PlanReceiver planReceiver = new PlanReceiver(limit);
SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(planReceiver, hyperGraph);
@ -77,13 +93,22 @@ public class JoinOrderJob extends Job {
throw new RuntimeException("DPHyp can not enumerate all sub graphs with limit=" + limit);
}
}
Group optimized = planReceiver.getBestPlan(hyperGraph.getNodesMap());
return copyToMemo(optimized);
Group memoRoot = copyToMemo(optimized);
// For other projects, such as project constant or project nullable, we construct a new project above root
if (otherProject.size() != 0) {
otherProject.addAll(memoRoot.getLogicalExpression().getPlan().getOutput());
LogicalProject logicalProject = new LogicalProject(new ArrayList<>(otherProject),
memoRoot.getLogicalExpression().getPlan());
GroupExpression groupExpression = new GroupExpression(logicalProject, Lists.newArrayList(group));
memoRoot = context.getCascadesContext().getMemo().copyInGroupExpression(groupExpression);
}
return memoRoot;
}
private Group copyToMemo(Group root) {
if (!root.isJoinGroup()) {
if (root.getGroupId() != null) {
return root;
}
GroupExpression groupExpression = root.getLogicalExpression();
@ -105,6 +130,11 @@ public class JoinOrderJob extends Job {
* @param hyperGraph build hyperGraph
*/
public void buildGraph(Group group, HyperGraph hyperGraph) {
if (group.isProjectGroup()) {
buildGraph(group.getLogicalExpression().child(0), hyperGraph);
processProjectPlan(hyperGraph, group);
return;
}
if (!group.isJoinGroup()) {
hyperGraph.addNode(optimizePlan(group));
return;
@ -113,4 +143,26 @@ public class JoinOrderJob extends Job {
buildGraph(group.getLogicalExpression().child(1), hyperGraph);
hyperGraph.addEdge(group);
}
/**
 * Folds the expressions of a project group into the HyperGraph.
 *
 * <p>Each project expression is handled by kind:
 * 1. a plain slot only prunes columns and is ignored;
 * 2. an alias may be referenced by a join condition, so it is registered in the graph;
 * 3. any other expression is remembered in {@code otherProject} and re-attached
 *    after optimization;
 * 4. as soon as an alias turns out to depend on a single node, the whole project
 *    group stands in for that node (like a table) and the scan stops.
 *
 * @param hyperGraph the graph being built
 * @param group the project group whose expressions are processed
 */
private void processProjectPlan(HyperGraph hyperGraph, Group group) {
    LogicalProject<? extends Plan> project =
            (LogicalProject<? extends Plan>) group.getLogicalExpression().getPlan();
    for (NamedExpression projection : project.getProjects()) {
        if (!projection.isAlias()) {
            if (!projection.isSlot()) {
                otherProject.add(projection);
            }
            continue;
        }
        // addAlias returns false once the project group has replaced an end node;
        // the remaining expressions then live inside that node.
        boolean keepScanning = hyperGraph.addAlias((Alias) projection, group);
        if (!keepScanning) {
            break;
        }
    }
}
}

View File

@ -19,7 +19,9 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -29,6 +31,8 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@ -37,8 +41,11 @@ import java.util.Set;
* It's used for join ordering
*/
public class HyperGraph {
private List<Edge> edges = new ArrayList<>();
private List<Node> nodes = new ArrayList<>();
private final List<Edge> edges = new ArrayList<>();
private final List<Node> nodes = new ArrayList<>();
private final HashSet<Group> nodeSet = new HashSet<>();
private final HashMap<Slot, Long> slotToNodeMap = new HashMap<>();
private final HashMap<Long, NamedExpression> complexProject = new HashMap<>();
public List<Edge> getEdges() {
return edges;
@ -60,12 +67,59 @@ public class HyperGraph {
return nodes.get(index);
}
/**
 * Registers an alias from a project operator so that join conditions referring
 * to the alias slot can be mapped to the end nodes the alias depends on.
 *
 * <p>Example with an alias spanning several nodes:
 *      a = b
 *        project((c + d) as b)
 * The alias is stored in {@code complexProject}, keyed by the bitmap of the
 * nodes it depends on.
 *
 * <p>Example with an alias that depends on exactly one end node:
 *      a = b
 *        project((c + 1) as b)
 * Here the group of that node is replaced with the project group.
 *
 * <p>NOTE(review): {@code complexProject} is keyed only by the node bitmap, so two
 * aliases over the same set of nodes would overwrite each other - confirm
 * whether that can happen.
 *
 * @param alias the alias expression in the project operator
 * @param group the project group that contains {@code alias}
 * @return {@code false} if the project group replaced an end node (the alias
 *         depends on exactly one node), {@code true} if the alias was recorded
 *         in {@code complexProject}
 * @throws IllegalArgumentException if an input slot of the alias is unknown to
 *         the graph, or the alias slot has already been registered
 */
public boolean addAlias(Alias alias, Group group) {
    long bitmap = LongBitmap.newBitmap();
    for (Slot slot : alias.getInputSlots()) {
        // Validate before unboxing: a missing slot used to surface as a bare
        // NullPointerException; calNodeMap already checks the same condition.
        Long owners = slotToNodeMap.get(slot);
        Preconditions.checkArgument(owners != null,
                "Alias %s uses slot %s which is not produced by any node of the HyperGraph", alias, slot);
        bitmap = LongBitmap.or(bitmap, owners);
    }
    Slot aliasSlot = alias.toSlot();
    Preconditions.checkArgument(!slotToNodeMap.containsKey(aliasSlot));
    slotToNodeMap.put(aliasSlot, bitmap);
    if (LongBitmap.getCardinality(bitmap) == 1) {
        // The alias only needs one end node: let the project group take that node's place.
        int index = LongBitmap.lowestOneIndex(bitmap);
        nodeSet.remove(nodes.get(index).getGroup());
        nodeSet.add(group);
        nodes.get(index).replaceGroupWith(group);
        return false;
    }
    complexProject.put(bitmap, alias);
    return true;
}
/**
 * Appends a non-join group to the graph as a new end node.
 *
 * <p>The node receives the next free index, and every output slot of the
 * group's plan is mapped to a bitmap containing only that index.
 *
 * @param group the group that becomes an end node; must not be a join group
 * @throws IllegalArgumentException if {@code group} is a join group or one of
 *         its output slots is already registered
 */
public void addNode(Group group) {
    Preconditions.checkArgument(!group.isJoinGroup());
    final int nodeIndex = nodes.size();
    // TODO: replace plan with group expression or others
    for (Slot outputSlot : group.getLogicalExpression().getPlan().getOutput()) {
        boolean alreadyMapped = slotToNodeMap.containsKey(outputSlot);
        Preconditions.checkArgument(!alreadyMapped);
        slotToNodeMap.put(outputSlot, LongBitmap.newBitmap(nodeIndex));
    }
    nodeSet.add(group);
    Node endNode = new Node(nodeIndex, group);
    nodes.add(endNode);
}
/**
* Tells whether the given group is currently registered as an end node of this graph,
* i.e. it was added through addNode or substituted for a node by addAlias.
*
* @param group the group to look up
* @return true if the group is in the node set
*/
public boolean isNodeGroup(Group group) {
return nodeSet.contains(group);
}
/**
* Returns the aliases recorded by addAlias that do not resolve to exactly one node,
* keyed by the bitmap of the nodes they depend on.
*
* <p>The returned map is the graph's internal instance, not a copy, so changes
* made by the caller are visible to the graph.
*
* @return the live map from node bitmap to alias expression
*/
public HashMap<Long, NamedExpression> getComplexProject() {
return complexProject;
}
/**
* try to add edge for join group
*
@ -78,15 +132,12 @@ public class HyperGraph {
LogicalJoin singleJoin = new LogicalJoin(join.getJoinType(), ImmutableList.of(expression), join.left(),
join.right());
Edge edge = new Edge(singleJoin, edges.size());
long bitmap = findNodes(expression.getInputSlots());
Preconditions.checkArgument(LongBitmap.getCardinality(bitmap) == 2,
String.format("HyperGraph has not supported polynomial %s yet", expression));
int leftIndex = LongBitmap.nextSetBit(bitmap, 0);
long left = LongBitmap.newBitmap(leftIndex);
edge.addLeftNode(left);
int rightIndex = LongBitmap.nextSetBit(bitmap, leftIndex + 1);
long right = LongBitmap.newBitmap(rightIndex);
edge.addRightNode(right);
Preconditions.checkArgument(expression.children().size() == 2);
long left = calNodeMap(expression.child(0).getInputSlots());
edge.setLeft(left);
long right = calNodeMap(expression.child(1).getInputSlots());
edge.setRight(right);
for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
nodes.get(nodeIndex).attachEdge(edge);
}
@ -96,15 +147,12 @@ public class HyperGraph {
// We don't implement this trick now.
}
private long findNodes(Set<Slot> slots) {
private long calNodeMap(Set<Slot> slots) {
Preconditions.checkArgument(slots.size() != 0);
long bitmap = LongBitmap.newBitmap();
for (Node node : nodes) {
for (Slot slot : node.getPlan().getOutput()) {
if (slots.contains(slot)) {
bitmap = LongBitmap.set(bitmap, node.getIndex());
break;
}
}
for (Slot slot : slots) {
Preconditions.checkArgument(slotToNodeMap.containsKey(slot));
bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot));
}
return bitmap;
}

View File

@ -37,6 +37,10 @@ public class Node {
this.index = index;
}
/**
* Replaces the group held by this node; the node's index is left unchanged.
*
* @param group the new group for this node
*/
public void replaceGroupWith(Group group) {
this.group = group;
}
public int getIndex() {
return index;
}

View File

@ -120,7 +120,7 @@ public class SubgraphEnumerator {
if (edges.isEmpty()) {
continue;
}
if (!receiver.emitCsgCmp(csg, newCmp, edges)) {
if (!receiver.emitCsgCmp(csg, newCmp, edges, hyperGraph.getComplexProject())) {
return false;
}
}
@ -144,16 +144,15 @@ public class SubgraphEnumerator {
forbiddenNodes = LongBitmap.or(forbiddenNodes, csg);
long neighborhoods = neighborhoodCalculator.calcNeighborhood(csg, LongBitmap.clone(forbiddenNodes),
edgeCalculator);
for (int nodeIndex : LongBitmap.getReverseIterator(neighborhoods)) {
long cmp = LongBitmap.newBitmap(nodeIndex);
// whether there is an edge between csg and cmp
List<Edge> edges = edgeCalculator.connectCsgCmp(csg, cmp);
if (edges.isEmpty()) {
continue;
}
if (!receiver.emitCsgCmp(csg, cmp, edges)) {
return false;
if (!edges.isEmpty()) {
if (!receiver.emitCsgCmp(csg, cmp, edges, hyperGraph.getComplexProject())) {
return false;
}
}
// In order to avoid enumerate repeated cmp, e.g.,
@ -191,7 +190,6 @@ public class SubgraphEnumerator {
forbiddenNodes = LongBitmap.or(forbiddenNodes, subgraph);
neighborhoods = LongBitmap.andNot(neighborhoods, forbiddenNodes);
forbiddenNodes = LongBitmap.or(forbiddenNodes, neighborhoods);
for (Edge edge : edgeCalculator.foundComplexEdgesContain(subgraph)) {
long left = edge.getLeft();
long right = edge.getRight();
@ -268,11 +266,11 @@ public class SubgraphEnumerator {
simpleContains.or(containSimpleEdges.get(bitmap1));
simpleContains.or(containSimpleEdges.get(bitmap2));
BitSet complexContains = new BitSet();
simpleContains.or(containComplexEdges.get(bitmap1));
simpleContains.or(containComplexEdges.get(bitmap2));
complexContains.or(containComplexEdges.get(bitmap1));
complexContains.or(containComplexEdges.get(bitmap2));
BitSet overlaps = new BitSet();
simpleContains.or(overlapEdges.get(bitmap1));
simpleContains.or(overlapEdges.get(bitmap2));
overlaps.or(overlapEdges.get(bitmap1));
overlaps.or(overlapEdges.get(bitmap2));
for (int index : overlaps.stream().toArray()) {
Edge edge = edges.get(index);
if (isContainEdge(subgraph, edge)) {

View File

@ -19,20 +19,23 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import java.util.HashMap;
import java.util.List;
/**
* An interface for receivers of the csg-cmp pairs emitted during subgraph enumeration
*/
public interface AbstractReceiver {
public boolean emitCsgCmp(long csg, long cmp, List<Edge> edges);
boolean emitCsgCmp(long csg, long cmp, List<Edge> edges,
HashMap<Long, NamedExpression> projectExpression);
public void addGroup(long bitSet, Group group);
void addGroup(long bitSet, Group group);
public boolean contain(long bitSet);
boolean contain(long bitSet);
public void reset();
void reset();
public Group getBestPlan(long bitSet);
Group getBestPlan(long bitSet);
}

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import com.google.common.base.Preconditions;
@ -31,9 +32,9 @@ import java.util.List;
*/
public class Counter implements AbstractReceiver {
// limit defines the maximum number of csg-cmp pairs this receiver accepts
private int limit;
private final int limit;
private int emitCount = 0;
private HashMap<Long, Integer> counter = new HashMap<>();
private final HashMap<Long, Integer> counter = new HashMap<>();
public Counter() {
this.limit = Integer.MAX_VALUE;
@ -51,7 +52,8 @@ public class Counter implements AbstractReceiver {
* @param edges the join operator
* @return the left and the right can be connected by the edge
*/
public boolean emitCsgCmp(long left, long right, List<Edge> edges) {
public boolean emitCsgCmp(long left, long right, List<Edge> edges,
HashMap<Long, NamedExpression> projectExpression) {
Preconditions.checkArgument(counter.containsKey(left));
Preconditions.checkArgument(counter.containsKey(right));
emitCount += 1;

View File

@ -24,8 +24,9 @@ import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.stats.StatsCalculator;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
@ -60,7 +61,8 @@ public class PlanReceiver implements AbstractReceiver {
* @return the left and the right can be connected by the edge
*/
@Override
public boolean emitCsgCmp(long left, long right, List<Edge> edges) {
public boolean emitCsgCmp(long left, long right, List<Edge> edges,
HashMap<Long, NamedExpression> projectExpression) {
Preconditions.checkArgument(planTable.containsKey(left));
Preconditions.checkArgument(planTable.containsKey(right));
emitCount += 1;
@ -81,9 +83,12 @@ public class PlanReceiver implements AbstractReceiver {
if (!planTable.containsKey(fullKey)
|| planTable.get(fullKey).getLogicalExpression().getCostByProperties(PhysicalProperties.ANY)
> winnerGroup.getLogicalExpression().getCostByProperties(PhysicalProperties.ANY)) {
// When we decide to store the new Plan, we need to add the complex project to it.
winnerGroup = tryAddProject(winnerGroup, projectExpression, fullKey);
planTable.put(fullKey, winnerGroup);
}
return true;
}
@Override
@ -108,11 +113,39 @@ public class PlanReceiver implements AbstractReceiver {
return planTable.get(bitmap);
}
private double getSimpleCost(Plan plan) {
if (!(plan instanceof LogicalJoin)) {
return plan.getGroupExpression().get().getOwnerGroup().getStatistics().getRowCount();
private double getSimpleCost(Group group) {
if (!group.isJoinGroup()) {
return group.getStatistics().getRowCount();
}
return plan.getGroupExpression().get().getCostByProperties(PhysicalProperties.ANY);
return group.getLogicalExpression().getCostByProperties(PhysicalProperties.ANY);
}
/**
 * Puts a project on top of {@code group} for every pending project expression
 * whose node bitmap is a subset of {@code fullKey}.
 *
 * <p>Matching entries are removed from {@code projectExpression}, so the
 * caller's map shrinks as projects are attached.
 * NOTE(review): because the entries are consumed, a cheaper plan found later
 * for the same key would no longer get these projects - confirm this is intended.
 *
 * @param group the group of the winning plan for {@code fullKey}
 * @param projectExpression pending projects keyed by the bitmap of the nodes they need
 * @param fullKey bitmap of all nodes covered by {@code group}
 * @return a new project group above {@code group}, or {@code group} itself when
 *         no pending project applies
 */
private Group tryAddProject(Group group, HashMap<Long, NamedExpression> projectExpression, long fullKey) {
    List<NamedExpression> coveredProjects = new ArrayList<>();
    // Walk a snapshot of the keys so that matching entries can be removed on the fly.
    for (Long nodeBitmap : new ArrayList<>(projectExpression.keySet())) {
        if (LongBitmap.isSubset(nodeBitmap, fullKey)) {
            coveredProjects.add(projectExpression.remove(nodeBitmap));
        }
    }
    if (coveredProjects.isEmpty()) {
        return group;
    }

    LogicalProject logicalProject = new LogicalProject<>(coveredProjects,
            group.getLogicalExpression().getPlan());
    GroupExpression projectExpr = new GroupExpression(logicalProject, Lists.newArrayList(group));
    // The project is recorded with the same cost as the plan below it.
    projectExpr.updateLowestCostTable(PhysicalProperties.ANY,
            Lists.newArrayList(PhysicalProperties.ANY, PhysicalProperties.ANY),
            group.getLogicalExpression().getCostByProperties(PhysicalProperties.ANY));
    Group projectGroup = new Group();
    projectGroup.addGroupExpression(projectExpr);
    StatsCalculator.estimate(projectExpr);
    return projectGroup;
}
private Group constructGroup(long left, long right, List<Edge> edges) {
@ -120,10 +153,8 @@ public class PlanReceiver implements AbstractReceiver {
Preconditions.checkArgument(planTable.containsKey(right));
Group leftGroup = planTable.get(left);
Group rightGroup = planTable.get(right);
Plan leftPlan = leftGroup.getLogicalExpression().getPlan();
Plan rightPlan = rightGroup.getLogicalExpression().getPlan();
double cost = getSimpleCost(leftPlan) + getSimpleCost(rightPlan);
double cost = getSimpleCost(leftGroup) + getSimpleCost(rightGroup);
List<Expression> conditions = new ArrayList<>();
for (Edge edge : edges) {
conditions.addAll(edge.getJoin().getExpressions());

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.TreeStringUtils;
import org.apache.doris.statistics.StatsDeriveResult;
@ -322,6 +323,10 @@ public class Group {
return getLogicalExpression().getPlan() instanceof LogicalJoin;
}
/**
* Tells whether the plan of this group's logical expression is a {@link LogicalProject}.
*
* @return true if the logical expression's plan is a LogicalProject
*/
public boolean isProjectGroup() {
return getLogicalExpression().getPlan() instanceof LogicalProject;
}
@Override
public boolean equals(Object o) {
if (this == o) {

View File

@ -60,7 +60,7 @@ public class Memo {
private final Map<GroupId, Group> groups = Maps.newLinkedHashMap();
// we could not use Set, because Set does not have get method.
private final Map<GroupExpression, GroupExpression> groupExpressions = Maps.newHashMap();
private final Group root;
private Group root;
// FOR TEST ONLY
public Memo() {
@ -71,10 +71,24 @@ public class Memo {
root = init(plan);
}
/**
* Returns the current value of the static {@code stateId} field.
*/
public static long getStateId() {
return stateId;
}
/**
* Returns the root group of this memo.
*/
public Group getRoot() {
return root;
}
/**
* Replaces the root group of this memo.
* This is needed when DPHyp changes the root group and is intended to be used
* only by DPHyp.
*
* @param root the new root group
*/
public void setRoot(Group root) {
this.root = root;
}
public List<Group> getGroups() {
return ImmutableList.copyOf(groups.values());
}
@ -83,10 +97,6 @@ public class Memo {
return groupExpressions;
}
public static long getStateId() {
return stateId;
}
/**
* Add plan to Memo.
*

View File

@ -143,6 +143,10 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
return this instanceof Slot;
}
/**
* Tells whether this expression is an {@link Alias}.
*
* @return true if this expression is an instance of Alias
*/
public boolean isAlias() {
return this instanceof Alias;
}
@Override
public boolean equals(Object o) {
if (this == o) {

View File

@ -1,58 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.joinorder;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Test;
/**
* Smoke test for the join-order job on a hand-built plan: five scans joined with
* inner joins under a project. The resulting tree is printed; nothing is asserted.
*/
public class JoinOrderJobTest {
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(1, "t1", 0);
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(2, "t2", 0);
private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(3, "t3", 0);
private final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(4, "t4", 0);
private final LogicalOlapScan scan5 = PlanConstructor.newLogicalOlapScan(5, "t5", 0);
// Builds scan1 JOIN (((scan2 JOIN scan3) JOIN scan4) JOIN scan5) under a project,
// then derives stats and runs the join-order job.
@Test
void testJoinOrderJob() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.hashJoinUsing(
new LogicalPlanBuilder(scan2)
.hashJoinUsing(scan3, JoinType.INNER_JOIN, Pair.of(0, 1))
.hashJoinUsing(scan4, JoinType.INNER_JOIN, Pair.of(0, 1))
.hashJoinUsing(scan5, JoinType.INNER_JOIN, Pair.of(0, 1))
.build(),
JoinType.INNER_JOIN, Pair.of(0, 1)
)
.project(Lists.newArrayList(1))
.build();
System.out.println(plan.treeString());
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.deriveStats()
.orderJoin()
.printlnTree();
}
}

View File

@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.joinorder;
import org.apache.doris.nereids.datasets.tpch.TPCHTestBase;
import org.apache.doris.nereids.datasets.tpch.TPCHUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.junit.jupiter.api.Test;
/**
* Smoke test that pushes a TPC-H query through the join-order job.
* It makes no explicit assertions; the pipeline is only expected to complete.
*/
public class TPCHTest extends TPCHTestBase {
// Runs TPC-H Q5 through analyze, rewrite, stats derivation, join ordering and optimization.
@Test
void testQ5() {
PlanChecker.from(connectContext)
.analyze(TPCHUtils.Q5)
.rewrite()
.deriveStats()
.orderJoin()
.optimize();
}
}

View File

@ -0,0 +1,91 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.sqltest;
import org.apache.doris.nereids.util.PlanChecker;
import org.junit.jupiter.api.Test;
/**
 * SQL-driven smoke tests for the DPHyp join-order job.
 * The tests print the resulting plan tree and make no explicit assertions.
 */
public class JoinOrderJobTest extends SqlTestBase {

    /** Chain of four tables joined by equality predicates, selecting all columns. */
    @Test
    protected void testSimpleSQL() {
        printOrderedPlan("select * from T1, T2, T3, T4 "
                + "where "
                + "T1.id = T2.id and "
                + "T2.score = T3.score and "
                + "T3.id = T4.id");
    }

    /** Same join chain, but only one column is projected. */
    @Test
    protected void testSimpleSQLWithProject() {
        printOrderedPlan("select T1.id from T1, T2, T3, T4 "
                + "where "
                + "T1.id = T2.id and "
                + "T2.score = T3.score and "
                + "T3.id = T4.id");
    }

    /** Join conditions that refer to aliases computed in sub-queries. */
    @Test
    protected void testComplexProject() {
        printOrderedPlan("select count(*) \n"
                + "from \n"
                + "T1, \n"
                + "(\n"
                + "select (T2.score + T3.score) as score from T2 join T3 on T2.id = T3.id"
                + ") subTable, \n"
                + "( \n"
                + "select (T4.id*2) as id from T4"
                + ") doubleT4 \n"
                + "where \n"
                + "T1.id = doubleT4.id and \n"
                + "T1.score = subTable.score;\n");
    }

    /**
     * Join against a sub-query that projects a constant.
     * NOTE(review): unlike the other tests this one never calls orderJoin() - confirm
     * whether that is intentional.
     */
    @Test
    protected void test() {
        String constantProjectSql = "select count(*) \n"
                + "from \n"
                + "T1 \n"
                + " join (\n"
                + "select (1) from T2"
                + ") subTable; \n";
        PlanChecker.from(connectContext)
                .analyze(constantProjectSql)
                .rewrite()
                .deriveStats()
                .printlnTree();
    }

    /** Analyzes and rewrites the SQL, derives stats, runs the join-order job and prints the tree. */
    private void printOrderedPlan(String sql) {
        PlanChecker.from(connectContext)
                .analyze(sql)
                .rewrite()
                .deriveStats()
                .orderJoin()
                .printlnTree();
    }
}

View File

@ -46,9 +46,9 @@ import java.util.Optional;
import java.util.Set;
public class HyperGraphBuilder {
private List<Integer> rowCounts = new ArrayList<>();
private HashMap<BitSet, LogicalPlan> plans = new HashMap<>();
private HashMap<BitSet, List<Integer>> schemas = new HashMap<>();
private final List<Integer> rowCounts = new ArrayList<>();
private final HashMap<BitSet, LogicalPlan> plans = new HashMap<>();
private final HashMap<BitSet, List<Integer>> schemas = new HashMap<>();
public HyperGraph build() {
assert plans.size() == 1 : "there are cross join";
@ -215,9 +215,9 @@ public class HyperGraphBuilder {
List<Expression> conditions = new ArrayList<>(join.getExpressions());
Set<Slot> inputs = condition.getInputSlots();
if (leftSlots.containsAll(inputs)) {
left = (LogicalJoin) attachCondition(condition, (LogicalJoin) left);
left = attachCondition(condition, (LogicalJoin) left);
} else if (rightSlots.containsAll(inputs)) {
right = (LogicalJoin) attachCondition(condition, (LogicalJoin) right);
right = attachCondition(condition, (LogicalJoin) right);
} else {
conditions.add(condition);
}

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.batch.NereidsRewriteJobExecutor;
import org.apache.doris.nereids.jobs.batch.OptimizeRulesJob;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
import org.apache.doris.nereids.memo.Group;
@ -38,8 +39,10 @@ import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
@ -165,6 +168,11 @@ public class PlanChecker {
return this;
}
/**
* Runs an {@link OptimizeRulesJob} on this checker's cascades context.
*
* @return this checker, for call chaining
*/
public PlanChecker optimize() {
new OptimizeRulesJob(cascadesContext).execute();
return this;
}
public PlanChecker implement() {
Plan plan = transformToPhysicalPlan(cascadesContext.getMemo().getRoot());
Assertions.assertTrue(plan instanceof PhysicalPlan);
@ -269,9 +277,22 @@ public class PlanChecker {
}
public PlanChecker orderJoin() {
cascadesContext.pushJob(
new JoinOrderJob(cascadesContext.getMemo().getRoot(), cascadesContext.getCurrentJobContext()));
Group root = cascadesContext.getMemo().getRoot();
boolean changeRoot = false;
if (root.isJoinGroup()) {
List<Slot> outputs = root.getLogicalExpression().getPlan().getOutput();
GroupExpression newExpr = new GroupExpression(
new LogicalProject(outputs, root.getLogicalExpression().getPlan()),
Lists.newArrayList(root));
root = new Group();
root.addGroupExpression(newExpr);
changeRoot = true;
}
cascadesContext.pushJob(new JoinOrderJob(root, cascadesContext.getCurrentJobContext()));
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
if (changeRoot) {
cascadesContext.getMemo().setRoot(root.getLogicalExpression().child(0));
}
return this;
}