[Improvement](Nereids) Support to query rewrite by materialized view when join input has aggregate (#30230)
Support to query rewrite by materialized view when join input has aggregate, the aggregate should be simple For example as following: The materialized view def is > select > l_linenumber, > count(distinct l_orderkey), > sum(case when l_orderkey in (1,2,3) then l_suppkey * l_linenumber else 0 end), > max(case when l_orderkey in (4, 5) then (l_quantity *2 + part_supp_a.qty_max) * 0.88 else 100 end), > avg(case when l_partkey in (2, 3, 4) then l_discount + o_totalprice + part_supp_a.qty_sum else 50 end) > from lineitem > left join orders on l_orderkey = o_orderkey > left join > (select ps_partkey, ps_suppkey, sum(ps_availqty) qty_sum, max(ps_availqty) qty_max, > min(ps_availqty) qty_min, > avg(ps_supplycost) cost_avg > from partsupp > group by ps_partkey,ps_suppkey) part_supp_a > on l_partkey = part_supp_a.ps_partkey > and l_suppkey = part_supp_a.ps_suppkey > group by l_linenumber; when query is like following, it can be rewritten by mv above > select > l_linenumber, > sum(case when l_orderkey in (1,2,3) then l_suppkey * l_linenumber else 0 end), > avg(case when l_partkey in (2, 3, 4) then l_discount + o_totalprice + part_supp_a.qty_sum else 50 end) > from lineitem > left join orders on l_orderkey = o_orderkey > left join > (select ps_partkey, ps_suppkey, sum(ps_availqty) qty_sum, max(ps_availqty) qty_max, > min(ps_availqty) qty_min, > avg(ps_supplycost) cost_avg > from partsupp > group by ps_partkey,ps_suppkey) part_supp_a > on l_partkey = part_supp_a.ps_partkey > and l_suppkey = part_supp_a.ps_suppkey > group by l_linenumber;
This commit is contained in:
@ -17,19 +17,21 @@
|
||||
|
||||
package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.LeafPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@ -57,24 +59,50 @@ public class StructInfoNode extends AbstractNode {
|
||||
}
|
||||
|
||||
private @Nullable List<Set<Expression>> collectExpressions(Plan plan) {
|
||||
if (plan instanceof LeafPlan) {
|
||||
return ImmutableList.of();
|
||||
}
|
||||
List<Set<Expression>> childExpressions = collectExpressions(plan.child(0));
|
||||
if (!isValidNodePlan(plan) || childExpressions == null) {
|
||||
return null;
|
||||
}
|
||||
if (plan instanceof LogicalAggregate) {
|
||||
return ImmutableList.<Set<Expression>>builder()
|
||||
.add(ImmutableSet.copyOf(plan.getExpressions()))
|
||||
.add(ImmutableSet.copyOf(((LogicalAggregate<?>) plan).getGroupByExpressions()))
|
||||
.addAll(childExpressions)
|
||||
.build();
|
||||
}
|
||||
return ImmutableList.<Set<Expression>>builder()
|
||||
.add(ImmutableSet.copyOf(plan.getExpressions()))
|
||||
.addAll(childExpressions)
|
||||
.build();
|
||||
|
||||
Pair<Boolean, Builder<Set<Expression>>> collector = Pair.of(true, ImmutableList.builder());
|
||||
plan.accept(new DefaultPlanVisitor<Void, Pair<Boolean, ImmutableList.Builder<Set<Expression>>>>() {
|
||||
@Override
|
||||
public Void visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate,
|
||||
Pair<Boolean, ImmutableList.Builder<Set<Expression>>> collector) {
|
||||
if (!collector.key()) {
|
||||
return null;
|
||||
}
|
||||
collector.value().add(ImmutableSet.copyOf(aggregate.getExpressions()));
|
||||
collector.value().add(ImmutableSet.copyOf(((LogicalAggregate<?>) plan).getGroupByExpressions()));
|
||||
return super.visit(aggregate, collector);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void visitLogicalFilter(LogicalFilter<? extends Plan> filter,
|
||||
Pair<Boolean, ImmutableList.Builder<Set<Expression>>> collector) {
|
||||
if (!collector.key()) {
|
||||
return null;
|
||||
}
|
||||
collector.value().add(ImmutableSet.copyOf(filter.getExpressions()));
|
||||
return super.visit(filter, collector);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void visitGroupPlan(GroupPlan groupPlan,
|
||||
Pair<Boolean, ImmutableList.Builder<Set<Expression>>> collector) {
|
||||
if (!collector.key()) {
|
||||
return null;
|
||||
}
|
||||
Plan groupActualPlan = groupPlan.getGroup().getLogicalExpressions().get(0).getPlan();
|
||||
return groupActualPlan.accept(this, collector);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void visit(Plan plan, Pair<Boolean, ImmutableList.Builder<Set<Expression>>> context) {
|
||||
if (!isValidNodePlan(plan)) {
|
||||
context.first = false;
|
||||
return null;
|
||||
}
|
||||
return super.visit(plan, context);
|
||||
}
|
||||
}, collector);
|
||||
return collector.key() ? collector.value().build() : null;
|
||||
}
|
||||
|
||||
private boolean isValidNodePlan(Plan plan) {
|
||||
@ -104,7 +132,7 @@ public class StructInfoNode extends AbstractNode {
|
||||
|
||||
private static Plan extractPlan(Plan plan) {
|
||||
if (plan instanceof GroupPlan) {
|
||||
//TODO: Note mv can be in logicalExpression, how can we choose it
|
||||
// TODO: Note mv can be in logicalExpression, how can we choose it
|
||||
plan = ((GroupPlan) plan).getGroup().getLogicalExpressions().get(0)
|
||||
.getPlan();
|
||||
}
|
||||
|
||||
@ -101,7 +101,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
|
||||
Pair<Plan, LogicalAggregate<Plan>> viewTopPlanAndAggPair = splitToTopPlanAndAggregate(viewStructInfo);
|
||||
if (viewTopPlanAndAggPair == null) {
|
||||
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
|
||||
Pair.of("Split view to top plan and agg fail",
|
||||
Pair.of("Split view to top plan and agg fail, view doesn't not contain aggregate",
|
||||
String.format("view plan = %s\n", viewStructInfo.getOriginalPlan().treeString())));
|
||||
return null;
|
||||
}
|
||||
|
||||
@ -186,11 +186,13 @@ public abstract class AbstractMaterializedViewRule implements ExplorationRuleFac
|
||||
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
|
||||
Pair.of("Predicate compensate fail",
|
||||
String.format("query predicates = %s,\n query equivalenceClass = %s, \n"
|
||||
+ "view predicates = %s,\n query equivalenceClass = %s\n",
|
||||
+ "view predicates = %s,\n query equivalenceClass = %s\n"
|
||||
+ "comparisonResult = %s ",
|
||||
queryStructInfo.getPredicates(),
|
||||
queryStructInfo.getEquivalenceClass(),
|
||||
viewStructInfo.getPredicates(),
|
||||
viewStructInfo.getEquivalenceClass())));
|
||||
viewStructInfo.getEquivalenceClass(),
|
||||
comparisonResult)));
|
||||
continue;
|
||||
}
|
||||
Plan rewrittenPlan;
|
||||
@ -467,21 +469,22 @@ public abstract class AbstractMaterializedViewRule implements ExplorationRuleFac
|
||||
Set<Set<Slot>> requireNoNullableViewSlot = comparisonResult.getViewNoNullableSlot();
|
||||
// check query is use the null reject slot which view comparison need
|
||||
if (!requireNoNullableViewSlot.isEmpty()) {
|
||||
Set<Expression> queryPulledUpPredicates = queryStructInfo.getPredicates().getPulledUpPredicates();
|
||||
Set<Expression> queryPulledUpPredicates = comparisonResult.getQueryAllPulledUpExpressions().stream()
|
||||
.flatMap(expr -> ExpressionUtils.extractConjunction(expr).stream())
|
||||
.collect(Collectors.toSet());
|
||||
Set<Expression> nullRejectPredicates = ExpressionUtils.inferNotNull(queryPulledUpPredicates,
|
||||
cascadesContext);
|
||||
if (nullRejectPredicates.isEmpty() || queryPulledUpPredicates.containsAll(nullRejectPredicates)) {
|
||||
// query has not null reject predicates, so return
|
||||
return SplitPredicate.INVALID_INSTANCE;
|
||||
}
|
||||
SlotMapping queryToViewMapping = viewToQuerySlotMapping.inverse();
|
||||
Set<Expression> queryUsedNeedRejectNullSlotsViewBased = nullRejectPredicates.stream()
|
||||
.map(expression -> TypeUtils.isNotNull(expression).orElse(null))
|
||||
.filter(Objects::nonNull)
|
||||
.map(expr -> ExpressionUtils.replace((Expression) expr, queryToViewMapping.toSlotReferenceMap()))
|
||||
.collect(Collectors.toSet());
|
||||
if (requireNoNullableViewSlot.stream().anyMatch(
|
||||
set -> Sets.intersection(set, queryUsedNeedRejectNullSlotsViewBased).isEmpty())) {
|
||||
// query pulledUp predicates should have null reject predicates and contains any require noNullable slot
|
||||
boolean valid = !queryPulledUpPredicates.containsAll(nullRejectPredicates)
|
||||
&& requireNoNullableViewSlot.stream().noneMatch(
|
||||
set -> Sets.intersection(set, queryUsedNeedRejectNullSlotsViewBased).isEmpty());
|
||||
if (!valid) {
|
||||
return SplitPredicate.INVALID_INSTANCE;
|
||||
}
|
||||
}
|
||||
|
||||
@ -35,20 +35,23 @@ public class ComparisonResult {
|
||||
private final boolean valid;
|
||||
private final List<Expression> viewExpressions;
|
||||
private final List<Expression> queryExpressions;
|
||||
private final List<Expression> queryAllPulledUpExpressions;
|
||||
private final Set<Set<Slot>> viewNoNullableSlot;
|
||||
private final String errorMessage;
|
||||
|
||||
ComparisonResult(List<Expression> queryExpressions, List<Expression> viewExpressions,
|
||||
Set<Set<Slot>> viewNoNullableSlot, boolean valid, String message) {
|
||||
ComparisonResult(List<Expression> queryExpressions, List<Expression> queryAllPulledUpExpressions,
|
||||
List<Expression> viewExpressions, Set<Set<Slot>> viewNoNullableSlot, boolean valid, String message) {
|
||||
this.viewExpressions = ImmutableList.copyOf(viewExpressions);
|
||||
this.queryExpressions = ImmutableList.copyOf(queryExpressions);
|
||||
this.queryAllPulledUpExpressions = ImmutableList.copyOf(queryAllPulledUpExpressions);
|
||||
this.viewNoNullableSlot = ImmutableSet.copyOf(viewNoNullableSlot);
|
||||
this.valid = valid;
|
||||
this.errorMessage = message;
|
||||
}
|
||||
|
||||
public static ComparisonResult newInvalidResWithErrorMessage(String errorMessage) {
|
||||
return new ComparisonResult(ImmutableList.of(), ImmutableList.of(), ImmutableSet.of(), false, errorMessage);
|
||||
return new ComparisonResult(ImmutableList.of(), ImmutableList.of(), ImmutableList.of(),
|
||||
ImmutableSet.of(), false, errorMessage);
|
||||
}
|
||||
|
||||
public List<Expression> getViewExpressions() {
|
||||
@ -59,6 +62,10 @@ public class ComparisonResult {
|
||||
return queryExpressions;
|
||||
}
|
||||
|
||||
public List<Expression> getQueryAllPulledUpExpressions() {
|
||||
return queryAllPulledUpExpressions;
|
||||
}
|
||||
|
||||
public Set<Set<Slot>> getViewNoNullableSlot() {
|
||||
return viewNoNullableSlot;
|
||||
}
|
||||
@ -78,6 +85,7 @@ public class ComparisonResult {
|
||||
ImmutableList.Builder<Expression> queryBuilder = new ImmutableList.Builder<>();
|
||||
ImmutableList.Builder<Expression> viewBuilder = new ImmutableList.Builder<>();
|
||||
ImmutableSet.Builder<Set<Slot>> viewNoNullableSlotBuilder = new ImmutableSet.Builder<>();
|
||||
ImmutableList.Builder<Expression> queryAllPulledUpExpressionsBuilder = new ImmutableList.Builder<>();
|
||||
boolean valid = true;
|
||||
|
||||
/**
|
||||
@ -108,25 +116,29 @@ public class ComparisonResult {
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder addQueryAllPulledUpExpressions(Collection<? extends Expression> expressions) {
|
||||
queryAllPulledUpExpressionsBuilder.addAll(expressions);
|
||||
return this;
|
||||
}
|
||||
|
||||
public boolean isInvalid() {
|
||||
return !valid;
|
||||
}
|
||||
|
||||
public ComparisonResult build() {
|
||||
Preconditions.checkArgument(valid, "Comparison result must be valid");
|
||||
return new ComparisonResult(queryBuilder.build(), viewBuilder.build(),
|
||||
viewNoNullableSlotBuilder.build(), valid, "");
|
||||
return new ComparisonResult(queryBuilder.build(), queryAllPulledUpExpressionsBuilder.build(),
|
||||
viewBuilder.build(), viewNoNullableSlotBuilder.build(), valid, "");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
if (isInvalid()) {
|
||||
return "INVALID";
|
||||
}
|
||||
return String.format("viewExpressions: %s \n "
|
||||
+ "queryExpressions :%s \n "
|
||||
+ "viewNoNullableSlot :%s \n",
|
||||
viewExpressions, queryExpressions, viewNoNullableSlot);
|
||||
return String.format("valid: %s \n "
|
||||
+ "viewExpressions: %s \n "
|
||||
+ "queryExpressions :%s \n "
|
||||
+ "viewNoNullableSlot :%s \n"
|
||||
+ "queryAllPulledUpExpressions :%s \n", valid, viewExpressions, queryExpressions,
|
||||
viewNoNullableSlot, queryAllPulledUpExpressions);
|
||||
}
|
||||
}
|
||||
|
||||
@ -168,6 +168,10 @@ public class HyperGraphComparator {
|
||||
for (Pair<JoinType, Set<Slot>> inferredCond : inferredViewEdgeWithCond.values()) {
|
||||
builder.addViewNoNullableSlot(inferredCond.second);
|
||||
}
|
||||
builder.addQueryAllPulledUpExpressions(
|
||||
getQueryFilterEdges().stream()
|
||||
.filter(this::canPullUp)
|
||||
.flatMap(filter -> filter.getExpressions().stream()).collect(Collectors.toList()));
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
|
||||
@ -170,6 +170,8 @@ public class StructInfo {
|
||||
}
|
||||
// Collect relations from hyper graph which in the bottom plan
|
||||
hyperGraph.getNodes().forEach(node -> {
|
||||
// plan relation collector and set to map
|
||||
StructInfoNode structInfoNode = (StructInfoNode) node;
|
||||
// plan relation collector and set to map
|
||||
Plan nodePlan = node.getPlan();
|
||||
List<CatalogRelation> nodeRelations = new ArrayList<>();
|
||||
@ -177,6 +179,24 @@ public class StructInfo {
|
||||
relationBuilder.addAll(nodeRelations);
|
||||
// every node should only have one relation, this is for LogicalCompatibilityContext
|
||||
relationIdStructInfoNodeMap.put(nodeRelations.get(0).getRelationId(), (StructInfoNode) node);
|
||||
|
||||
// record expressions in node
|
||||
if (structInfoNode.getExpressions() != null) {
|
||||
structInfoNode.getExpressions().forEach(expression -> {
|
||||
ExpressionLineageReplacer.ExpressionReplaceContext replaceContext =
|
||||
new ExpressionLineageReplacer.ExpressionReplaceContext(
|
||||
Lists.newArrayList(expression),
|
||||
ImmutableSet.of(),
|
||||
ImmutableSet.of());
|
||||
topPlan.accept(ExpressionLineageReplacer.INSTANCE, replaceContext);
|
||||
// Replace expressions by expression map
|
||||
List<Expression> replacedExpressions = replaceContext.getReplacedExpressions();
|
||||
shuttledHashConjunctsToConjunctsMap.put(replacedExpressions.get(0), expression);
|
||||
// Record this, will be used in top level expression shuttle later, see the method
|
||||
// ExpressionLineageReplacer#visitGroupPlan
|
||||
namedExprIdAndExprMapping.putAll(replaceContext.getExprIdExpressionMap());
|
||||
});
|
||||
}
|
||||
});
|
||||
// Collect expression from where in hyper graph
|
||||
hyperGraph.getFilterEdges().forEach(filterEdge -> {
|
||||
@ -436,7 +456,9 @@ public class StructInfo {
|
||||
if (!(plan instanceof Filter)
|
||||
&& !(plan instanceof Project)
|
||||
&& !(plan instanceof CatalogRelation)
|
||||
&& !(plan instanceof Join)) {
|
||||
&& !(plan instanceof Join)
|
||||
&& !(plan instanceof LogicalAggregate && !((LogicalAggregate) plan).getSourceRepeat()
|
||||
.isPresent())) {
|
||||
return false;
|
||||
}
|
||||
if (plan instanceof Join) {
|
||||
|
||||
Reference in New Issue
Block a user