extract agg in struct info node (#29853)
This commit is contained in:
@ -19,14 +19,26 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
|
||||
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.LeafPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
* HyperGraph Node.
|
||||
@ -34,9 +46,13 @@ import java.util.List;
|
||||
public class StructInfoNode extends AbstractNode {
|
||||
|
||||
private List<HyperGraph> graphs = new ArrayList<>();
|
||||
private final List<Set<Expression>> expressions;
|
||||
private final Set<CatalogRelation> relationSet;
|
||||
|
||||
public StructInfoNode(int index, Plan plan, List<Edge> edges) {
|
||||
super(extractPlan(plan), index, edges);
|
||||
relationSet = plan.collect(CatalogRelation.class::isInstance);
|
||||
expressions = collectExpressions(plan);
|
||||
}
|
||||
|
||||
public StructInfoNode(int index, Plan plan) {
|
||||
@ -48,6 +64,52 @@ public class StructInfoNode extends AbstractNode {
|
||||
this.graphs = graphs;
|
||||
}
|
||||
|
||||
private @Nullable List<Set<Expression>> collectExpressions(Plan plan) {
|
||||
if (plan instanceof LeafPlan) {
|
||||
return ImmutableList.of();
|
||||
}
|
||||
List<Set<Expression>> childExpressions = collectExpressions(plan.child(0));
|
||||
if (!isValidNodePlan(plan) || childExpressions == null) {
|
||||
return null;
|
||||
}
|
||||
if (plan instanceof LogicalAggregate) {
|
||||
return ImmutableList.<Set<Expression>>builder()
|
||||
.add(ImmutableSet.copyOf(plan.getExpressions()))
|
||||
.add(ImmutableSet.copyOf(((LogicalAggregate<?>) plan).getGroupByExpressions()))
|
||||
.addAll(childExpressions)
|
||||
.build();
|
||||
}
|
||||
return ImmutableList.<Set<Expression>>builder()
|
||||
.add(ImmutableSet.copyOf(plan.getExpressions()))
|
||||
.addAll(childExpressions)
|
||||
.build();
|
||||
}
|
||||
|
||||
private boolean isValidNodePlan(Plan plan) {
|
||||
return plan instanceof LogicalProject || plan instanceof LogicalAggregate
|
||||
|| plan instanceof LogicalFilter || plan instanceof LogicalCatalogRelation;
|
||||
}
|
||||
|
||||
/**
|
||||
* get all expressions of nodes
|
||||
*/
|
||||
public @Nullable List<Expression> getExpressions() {
|
||||
return expressions.stream()
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public @Nullable List<Set<Expression>> getExprSetList() {
|
||||
return expressions;
|
||||
}
|
||||
|
||||
/**
|
||||
* return catalog relation
|
||||
*/
|
||||
public Set<CatalogRelation> getCatalogRelation() {
|
||||
return relationSet;
|
||||
}
|
||||
|
||||
private static Plan extractPlan(Plan plan) {
|
||||
if (plan instanceof GroupPlan) {
|
||||
//TODO: Note mv can be in logicalExpression, how can we choose it
|
||||
|
||||
@ -24,6 +24,7 @@ import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.FilterEdge;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
|
||||
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
@ -90,14 +91,21 @@ public class HyperGraphComparator {
|
||||
}
|
||||
|
||||
private ComparisonResult isLogicCompatible() {
|
||||
// 1 try to construct a map which can be mapped from edge to edge
|
||||
// 1 compare nodes
|
||||
boolean nodeMatches = logicalCompatibilityContext.getQueryToViewNodeMapping().entrySet()
|
||||
.stream().allMatch(e -> compareNodeWithExpr(e.getKey(), e.getValue()));
|
||||
if (!nodeMatches) {
|
||||
return ComparisonResult.newInvalidResWithErrorMessage("StructInfoNode are not compatible\n");
|
||||
}
|
||||
|
||||
// 2 try to construct a map which can be mapped from edge to edge
|
||||
Map<Edge, Edge> queryToView = constructQueryToViewMapWithExpr();
|
||||
if (!makeViewJoinCompatible(queryToView)) {
|
||||
return ComparisonResult.newInvalidResWithErrorMessage("Join types are not compatible\n");
|
||||
}
|
||||
refreshViewEdges();
|
||||
|
||||
// 2. compare them by expression and nodes. Note compare edges after inferring for nodes
|
||||
// 3. compare them by expression and nodes. Note compare edges after inferring for nodes
|
||||
boolean matchNodes = queryToView.entrySet().stream()
|
||||
.allMatch(e -> compareEdgeWithNode(e.getKey(), e.getValue()));
|
||||
if (!matchNodes) {
|
||||
@ -105,7 +113,7 @@ public class HyperGraphComparator {
|
||||
}
|
||||
queryToView.forEach(this::compareEdgeWithExpr);
|
||||
|
||||
// 3. process residual edges
|
||||
// 1. process residual edges
|
||||
Sets.difference(getQueryJoinEdgeSet(), queryToView.keySet())
|
||||
.forEach(e -> pullUpQueryExprWithEdge.put(e, e.getExpressions()));
|
||||
Sets.difference(getQueryFilterEdgeSet(), queryToView.keySet())
|
||||
@ -118,6 +126,25 @@ public class HyperGraphComparator {
|
||||
return buildComparisonRes();
|
||||
}
|
||||
|
||||
private boolean compareNodeWithExpr(StructInfoNode query, StructInfoNode view) {
|
||||
List<Set<Expression>> queryExprSetList = query.getExprSetList();
|
||||
List<Set<Expression>> viewExprSetList = view.getExprSetList();
|
||||
if (queryExprSetList == null || viewExprSetList == null
|
||||
|| queryExprSetList.size() != viewExprSetList.size()) {
|
||||
return false;
|
||||
}
|
||||
int size = queryExprSetList.size();
|
||||
for (int i = 0; i < size; i++) {
|
||||
Set<Expression> mappingQueryExprSet = queryExprSetList.get(i).stream()
|
||||
.map(e -> logicalCompatibilityContext.getQueryToViewEdgeExpressionMapping().get(e))
|
||||
.collect(Collectors.toSet());
|
||||
if (!mappingQueryExprSet.equals(viewExprSetList.get(i))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private ComparisonResult buildComparisonRes() {
|
||||
ComparisonResult.Builder builder = new ComparisonResult.Builder();
|
||||
for (Entry<Edge, List<? extends Expression>> e : pullUpQueryExprWithEdge.entrySet()) {
|
||||
@ -134,7 +161,7 @@ public class HyperGraphComparator {
|
||||
.filter(expr -> !ExpressionUtils.isInferred(expr))
|
||||
.collect(Collectors.toList());
|
||||
if (!rawFilter.isEmpty() && !canPullUp(getViewEdgeAfterInferring(e.getKey()))) {
|
||||
return ComparisonResult.newInvalidResWithErrorMessage(getErrorMessage() + "with error edge\n" + e);
|
||||
return ComparisonResult.newInvalidResWithErrorMessage(getErrorMessage() + "\nwith error edge\n" + e);
|
||||
}
|
||||
builder.addViewExpressions(rawFilter);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user