[feature](mtmv) Support agg state roll up and optimize the roll up code (#35026)

agg_state is agg  intermediate state, detail see 
state combinator: https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-functions/combinators/state

this support agg function roll up as following
 
+---------------------+---------------------------------------------+---------------------+
| query               | materialized view                           | roll up             |
| ------------------- | ------------------------------------------- | ------------------- |
| agg_funtion()       | agg_funtion_unoin()  or agg_funtion_state() | agg_funtion_merge() |
| agg_funtion_unoin() | agg_funtion_unoin() or agg_funtion_state()  | agg_funtion_union() |
| agg_funtion_merge() | agg_funtion_unoin() or agg_funtion_state()  | agg_funtion_merge() |
+---------------------+---------------------------------------------+---------------------+

for example which can be rewritten by mv sucessfully as following

MV defination is

```
            select
            o_orderstatus,
            l_partkey,
            l_suppkey,
            sum_union(sum_state(o_shippriority)),
            group_concat_union(group_concat_state(l_shipinstruct)),
            avg_union(avg_state(l_linenumber)),
            max_by_union(max_by_state(l_shipmode, l_suppkey)),
            count_union(count_state(l_orderkey)),
            multi_distinct_count_union(multi_distinct_count_state(l_shipmode))
            from lineitem
            left join orders
            on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
            group by
            o_orderstatus,
            l_partkey,
            l_suppkey;
```

Query is

```
            select
            o_orderstatus,
            l_suppkey,
            sum(o_shippriority),
            group_concat(l_shipinstruct),
            avg(l_linenumber),
            max_by(l_shipmode,l_suppkey),
            count(l_orderkey),
            multi_distinct_count(l_shipmode)
            from lineitem
            left join orders 
            on l_orderkey = o_orderkey and l_shipdate = o_orderdate
            group by
            o_orderstatus,
            l_suppkey;
```
This commit is contained in:
seawinde
2024-05-24 12:02:43 +08:00
committed by yiguolei
parent 4b91ad003f
commit d6e8fb7d77
28 changed files with 1350 additions and 244 deletions

View File

@ -21,47 +21,35 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanCheckContext;
import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.rules.exploration.mv.rollup.AggFunctionRollUpHandler;
import org.apache.doris.nereids.rules.exploration.mv.rollup.BothCombinatorRollupHandler;
import org.apache.doris.nereids.rules.exploration.mv.rollup.DirectRollupHandler;
import org.apache.doris.nereids.rules.exploration.mv.rollup.MappingRollupHandler;
import org.apache.doris.nereids.rules.exploration.mv.rollup.SingleCombinatorRollupHandler;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Any;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllCardinality;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@ -71,92 +59,15 @@ import java.util.stream.Collectors;
*/
public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule {
protected static final Multimap<Function, Expression>
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = ArrayListMultimap.create();
public static final List<AggFunctionRollUpHandler> ROLL_UP_HANDLERS =
ImmutableList.of(DirectRollupHandler.INSTANCE,
MappingRollupHandler.INSTANCE,
SingleCombinatorRollupHandler.INSTANCE,
BothCombinatorRollupHandler.INSTANCE);
protected static final AggregateExpressionRewriter AGGREGATE_EXPRESSION_REWRITER =
new AggregateExpressionRewriter();
static {
// support roll up when count distinct is in query
// the column type is not bitMap
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE),
new BitmapUnion(new ToBitmap(Any.INSTANCE)));
// with bitmap_union, to_bitmap and cast
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE),
new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))));
// support roll up when bitmap_union_count is in query
// the column type is bitMap
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new BitmapUnionCount(Any.INSTANCE),
new BitmapUnion(Any.INSTANCE));
// the column type is not bitMap
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new BitmapUnionCount(new ToBitmap(Any.INSTANCE)),
new BitmapUnion(new ToBitmap(Any.INSTANCE)));
// with bitmap_union, to_bitmap and cast
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new BitmapUnionCount(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))),
new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))));
// support roll up when the column type is not hll
// query is approx_count_distinct
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Ndv(Any.INSTANCE),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Ndv(Any.INSTANCE),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
// query is HLL_UNION_AGG
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnionAgg(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnionAgg(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnionAgg(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnionAgg(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
// query is HLL_CARDINALITY
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllCardinality(new HllUnion(new HllHash(Any.INSTANCE))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllCardinality(new HllUnion(new HllHash(Any.INSTANCE))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllCardinality(new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT)))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllCardinality(new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT)))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
// query is HLL_RAW_AGG or HLL_UNION
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnion(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnion(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
// support roll up when the column type is hll
// query is HLL_UNION_AGG
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnionAgg(Any.INSTANCE),
new HllUnion(Any.INSTANCE));
// query is HLL_CARDINALITY
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllCardinality(new HllUnion(Any.INSTANCE)),
new HllUnion(Any.INSTANCE));
// query is HLL_RAW_AGG or HLL_UNION
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnion(Any.INSTANCE),
new HllUnion(Any.INSTANCE));
}
@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
@ -374,35 +285,22 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
private static Function rollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
if (!(queryAggregateFunction instanceof CouldRollUp)) {
return null;
}
Expression rollupParam = null;
Expression viewRollupFunction = null;
// handle simple aggregate function roll up which is not in the AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP
if (mvExprToMvScanExprQueryBased.containsKey(queryAggregateFunctionShuttled)
&& AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.keySet().stream()
.noneMatch(aggFunction -> aggFunction.equals(queryAggregateFunction))) {
rollupParam = mvExprToMvScanExprQueryBased.get(queryAggregateFunctionShuttled);
viewRollupFunction = queryAggregateFunctionShuttled;
} else {
// handle complex functions roll up
// eg: query is count(distinct param), mv sql is bitmap_union(to_bitmap(param))
for (Expression mvExprShuttled : mvExprToMvScanExprQueryBased.keySet()) {
if (!(mvExprShuttled instanceof Function)) {
for (Map.Entry<Expression, Expression> expressionEntry : mvExprToMvScanExprQueryBased.entrySet()) {
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair = Pair.of(expressionEntry.getKey(),
expressionEntry.getValue());
for (AggFunctionRollUpHandler rollUpHandler : ROLL_UP_HANDLERS) {
if (!rollUpHandler.canRollup(queryAggregateFunction, queryAggregateFunctionShuttled,
mvExprToMvScanExprQueryBasedPair)) {
continue;
}
if (isAggregateFunctionEquivalent(queryAggregateFunction, (Function) mvExprShuttled)) {
rollupParam = mvExprToMvScanExprQueryBased.get(mvExprShuttled);
viewRollupFunction = mvExprShuttled;
Function rollupFunction = rollUpHandler.doRollup(queryAggregateFunction,
queryAggregateFunctionShuttled, mvExprToMvScanExprQueryBasedPair);
if (rollupFunction != null) {
return rollupFunction;
}
}
}
if (rollupParam == null || !canRollup(viewRollupFunction)) {
return null;
}
// do roll up
return ((CouldRollUp) queryAggregateFunction).constructRollUp(rollupParam);
return null;
}
// Check the aggregate function can roll up or not, return true if could roll up
@ -417,7 +315,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
}
if (rollupExpression instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) rollupExpression;
return !aggregateFunction.isDistinct() && aggregateFunction instanceof CouldRollUp;
return !aggregateFunction.isDistinct() && aggregateFunction instanceof RollUpTrait;
}
return true;
}
@ -479,60 +377,6 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
&& checkContext.isContainsTopAggregate() && checkContext.getTopAggregateNum() <= 1;
}
/**
* Check the queryFunction is equivalent to view function when function roll up.
* Not only check the function name but also check the argument between query and view aggregate function.
* Such as query is
* select count(distinct a) + 1 from table group by b.
* mv is
* select bitmap_union(to_bitmap(a)) from table group by a, b.
* the queryAggregateFunction is count(distinct a), queryAggregateFunctionShuttled is count(distinct a) + 1
* mvExprToMvScanExprQueryBased is { bitmap_union(to_bitmap(a)) : MTMVScan(output#0) }
* This will check the count(distinct a) in query is equivalent to bitmap_union(to_bitmap(a)) in mv,
* and then check their arguments is equivalent.
*/
private static boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) {
if (queryFunction.equals(viewFunction)) {
return true;
}
// check the argument of rollup function is equivalent to view function or not
for (Map.Entry<Function, Collection<Expression>> equivalentFunctionEntry :
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.asMap().entrySet()) {
if (equivalentFunctionEntry.getKey().equals(queryFunction)) {
// check is have equivalent function or not
for (Expression equivalentFunction : equivalentFunctionEntry.getValue()) {
if (!Any.equals(equivalentFunction, viewFunction)) {
continue;
}
// check param in query function is same as the view function
List<Expression> viewFunctionArguments = extractArguments(equivalentFunction, viewFunction);
List<Expression> queryFunctionArguments =
extractArguments(equivalentFunctionEntry.getKey(), queryFunction);
// check argument size,we only support roll up function which has only one argument currently
if (queryFunctionArguments.size() != 1 || viewFunctionArguments.size() != 1) {
continue;
}
if (Objects.equals(queryFunctionArguments.get(0), viewFunctionArguments.get(0))) {
return true;
}
}
}
}
return false;
}
/**
* Extract the function arguments by functionWithAny pattern
* Such as functionWithAny def is bitmap_union(to_bitmap(Any.INSTANCE)),
* actualFunction is bitmap_union(to_bitmap(case when a = 5 then 1 else 2 end))
* after extracting, the return argument is: case when a = 5 then 1 else 2 end
*/
private static List<Expression> extractArguments(Expression functionWithAny, Function actualFunction) {
Set<Object> exprSetToRemove = functionWithAny.collectToSet(expr -> !(expr instanceof Any));
return actualFunction.collectFirst(expr ->
exprSetToRemove.stream().noneMatch(exprToRemove -> exprToRemove.equals(expr)));
}
/**
* Aggregate expression rewriter which is responsible for rewriting group by and
* aggregate function expression

View File

@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.mv.rollup;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Any;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait;
import java.util.List;
import java.util.Set;
/**
* Aggregate function roll up handler
*/
public abstract class AggFunctionRollUpHandler {
/**
* Decide the query and view function can roll up or not
*/
public boolean canRollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair) {
Expression viewExpression = mvExprToMvScanExprQueryBasedPair.key();
if (!(viewExpression instanceof RollUpTrait) || !((RollUpTrait) viewExpression).canRollUp()) {
return false;
}
AggregateFunction aggregateFunction = (AggregateFunction) viewExpression;
return !aggregateFunction.isDistinct();
}
/**
* Do the aggregate function roll up
*/
public abstract Function doRollup(
AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair);
/**
* Extract the function arguments by functionWithAny pattern
* Such as functionWithAny def is bitmap_union(to_bitmap(Any.INSTANCE)),
* actualFunction is bitmap_union(to_bitmap(case when a = 5 then 1 else 2 end))
* after extracting, the return argument is: case when a = 5 then 1 else 2 end
*/
protected static List<Expression> extractArguments(Expression functionWithAny, Function actualFunction) {
Set<Object> exprSetToRemove = functionWithAny.collectToSet(expr -> !(expr instanceof Any));
return actualFunction.collectFirst(expr ->
exprSetToRemove.stream().noneMatch(exprToRemove -> exprToRemove.equals(expr)));
}
/**
* Extract the target expression in actualFunction by targetClazz
* Such as actualFunction def is avg_merge(avg_union(c1)), target Clazz is Combinator
* after extracting, the return argument is avg_union(c1)
*/
protected static <T> T extractLastExpression(Expression actualFunction, Class<T> targetClazz) {
List<Expression> expressions = actualFunction.collectToList(targetClazz::isInstance);
return targetClazz.cast(expressions.get(expressions.size() - 1));
}
}

