extract agg in struct info node (#29853)

This commit is contained in:
谢健
2024-01-15 13:32:21 +08:00
committed by yiguolei
parent ffc6f58e85
commit e09118eb9c
3 changed files with 191 additions and 4 deletions

View File

@ -19,14 +19,26 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.LeafPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
* HyperGraph Node.
@ -34,9 +46,13 @@ import java.util.List;
public class StructInfoNode extends AbstractNode {
private List<HyperGraph> graphs = new ArrayList<>();
private final List<Set<Expression>> expressions;
private final Set<CatalogRelation> relationSet;
public StructInfoNode(int index, Plan plan, List<Edge> edges) {
super(extractPlan(plan), index, edges);
relationSet = plan.collect(CatalogRelation.class::isInstance);
expressions = collectExpressions(plan);
}
public StructInfoNode(int index, Plan plan) {
@ -48,6 +64,52 @@ public class StructInfoNode extends AbstractNode {
this.graphs = graphs;
}
private @Nullable List<Set<Expression>> collectExpressions(Plan plan) {
if (plan instanceof LeafPlan) {
return ImmutableList.of();
}
List<Set<Expression>> childExpressions = collectExpressions(plan.child(0));
if (!isValidNodePlan(plan) || childExpressions == null) {
return null;
}
if (plan instanceof LogicalAggregate) {
return ImmutableList.<Set<Expression>>builder()
.add(ImmutableSet.copyOf(plan.getExpressions()))
.add(ImmutableSet.copyOf(((LogicalAggregate<?>) plan).getGroupByExpressions()))
.addAll(childExpressions)
.build();
}
return ImmutableList.<Set<Expression>>builder()
.add(ImmutableSet.copyOf(plan.getExpressions()))
.addAll(childExpressions)
.build();
}
private boolean isValidNodePlan(Plan plan) {
return plan instanceof LogicalProject || plan instanceof LogicalAggregate
|| plan instanceof LogicalFilter || plan instanceof LogicalCatalogRelation;
}
/**
* get all expressions of nodes
*/
public @Nullable List<Expression> getExpressions() {
return expressions.stream()
.flatMap(Collection::stream)
.collect(Collectors.toList());
}
public @Nullable List<Set<Expression>> getExprSetList() {
return expressions;
}
/**
* return catalog relation
*/
public Set<CatalogRelation> getCatalogRelation() {
return relationSet;
}
private static Plan extractPlan(Plan plan) {
if (plan instanceof GroupPlan) {
//TODO: Note mv can be in logicalExpression, how can we choose it

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.FilterEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -90,14 +91,21 @@ public class HyperGraphComparator {
}
private ComparisonResult isLogicCompatible() {
// 1 try to construct a map which can be mapped from edge to edge
// 1 compare nodes
boolean nodeMatches = logicalCompatibilityContext.getQueryToViewNodeMapping().entrySet()
.stream().allMatch(e -> compareNodeWithExpr(e.getKey(), e.getValue()));
if (!nodeMatches) {
return ComparisonResult.newInvalidResWithErrorMessage("StructInfoNode are not compatible\n");
}
// 2 try to construct a map which can be mapped from edge to edge
Map<Edge, Edge> queryToView = constructQueryToViewMapWithExpr();
if (!makeViewJoinCompatible(queryToView)) {
return ComparisonResult.newInvalidResWithErrorMessage("Join types are not compatible\n");
}
refreshViewEdges();
// 2. compare them by expression and nodes. Note compare edges after inferring for nodes
// 3. compare them by expression and nodes. Note compare edges after inferring for nodes
boolean matchNodes = queryToView.entrySet().stream()
.allMatch(e -> compareEdgeWithNode(e.getKey(), e.getValue()));
if (!matchNodes) {
@ -105,7 +113,7 @@ public class HyperGraphComparator {
}
queryToView.forEach(this::compareEdgeWithExpr);
// 3. process residual edges
// 1. process residual edges
Sets.difference(getQueryJoinEdgeSet(), queryToView.keySet())
.forEach(e -> pullUpQueryExprWithEdge.put(e, e.getExpressions()));
Sets.difference(getQueryFilterEdgeSet(), queryToView.keySet())
@ -118,6 +126,25 @@ public class HyperGraphComparator {
return buildComparisonRes();
}
private boolean compareNodeWithExpr(StructInfoNode query, StructInfoNode view) {
List<Set<Expression>> queryExprSetList = query.getExprSetList();
List<Set<Expression>> viewExprSetList = view.getExprSetList();
if (queryExprSetList == null || viewExprSetList == null
|| queryExprSetList.size() != viewExprSetList.size()) {
return false;
}
int size = queryExprSetList.size();
for (int i = 0; i < size; i++) {
Set<Expression> mappingQueryExprSet = queryExprSetList.get(i).stream()
.map(e -> logicalCompatibilityContext.getQueryToViewEdgeExpressionMapping().get(e))
.collect(Collectors.toSet());
if (!mappingQueryExprSet.equals(viewExprSetList.get(i))) {
return false;
}
}
return true;
}
private ComparisonResult buildComparisonRes() {
ComparisonResult.Builder builder = new ComparisonResult.Builder();
for (Entry<Edge, List<? extends Expression>> e : pullUpQueryExprWithEdge.entrySet()) {
@ -134,7 +161,7 @@ public class HyperGraphComparator {
.filter(expr -> !ExpressionUtils.isInferred(expr))
.collect(Collectors.toList());
if (!rawFilter.isEmpty() && !canPullUp(getViewEdgeAfterInferring(e.getKey()))) {
return ComparisonResult.newInvalidResWithErrorMessage(getErrorMessage() + "with error edge\n" + e);
return ComparisonResult.newInvalidResWithErrorMessage(getErrorMessage() + "\nwith error edge\n" + e);
}
builder.addViewExpressions(rawFilter);
}