[feat](Nereids): support alias when eliminating join for partial mv rewriting #30498

This commit is contained in:
谢健
2024-01-30 11:53:07 +08:00
committed by yiguolei
parent 4648902350
commit 53c624ffa0
6 changed files with 154 additions and 52 deletions

View File

@ -33,6 +33,7 @@ import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayDeque;
import java.util.ArrayList;
@ -417,7 +418,7 @@ public class GraphSimplifier {
JoinEdge newEdge = new JoinEdge(join, edge.getIndex(),
edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes(),
edge.getLeftRequiredNodes(), edge.getRightRequiredNodes());
edge.getLeftRequiredNodes(), edge.getRightRequiredNodes(), ImmutableSet.of(), ImmutableSet.of());
newEdge.addLeftExtendNode(leftNodes);
newEdge.addRightExtendNode(rightNodes);
return newEdge;

View File

@ -46,14 +46,17 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
* The graph is a join graph, whose node is the leaf plan and edge is a join operator.
@ -288,6 +291,36 @@ public class HyperGraph {
return new HyperGraph.Builder().buildHyperGraphForMv(plan);
}
/**
 * Builds the named expressions needed to produce {@code requireOutputs} from {@code outputSet}.
 *
 * <p>A required slot that is already contained in {@code outputSet} is passed through unchanged.
 * Any other required slot must be the slot of one of the complex projects registered under
 * {@code nodeMap}, and that project may only read slots contained in {@code outputSet}.
 *
 * @param nodeMap key under which the complex projects are looked up
 * @param outputSet slots that are already available
 * @param requireOutputs slots the caller needs
 * @return one expression per required slot, in the iteration order of {@code requireOutputs};
 *         {@code null} when no complex projects are registered under {@code nodeMap} or when
 *         some required slot cannot be produced
 */
public @Nullable List<NamedExpression> getNamedExpressions(
        long nodeMap, Set<Slot> outputSet, Set<Slot> requireOutputs) {
    List<NamedExpression> candidates = getComplexProject().get(nodeMap);
    if (candidates == null) {
        return null;
    }
    List<NamedExpression> result = new ArrayList<>();
    for (Slot required : requireOutputs) {
        if (outputSet.contains(required)) {
            result.add(required);
            continue;
        }
        // Pick the first complex project whose slot is the required one.
        NamedExpression producer = null;
        for (NamedExpression candidate : candidates) {
            if (candidate.toSlot().equals(required)) {
                producer = candidate;
                break;
            }
        }
        // TODO: consider cascades alias
        if (producer == null || !outputSet.containsAll(producer.getInputSlots())) {
            return null;
        }
        result.add(producer);
    }
    return result;
}
/**
* Builder of HyperGraph
*/
@ -509,6 +542,10 @@ public class HyperGraph {
}
BitSet curJoinEdges = new BitSet();
Set<Slot> leftInputSlots = ImmutableSet.copyOf(
Sets.intersection(join.getInputSlots(), join.left().getOutputSet()));
Set<Slot> rightInputSlots = ImmutableSet.copyOf(
Sets.intersection(join.getInputSlots(), join.right().getOutputSet()));
for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> entry : conjuncts
.entrySet()) {
LogicalJoin<?, ?> singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first,
@ -518,7 +555,7 @@ public class HyperGraph {
Pair<Long, Long> ends = entry.getKey();
JoinEdge edge = new JoinEdge(singleJoin, joinEdges.size(), leftEdgeNodes.first, rightEdgeNodes.first,
LongBitmap.newBitmapUnion(leftEdgeNodes.second, rightEdgeNodes.second),
ends.first, ends.second);
ends.first, ends.second, leftInputSlots, rightInputSlots);
for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
nodes.get(nodeIndex).attachEdge(edge);
}

View File

@ -79,6 +79,10 @@ public abstract class Edge {
return LongBitmap.getCardinality(leftExtendedNodes) == 1 && LongBitmap.getCardinality(rightExtendedNodes) == 1;
}
/**
 * Tells whether the right extended nodes of this edge hold exactly one element.
 *
 * @return {@code true} when the cardinality of {@code rightExtendedNodes} is one
 */