View File

@ -0,0 +1,64 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.mv.rollup;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait;
import org.apache.doris.nereids.trees.expressions.functions.combinator.Combinator;
import java.util.Objects;
/**
* Handle the combinator aggregate function roll up, both query and view are combinator
* Such as query is select c1 sum_merge(sum_state(c2)) from orders group by c1;
* view is select c1 sum_union(sum_state(c2)) from orders group by c1;
*/
public class BothCombinatorRollupHandler extends AggFunctionRollUpHandler {
public static BothCombinatorRollupHandler INSTANCE = new BothCombinatorRollupHandler();
@Override
public boolean canRollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair) {
Expression viewFunction = mvExprToMvScanExprQueryBasedPair.key();
if (!super.canRollup(queryAggregateFunction, queryAggregateFunctionShuttled,
mvExprToMvScanExprQueryBasedPair)) {
return false;
}
if (queryAggregateFunction instanceof Combinator && viewFunction instanceof Combinator) {
Combinator queryCombinator = extractLastExpression(queryAggregateFunction, Combinator.class);
Combinator viewCombinator = extractLastExpression(viewFunction, Combinator.class);
// construct actual aggregate function in combinator and compare
return Objects.equals(queryCombinator.getNestedFunction().withChildren(queryCombinator.getArguments()),
viewCombinator.getNestedFunction().withChildren(viewCombinator.getArguments()));
}
return false;
}
@Override
public Function doRollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair) {
Expression rollupParam = mvExprToMvScanExprQueryBasedPair.value();
return ((RollUpTrait) queryAggregateFunction).constructRollUp(rollupParam);
}
}

