[fix](nereids) fix aggregate function roll up when expression arguments are not equal (#29256)

When rolling up an aggregate function, we should check that the query and mv function arguments are equal.
For example, with the mv def and query sql below, the rewrite should not succeed, because the argument of the bitmap_union_basic field is
not equal to the argument of the `count(distinct case when o_shippriority > 10 and o_orderkey IN (1, 3) then o_custkey else null end)` field in the query (the mv uses `o_shippriority > 1`, the query uses `o_shippriority > 10`).

mv def:
>      select l_shipdate, o_orderdate, l_partkey, l_suppkey, 
>            sum(o_totalprice) as sum_total, 
>            max(o_totalprice) as max_total, 
>            min(o_totalprice) as min_total, 
>            count(*) as count_all, 
>            bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) as bitmap_union_basic 
>           from lineitem 
>           left join orders on lineitem.l_orderkey = orders.o_orderkey and l_shipdate = o_orderdate 
>            group by 
>         l_shipdate, 
>         o_orderdate, 
>          l_partkey, 
>         l_suppkey;

query sql:

>             select t1.l_partkey, t1.l_suppkey, o_orderdate,
>           sum(o_totalprice),
>            max(o_totalprice),
>           min(o_totalprice),
>           count(*),
>            count(distinct case when o_shippriority > 10 and o_orderkey IN (1, 3) then o_custkey else null end)
>            from (select * from lineitem where l_shipdate = '2023-12-11') t1
>            left join orders on t1.l_orderkey = orders.o_orderkey and t1.l_shipdate = o_orderdate
>            group by
>            o_orderdate, 
>            l_partkey,
>            l_suppkey;
This commit is contained in:
seawinde
2024-01-03 18:58:18 +08:00
committed by GitHub
parent 2386a1ce5a
commit 49a3bab399
5 changed files with 158 additions and 24 deletions

View File