public boolean isRightSimple() {
    return 1 == LongBitmap.getCardinality(rightExtendedNodes);
}
/**
 * Adds {@code edge} to the left reject edges of this edge.
 *
 * @param edge the join edge to add
 */
public void addLeftRejectEdge(JoinEdge edge) {
    this.leftRejectEdges.add(edge);
}

View File

@ -37,12 +37,16 @@ import javax.annotation.Nullable;
public class JoinEdge extends Edge {
private final LogicalJoin<? extends Plan, ? extends Plan> join;
private final Set<Slot> leftInputSlots;
private final Set<Slot> rightInputSlots;
/**
 * Creates a join edge. The index, child edges, sub-tree nodes and required nodes are handed to
 * the {@code Edge} constructor; the join and both input slot sets are kept on this edge.
 */
public JoinEdge(LogicalJoin<? extends Plan, ? extends Plan> join, int index,
BitSet leftChildEdges, BitSet rightChildEdges, long subTreeNodes,
long leftRequireNodes, long rightRequireNodes) {
long leftRequireNodes, long rightRequireNodes, Set<Slot> leftInputSlots, Set<Slot> rightInputSlots) {
super(index, leftChildEdges, rightChildEdges, subTreeNodes, leftRequireNodes, rightRequireNodes);
this.join = join;
// The slot sets are stored as passed in; no copy is made here.
this.leftInputSlots = leftInputSlots;
this.rightInputSlots = rightInputSlots;
}
/**
@ -51,7 +55,8 @@ public class JoinEdge extends Edge {
public JoinEdge swap() {
// Mirror the edge: the join is swapped and each left/right pair (child edges, required nodes,
// input slots) is passed in exchanged order; the index and sub-tree nodes stay the same.
JoinEdge swapEdge = new
JoinEdge(join.swap(), getIndex(), getRightChildEdges(),
getLeftChildEdges(), getSubTreeNodes(), getRightRequiredNodes(), getLeftRequiredNodes());
getLeftChildEdges(), getSubTreeNodes(), getRightRequiredNodes(), getLeftRequiredNodes(),
this.rightInputSlots, this.leftInputSlots);
// NOTE(review): reject edges are carried over on the same side (left to left, right to right),
// unlike the constructor arguments above — confirm this is intended.
swapEdge.addLeftRejectEdges(getLeftRejectEdge());
swapEdge.addRightRejectEdges(getRightRejectEdge());
return swapEdge;
}
@ -63,7 +68,7 @@ public class JoinEdge extends Edge {
public JoinEdge withJoinTypeAndCleanCR(JoinType joinType) {
// Build a new edge around the join with the given type, reusing this edge's index, child edges,
// sub-tree nodes, required nodes and input slots. No reject edges are added to the new edge here.
return new JoinEdge(join.withJoinType(joinType), getIndex(), getLeftChildEdges(), getRightChildEdges(),
getSubTreeNodes(), getLeftRequiredNodes(), getRightRequiredNodes());
getSubTreeNodes(), getLeftRequiredNodes(), getRightRequiredNodes(), leftInputSlots, rightInputSlots);
}
public LogicalJoin<? extends Plan, ? extends Plan> getJoin() {
@ -112,4 +117,12 @@ public class JoinEdge extends Edge {
join.getExpressions().forEach(expression -> slots.addAll(expression.getInputSlots()));
return slots;
}
/**
 * Returns the left input slots held by this edge.
 *
 * @return the set stored in {@code leftInputSlots}
 */
public Set<Slot> getLeftInputSlots() {
    return this.leftInputSlots;
}
/**
 * Returns the right input slots held by this edge.
 *
 * @return the set stored in {@code rightInputSlots}
 */
public Set<Slot> getRightInputSlots() {
    return this.rightInputSlots;
}
}

View File