View File

@ -0,0 +1,63 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.mv.rollup;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait;
import org.apache.doris.nereids.trees.expressions.functions.combinator.Combinator;
/**
* Roll up directly, for example,
* query is select c1, sum(c2) from t1 group by c1
* view is select c1, c2, sum(c2) from t1 group by c1, c2,
* the aggregate function in query and view is same, This handle the sum aggregate function roll up
*/
public class DirectRollupHandler extends AggFunctionRollUpHandler {
public static DirectRollupHandler INSTANCE = new DirectRollupHandler();
@Override
public boolean canRollup(
AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair) {
Expression viewExpression = mvExprToMvScanExprQueryBasedPair.key();
if (!super.canRollup(queryAggregateFunction, queryAggregateFunctionShuttled,
mvExprToMvScanExprQueryBasedPair)) {
return false;
}
return queryAggregateFunctionShuttled.equals(viewExpression)
&& MappingRollupHandler.AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.keySet().stream()
.noneMatch(aggFunction -> aggFunction.equals(queryAggregateFunction))
&& !(queryAggregateFunction instanceof Combinator);
}
@Override
public Function doRollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair) {
Expression rollupParam = mvExprToMvScanExprQueryBasedPair.value();
if (rollupParam == null) {
return null;
}
return ((RollUpTrait) queryAggregateFunction).constructRollUp(rollupParam);
}
}

View File

@ -0,0 +1,181 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.mv.rollup;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Any;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv;
import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllCardinality;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.VarcharType;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/**
* Handle the aggregate functions roll up which are in AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP
*/
public class MappingRollupHandler extends AggFunctionRollUpHandler {
public static MappingRollupHandler INSTANCE = new MappingRollupHandler();
public static final Multimap<Function, Expression>
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = ArrayListMultimap.create();
static {
// support roll up when count distinct is in query
// the column type is not bitMap
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE),
new BitmapUnion(new ToBitmap(Any.INSTANCE)));
// with bitmap_union, to_bitmap and cast
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE),
new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))));
// support roll up when bitmap_union_count is in query
// the column type is bitMap
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new BitmapUnionCount(Any.INSTANCE),
new BitmapUnion(Any.INSTANCE));
// the column type is not bitMap
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new BitmapUnionCount(new ToBitmap(Any.INSTANCE)),
new BitmapUnion(new ToBitmap(Any.INSTANCE)));
// with bitmap_union, to_bitmap and cast
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new BitmapUnionCount(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))),
new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))));
// support roll up when the column type is not hll
// query is approx_count_distinct
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Ndv(Any.INSTANCE),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Ndv(Any.INSTANCE),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
// query is HLL_UNION_AGG
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnionAgg(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnionAgg(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnionAgg(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnionAgg(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
// query is HLL_CARDINALITY
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllCardinality(new HllUnion(new HllHash(Any.INSTANCE))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllCardinality(new HllUnion(new HllHash(Any.INSTANCE))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllCardinality(new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT)))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllCardinality(new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT)))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
// query is HLL_RAW_AGG or HLL_UNION
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnion(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnion(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
// support roll up when the column type is hll
// query is HLL_UNION_AGG
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnionAgg(Any.INSTANCE),
new HllUnion(Any.INSTANCE));
// query is HLL_CARDINALITY
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllCardinality(new HllUnion(Any.INSTANCE)),
new HllUnion(Any.INSTANCE));
// query is HLL_RAW_AGG or HLL_UNION
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnion(Any.INSTANCE),
new HllUnion(Any.INSTANCE));
}
@Override
public boolean canRollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair) {
// handle complex functions roll up by mapping and combinator expression
// eg: query is count(distinct param), mv sql is bitmap_union(to_bitmap(param))
Expression viewExpression = mvExprToMvScanExprQueryBasedPair.key();
if (!super.canRollup(queryAggregateFunction, queryAggregateFunctionShuttled,
mvExprToMvScanExprQueryBasedPair)) {
return false;
}
Function viewFunction = (Function) viewExpression;
for (Map.Entry<Function, Collection<Expression>> equivalentFunctionEntry :
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.asMap().entrySet()) {
if (equivalentFunctionEntry.getKey().equals(queryAggregateFunction)) {
// check is have equivalent function or not
for (Expression equivalentFunction : equivalentFunctionEntry.getValue()) {
if (!Any.equals(equivalentFunction, viewFunction)) {
continue;
}
// check param in query function is same as the view function
List<Expression> viewFunctionArguments = extractArguments(equivalentFunction, viewFunction);
List<Expression> queryFunctionArguments =
extractArguments(equivalentFunctionEntry.getKey(), queryAggregateFunction);
// check argument size,we only support roll up function which has only one argument currently
if (queryFunctionArguments.size() != 1 || viewFunctionArguments.size() != 1) {
continue;
}
if (Objects.equals(queryFunctionArguments.get(0), viewFunctionArguments.get(0))) {
return true;
}
}
}
}
return false;
}
@Override
public Function doRollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair) {
Expression rollupParam = mvExprToMvScanExprQueryBasedPair.value();
return ((RollUpTrait) queryAggregateFunction).constructRollUp(rollupParam);
}
}

View File

@ -0,0 +1,75 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.mv.rollup;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.combinator.Combinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
import java.util.Objects;
/**
* Handle the combinator aggregate function roll up, Only view is combinator, query is aggregate function.
* Such as query is select c1 sum(c2) from orders group by c1;
* view is select c1 sum_union(sum_state(c2)) from orders group by c1;
* */
public class SingleCombinatorRollupHandler extends AggFunctionRollUpHandler {
public static SingleCombinatorRollupHandler INSTANCE = new SingleCombinatorRollupHandler();
@Override
public boolean canRollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair) {
Expression viewFunction = mvExprToMvScanExprQueryBasedPair.key();
if (!super.canRollup(queryAggregateFunction, queryAggregateFunctionShuttled,
mvExprToMvScanExprQueryBasedPair)) {
return false;
}
if (!(queryAggregateFunction instanceof Combinator)
&& (viewFunction instanceof UnionCombinator || viewFunction instanceof StateCombinator)) {
Combinator viewCombinator = extractLastExpression(viewFunction, Combinator.class);
return Objects.equals(queryAggregateFunction,
viewCombinator.getNestedFunction().withChildren(viewCombinator.getArguments()));
}
return false;
}
@Override
public Function doRollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair) {
FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry();
String combinatorName = queryAggregateFunction.getName() + AggCombinerFunctionBuilder.MERGE_SUFFIX;
Expression rollupParam = mvExprToMvScanExprQueryBasedPair.value();
FunctionBuilder functionBuilder =
functionRegistry.findFunctionBuilder(combinatorName, rollupParam);
Pair<? extends Expression, ? extends BoundFunction> targetExpressionPair =
functionBuilder.build(combinatorName, rollupParam);
return (Function) targetExpressionPair.key();
}
}

