[Improvement](Nereids) Support aggregate rewrite by materialized view with complex expression (#30440)

materialized view definition is

>            select
>            sum(o_totalprice) as sum_total,
>            max(o_totalprice) as max_total,
>            min(o_totalprice) as min_total,
>           count(*) as count_all,
>            bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) >cnt_1,
>            bitmap_union(to_bitmap(case when o_shippriority > 2 and o_orderkey IN (2) then o_custkey else null end)) as >cnt_2
>            from lineitem
>            left join orders on l_orderkey = o_orderkey and l_shipdate = o_orderdate;
   

the query following can be rewritten by materialized view above.
it use the aggregate fuction arithmetic calculation in the select 

>            select
>            count(distinct case when O_SHIPPRIORITY > 2 and o_orderkey IN (2) then o_custkey else null end) as cnt_2,
>            (sum(o_totalprice) + min(o_totalprice)) * count(*),
>            min(o_totalprice) + count(distinct case when O_SHIPPRIORITY > 2 and o_orderkey IN (2) then o_custkey else null >end)
>            from lineitem
>            left join orders on l_orderkey = o_orderkey and l_shipdate = o_orderdate;
This commit is contained in:
seawinde
2024-01-29 16:47:26 +08:00
committed by yiguolei
parent edeec320d3
commit dce6c8bd65
5 changed files with 402 additions and 28 deletions

View File