@ -27,9 +27,11 @@ import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
@ -46,6 +48,7 @@ import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
* HyperGraphComparator
@ -144,60 +147,81 @@ public class HyperGraphComparator {
return buildComparisonRes();
}
/**
 * Produces a plan over a single view node that outputs {@code requireOutputs}.
 *
 * <p>When the node's plan already outputs every required slot, that plan is returned as-is.
 * Otherwise the plan is wrapped in a {@code LogicalProject} built from the named expressions
 * the view hyper graph supplies for the node.
 *
 * @param nodeBitmap bitmap that must contain exactly one node
 * @param requireOutputs slots the resulting plan has to output
 * @return the plan, or {@code null} when the bitmap does not hold exactly one node or the
 *         view hyper graph returns no named expressions for the required slots
 */
private @Nullable Plan constructViewPlan(long nodeBitmap, Set<Slot> requireOutputs) {
    if (LongBitmap.getCardinality(nodeBitmap) != 1) {
        return null;
    }
    Plan nodePlan = viewHyperGraph.getNode(LongBitmap.lowestOneIndex(nodeBitmap)).getPlan();
    Set<Slot> nodeOutputs = nodePlan.getOutputSet();
    if (nodeOutputs.containsAll(requireOutputs)) {
        return nodePlan;
    }
    List<NamedExpression> projections =
            viewHyperGraph.getNamedExpressions(nodeBitmap, nodeOutputs, requireOutputs);
    if (projections != null) {
        return new LogicalProject<>(projections, nodePlan);
    }
    return null;
}
/**
 * Checks whether the join of {@code joinEdge} can be eliminated by foreign key.
 *
 * <p>A plan is built for each side with {@code constructViewPlan}; if either cannot be built the
 * answer is {@code false}, otherwise the decision is delegated to
 * {@code JoinUtils.canEliminateByFk}.
 *
 * @param primaryNodes node bitmap of the primary side
 * @param foreignNodes node bitmap of the foreign side
 * @param primarySlots slots required from the primary side
 * @param foreignSlots slots required from the foreign side
 * @param joinEdge the edge whose join is examined
 * @return {@code true} when both plans exist and {@code JoinUtils.canEliminateByFk} accepts them
 */
private boolean canEliminatePrimaryByForeign(long primaryNodes, long foreignNodes,
        Set<Slot> primarySlots, Set<Slot> foreignSlots, JoinEdge joinEdge) {
    Plan foreignPlan = constructViewPlan(foreignNodes, foreignSlots);
    Plan primaryPlan = constructViewPlan(primaryNodes, primarySlots);
    boolean bothPlansBuilt = foreignPlan != null && primaryPlan != null;
    return bothPlansBuilt
            && JoinUtils.canEliminateByFk(joinEdge.getJoin(), primaryPlan, foreignPlan);
}
/**
 * Decides whether a single view join edge can be eliminated together with the nodes recorded
 * in {@code eliminateViewNodesMap}.
 *
 * <p>Two cases are accepted: a left outer join with a simple right side whose right node is
 * eliminated, checked with {@code JoinUtils.canEliminateByLeft}; and a simple inner join where
 * exactly one node on exactly one side is eliminated, checked as a primary/foreign key join.
 *
 * @param joinEdge the view join edge to examine
 * @return {@code true} when the edge can be eliminated, {@code false} otherwise
 */