View File

@ -36,7 +36,7 @@ import java.util.List;
* AggregateFunction 'bitmap_union'. This class is generated by GenerateFunction.
*/
public class BitmapUnion extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, BitmapFunction, CouldRollUp {
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, BitmapFunction, RollUpTrait {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE)
@ -84,4 +84,9 @@ public class BitmapUnion extends AggregateFunction
public Function constructRollUp(Expression param, Expression... varParams) {
return new BitmapUnion(this.isDistinct(), param);
}
@Override
public boolean canRollUp() {
return true;
}
}

View File

@ -37,7 +37,7 @@ import java.util.List;
* AggregateFunction 'bitmap_union_count'. This class is generated by GenerateFunction.
*/
public class BitmapUnionCount extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, BitmapFunction, CouldRollUp {
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, BitmapFunction, RollUpTrait {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(BitmapType.INSTANCE)
@ -85,4 +85,9 @@ public class BitmapUnionCount extends AggregateFunction
public Function constructRollUp(Expression param, Expression... varParams) {
return new BitmapUnionCount(param);
}
@Override
public boolean canRollUp() {
return false;
}
}

View File

@ -37,7 +37,7 @@ import java.util.List;
/** count agg function. */
public class Count extends AggregateFunction
implements ExplicitlyCastableSignature, AlwaysNotNullable, SupportWindowAnalytic, CouldRollUp {
implements ExplicitlyCastableSignature, AlwaysNotNullable, SupportWindowAnalytic, RollUpTrait {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
// count(*)
@ -152,4 +152,9 @@ public class Count extends AggregateFunction
return new Sum(param);
}
}
@Override
public boolean canRollUp() {
return true;
}
}

View File