@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount
import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@ -65,6 +66,8 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
protected static final Multimap<Function, Expression>
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = ArrayListMultimap.create();
protected static final AggregateExpressionRewriter AGGREGATE_EXPRESSION_REWRITER =
new AggregateExpressionRewriter();
static {
// support count distinct roll up
@ -156,7 +159,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
= topPlanSplitToGroupAndFunction(queryTopPlanAndAggPair);
Set<? extends Expression> queryTopPlanFunctionSet = queryGroupAndFunctionPair.value();
// try to rewrite, contains both roll up aggregate functions and aggregate group expression
List<NamedExpression> finalAggregateExpressions = new ArrayList<>();
List<NamedExpression> finalOutputExpressions = new ArrayList<>();
List<Expression> finalGroupExpressions = new ArrayList<>();
List<? extends Expression> queryExpressions = queryTopPlan.getExpressions();
// permute the mv expr mapping to query based
@ -169,32 +172,29 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
Expression queryFunctionShuttled = ExpressionUtils.shuttleExpressionWithLineage(
topExpression,
queryTopPlan);
// try to roll up
List<Object> queryFunctions =
queryFunctionShuttled.collectFirst(expr -> expr instanceof AggregateFunction);
if (queryFunctions.isEmpty()) {
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
Pair.of("Can not found query function",
String.format("queryFunctionShuttled = %s", queryFunctionShuttled)));
return null;
}
Function rollupAggregateFunction = rollup((AggregateFunction) queryFunctions.get(0),
queryFunctionShuttled, mvExprToMvScanExprQueryBased);
if (rollupAggregateFunction == null) {
AggregateExpressionRewriteContext context = new AggregateExpressionRewriteContext(
false, mvExprToMvScanExprQueryBased, queryTopPlan);
// queryFunctionShuttled maybe sum(column) + count(*), so need to use expression rewriter
Expression rollupedExpression = queryFunctionShuttled.accept(AGGREGATE_EXPRESSION_REWRITER,
context);
if (!context.isValid()) {
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
Pair.of("Query function roll up fail",
String.format("queryFunction = %s,\n queryFunctionShuttled = %s,\n"
+ "mvExprToMvScanExprQueryBased = %s",
queryFunctions.get(0), queryFunctionShuttled,
mvExprToMvScanExprQueryBased)));
String.format("queryFunctionShuttled = %s,\n mvExprToMvScanExprQueryBased = %s",
queryFunctionShuttled, mvExprToMvScanExprQueryBased)));
return null;
}
finalAggregateExpressions.add(new Alias(rollupAggregateFunction));
finalOutputExpressions.add(new Alias(rollupedExpression));
} else {
// if group by expression, try to rewrite group by expression
Expression queryGroupShuttledExpr =
ExpressionUtils.shuttleExpressionWithLineage(topExpression, queryTopPlan);
if (!mvExprToMvScanExprQueryBased.containsKey(queryGroupShuttledExpr)) {
AggregateExpressionRewriteContext context = new AggregateExpressionRewriteContext(
true, mvExprToMvScanExprQueryBased, queryTopPlan);
// group by expression maybe group by a + b, so we need expression rewriter
Expression rewrittenGroupByExpression = queryGroupShuttledExpr.accept(AGGREGATE_EXPRESSION_REWRITER,
context);
if (!context.isValid()) {
// group expr can not rewrite by view
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
Pair.of("View dimensions doesn't not cover the query dimensions",
@ -202,9 +202,10 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
mvExprToMvScanExprQueryBased, queryGroupShuttledExpr)));
return null;
}
Expression expression = mvExprToMvScanExprQueryBased.get(queryGroupShuttledExpr);
finalAggregateExpressions.add((NamedExpression) expression);
finalGroupExpressions.add(expression);
NamedExpression groupByExpression = rewrittenGroupByExpression instanceof NamedExpression
? (NamedExpression) rewrittenGroupByExpression : new Alias(rewrittenGroupByExpression);
finalOutputExpressions.add(groupByExpression);
finalGroupExpressions.add(groupByExpression);
}
}
// add project to guarantee group by column ref is slot reference,
@ -229,7 +230,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
return (NamedExpression) expr;
})
.collect(Collectors.toList());
finalAggregateExpressions = finalAggregateExpressions.stream()
finalOutputExpressions = finalOutputExpressions.stream()
.map(expr -> {
ExprId exprId = expr.getExprId();
if (projectOutPutExprIdMap.containsKey(exprId)) {
@ -238,7 +239,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
return expr;
})
.collect(Collectors.toList());
return new LogicalAggregate(finalGroupExpressions, finalAggregateExpressions, mvProject);
return new LogicalAggregate(finalGroupExpressions, finalOutputExpressions, mvProject);
}
private boolean isGroupByEquals(Pair<Plan, LogicalAggregate<Plan>> queryTopPlanAndAggPair,
@ -273,7 +274,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
* the queryAggregateFunction is max(a), queryAggregateFunctionShuttled is max(a) + 1
* mvExprToMvScanExprQueryBased is { max(a) : MTMVScan(output#0) }
*/
private Function rollup(AggregateFunction queryAggregateFunction,
private static Function rollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
if (!(queryAggregateFunction instanceof CouldRollUp)) {
@ -310,7 +311,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
// Check the aggregate function can roll up or not, return true if could roll up
// if view aggregate function is distinct or is in the un supported rollup functions, it doesn't support
// roll up.
private boolean canRollup(Expression rollupExpression) {
private static boolean canRollup(Expression rollupExpression) {
if (rollupExpression == null) {
return false;
}
@ -402,7 +403,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
* This will check the count(distinct a) in query is equivalent to bitmap_union(to_bitmap(a)) in mv,
* and then check their arguments is equivalent.
*/
private boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) {
private static boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) {
if (queryFunction.equals(viewFunction)) {
return true;
}
@ -438,9 +439,109 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
* actualFunction is bitmap_union(to_bitmap(case when a = 5 then 1 else 2 end))
* after extracting, the return argument is: case when a = 5 then 1 else 2 end
*/
private List<Expression> extractArguments(Expression functionWithAny, Function actualFunction) {
private static List<Expression> extractArguments(Expression functionWithAny, Function actualFunction) {
Set<Object> exprSetToRemove = functionWithAny.collectToSet(expr -> !(expr instanceof Any));
return actualFunction.collectFirst(expr ->
exprSetToRemove.stream().noneMatch(exprToRemove -> exprToRemove.equals(expr)));
}
/**
* Aggregate expression rewriter which is responsible for rewriting group by and
* aggregate function expression
*/
protected static class AggregateExpressionRewriter
extends DefaultExpressionRewriter<AggregateExpressionRewriteContext> {
@Override
public Expression visitAggregateFunction(AggregateFunction aggregateFunction,
AggregateExpressionRewriteContext rewriteContext) {
if (!rewriteContext.isValid()) {
return aggregateFunction;
}
Expression queryFunctionShuttled = ExpressionUtils.shuttleExpressionWithLineage(
aggregateFunction,
rewriteContext.getQueryTopPlan());
Function rollupAggregateFunction = rollup(aggregateFunction, queryFunctionShuttled,
rewriteContext.getMvExprToMvScanExprQueryBasedMapping());
if (rollupAggregateFunction == null) {
rewriteContext.setValid(false);
return aggregateFunction;
}
return rollupAggregateFunction;
}
@Override
public Expression visitSlot(Slot slot, AggregateExpressionRewriteContext rewriteContext) {
if (!rewriteContext.isValid()) {
return slot;
}
if (rewriteContext.getMvExprToMvScanExprQueryBasedMapping().containsKey(slot)) {
return rewriteContext.getMvExprToMvScanExprQueryBasedMapping().get(slot);
}
rewriteContext.setValid(false);
return slot;
}
@Override
public Expression visit(Expression expr, AggregateExpressionRewriteContext rewriteContext) {
if (!rewriteContext.isValid()) {
return expr;
}
// for group by expression try to get corresponding expression directly
if (rewriteContext.isOnlyContainGroupByExpression()
&& rewriteContext.getMvExprToMvScanExprQueryBasedMapping().containsKey(expr)) {
return rewriteContext.getMvExprToMvScanExprQueryBasedMapping().get(expr);
}
List<Expression> newChildren = new ArrayList<>(expr.arity());
boolean hasNewChildren = false;
for (Expression child : expr.children()) {
Expression newChild = child.accept(this, rewriteContext);
if (!rewriteContext.isValid()) {
return expr;
}
if (newChild != child) {
hasNewChildren = true;
}
newChildren.add(newChild);
}
return hasNewChildren ? expr.withChildren(newChildren) : expr;
}
}
/**
* AggregateExpressionRewriteContext
*/
protected static class AggregateExpressionRewriteContext {
private boolean valid = true;
private final boolean onlyContainGroupByExpression;
private final Map<Expression, Expression> mvExprToMvScanExprQueryBasedMapping;
private final Plan queryTopPlan;
public AggregateExpressionRewriteContext(boolean onlyContainGroupByExpression,
Map<Expression, Expression> mvExprToMvScanExprQueryBasedMapping, Plan queryTopPlan) {
this.onlyContainGroupByExpression = onlyContainGroupByExpression;
this.mvExprToMvScanExprQueryBasedMapping = mvExprToMvScanExprQueryBasedMapping;
this.queryTopPlan = queryTopPlan;
}
public boolean isValid() {
return valid;
}
public void setValid(boolean valid) {
this.valid = valid;
}
public boolean isOnlyContainGroupByExpression() {
return onlyContainGroupByExpression;
}
public Map<Expression, Expression> getMvExprToMvScanExprQueryBasedMapping() {
return mvExprToMvScanExprQueryBasedMapping;
}
public Plan getQueryTopPlan() {
return queryTopPlan;
}
}
}