@ -26,6 +26,7 @@ import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Any;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -35,11 +36,14 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
@ -47,10 +51,11 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@ -60,15 +65,18 @@ import java.util.stream.Collectors;
*/
public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule {
protected static final Map<Expression, Expression>
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = new HashMap<>();
// we only support roll up function which has only one argument currently
protected static final Multimap<Expression, Expression>
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = ArrayListMultimap.create();
protected final String currentClassName = this.getClass().getSimpleName();
private final Logger logger = LogManager.getLogger(this.getClass());
static {
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE),
new BitmapUnion(Any.INSTANCE));
new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE),
new BitmapUnion(new ToBitmap(Any.INSTANCE)));
}
@Override
@ -249,17 +257,30 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
return rewrittenAggregate;
}
// only support sum roll up, support other agg functions later.
private Function rollup(AggregateFunction queryFunction,
Expression queryFunctionShuttled,
/**
* Roll up the query aggregate function when the query has fewer dimensions than the mv.
*
* @param queryAggregateFunction query aggregate function to roll up.
* @param queryAggregateFunctionShuttled query aggregate function shuttled by lineage.
* @param mvExprToMvScanExprQueryBased mv def sql output expressions to mv result data output mapping.
* <p>
* Such as query is
* select max(a) + 1 from table group by b.
* mv is
* select max(a) from table group by a, b.
* the queryAggregateFunction is max(a), queryAggregateFunctionShuttled is max(a) + 1
* mvExprToMvScanExprQueryBased is { max(a) : MTMVScan(output#0) }
*/
private Function rollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
if (!(queryFunction instanceof CouldRollUp)) {
if (!(queryAggregateFunction instanceof CouldRollUp)) {
return null;
}
Expression rollupParam = null;
if (mvExprToMvScanExprQueryBased.containsKey(queryFunctionShuttled)) {
if (mvExprToMvScanExprQueryBased.containsKey(queryAggregateFunctionShuttled)) {
// function can rewrite by view
rollupParam = mvExprToMvScanExprQueryBased.get(queryFunctionShuttled);
rollupParam = mvExprToMvScanExprQueryBased.get(queryAggregateFunctionShuttled);
} else {
// function can not rewrite by view, try to use complex roll up param
// eg: query is count(distinct param), mv sql is bitmap_union(to_bitmap(param))
@ -267,7 +288,8 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
if (!(mvExprShuttled instanceof Function)) {
continue;
}
if (isAggregateFunctionEquivalent(queryFunction, (Function) mvExprShuttled)) {
if (isAggregateFunctionEquivalent(queryAggregateFunction, queryAggregateFunctionShuttled,
(Function) mvExprShuttled)) {
rollupParam = mvExprToMvScanExprQueryBased.get(mvExprShuttled);
}
}
@ -276,7 +298,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
return null;
}
// do roll up
return ((CouldRollUp) queryFunction).constructRollUp(rollupParam);
return ((CouldRollUp) queryAggregateFunction).constructRollUp(rollupParam);
}
private Pair<Set<? extends Expression>, Set<? extends Expression>> topPlanSplitToGroupAndFunction(
@ -347,22 +369,55 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
return true;
}
private boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) {
/**
* Check whether the query function is equivalent to the view function when rolling up.
* Not only the function name but also the arguments of the query and view aggregate functions are checked.
* Such as query is
* select count(distinct a) + 1 from table group by b.
* mv is
* select bitmap_union(to_bitmap(a)) from table group by a, b.
* the queryAggregateFunction is count(distinct a), queryAggregateFunctionShuttled is count(distinct a) + 1
* mvExprToMvScanExprQueryBased is { bitmap_union(to_bitmap(a)) : MTMVScan(output#0) }
* This checks that count(distinct a) in the query is equivalent to bitmap_union(to_bitmap(a)) in the mv,
* and then checks that their arguments are equivalent.
*/
private boolean isAggregateFunctionEquivalent(Function queryFunction, Expression queryFunctionShuttled,
Function viewFunction) {
if (queryFunction.equals(viewFunction)) {
return true;
}
// get query equivalent function
Expression equivalentFunction = null;
for (Map.Entry<Expression, Expression> entry : AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.entrySet()) {
if (entry.getKey().equals(queryFunction)) {
equivalentFunction = entry.getValue();
// check the argument of rollup function is equivalent to view function or not
for (Map.Entry<Expression, Collection<Expression>> equivalentFunctionEntry :
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.asMap().entrySet()) {
if (equivalentFunctionEntry.getKey().equals(queryFunction)) {
// check whether the view function matches one of the equivalent function patterns
for (Expression equivalentFunction : equivalentFunctionEntry.getValue()) {
if (!Any.equals(equivalentFunction, viewFunction)) {
continue;
}
// check param in query function is same as the view function
List<Expression> viewFunctionArguments = extractViewArguments(equivalentFunction, viewFunction);
if (queryFunctionShuttled.getArguments().size() != 1 || viewFunctionArguments.size() != 1) {
continue;
}
if (Objects.equals(queryFunctionShuttled.getArguments().get(0), viewFunctionArguments.get(0))) {
return true;
}
}
}
}
// check is have equivalent function or not
if (equivalentFunction == null) {
return false;
}
// current compare
return equivalentFunction.equals(viewFunction);
return false;
}
/**
 * Extract the argument of the view function by stripping away the equivalentFunction pattern.
 * Every node of equivalentFunction that is not {@link Any} is treated as part of the pattern; the view
 * function is then searched for its first node that equals none of those pattern nodes.
 * Such as equivalentFunction def is bitmap_union(to_bitmap(Any.INSTANCE)),
 * viewFunction is bitmap_union(to_bitmap(case when a = 5 then 1 else 2 end))
 * after extracting, the returned argument is: case when a = 5 then 1 else 2 end
 *
 * @param equivalentFunction the pattern expression, in which Any marks the argument position
 * @param viewFunction the aggregate function from the mv to extract the argument from
 * @return the result of collectFirst: the first non-pattern node of viewFunction, if any
 */
private List<Expression> extractViewArguments(Expression equivalentFunction, Function viewFunction) {
    Set<Object> patternNodes = equivalentFunction.collectToSet(node -> !(node instanceof Any));
    return viewFunction.collectFirst(candidate -> {
        // the pattern node must be the receiver of equals, it is the side that may contain Any
        for (Object patternNode : patternNodes) {
            if (patternNode.equals(candidate)) {
                return false;
            }
        }
        return true;
    });
}
}

View File

@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Set;
@ -150,6 +151,19 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> {
return rewriteFunction.apply(rewrittenChildren);
}
/**
 * Traverse the tree top-down, testing every visited node with the given predicate.
 * When the predicate accepts a node, the children of that node are not visited;
 * the remaining siblings of that node are still visited.
 *
 * @param func predicate applied to each visited node, returning true prunes the subtree below that node
 */
default void foreach(Predicate<TreeNode<NODE_TYPE>> func) {
    if (func.test(this)) {
        // accepted: do not descend into this node's children
        return;
    }
    for (NODE_TYPE childNode : children()) {
        childNode.foreach(func);
    }
}
/**
* Foreach treeNode. Top-down traverse implicitly.
* @param func foreach function
@ -241,6 +255,20 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> {
return (Set<T>) result.build();
}
/**
 * Collect the first node, in top-down traversal order, that satisfies the predicate.
 *
 * @param predicate the condition a node has to satisfy to be collected
 * @return an immutable list holding the first satisfying node, or an empty list if no node satisfies it
 */
default <T> List<T> collectFirst(Predicate<TreeNode<NODE_TYPE>> predicate) {
    List<TreeNode<NODE_TYPE>> firstMatch = new ArrayList<>();
    foreach(node -> {
        if (!firstMatch.isEmpty()) {
            // already found, prune everything that is still to be visited
            return true;
        }
        if (predicate.test(node)) {
            firstMatch.add(node);
            return true;
        }
        return false;
    });
    return (List<T>) ImmutableList.copyOf(firstMatch);
}
/**
* iterate top down and test predicate if contains any instance of the classes
* @param types classes array

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
/**
* This represents any expression, it means it equals any expression
@ -55,6 +56,15 @@ public class Any extends Expression implements LeafExpression {
return true;
}
/**
 * Directional, null-safe equality check for expressions that may contain Any.
 * The equals method of Any always returns true, so an expression containing Any equals other expressions,
 * but the comparison does not hold in the reverse direction.
 * The expression containing Any must therefore always be passed as the first argument.
 *
 * @param expressionWithAny the expression that may contain Any, it is the receiver of the equals call
 * @param target the expression to compare against
 * @return true if both are the same reference (including both null),
 *     or expressionWithAny is not null and expressionWithAny.equals(target)
 */
public static boolean equals(Expression expressionWithAny, Expression target) {
    if (expressionWithAny == target) {
        return true;
    }
    return expressionWithAny != null && expressionWithAny.equals(target);
}
@Override
public int hashCode() {
return 0;