@ -36,7 +36,7 @@ import java.util.List;
* AggregateFunction 'hll_union'. This class is generated by GenerateFunction.
*/
public class HllUnion extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, HllFunction, CouldRollUp {
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, HllFunction, RollUpTrait {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(HllType.INSTANCE).args(HllType.INSTANCE)
@ -84,4 +84,9 @@ public class HllUnion extends AggregateFunction
public Function constructRollUp(Expression param, Expression... varParams) {
return new HllUnion(param);
}
@Override
public boolean canRollUp() {
return true;
}
}

View File

@ -37,7 +37,7 @@ import java.util.List;
* AggregateFunction 'hll_union_agg'. This class is generated by GenerateFunction.
*/
public class HllUnionAgg extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, HllFunction, CouldRollUp {
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, HllFunction, RollUpTrait {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(HllType.INSTANCE)
@ -85,4 +85,9 @@ public class HllUnionAgg extends AggregateFunction
public Function constructRollUp(Expression param, Expression... varParams) {
return new HllUnionAgg(param);
}
@Override
public boolean canRollUp() {
return false;
}
}

View File

@ -37,7 +37,7 @@ import java.util.List;
/** max agg function. */
public class Max extends NullableAggregateFunction
implements UnaryExpression, CustomSignature, SupportWindowAnalytic, CouldRollUp {
implements UnaryExpression, CustomSignature, SupportWindowAnalytic, RollUpTrait {
public Max(Expression child) {
this(false, false, child);
}
@ -91,4 +91,9 @@ public class Max extends NullableAggregateFunction
public Function constructRollUp(Expression param, Expression... varParams) {
return new Max(this.distinct, this.alwaysNullable, param);
}
@Override
public boolean canRollUp() {
return true;
}
}

View File

@ -37,7 +37,7 @@ import java.util.List;
/** min agg function. */
public class Min extends NullableAggregateFunction
implements UnaryExpression, CustomSignature, SupportWindowAnalytic, CouldRollUp {
implements UnaryExpression, CustomSignature, SupportWindowAnalytic, RollUpTrait {
public Min(Expression child) {
this(false, false, child);
@ -92,4 +92,9 @@ public class Min extends NullableAggregateFunction
public Function constructRollUp(Expression param, Expression... varParams) {
return new Min(this.distinct, this.alwaysNullable, param);
}
@Override
public boolean canRollUp() {
return true;
}
}

View File

@ -38,7 +38,7 @@ import java.util.List;
* AggregateFunction 'ndv'. This class is generated by GenserateFunction.
*/
public class Ndv extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, CouldRollUp {
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, RollUpTrait {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(AnyDataType.INSTANCE_WITHOUT_INDEX)
@ -85,4 +85,9 @@ public class Ndv extends AggregateFunction
public Function constructRollUp(Expression param, Expression... varParams) {
return new HllUnionAgg(param);
}
@Override
public boolean canRollUp() {
return false;
}
}

View File

@ -21,14 +21,20 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
/**
* Could roll up trait, it could be rolled up if a function appear in query which can be represented
* by aggregate function in view.
* Roll up trait, which identify an function could be rolled up if a function appear in query
* which can be represented by aggregate function in view.
* Acquire the rolled up function by constructRollUp method.
*/
public interface CouldRollUp {
public interface RollUpTrait {
/**
* construct the roll up function with custom param
* Construct the roll up function with custom param
*/
Function constructRollUp(Expression param, Expression... varParams);
/**
* identify the function itself can be rolled up
* Such as Sum can be rolled up directly, but BitmapUnionCount can not
*/
boolean canRollUp();
}

View File

@ -47,7 +47,7 @@ import java.util.List;
*/
public class Sum extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, ComputePrecisionForSum, SupportWindowAnalytic,
CouldRollUp {
RollUpTrait {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE),
@ -129,4 +129,9 @@ public class Sum extends NullableAggregateFunction
public Function constructRollUp(Expression param, Expression... varParams) {
return new Sum(this.distinct, param);
}
@Override
public boolean canRollUp() {
return true;
}
}

View File

@ -49,7 +49,7 @@ import java.util.List;
*/
public class Sum0 extends AggregateFunction
implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature, ComputePrecisionForSum,
SupportWindowAnalytic, CouldRollUp {
SupportWindowAnalytic, RollUpTrait {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE),
@ -122,4 +122,9 @@ public class Sum0 extends AggregateFunction
public Function constructRollUp(Expression param, Expression... varParams) {
return new Sum0(this.distinct, param);
}
@Override
public boolean canRollUp() {
return true;
}
}

View File

@ -0,0 +1,29 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.combinator;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
/**
* Combinator interface which identify the expression is Combinator
*/
public interface Combinator extends ExpressionTrait {
AggregateFunction getNestedFunction();
}

View File

@ -37,7 +37,7 @@ import java.util.Objects;
* combinator foreach
*/
public class ForEachCombinator extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable, Combinator {
private final AggregateFunction nested;
@ -79,6 +79,7 @@ public class ForEachCombinator extends AggregateFunction
return ArrayType.of(nested.getDataType(), nested.nullable());
}
@Override
public AggregateFunction getNestedFunction() {
return nested;
}

View File

@ -17,12 +17,19 @@
package org.apache.doris.nereids.trees.expressions.functions.combinator;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.ComputeNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.AggStateType;
@ -37,7 +44,7 @@ import java.util.Objects;
* AggState combinator merge
*/
public class MergeCombinator extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, ComputeNullable {
implements UnaryExpression, ExplicitlyCastableSignature, ComputeNullable, Combinator, RollUpTrait {
private final AggregateFunction nested;
private final AggStateType inputType;
@ -71,6 +78,7 @@ public class MergeCombinator extends AggregateFunction
return nested.getDataType();
}
@Override
public AggregateFunction getNestedFunction() {
return nested;
}
@ -84,4 +92,18 @@ public class MergeCombinator extends AggregateFunction
public boolean nullable() {
return nested.nullable();
}
@Override
public Function constructRollUp(Expression param, Expression... varParams) {
FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry();
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(getName(), param);
Pair<? extends Expression, ? extends BoundFunction> targetExpressionPair = functionBuilder.build(getName(),
param);
return (Function) targetExpressionPair.key();
}
@Override
public boolean canRollUp() {
return false;
}
}

View File

@ -17,12 +17,19 @@
package org.apache.doris.nereids.trees.expressions.functions.combinator;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -38,7 +45,7 @@ import java.util.Objects;
* AggState combinator state
*/
public class StateCombinator extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, Combinator, RollUpTrait {
private final AggregateFunction nested;
private final AggStateType returnType;
@ -83,7 +90,25 @@ public class StateCombinator extends ScalarFunction
return returnType;
}
@Override
public AggregateFunction getNestedFunction() {
return nested;
}
@Override
public Function constructRollUp(Expression param, Expression... varParams) {
String nestedName = AggCombinerFunctionBuilder.getNestedName(getName());
FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry();
// state combinator roll up result should be union combinator
String combinatorName = nestedName + AggCombinerFunctionBuilder.UNION_SUFFIX;
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(combinatorName, param);
Pair<? extends Expression, ? extends BoundFunction> targetExpressionPair =
functionBuilder.build(combinatorName, param);
return (Function) targetExpressionPair.key();
}
@Override
public boolean canRollUp() {
return true;
}
}

View File

@ -17,12 +17,19 @@
package org.apache.doris.nereids.trees.expressions.functions.combinator;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.AggStateType;
@ -37,7 +44,7 @@ import java.util.Objects;
* AggState combinator union
*/
public class UnionCombinator extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, Combinator, RollUpTrait {
private final AggregateFunction nested;
private final AggStateType inputType;
@ -71,6 +78,7 @@ public class UnionCombinator extends AggregateFunction
return inputType;
}
@Override
public AggregateFunction getNestedFunction() {
return nested;
}
@ -79,4 +87,18 @@ public class UnionCombinator extends AggregateFunction
public AggregateFunction withDistinctAndChildren(boolean distinct, List<Expression> children) {
throw new UnsupportedOperationException("Unimplemented method 'withDistinctAndChildren'");
}
@Override
public Function constructRollUp(Expression param, Expression... varParams) {
FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry();
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(getName(), param);
Pair<? extends Expression, ? extends BoundFunction> targetExpressionPair = functionBuilder.build(getName(),
param);
return (Function) targetExpressionPair.key();
}
@Override
public boolean canRollUp() {
return true;
}
}

View File

@ -22,9 +22,9 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
@ -39,7 +39,7 @@ import java.util.List;
* ScalarFunction 'hll_cardinality'. This class is generated by GenerateFunction.
*/
public class HllCardinality extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, HllFunction, CouldRollUp {
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable, HllFunction, RollUpTrait {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(HllType.INSTANCE)
@ -75,4 +75,9 @@ public class HllCardinality extends ScalarFunction
public Function constructRollUp(Expression param, Expression... varParams) {
return new HllUnionAgg(param);
}
@Override
public boolean canRollUp() {
return false;
}
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.analysis.ListPartitionDesc;
import org.apache.doris.analysis.PartitionDesc;
import org.apache.doris.analysis.RangePartitionDesc;
import org.apache.doris.analysis.TableName;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.KeysType;
@ -59,6 +60,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalSink;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.visitor.NondeterministicFunctionCollector;
import org.apache.doris.nereids.types.AggStateType;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
@ -74,6 +76,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
@ -317,8 +320,14 @@ public class CreateMTMVInfo {
} else {
colNames.add(colName);
}
// If datatype is AggStateType, AggregateType should be generic, or column definition check will fail
columns.add(new ColumnDefinition(
colName, slots.get(i).getDataType(), slots.get(i).nullable(),
colName,
slots.get(i).getDataType(),
false,
slots.get(i).getDataType() instanceof AggStateType ? AggregateType.GENERIC : null,
slots.get(i).nullable(),
Optional.empty(),
CollectionUtils.isEmpty(simpleColumnDefinitions) ? null
: simpleColumnDefinitions.get(i).getComment()));
}

View File

@ -355,3 +355,35 @@
2023-12-11 1
2023-12-12 2
-- !query33_0_before --
o 3 9 o,o,o,o,o,o 4.666666666666667 mi 6 2
o 4 2 o,o 4.0 yy 2 1
-- !query33_0_after --
o 3 9 o,o,o,o,o,o 4.666666666666667 mi 6 2
o 4 2 o,o 4.0 yy 2 1
-- !query33_1_before --
o 3 9 o,o,o,o,o,o 4.666666666666667 mi 6 2
o 4 2 o,o 4.0 yy 2 1
-- !query33_1_after --
o 3 9 o,o,o,o,o,o 4.666666666666667 mi 6 2
o 4 2 o,o 4.0 yy 2 1
-- !query35_0_before --
o 3 9 o,o,o,o,o,o 4.666666666666667 mi 6 2
o 4 2 o,o 4.0 yy 2 1
-- !query35_0_after --
o 3 9 o,o,o,o,o,o 4.666666666666667 mi 6 2
o 4 2 o,o 4.0 yy 2 1
-- !query36_0_before --
o 3 9 o,o,o,o,o,o 4.666666666666667 mi 6 2
o 4 2 o,o 4.0 yy 2 1
-- !query36_0_after --
o 3 9 o,o,o,o,o,o 4.666666666666667 mi 6 2
o 4 2 o,o 4.0 yy 2 1

View File

@ -261,3 +261,57 @@ d c 17.00 2
4 4 68 100.0000 36.5000
6 1 0 22.0000 57.2000
-- !query22_0_before --
2023-12-08 2 3 1
2023-12-09 4 3 1
2023-12-10 2 4 1
2023-12-11 3 3 2
2023-12-12 2 3 2
-- !query22_0_after --
2023-12-08 2 3 1
2023-12-09 4 3 1
2023-12-10 2 4 1
2023-12-11 3 3 2
2023-12-12 2 3 2
-- !query23_0_before --
a 3 3 a,a,a 4.0 yy 3 1
a 4 2 a,a 4.0 yy 2 1
c 3 6 c,c,c 5.333333333333333 mi 3 2
-- !query23_0_after --
a 3 3 a,a,a 4.0 yy 3 1
a 4 2 a,a 4.0 yy 2 1
c 3 6 c,c,c 5.333333333333333 mi 3 2
-- !query23_1_before --
a 3 3 a,a,a 4.0 yy 3 1
a 4 2 a,a 4.0 yy 2 1
c 3 6 c,c,c 5.333333333333333 mi 3 2
-- !query23_1_after --
a 3 3 a,a,a 4.0 yy 3 1
a 4 2 a,a 4.0 yy 2 1
c 3 6 c,c,c 5.333333333333333 mi 3 2
-- !query25_0_before --
a 3 3 a,a,a 4.0 yy 3 1
a 4 2 a,a 4.0 yy 2 1
c 3 6 c,c,c 5.333333333333333 mi 3 2
-- !query25_0_after --
a 3 3 a,a,a 4.0 yy 3 1
a 4 2 a,a 4.0 yy 2 1
c 3 6 c,c,c 5.333333333333333 mi 3 2
-- !query26_0_before --
a 3 3 a,a,a 4.0 yy 3 1
a 4 2 a,a 4.0 yy 2 1
c 3 6 c,c,c 5.333333333333333 mi 3 2
-- !query26_0_after --
a 3 3 a,a,a 4.0 yy 3 1
a 4 2 a,a 4.0 yy 2 1
c 3 6 c,c,c 5.333333333333333 mi 3 2

View File

@ -1284,7 +1284,7 @@ suite("aggregate_with_roll_up") {
avg(case when l_partkey in (2, 3, 4) then l_discount + o_totalprice + part_supp_a.qty_sum else 50 end)
from lineitem
left join orders on l_orderkey = o_orderkey
left join
left join
(select ps_partkey, ps_suppkey, sum(ps_availqty) qty_sum, max(ps_availqty) qty_max,
min(ps_availqty) qty_min,
avg(ps_supplycost) cost_avg
@ -1303,7 +1303,7 @@ suite("aggregate_with_roll_up") {
avg(case when l_partkey in (2, 3, 4) then l_discount + o_totalprice + part_supp_a.qty_sum else 50 end)
from lineitem
left join orders on l_orderkey = o_orderkey
left join
left join
(select ps_partkey, ps_suppkey, sum(ps_availqty) qty_sum, max(ps_availqty) qty_max,
min(ps_availqty) qty_min,
avg(ps_supplycost) cost_avg
@ -1320,39 +1320,39 @@ suite("aggregate_with_roll_up") {
// should rewrite fail, because the part of query is join but mv is aggregate
def mv31_0 = """
select
o_orderdate,
o_shippriority,
o_comment,
l_orderkey,
o_orderkey,
count(*)
from
orders
select
o_orderdate,
o_shippriority,
o_comment,
l_orderkey,
o_orderkey,
count(*)
from
orders
left join lineitem on l_orderkey = o_orderkey
group by o_orderdate,
o_shippriority,
o_comment,
l_orderkey,
group by o_orderdate,
o_shippriority,
o_comment,
l_orderkey,
o_orderkey;
"""
def query31_0 = """
select
o_orderdate,
o_shippriority,
o_comment,
l_orderkey,
ps_partkey,
count(*)
from
orders left
select
o_orderdate,
o_shippriority,
o_comment,
l_orderkey,
ps_partkey,
count(*)
from
orders left
join lineitem on l_orderkey = o_orderkey
left join partsupp on ps_partkey = l_orderkey
group by
o_orderdate,
o_shippriority,
o_comment,
l_orderkey,
o_orderdate,
o_shippriority,
o_comment,
l_orderkey,
ps_partkey;
"""
order_qt_query31_0_before "${query31_0}"
@ -1362,22 +1362,22 @@ suite("aggregate_with_roll_up") {
// should rewrite fail, because the part of query is join but mv is aggregate
def mv32_0 = """
select
o_orderdate,
count(*)
from
orders
group by
select
o_orderdate,
count(*)
from
orders
group by
o_orderdate;
"""
def query32_0 = """
select
o_orderdate,
count(*)
from
orders
group by
o_orderdate,
select
o_orderdate,
count(*)
from
orders
group by
o_orderdate,
o_shippriority;
"""
order_qt_query32_0_before "${query32_0}"
@ -1385,4 +1385,264 @@ suite("aggregate_with_roll_up") {
order_qt_query32_0_after "${query32_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv32_0"""
// test combinator aggregate function rewrite
sql """set enable_agg_state=true"""
// query has no combinator and mv has combinator
// mv is union
def mv33_0 = """
select
o_orderstatus,
l_partkey,
l_suppkey,
sum_union(sum_state(o_shippriority)),
group_concat_union(group_concat_state(o_orderstatus)),
avg_union(avg_state(l_linenumber)),
max_by_union(max_by_state(O_COMMENT,o_totalprice)),
count_union(count_state(l_orderkey)),
multi_distinct_count_union(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderstatus,
l_partkey,
l_suppkey;
"""
def query33_0 = """
select
o_orderstatus,
l_suppkey,
sum(o_shippriority),
group_concat(o_orderstatus),
avg(l_linenumber),
max_by(O_COMMENT,o_totalprice),
count(l_orderkey),
multi_distinct_count(l_shipmode)
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderstatus,
l_suppkey;
"""
order_qt_query33_0_before "${query33_0}"
check_mv_rewrite_success(db, mv33_0, query33_0, "mv33_0")
order_qt_query33_0_after "${query33_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv33_0"""
// mv is merge
def mv33_1 = """
select
o_orderstatus,
l_partkey,
l_suppkey,
sum_merge(sum_state(o_shippriority)),
group_concat_merge(group_concat_state(o_orderstatus)),
avg_merge(avg_state(l_linenumber)),
max_by_merge(max_by_state(O_COMMENT,o_totalprice)),
count_merge(count_state(l_orderkey)),
multi_distinct_count_merge(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderstatus,
l_partkey,
l_suppkey;
"""
def query33_1 = """
select
o_orderstatus,
l_suppkey,
sum(o_shippriority),
group_concat(o_orderstatus),
avg(l_linenumber),
max_by(O_COMMENT,o_totalprice),
count(l_orderkey),
multi_distinct_count(l_shipmode)
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderstatus,
l_suppkey
order by o_orderstatus;
"""
order_qt_query33_1_before "${query33_1}"
check_mv_rewrite_fail(db, mv33_1, query33_1, "mv33_1")
order_qt_query33_1_after "${query33_1}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv33_1"""
// both query and mv are combinator
// mv is union, query is union
def mv34_0 = """
select
o_orderstatus,
l_partkey,
l_suppkey,
sum_union(sum_state(o_shippriority)),
group_concat_union(group_concat_state(o_orderstatus)),
avg_union(avg_state(l_linenumber)),
max_by_union(max_by_state(O_COMMENT,o_totalprice)),
count_union(count_state(l_orderkey)),
multi_distinct_count_union(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderstatus,
l_partkey,
l_suppkey;
"""
def query34_0 = """
select
o_orderstatus,
l_suppkey,
sum_union(sum_state(o_shippriority)),
group_concat_union(group_concat_state(o_orderstatus)),
avg_union(avg_state(l_linenumber)),
max_by_union(max_by_state(O_COMMENT,o_totalprice)),
count_union(count_state(l_orderkey)),
multi_distinct_count_union(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderstatus,
l_suppkey;
"""
check_mv_rewrite_success(db, mv34_0, query34_0, "mv34_0")
sql """ DROP MATERIALIZED VIEW IF EXISTS mv34_0"""
// mv is union, query is merge
def mv35_0 = """
select
o_orderstatus,
l_partkey,
l_suppkey,
sum_union(sum_state(o_shippriority)),
group_concat_union(group_concat_state(o_orderstatus)),
avg_union(avg_state(l_linenumber)),
max_by_union(max_by_state(O_COMMENT,o_totalprice)),
count_union(count_state(l_orderkey)),
multi_distinct_count_union(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderstatus,
l_partkey,
l_suppkey;
"""
def query35_0 = """
select
o_orderstatus,
l_suppkey,
sum_merge(sum_state(o_shippriority)),
group_concat_merge(group_concat_state(o_orderstatus)),
avg_merge(avg_state(l_linenumber)),
max_by_merge(max_by_state(O_COMMENT,o_totalprice)),
count_merge(count_state(l_orderkey)),
multi_distinct_count_merge(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderstatus,
l_suppkey
order by o_orderstatus;
"""
order_qt_query35_0_before "${query35_0}"
check_mv_rewrite_success(db, mv35_0, query35_0, "mv35_0")
order_qt_query35_0_after "${query35_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv35_0"""
// mv is merge, query is merge
def mv36_0 = """
select
o_orderstatus,
l_partkey,
l_suppkey,
sum_merge(sum_state(o_shippriority)),
group_concat_merge(group_concat_state(o_orderstatus)),
avg_merge(avg_state(l_linenumber)),
max_by_merge(max_by_state(O_COMMENT,o_totalprice)),
count_merge(count_state(l_orderkey)),
multi_distinct_count_merge(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderstatus,
l_partkey,
l_suppkey;
"""
def query36_0 = """
select
o_orderstatus,
l_suppkey,
sum_merge(sum_state(o_shippriority)),
group_concat_merge(group_concat_state(o_orderstatus)),
avg_merge(avg_state(l_linenumber)),
max_by_merge(max_by_state(O_COMMENT,o_totalprice)),
count_merge(count_state(l_orderkey)),
multi_distinct_count_merge(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderstatus,
l_suppkey
order by o_orderstatus;
"""
order_qt_query36_0_before "${query36_0}"
check_mv_rewrite_fail(db, mv36_0, query36_0, "mv36_0")
order_qt_query36_0_after "${query36_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv36_0"""
// mv is merge, query is union
def mv37_0 = """
select
o_orderstatus,
l_partkey,
l_suppkey,
sum_merge(sum_state(o_shippriority)),
group_concat_merge(group_concat_state(o_orderstatus)),
avg_merge(avg_state(l_linenumber)),
max_by_merge(max_by_state(O_COMMENT,o_totalprice)),
count_merge(count_state(l_orderkey)),
multi_distinct_count_merge(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderstatus,
l_partkey,
l_suppkey;
"""
def query37_0 = """
select
o_orderstatus,
l_suppkey,
sum_union(sum_state(o_shippriority)),
group_concat_union(group_concat_state(o_orderstatus)),
avg_union(avg_state(l_linenumber)),
max_by_union(max_by_state(O_COMMENT,o_totalprice)),
count_union(count_state(l_orderkey)),
multi_distinct_count_union(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderstatus,
l_suppkey;
"""
check_mv_rewrite_fail(db, mv37_0, query37_0, "mv37_0")
sql """ DROP MATERIALIZED VIEW IF EXISTS mv37_0"""
}

View File

@ -24,6 +24,7 @@ suite("aggregate_without_roll_up") {
sql "SET enable_fallback_to_original_planner=false"
sql "SET enable_materialized_view_rewrite=true"
sql "SET enable_nereids_timeout = false"
sql "SET enable_agg_state = true"
sql """
drop table if exists orders
@ -1044,7 +1045,7 @@ suite("aggregate_without_roll_up") {
avg(case when l_partkey in (2, 3, 4) then l_discount + o_totalprice + part_supp_a.qty_sum else 50 end)
from lineitem
left join orders on l_orderkey = o_orderkey
left join
left join
(select ps_partkey, ps_suppkey, sum(ps_availqty) qty_sum, max(ps_availqty) qty_max,
min(ps_availqty) qty_min,
avg(ps_supplycost) cost_avg
@ -1064,7 +1065,7 @@ suite("aggregate_without_roll_up") {
avg(case when l_partkey in (2, 3, 4) then l_discount + o_totalprice + part_supp_a.qty_sum else 50 end)
from lineitem
left join orders on l_orderkey = o_orderkey
left join
left join
(select ps_partkey, ps_suppkey, sum(ps_availqty) qty_sum, max(ps_availqty) qty_max,
min(ps_availqty) qty_min,
avg(ps_supplycost) cost_avg
@ -1078,4 +1079,289 @@ suite("aggregate_without_roll_up") {
check_mv_rewrite_fail(db, mv21_2, query21_2, "mv21_2")
order_qt_query21_2_after "${query21_2}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv21_2"""
def mv22_0 = """
select
o_orderdate,
l_partkey,
l_suppkey,
max_union(max_state(o_shippriority))
from lineitem
left join orders t2
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderdate,
l_partkey,
l_suppkey;
"""
def query22_0 = """
select
o_orderdate,
l_partkey,
l_suppkey,
max(o_shippriority)
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderdate,
l_partkey,
l_suppkey;
"""
order_qt_query22_0_before "${query22_0}"
check_mv_rewrite_success(db, mv22_0, query22_0, "mv22_0")
order_qt_query22_0_after "${query22_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv22_0"""
// test combinator aggregate function rewrite
sql """set enable_agg_state=true"""
// query has no combinator and mv has combinator
// mv is union
def mv23_0 = """
select
o_orderpriority,
l_suppkey,
sum_union(sum_state(o_shippriority)),
group_concat_union(group_concat_state(o_orderpriority)),
avg_union(avg_state(l_linenumber)),
max_by_union(max_by_state(O_COMMENT,o_totalprice)),
count_union(count_state(l_orderkey)),
multi_distinct_count_union(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderpriority,
l_suppkey;
"""
def query23_0 = """
select
o_orderpriority,
l_suppkey,
sum(o_shippriority),
group_concat(o_orderpriority),
avg(l_linenumber),
max_by(O_COMMENT,o_totalprice),
count(l_orderkey),
multi_distinct_count(l_shipmode)
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderpriority,
l_suppkey
order by o_orderpriority;
"""
order_qt_query23_0_before "${query23_0}"
check_mv_rewrite_success(db, mv23_0, query23_0, "mv23_0")
order_qt_query23_0_after "${query23_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv23_0"""
// mv is merge
def mv23_1 = """
select
o_orderpriority,
l_suppkey,
sum_merge(sum_state(o_shippriority)),
group_concat_merge(group_concat_state(o_orderpriority)),
avg_merge(avg_state(l_linenumber)),
max_by_merge(max_by_state(O_COMMENT,o_totalprice)),
count_merge(count_state(l_orderkey)),
multi_distinct_count_merge(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderpriority,
l_suppkey;
"""
def query23_1 = """
select
o_orderpriority,
l_suppkey,
sum(o_shippriority),
group_concat(o_orderpriority),
avg(l_linenumber),
max_by(O_COMMENT,o_totalprice),
count(l_orderkey),
multi_distinct_count(l_shipmode)
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderpriority,
l_suppkey
order by o_orderpriority;
"""
order_qt_query23_1_before "${query23_1}"
// not supported, this usage is rare
check_mv_rewrite_fail(db, mv23_1, query23_1, "mv23_1")
order_qt_query23_1_after "${query23_1}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv23_1"""
// both query and mv are combinator
// mv is union, query is union
def mv24_0 = """
select
o_orderpriority,
l_suppkey,
sum_union(sum_state(o_shippriority)),
group_concat_union(group_concat_state(o_orderpriority)),
avg_union(avg_state(l_linenumber)),
max_by_union(max_by_state(O_COMMENT,o_totalprice)),
count_union(count_state(l_orderkey)),
multi_distinct_count_union(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderpriority,
l_suppkey;
"""
def query24_0 = """
select
o_orderpriority,
l_suppkey,
sum_union(sum_state(o_shippriority)),
group_concat_union(group_concat_state(o_orderpriority)),
avg_union(avg_state(l_linenumber)),
max_by_union(max_by_state(O_COMMENT,o_totalprice)),
count_union(count_state(l_orderkey)),
multi_distinct_count_union(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderpriority,
l_suppkey;
"""
check_mv_rewrite_success(db, mv24_0, query24_0, "mv24_0")
sql """ DROP MATERIALIZED VIEW IF EXISTS mv24_0"""
// mv is union, query is merge
def mv25_0 = """
select
o_orderpriority,
l_suppkey,
sum_union(sum_state(o_shippriority)),
group_concat_union(group_concat_state(o_orderpriority)),
avg_union(avg_state(l_linenumber)),
max_by_union(max_by_state(O_COMMENT,o_totalprice)),
count_union(count_state(l_orderkey)),
multi_distinct_count_union(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderpriority,
l_suppkey;
"""
def query25_0 = """
select
o_orderpriority,
l_suppkey,
sum_merge(sum_state(o_shippriority)),
group_concat_merge(group_concat_state(o_orderpriority)),
avg_merge(avg_state(l_linenumber)),
max_by_merge(max_by_state(O_COMMENT,o_totalprice)),
count_merge(count_state(l_orderkey)),
multi_distinct_count_merge(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderpriority,
l_suppkey
order by o_orderpriority;
"""
order_qt_query25_0_before "${query25_0}"
check_mv_rewrite_success(db, mv25_0, query25_0, "mv25_0")
order_qt_query25_0_after "${query25_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv25_0"""
// mv is merge, query is merge
def mv26_0 = """
select
o_orderpriority,
l_suppkey,
sum_merge(sum_state(o_shippriority)),
group_concat_merge(group_concat_state(o_orderpriority)),
avg_merge(avg_state(l_linenumber)),
max_by_merge(max_by_state(O_COMMENT,o_totalprice)),
count_merge(count_state(l_orderkey)),
multi_distinct_count_merge(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderpriority,
l_suppkey;
"""
def query26_0 = """
select
o_orderpriority,
l_suppkey,
sum_merge(sum_state(o_shippriority)),
group_concat_merge(group_concat_state(o_orderpriority)),
avg_merge(avg_state(l_linenumber)),
max_by_merge(max_by_state(O_COMMENT,o_totalprice)),
count_merge(count_state(l_orderkey)),
multi_distinct_count_merge(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderpriority,
l_suppkey
order by o_orderpriority;
"""
order_qt_query26_0_before "${query26_0}"
check_mv_rewrite_success(db, mv26_0, query26_0, "mv26_0")
order_qt_query26_0_after "${query26_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv26_0"""
// mv is merge, query is union
def mv27_0 = """
select
o_orderpriority,
l_suppkey,
sum_merge(sum_state(o_shippriority)),
group_concat_merge(group_concat_state(o_orderpriority)),
avg_merge(avg_state(l_linenumber)),
max_by_merge(max_by_state(O_COMMENT,o_totalprice)),
count_merge(count_state(l_orderkey)),
multi_distinct_count_merge(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderpriority,
l_suppkey;
"""
def query27_0 = """
select
o_orderpriority,
l_suppkey,
sum_union(sum_state(o_shippriority)),
group_concat_union(group_concat_state(o_orderpriority)),
avg_union(avg_state(l_linenumber)),
max_by_union(max_by_state(O_COMMENT,o_totalprice)),
count_union(count_state(l_orderkey)),
multi_distinct_count_union(multi_distinct_count_state(l_shipmode))
from lineitem
left join orders
on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
group by
o_orderpriority,
l_suppkey;
"""
check_mv_rewrite_fail(db, mv27_0, query27_0, "mv27_0")
sql """ DROP MATERIALIZED VIEW IF EXISTS mv27_0"""
}