diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java index 685f8a8c3a..11faaa6a6d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping; import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping; import org.apache.doris.nereids.trees.expressions.Any; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -35,11 +36,14 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion; import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.util.ExpressionUtils; +import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.HashMultimap; import com.google.common.collect.Multimap; import com.google.common.collect.Sets; @@ -47,10 +51,11 @@ import 
org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import java.util.ArrayList; -import java.util.HashMap; +import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; @@ -60,15 +65,18 @@ import java.util.stream.Collectors; */ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule { - protected static final Map - AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = new HashMap<>(); + // query function -> equivalent view function patterns; only single-argument roll up functions are supported + protected static final Multimap + AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = ArrayListMultimap.create(); protected final String currentClassName = this.getClass().getSimpleName(); private final Logger logger = LogManager.getLogger(this.getClass()); static { AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE), - new BitmapUnion(Any.INSTANCE)); + new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE)))); + AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE), + new BitmapUnion(new ToBitmap(Any.INSTANCE))); } @Override @@ -249,17 +257,30 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate return rewrittenAggregate; } - // only support sum roll up, support other agg functions later. - private Function rollup(AggregateFunction queryFunction, - Expression queryFunctionShuttled, + /** + * Roll up query aggregate function when query dimension num is less than mv dimension num. + * + * @param queryAggregateFunction query aggregate function to roll up. + * @param queryAggregateFunctionShuttled query aggregate function shuttled by lineage. + * @param mvExprToMvScanExprQueryBased mv def sql output expressions to mv result data output mapping. + *

+ * Such as query is + * select max(a) + 1 from table group by b. + * mv is + * select max(a) from table group by a, b. + * the queryAggregateFunction is max(a), queryAggregateFunctionShuttled is max(a) + 1 + * mvExprToMvScanExprQueryBased is { max(a) : MTMVScan(output#0) } + */ + private Function rollup(AggregateFunction queryAggregateFunction, + Expression queryAggregateFunctionShuttled, Map mvExprToMvScanExprQueryBased) { - if (!(queryFunction instanceof CouldRollUp)) { + if (!(queryAggregateFunction instanceof CouldRollUp)) { return null; } Expression rollupParam = null; - if (mvExprToMvScanExprQueryBased.containsKey(queryFunctionShuttled)) { + if (mvExprToMvScanExprQueryBased.containsKey(queryAggregateFunctionShuttled)) { // function can rewrite by view - rollupParam = mvExprToMvScanExprQueryBased.get(queryFunctionShuttled); + rollupParam = mvExprToMvScanExprQueryBased.get(queryAggregateFunctionShuttled); } else { // function can not rewrite by view, try to use complex roll up param // eg: query is count(distinct param), mv sql is bitmap_union(to_bitmap(param)) @@ -267,7 +288,8 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate if (!(mvExprShuttled instanceof Function)) { continue; } - if (isAggregateFunctionEquivalent(queryFunction, (Function) mvExprShuttled)) { + if (isAggregateFunctionEquivalent(queryAggregateFunction, queryAggregateFunctionShuttled, + (Function) mvExprShuttled)) { rollupParam = mvExprToMvScanExprQueryBased.get(mvExprShuttled); } } @@ -276,7 +298,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate return null; } // do roll up - return ((CouldRollUp) queryFunction).constructRollUp(rollupParam); + return ((CouldRollUp) queryAggregateFunction).constructRollUp(rollupParam); } private Pair, Set> topPlanSplitToGroupAndFunction( @@ -347,22 +369,55 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate return true; } - private boolean 
isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) { + /** + * Check whether the query aggregate function is equivalent to the view aggregate function for roll up. + * Both the function pattern and the arguments of the query and view aggregate functions are compared. + * Such as query is + * select count(distinct a) + 1 from table group by b. + * mv is + * select bitmap_union(to_bitmap(a)) from table group by a, b. + * the queryAggregateFunction is count(distinct a), queryAggregateFunctionShuttled is count(distinct a) + 1 + * mvExprToMvScanExprQueryBased is { bitmap_union(to_bitmap(a)) : MTMVScan(output#0) } + * This checks that count(distinct a) in the query is equivalent to bitmap_union(to_bitmap(a)) in the mv, + * and then checks that their arguments are equal. + */ + private boolean isAggregateFunctionEquivalent(Function queryFunction, Expression queryFunctionShuttled, + Function viewFunction) { if (queryFunction.equals(viewFunction)) { return true; } - // get query equivalent function - Expression equivalentFunction = null; - for (Map.Entry entry : AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.entrySet()) { - if (entry.getKey().equals(queryFunction)) { - equivalentFunction = entry.getValue(); + // look up the registered equivalent functions of the query function + for (Map.Entry> equivalentFunctionEntry : + AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.asMap().entrySet()) { + if (equivalentFunctionEntry.getKey().equals(queryFunction)) { + // check whether any equivalent function matches the view function + for (Expression equivalentFunction : equivalentFunctionEntry.getValue()) { + if (!Any.equals(equivalentFunction, viewFunction)) { + continue; + } + // check that the single argument of the query function equals that of the view function + List viewFunctionArguments = extractViewArguments(equivalentFunction, viewFunction); + if (queryFunctionShuttled.getArguments().size() != 1 || viewFunctionArguments.size() != 1) { + continue; + } + if (Objects.equals(queryFunctionShuttled.getArguments().get(0), 
viewFunctionArguments.get(0))) { + return true; + } + } } } - // check is have equivalent function or not - if (equivalentFunction == null) { - return false; - } - // current compare - return equivalentFunction.equals(viewFunction); + return false; + } + + /** + * Extract the first node of viewFunction that equals none of the non-Any nodes of equivalentFunction. + * Such as equivalentFunction def is bitmap_union(to_bitmap(Any.INSTANCE)), + * viewFunction is bitmap_union(to_bitmap(case when a = 5 then 1 else 2 end)), + * after extracting, the returned argument is: case when a = 5 then 1 else 2 end + */ + private List extractViewArguments(Expression equivalentFunction, Function viewFunction) { + Set exprSetToRemove = equivalentFunction.collectToSet(expr -> !(expr instanceof Any)); + return viewFunction.collectFirst(expr -> + exprSetToRemove.stream().noneMatch(exprToRemove -> exprToRemove.equals(expr))); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java index 557ff43b51..00ac71eaf2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableSet; import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Deque; import java.util.List; import java.util.Set; @@ -150,6 +151,19 @@ public interface TreeNode> { return rewriteFunction.apply(rewrittenChildren); } + /** + * Traverse the tree top-down, do not visit the children of a node once the test returns true for it. + * @param func predicate tested on each visited node + */ + default void foreach(Predicate> func) { + boolean valid = func.test(this); + if (!valid) { + for (NODE_TYPE child : children()) { + child.foreach(func); + } + } + } + /** * Foreach treeNode. Top-down traverse implicitly. 
* @param func foreach function @@ -241,6 +255,20 @@ public interface TreeNode> { return (Set) result.build(); } + /** + * Collect the first node, in top-down order, that satisfies the predicate. The result has at most one node. + */ + default List collectFirst(Predicate> predicate) { + List> result = new ArrayList<>(); + foreach(node -> { + if (result.isEmpty() && predicate.test(node)) { + result.add(node); + } + return !result.isEmpty(); + }); + return (List) ImmutableList.copyOf(result); + } + /** * iterate top down and test predicate if contains any instance of the classes * @param types classes array diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java index 43d284bf67..2e4bc745b2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.Objects; /** * This represents any expression, it means it equals any expression @@ -55,6 +56,15 @@ public class Any extends Expression implements LeafExpression { return true; } + /** + * Equals with direction. + * Since the equals method in Any always returns true, Any equals any other expression, but not in reverse. + * The expression with Any should always be the first argument. 
+ */ + public static boolean equals(Expression expressionWithAny, Expression target) { + return Objects.equals(expressionWithAny, target); + } + @Override public int hashCode() { return 0; diff --git a/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out b/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out index fb223bc661..334980ed00 100644 --- a/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out +++ b/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out @@ -5,6 +5,12 @@ -- !query13_0_after -- 3 3 2023-12-11 43.20 43.20 43.20 1 0 +-- !query13_1_before -- +3 3 2023-12-11 43.20 43.20 43.20 1 0 + +-- !query13_1_after -- +3 3 2023-12-11 43.20 43.20 43.20 1 0 + -- !query14_0_before -- 2 3 2023-12-08 20.00 10.50 9.50 2 0 2 3 2023-12-12 \N \N \N 1 0 diff --git a/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy b/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy index fd3c02408d..e9d1ee76b3 100644 --- a/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy +++ b/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy @@ -247,6 +247,41 @@ suite("aggregate_with_roll_up") { sql """ DROP MATERIALIZED VIEW IF EXISTS mv13_0""" + def mv13_1 = """ + select l_shipdate, o_orderdate, l_partkey, l_suppkey, + sum(o_totalprice) as sum_total, + max(o_totalprice) as max_total, + min(o_totalprice) as min_total, + count(*) as count_all, + bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) as bitmap_union_basic + from lineitem + left join orders on lineitem.l_orderkey = orders.o_orderkey and l_shipdate = o_orderdate + group by + l_shipdate, + o_orderdate, + l_partkey, + l_suppkey; + """ + def query13_1 = """ + select t1.l_partkey, t1.l_suppkey, o_orderdate, 
+ sum(o_totalprice), + max(o_totalprice), + min(o_totalprice), + count(*), + count(distinct case when o_shippriority > 10 and o_orderkey IN (1, 3) then o_custkey else null end) + from (select * from lineitem where l_shipdate = '2023-12-11') t1 + left join orders on t1.l_orderkey = orders.o_orderkey and t1.l_shipdate = o_orderdate + group by + o_orderdate, + l_partkey, + l_suppkey; + """ + order_qt_query13_1_before "${query13_1}" + check_not_match(mv13_1, query13_1, "mv13_1") + order_qt_query13_1_after "${query13_1}" + sql """ DROP MATERIALIZED VIEW IF EXISTS mv13_1""" + + // filter inside + right + use roll up dimension def mv14_0 = "select l_shipdate, o_orderdate, l_partkey, l_suppkey, " + "sum(o_totalprice) as sum_total, " +