private boolean canEliminateViewEdge(JoinEdge joinEdge) {
    long leftNodes = joinEdge.getLeftExtendedNodes();
    long rightNodes = joinEdge.getRightExtendedNodes();
    long removedOnRight = LongBitmap.getCardinality(
            LongBitmap.newBitmapIntersect(rightNodes, eliminateViewNodesMap));

    // eliminate by unique
    if (joinEdge.getJoinType().isLeftOuterJoin() && joinEdge.isRightSimple()) {
        if (removedOnRight != 1) {
            return false;
        }
        Plan rightPlan = constructViewPlan(rightNodes, joinEdge.getRightInputSlots());
        return rightPlan != null
                && JoinUtils.canEliminateByLeft(joinEdge.getJoin(),
                        rightPlan.getLogicalProperties().getFunctionalDependencies());
    }

    // eliminate by pk fk
    if (!joinEdge.getJoinType().isInnerJoin() || !joinEdge.isSimple()) {
        return false;
    }
    long removedOnLeft = LongBitmap.getCardinality(
            LongBitmap.newBitmapIntersect(leftNodes, eliminateViewNodesMap));
    if (removedOnLeft == 0 && removedOnRight == 1) {
        // The eliminated right side plays the primary role.
        return canEliminatePrimaryByForeign(rightNodes, leftNodes,
                joinEdge.getRightInputSlots(), joinEdge.getLeftInputSlots(), joinEdge);
    }
    if (removedOnLeft == 1 && removedOnRight == 0) {
        // The eliminated left side plays the primary role.
        return canEliminatePrimaryByForeign(leftNodes, rightNodes,
                joinEdge.getLeftInputSlots(), joinEdge.getRightInputSlots(), joinEdge);
    }
    return false;
}
// Decides whether the view nodes in eliminateViewNodesMap can be eliminated: rejects when a
// single-node filter edge references only eliminated nodes, then checks the join edges that
// overlap the eliminated nodes.
// NOTE(review): this span contains both a per-edge loop ending in "return true;" and a
// stream-based allMatch return after it; the later return looks unreachable as written —
// confirm which form is intended.
private boolean tryEliminateNodesAndEdge() {
// Only filter edges that reference exactly one node are considered here.
boolean hasFilterEdgeAbove = viewHyperGraph.getFilterEdges().stream()
.filter(e -> LongBitmap.getCardinality(e.getReferenceNodes()) == 1)
.anyMatch(e -> LongBitmap.isSubset(e.getReferenceNodes(), eliminateViewNodesMap));
if (hasFilterEdgeAbove) {
// If there is some filter edge above the eliminated node, we should rebuild a plan
// Right now, just refuse it.
// Right now, just reject it.
return false;
}
for (JoinEdge joinEdge : viewHyperGraph.getJoinEdges()) {
// Join edges that do not touch any eliminated node are not examined.
if (!LongBitmap.isOverlap(joinEdge.getReferenceNodes(), eliminateViewNodesMap)) {
continue;
}
// eliminate by unique
if (joinEdge.getJoinType().isLeftOuterJoin()) {
long eliminatedRight =
LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), eliminateViewNodesMap);
if (LongBitmap.getCardinality(eliminatedRight) != 1) {
return false;
}
Plan rigthPlan = viewHyperGraph
.getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan();
return JoinUtils.canEliminateByLeft(joinEdge.getJoin(),
rigthPlan.getLogicalProperties().getFunctionalDependencies());
}
// eliminate by pk fk
if (joinEdge.getJoinType().isInnerJoin()) {
if (!joinEdge.isSimple()) {
return false;
}
long eliminatedLeft =
LongBitmap.newBitmapIntersect(joinEdge.getLeftExtendedNodes(), eliminateViewNodesMap);
long eliminatedRight =
LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), eliminateViewNodesMap);
if (LongBitmap.getCardinality(eliminatedLeft) == 0
&& LongBitmap.getCardinality(eliminatedRight) == 1) {
Plan foreign = viewHyperGraph
.getNode(LongBitmap.lowestOneIndex(joinEdge.getLeftExtendedNodes())).getPlan();
Plan primary = viewHyperGraph
.getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan();
return JoinUtils.canEliminateByFk(joinEdge.getJoin(), primary, foreign);
} else if (LongBitmap.getCardinality(eliminatedLeft) == 1
&& LongBitmap.getCardinality(eliminatedRight) == 0) {
Plan foreign = viewHyperGraph
.getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan();
Plan primary = viewHyperGraph
.getNode(LongBitmap.lowestOneIndex(joinEdge.getLeftExtendedNodes())).getPlan();
return JoinUtils.canEliminateByFk(joinEdge.getJoin(), primary, foreign);
}
return false;
}
}
return true;
// Every join edge overlapping the eliminated nodes must pass canEliminateViewEdge.
return viewHyperGraph.getJoinEdges().stream()
.filter(joinEdge -> LongBitmap.isOverlap(joinEdge.getReferenceNodes(), eliminateViewNodesMap))
.allMatch(this::canEliminateViewEdge);
}
private boolean compareNodeWithExpr(StructInfoNode query, StructInfoNode view) {