[Improvement](Nereids) Support to query rewrite by materialized view when join input has aggregate (#30230)

Support query rewrite by materialized view when a join input contains an aggregate; the aggregate should be simple.
For example:
The materialized view definition is
>            select
>              l_linenumber,
>              count(distinct l_orderkey),
>              sum(case when l_orderkey in (1,2,3) then l_suppkey * l_linenumber else 0 end),
>              max(case when l_orderkey in (4, 5) then (l_quantity *2 + part_supp_a.qty_max) * 0.88 else 100 end),
>              avg(case when l_partkey in (2, 3, 4) then l_discount + o_totalprice + part_supp_a.qty_sum else 50 end)
>            from lineitem
>            left join orders on l_orderkey = o_orderkey
>            left join 
>              (select ps_partkey, ps_suppkey, sum(ps_availqty) qty_sum, max(ps_availqty) qty_max,
>                min(ps_availqty) qty_min,
>                avg(ps_supplycost) cost_avg
>                from partsupp
>                group by ps_partkey,ps_suppkey) part_supp_a
>              on l_partkey = part_supp_a.ps_partkey
>                and l_suppkey = part_supp_a.ps_suppkey
>            group by l_linenumber;

When the query is like the following, it can be rewritten by the materialized view above:
>            select
>              l_linenumber,
>              sum(case when l_orderkey in (1,2,3) then l_suppkey * l_linenumber else 0 end),
>              avg(case when l_partkey in (2, 3, 4) then l_discount + o_totalprice + part_supp_a.qty_sum else 50 end)
>            from lineitem
>            left join orders on l_orderkey = o_orderkey
>            left join 
>              (select ps_partkey, ps_suppkey, sum(ps_availqty) qty_sum, max(ps_availqty) qty_max,
>                min(ps_availqty) qty_min,
>                avg(ps_supplycost) cost_avg
>                from partsupp
>                group by ps_partkey,ps_suppkey) part_supp_a
>              on l_partkey = part_supp_a.ps_partkey
>                and l_suppkey = part_supp_a.ps_suppkey
>            group by l_linenumber;
This commit is contained in:
seawinde
2024-01-24 10:46:50 +08:00
committed by yiguolei
parent f85b04c2c6
commit 2f68aac885
13 changed files with 441 additions and 55 deletions

View File

@ -17,19 +17,21 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.LeafPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
@ -57,24 +59,50 @@ public class StructInfoNode extends AbstractNode {
}
private @Nullable List<Set<Expression>> collectExpressions(Plan plan) {
if (plan instanceof LeafPlan) {
return ImmutableList.of();
}
List<Set<Expression>> childExpressions = collectExpressions(plan.child(0));
if (!isValidNodePlan(plan) || childExpressions == null) {
return null;
}
if (plan instanceof LogicalAggregate) {
return ImmutableList.<Set<Expression>>builder()
.add(ImmutableSet.copyOf(plan.getExpressions()))
.add(ImmutableSet.copyOf(((LogicalAggregate<?>) plan).getGroupByExpressions()))
.addAll(childExpressions)
.build();
}
return ImmutableList.<Set<Expression>>builder()
.add(ImmutableSet.copyOf(plan.getExpressions()))
.addAll(childExpressions)
.build();
Pair<Boolean, Builder<Set<Expression>>> collector = Pair.of(true, ImmutableList.builder());
plan.accept(new DefaultPlanVisitor<Void, Pair<Boolean, ImmutableList.Builder<Set<Expression>>>>() {
@Override
public Void visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate,
Pair<Boolean, ImmutableList.Builder<Set<Expression>>> collector) {
if (!collector.key()) {
return null;
}
collector.value().add(ImmutableSet.copyOf(aggregate.getExpressions()));
collector.value().add(ImmutableSet.copyOf(((LogicalAggregate<?>) plan).getGroupByExpressions()));
return super.visit(aggregate, collector);
}
@Override
public Void visitLogicalFilter(LogicalFilter<? extends Plan> filter,
Pair<Boolean, ImmutableList.Builder<Set<Expression>>> collector) {
if (!collector.key()) {
return null;
}
collector.value().add(ImmutableSet.copyOf(filter.getExpressions()));
return super.visit(filter, collector);
}
@Override
public Void visitGroupPlan(GroupPlan groupPlan,
Pair<Boolean, ImmutableList.Builder<Set<Expression>>> collector) {
if (!collector.key()) {
return null;
}
Plan groupActualPlan = groupPlan.getGroup().getLogicalExpressions().get(0).getPlan();
return groupActualPlan.accept(this, collector);
}
@Override
public Void visit(Plan plan, Pair<Boolean, ImmutableList.Builder<Set<Expression>>> context) {
if (!isValidNodePlan(plan)) {
context.first = false;
return null;
}
return super.visit(plan, context);
}
}, collector);
return collector.key() ? collector.value().build() : null;
}
private boolean isValidNodePlan(Plan plan) {
@ -104,7 +132,7 @@ public class StructInfoNode extends AbstractNode {
private static Plan extractPlan(Plan plan) {
if (plan instanceof GroupPlan) {
//TODO: Note mv can be in logicalExpression, how can we choose it
// TODO: Note mv can be in logicalExpression, how can we choose it
plan = ((GroupPlan) plan).getGroup().getLogicalExpressions().get(0)
.getPlan();
}

View File

@ -101,7 +101,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
Pair<Plan, LogicalAggregate<Plan>> viewTopPlanAndAggPair = splitToTopPlanAndAggregate(viewStructInfo);
if (viewTopPlanAndAggPair == null) {
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
Pair.of("Split view to top plan and agg fail",
Pair.of("Split view to top plan and agg fail, view doesn't not contain aggregate",
String.format("view plan = %s\n", viewStructInfo.getOriginalPlan().treeString())));
return null;
}

View File

@ -186,11 +186,13 @@ public abstract class AbstractMaterializedViewRule implements ExplorationRuleFac
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
Pair.of("Predicate compensate fail",
String.format("query predicates = %s,\n query equivalenceClass = %s, \n"
+ "view predicates = %s,\n query equivalenceClass = %s\n",
+ "view predicates = %s,\n query equivalenceClass = %s\n"
+ "comparisonResult = %s ",
queryStructInfo.getPredicates(),
queryStructInfo.getEquivalenceClass(),
viewStructInfo.getPredicates(),
viewStructInfo.getEquivalenceClass())));
viewStructInfo.getEquivalenceClass(),
comparisonResult)));
continue;
}
Plan rewrittenPlan;
@ -467,21 +469,22 @@ public abstract class AbstractMaterializedViewRule implements ExplorationRuleFac
Set<Set<Slot>> requireNoNullableViewSlot = comparisonResult.getViewNoNullableSlot();
// check query is use the null reject slot which view comparison need
if (!requireNoNullableViewSlot.isEmpty()) {
Set<Expression> queryPulledUpPredicates = queryStructInfo.getPredicates().getPulledUpPredicates();
Set<Expression> queryPulledUpPredicates = comparisonResult.getQueryAllPulledUpExpressions().stream()
.flatMap(expr -> ExpressionUtils.extractConjunction(expr).stream())
.collect(Collectors.toSet());
Set<Expression> nullRejectPredicates = ExpressionUtils.inferNotNull(queryPulledUpPredicates,
cascadesContext);
if (nullRejectPredicates.isEmpty() || queryPulledUpPredicates.containsAll(nullRejectPredicates)) {
// query has not null reject predicates, so return
return SplitPredicate.INVALID_INSTANCE;
}
SlotMapping queryToViewMapping = viewToQuerySlotMapping.inverse();
Set<Expression> queryUsedNeedRejectNullSlotsViewBased = nullRejectPredicates.stream()
.map(expression -> TypeUtils.isNotNull(expression).orElse(null))
.filter(Objects::nonNull)
.map(expr -> ExpressionUtils.replace((Expression) expr, queryToViewMapping.toSlotReferenceMap()))
.collect(Collectors.toSet());
if (requireNoNullableViewSlot.stream().anyMatch(
set -> Sets.intersection(set, queryUsedNeedRejectNullSlotsViewBased).isEmpty())) {
// query pulledUp predicates should have null reject predicates and contains any require noNullable slot
boolean valid = !queryPulledUpPredicates.containsAll(nullRejectPredicates)
&& requireNoNullableViewSlot.stream().noneMatch(
set -> Sets.intersection(set, queryUsedNeedRejectNullSlotsViewBased).isEmpty());
if (!valid) {
return SplitPredicate.INVALID_INSTANCE;
}
}

View File

@ -35,20 +35,23 @@ public class ComparisonResult {
private final boolean valid;
private final List<Expression> viewExpressions;
private final List<Expression> queryExpressions;
private final List<Expression> queryAllPulledUpExpressions;
private final Set<Set<Slot>> viewNoNullableSlot;
private final String errorMessage;
ComparisonResult(List<Expression> queryExpressions, List<Expression> viewExpressions,
Set<Set<Slot>> viewNoNullableSlot, boolean valid, String message) {
ComparisonResult(List<Expression> queryExpressions, List<Expression> queryAllPulledUpExpressions,
List<Expression> viewExpressions, Set<Set<Slot>> viewNoNullableSlot, boolean valid, String message) {
this.viewExpressions = ImmutableList.copyOf(viewExpressions);
this.queryExpressions = ImmutableList.copyOf(queryExpressions);
this.queryAllPulledUpExpressions = ImmutableList.copyOf(queryAllPulledUpExpressions);
this.viewNoNullableSlot = ImmutableSet.copyOf(viewNoNullableSlot);
this.valid = valid;
this.errorMessage = message;
}
public static ComparisonResult newInvalidResWithErrorMessage(String errorMessage) {
return new ComparisonResult(ImmutableList.of(), ImmutableList.of(), ImmutableSet.of(), false, errorMessage);
return new ComparisonResult(ImmutableList.of(), ImmutableList.of(), ImmutableList.of(),
ImmutableSet.of(), false, errorMessage);
}
public List<Expression> getViewExpressions() {
@ -59,6 +62,10 @@ public class ComparisonResult {
return queryExpressions;
}
public List<Expression> getQueryAllPulledUpExpressions() {
return queryAllPulledUpExpressions;
}
public Set<Set<Slot>> getViewNoNullableSlot() {
return viewNoNullableSlot;
}
@ -78,6 +85,7 @@ public class ComparisonResult {
ImmutableList.Builder<Expression> queryBuilder = new ImmutableList.Builder<>();
ImmutableList.Builder<Expression> viewBuilder = new ImmutableList.Builder<>();
ImmutableSet.Builder<Set<Slot>> viewNoNullableSlotBuilder = new ImmutableSet.Builder<>();
ImmutableList.Builder<Expression> queryAllPulledUpExpressionsBuilder = new ImmutableList.Builder<>();
boolean valid = true;
/**
@ -108,25 +116,29 @@ public class ComparisonResult {
return this;
}
public Builder addQueryAllPulledUpExpressions(Collection<? extends Expression> expressions) {
queryAllPulledUpExpressionsBuilder.addAll(expressions);
return this;
}
public boolean isInvalid() {
return !valid;
}
public ComparisonResult build() {
Preconditions.checkArgument(valid, "Comparison result must be valid");
return new ComparisonResult(queryBuilder.build(), viewBuilder.build(),
viewNoNullableSlotBuilder.build(), valid, "");
return new ComparisonResult(queryBuilder.build(), queryAllPulledUpExpressionsBuilder.build(),
viewBuilder.build(), viewNoNullableSlotBuilder.build(), valid, "");
}
}
@Override
public String toString() {
if (isInvalid()) {
return "INVALID";
}
return String.format("viewExpressions: %s \n "
+ "queryExpressions :%s \n "
+ "viewNoNullableSlot :%s \n",
viewExpressions, queryExpressions, viewNoNullableSlot);
return String.format("valid: %s \n "
+ "viewExpressions: %s \n "
+ "queryExpressions :%s \n "
+ "viewNoNullableSlot :%s \n"
+ "queryAllPulledUpExpressions :%s \n", valid, viewExpressions, queryExpressions,
viewNoNullableSlot, queryAllPulledUpExpressions);
}
}

View File

@ -168,6 +168,10 @@ public class HyperGraphComparator {
for (Pair<JoinType, Set<Slot>> inferredCond : inferredViewEdgeWithCond.values()) {
builder.addViewNoNullableSlot(inferredCond.second);
}
builder.addQueryAllPulledUpExpressions(
getQueryFilterEdges().stream()
.filter(this::canPullUp)
.flatMap(filter -> filter.getExpressions().stream()).collect(Collectors.toList()));
return builder.build();
}

View File

@ -170,6 +170,8 @@ public class StructInfo {
}
// Collect relations from hyper graph which in the bottom plan
hyperGraph.getNodes().forEach(node -> {
// plan relation collector and set to map
StructInfoNode structInfoNode = (StructInfoNode) node;
// plan relation collector and set to map
Plan nodePlan = node.getPlan();
List<CatalogRelation> nodeRelations = new ArrayList<>();
@ -177,6 +179,24 @@ public class StructInfo {
relationBuilder.addAll(nodeRelations);
// every node should only have one relation, this is for LogicalCompatibilityContext
relationIdStructInfoNodeMap.put(nodeRelations.get(0).getRelationId(), (StructInfoNode) node);
// record expressions in node
if (structInfoNode.getExpressions() != null) {
structInfoNode.getExpressions().forEach(expression -> {
ExpressionLineageReplacer.ExpressionReplaceContext replaceContext =
new ExpressionLineageReplacer.ExpressionReplaceContext(
Lists.newArrayList(expression),
ImmutableSet.of(),
ImmutableSet.of());
topPlan.accept(ExpressionLineageReplacer.INSTANCE, replaceContext);
// Replace expressions by expression map
List<Expression> replacedExpressions = replaceContext.getReplacedExpressions();
shuttledHashConjunctsToConjunctsMap.put(replacedExpressions.get(0), expression);
// Record this, will be used in top level expression shuttle later, see the method
// ExpressionLineageReplacer#visitGroupPlan
namedExprIdAndExprMapping.putAll(replaceContext.getExprIdExpressionMap());
});
}
});
// Collect expression from where in hyper graph
hyperGraph.getFilterEdges().forEach(filterEdge -> {
@ -436,7 +456,9 @@ public class StructInfo {
if (!(plan instanceof Filter)
&& !(plan instanceof Project)
&& !(plan instanceof CatalogRelation)
&& !(plan instanceof Join)) {
&& !(plan instanceof Join)
&& !(plan instanceof LogicalAggregate && !((LogicalAggregate) plan).getSourceRepeat()
.isPresent())) {
return false;
}
if (plan instanceof Join) {