return residual expr of join (#28760)

This commit is contained in:
谢健
2023-12-25 12:53:14 +08:00
committed by GitHub
parent e9e1e2894b
commit 1d984e0ebb
9 changed files with 380 additions and 85 deletions

View File

@ -27,6 +27,7 @@ import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.rules.exploration.mv.ComparisonResult;
import org.apache.doris.nereids.rules.exploration.mv.LogicalCompatibilityContext;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
import org.apache.doris.nereids.trees.expressions.Alias;
@ -44,18 +45,21 @@ import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
* The graph is a join graph, whose node is the leaf plan and edge is a join operator.
@ -268,11 +272,11 @@ public class HyperGraph {
filterEdges.forEach(e -> {
if (LongBitmap.isSubset(e.getReferenceNodes(), leftSubNodes)
&& !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_LEFT.contains(joinEdge.getJoinType())) {
e.addRejectJoin(joinEdge);
e.addRejectEdge(joinEdge);
}
if (LongBitmap.isSubset(e.getReferenceNodes(), rightSubNodes)
&& !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_RIGHT.contains(joinEdge.getJoinType())) {
e.addRejectJoin(joinEdge);
e.addRejectEdge(joinEdge);
}
});
}
@ -289,9 +293,11 @@ public class HyperGraph {
JoinEdge childA = joinEdges.get(i);
if (!JoinType.isAssoc(childA.getJoinType(), edgeB.getJoinType())) {
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
}
if (!JoinType.isLAssoc(childA.getJoinType(), edgeB.getJoinType())) {
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
}
}
@ -299,9 +305,11 @@ public class HyperGraph {
JoinEdge childA = joinEdges.get(i);
if (!JoinType.isAssoc(edgeB.getJoinType(), childA.getJoinType())) {
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
}
if (!JoinType.isRAssoc(edgeB.getJoinType(), childA.getJoinType())) {
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
}
}
edgeB.setLeftExtendedNodes(leftRequired);
@ -593,57 +601,75 @@ public class HyperGraph {
* compare hypergraph
*
* @param viewHG the compared hyper graph
* @return null represents not compatible, or return some expression which can
* be pull up from this hyper graph
* @return Comparison result
*/
public @Nullable List<Expression> isLogicCompatible(HyperGraph viewHG, LogicalCompatibilityContext ctx) {
Map<Edge, Edge> queryToView = constructEdgeMap(viewHG, ctx.getQueryToViewEdgeExpressionMapping());
public ComparisonResult isLogicCompatible(HyperGraph viewHG, LogicalCompatibilityContext ctx) {
// 1 try to construct a map which can be mapped from edge to edge
Map<Edge, Edge> queryToView = constructMapWithNode(viewHG, ctx.getQueryToViewNodeIDMapping());
// All edge in view must have a mapped edge in query
if (queryToView.size() != viewHG.edgeSize()) {
return null;
// 2. compare them by expression and extract residual expr
ComparisonResult.Builder builder = new ComparisonResult.Builder();
ComparisonResult edgeCompareRes = compareEdgesWithExpr(queryToView, ctx.getQueryToViewEdgeExpressionMapping());
if (edgeCompareRes.isInvalid()) {
return ComparisonResult.INVALID;
}
builder.addComparisonResult(edgeCompareRes);
// 3. pull join edge of view is no sense, so reject them
if (!queryToView.values().containsAll(viewHG.joinEdges)) {
return ComparisonResult.INVALID;
}
boolean allMatch = queryToView.entrySet().stream()
.allMatch(entry ->
compareEdgeWithNode(entry.getKey(), entry.getValue(), ctx.getQueryToViewNodeIDMapping()));
if (!allMatch) {
return null;
// 4. process residual edges
List<Expression> residualQueryJoin =
processOrphanEdges(Sets.difference(Sets.newHashSet(joinEdges), queryToView.keySet()));
if (residualQueryJoin == null) {
return ComparisonResult.INVALID;
}
builder.addQueryExpressions(residualQueryJoin);
// join edges must be identical
boolean isJoinIdentical = joinEdges.stream()
.allMatch(queryToView::containsKey);
if (!isJoinIdentical) {
return null;
List<Expression> residualQueryFilter =
processOrphanEdges(Sets.difference(Sets.newHashSet(filterEdges), queryToView.keySet()));
if (residualQueryFilter == null) {
return ComparisonResult.INVALID;
}
builder.addQueryExpressions(residualQueryFilter);
// extract all top filters
List<FilterEdge> residualFilterEdges = filterEdges.stream()
.filter(e -> !queryToView.containsKey(e))
.collect(ImmutableList.toImmutableList());
if (residualFilterEdges.stream().anyMatch(e -> !e.isTopFilter())) {
return null;
List<Expression> residualViewFilter =
processOrphanEdges(
Sets.difference(Sets.newHashSet(viewHG.filterEdges), Sets.newHashSet(queryToView.values())));
if (residualViewFilter == null) {
return ComparisonResult.INVALID;
}
return residualFilterEdges.stream()
.flatMap(e -> e.getExpressions().stream())
.collect(ImmutableList.toImmutableList());
builder.addViewExpressions(residualViewFilter);
return builder.build();
}
private Map<Edge, Edge> constructEdgeMap(HyperGraph viewHG, Map<Expression, Expression> exprMap) {
Map<Expression, Edge> exprToEdge = constructExprMap(viewHG);
Map<Edge, Edge> queryToView = new HashMap<>();
joinEdges.stream()
.filter(e -> !e.getExpressions().isEmpty()
&& exprMap.containsKey(e.getExpression(0))
&& compareEdgeWithExpr(e, exprToEdge.get(exprMap.get(e.getExpression(0))), exprMap))
.forEach(e -> queryToView.put(e, exprToEdge.get(exprMap.get(e.getExpression(0)))));
filterEdges.stream()
.filter(e -> !e.getExpressions().isEmpty()
&& exprMap.containsKey(e.getExpression(0))
&& compareEdgeWithExpr(e, exprToEdge.get(exprMap.get(e.getExpression(0))), exprMap))
.forEach(e -> queryToView.put(e, exprToEdge.get(exprMap.get(e.getExpression(0)))));
return queryToView;
private List<Expression> processOrphanEdges(Set<Edge> edges) {
List<Expression> expressions = new ArrayList<>();
for (Edge edge : edges) {
if (!edge.canPullUp()) {
return null;
}
expressions.addAll(edge.getExpressions());
}
return expressions;
}
private Map<Edge, Edge> constructMapWithNode(HyperGraph viewHG, Map<Integer, Integer> nodeMap) {
// TODO use hash map to reduce loop
Map<Edge, Edge> joinEdgeMap = joinEdges.stream().map(qe -> {
Optional<JoinEdge> viewEdge = viewHG.joinEdges.stream()
.filter(ve -> compareEdgeWithNode(qe, ve, nodeMap)).findFirst();
return Pair.of(qe, viewEdge);
}).filter(e -> e.second.isPresent()).collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second.get()));
Map<Edge, Edge> filterEdgeMap = filterEdges.stream().map(qe -> {
Optional<FilterEdge> viewEdge = viewHG.filterEdges.stream()
.filter(ve -> compareEdgeWithNode(qe, ve, nodeMap)).findFirst();
return Pair.of(qe, viewEdge);
}).filter(e -> e.second.isPresent()).collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second.get()));
return ImmutableMap.<Edge, Edge>builder().putAll(joinEdgeMap).putAll(filterEdgeMap).build();
}
private boolean compareEdgeWithNode(Edge t, Edge o, Map<Integer, Integer> nodeMap) {
@ -686,24 +712,40 @@ public class HyperGraph {
return bitmap2 == newBitmap1;
}
private boolean compareEdgeWithExpr(Edge t, Edge o, Map<Expression, Expression> expressionMap) {
if (t.getExpressions().size() != o.getExpressions().size()) {
return false;
}
int size = t.getExpressions().size();
for (int i = 0; i < size; i++) {
if (!Objects.equals(expressionMap.get(t.getExpression(i)), o.getExpression(i))) {
return false;
private ComparisonResult compareEdgesWithExpr(Map<Edge, Edge> queryToViewedgeMap,
Map<Expression, Expression> queryToView) {
ComparisonResult.Builder builder = new ComparisonResult.Builder();
for (Entry<Edge, Edge> e : queryToViewedgeMap.entrySet()) {
ComparisonResult res = compareEdgeWithExpr(e.getKey(), e.getValue(), queryToView);
if (res.isInvalid()) {
return ComparisonResult.INVALID;
}
builder.addComparisonResult(res);
}
return true;
return builder.build();
}
private Map<Expression, Edge> constructExprMap(HyperGraph hyperGraph) {
Map<Expression, Edge> exprToEdge = new HashMap<>();
hyperGraph.joinEdges.forEach(edge -> edge.getExpressions().forEach(expr -> exprToEdge.put(expr, edge)));
hyperGraph.filterEdges.forEach(edge -> edge.getExpressions().forEach(expr -> exprToEdge.put(expr, edge)));
return exprToEdge;
private ComparisonResult compareEdgeWithExpr(Edge query, Edge view, Map<Expression, Expression> queryToView) {
Set<? extends Expression> queryExprSet = query.getExpressionSet();
Set<? extends Expression> viewExprSet = view.getExpressionSet();
Set<Expression> equalViewExpr = new HashSet<>();
List<Expression> residualQueryExpr = new ArrayList<>();
for (Expression queryExpr : queryExprSet) {
if (queryToView.containsKey(queryExpr) && viewExprSet.contains(queryToView.get(queryExpr))) {
equalViewExpr.add(queryToView.get(queryExpr));
} else {
residualQueryExpr.add(queryExpr);
}
}
List<Expression> residualViewExpr = ImmutableList.copyOf(Sets.difference(viewExprSet, equalViewExpr));
if (!residualViewExpr.isEmpty() && !view.canPullUp()) {
return ComparisonResult.INVALID;
}
if (!residualQueryExpr.isEmpty() && !query.canPullUp()) {
return ComparisonResult.INVALID;
}
return new ComparisonResult(residualQueryExpr, residualViewExpr);
}
/**

View File

@ -22,6 +22,8 @@ import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import com.google.common.collect.ImmutableSet;
import java.util.BitSet;
import java.util.List;
import java.util.Set;
@ -51,6 +53,8 @@ public abstract class Edge {
// record all sub nodes behind in this operator. It's T function in paper
private final long subTreeNodes;
private long rejectNodes = 0;
/**
* Create simple edge.
*/
@ -71,6 +75,10 @@ public abstract class Edge {
return LongBitmap.getCardinality(leftExtendedNodes) == 1 && LongBitmap.getCardinality(rightExtendedNodes) == 1;
}
public void addRejectEdge(Edge edge) {
rejectNodes = LongBitmap.newBitmapUnion(edge.getReferenceNodes(), rejectNodes);
}
public void addLeftExtendNode(long left) {
this.leftExtendedNodes = LongBitmap.or(this.leftExtendedNodes, left);
}
@ -171,6 +179,20 @@ public abstract class Edge {
public abstract List<? extends Expression> getExpressions();
public Set<? extends Expression> getExpressionSet() {
return ImmutableSet.copyOf(getExpressions());
}
public boolean canPullUp() {
// Only inner join and filter with none rejectNodes can be pull up
return rejectNodes == 0
&& !(this instanceof JoinEdge && !((JoinEdge) this).getJoinType().isInnerJoin());
}
public long getRejectNodes() {
return rejectNodes;
}
public Expression getExpression(int i) {
return getExpressions().get(i);
}

View File

@ -22,7 +22,6 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.Set;
@ -32,25 +31,11 @@ import java.util.Set;
*/
public class FilterEdge extends Edge {
private final LogicalFilter<? extends Plan> filter;
private final List<Integer> rejectEdges;
public FilterEdge(LogicalFilter<? extends Plan> filter, int index,
BitSet childEdges, long subTreeNodes, long childRequireNodes) {
super(index, childEdges, new BitSet(), subTreeNodes, childRequireNodes, 0L);
this.filter = filter;
rejectEdges = new ArrayList<>();
}
public void addRejectJoin(JoinEdge joinEdge) {
rejectEdges.add(joinEdge.getIndex());
}
public List<Integer> getRejectEdges() {
return rejectEdges;
}
public boolean isTopFilter() {
return rejectEdges.isEmpty();
}
@Override

View File

@ -136,12 +136,14 @@ public abstract class AbstractMaterializedViewRule implements ExplorationRuleFac
LogicalCompatibilityContext compatibilityContext =
LogicalCompatibilityContext.from(queryToViewTableMapping, queryToViewSlotMapping,
queryStructInfo, viewStructInfo);
List<Expression> pulledUpExpressions = StructInfo.isGraphLogicalEquals(queryStructInfo, viewStructInfo,
ComparisonResult comparisonResult = StructInfo.isGraphLogicalEquals(queryStructInfo, viewStructInfo,
compatibilityContext);
if (pulledUpExpressions == null) {
if (comparisonResult.isInvalid()) {
logger.debug(currentClassName + " graph logical is not equals so continue");
continue;
}
// TODO: Use set of list? And consider view expr
List<Expression> pulledUpExpressions = ImmutableList.copyOf(comparisonResult.getQueryExpressions());
// set pulled up expression to queryStructInfo predicates and update related predicates
if (!pulledUpExpressions.isEmpty()) {
queryStructInfo.addPredicates(pulledUpExpressions);

View File

@ -0,0 +1,94 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.mv;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.collect.ImmutableList;
import java.util.Collection;
import java.util.List;
/**
* comparison result of view and query
*/
public class ComparisonResult {
public static final ComparisonResult INVALID = new ComparisonResult(ImmutableList.of(), ImmutableList.of(), false);
public static final ComparisonResult EMPTY = new ComparisonResult(ImmutableList.of(), ImmutableList.of(), true);
private final boolean valid;
private final List<Expression> viewExpressions;
private final List<Expression> queryExpressions;
public ComparisonResult(List<Expression> queryExpressions, List<Expression> viewExpressions) {
this(queryExpressions, viewExpressions, true);
}
ComparisonResult(List<Expression> queryExpressions, List<Expression> viewExpressions, boolean valid) {
this.viewExpressions = ImmutableList.copyOf(viewExpressions);
this.queryExpressions = ImmutableList.copyOf(queryExpressions);
this.valid = valid;
}
public List<Expression> getViewExpressions() {
return viewExpressions;
}
public List<Expression> getQueryExpressions() {
return queryExpressions;
}
public boolean isInvalid() {
return !valid;
}
/**
* Builder
*/
public static class Builder {
ImmutableList.Builder<Expression> queryBuilder = new ImmutableList.Builder<>();
ImmutableList.Builder<Expression> viewBuilder = new ImmutableList.Builder<>();
boolean valid = true;
/**
* add comparisonResult
*/
public Builder addComparisonResult(ComparisonResult comparisonResult) {
if (comparisonResult.isInvalid()) {
valid = false;
return this;
}
queryBuilder.addAll(comparisonResult.getQueryExpressions());
viewBuilder.addAll(comparisonResult.getViewExpressions());
return this;
}
public Builder addQueryExpressions(Collection<Expression> expressions) {
queryBuilder.addAll(expressions);
return this;
}
public Builder addViewExpressions(Collection<Expression> expressions) {
viewBuilder.addAll(expressions);
return this;
}
public ComparisonResult build() {
return new ComparisonResult(queryBuilder.build(), viewBuilder.build(), valid);
}
}
}

View File

@ -263,7 +263,7 @@ public class StructInfo {
* For inner join should judge only the join tables,
* for other join type should also judge the join direction, it's input filter that can not be pulled up etc.
*/
public static @Nullable List<Expression> isGraphLogicalEquals(StructInfo queryStructInfo, StructInfo viewStructInfo,
public static ComparisonResult isGraphLogicalEquals(StructInfo queryStructInfo, StructInfo viewStructInfo,
LogicalCompatibilityContext compatibilityContext) {
return queryStructInfo.hyperGraph.isLogicCompatible(viewStructInfo.hyperGraph, compatibilityContext);
}