[Feature](Nereids) Support hll and count for materialized index. (#15275)

This commit is contained in:
Shuo Wang
2022-12-27 00:38:04 +08:00
committed by GitHub
parent 650136c32e
commit 325d247b92
9 changed files with 442 additions and 143 deletions

View File

@ -31,6 +31,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import com.google.common.collect.ImmutableList;
@ -57,7 +58,8 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
agg(Sum.class),
agg(GroupBitAnd.class, "group_bit_and"),
agg(GroupBitOr.class, "group_bit_or"),
agg(GroupBitXor.class, "group_bit_xor")
agg(GroupBitXor.class, "group_bit_xor"),
agg(Ndv.class)
);

View File

@ -35,9 +35,13 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -136,8 +140,8 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
result.exprRewriteMap),
agg.isNormalized(),
agg.getSourceRepeat(),
// Not that no need to replace slots in the filter, because the slots to replace
// are value columns, which shouldn't appear in filters.
// Note that no need to replace slots in the filter, because the slots to
// replace are value columns, which shouldn't appear in filters.
filter.withChildren(
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId))
);
@ -311,15 +315,13 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
.collect(Collectors.groupingBy(index -> index.getId() == table.getBaseIndexId()));
// Duplicate-keys table could use base index and indexes that pre-aggregation status is on.
Stream<MaterializedIndex> checkPreAggResult = Stream.concat(
Set<MaterializedIndex> candidatesWithoutRewriting = Stream.concat(
indexesGroupByIsBaseOrNot.get(true).stream(),
indexesGroupByIsBaseOrNot.getOrDefault(false, ImmutableList.of())
.stream()
.filter(index -> checkPreAggStatus(scan, index.getId(), predicates,
aggregateFunctions, groupingExprs).isOn())
);
Set<MaterializedIndex> candidatesWithoutRewriting = checkPreAggResult.collect(Collectors.toSet());
).collect(ImmutableSet.toImmutableSet());
// try to rewrite bitmap, hll by materialized index columns.
List<AggRewriteResult> candidatesWithRewriting = indexesGroupByIsBaseOrNot.getOrDefault(false,
@ -328,6 +330,11 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
.filter(index -> !candidatesWithoutRewriting.contains(index))
.map(index -> rewriteAgg(index, scan, requiredScanOutput, predicates, aggregateFunctions,
groupingExprs))
.filter(aggRewriteResult -> checkPreAggStatus(scan, aggRewriteResult.index.getId(),
predicates,
// check pre-agg status of aggregate function that couldn't rewrite.
aggFuncsDiff(aggregateFunctions, aggRewriteResult),
groupingExprs).isOn())
.filter(result -> result.success)
.collect(Collectors.toList());
@ -354,6 +361,16 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
}
}
/**
 * Returns the aggregate functions that were not rewritten for the given rewrite result.
 *
 * <p>When the rewrite did not succeed, the incoming list is returned as is. Otherwise the result
 * is the requested functions minus the keys of the rewrite result's aggregate-function map.
 *
 * @param aggregateFunctions aggregate functions requested by the query
 * @param aggRewriteResult outcome of trying to rewrite those functions for one index
 * @return the functions that still need a pre-aggregation check
 */
private List<AggregateFunction> aggFuncsDiff(List<AggregateFunction> aggregateFunctions,
        AggRewriteResult aggRewriteResult) {
    if (!aggRewriteResult.success) {
        return aggregateFunctions;
    }
    Set<AggregateFunction> requested = ImmutableSet.copyOf(aggregateFunctions);
    Set<?> rewritten = aggRewriteResult.exprRewriteMap.aggFuncMap.keySet();
    return ImmutableList.copyOf(Sets.difference(requested, rewritten));
}
private static class SelectResult {
public final PreAggStatus preAggStatus;
public final long indexId;
@ -468,10 +485,8 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
return checkAggFunc(sum, AggregateType.SUM, extractSlotId(sum.child()), context, false);
}
// TODO: select count(xxx) for duplicated-keys table.
@Override
public PreAggStatus visitCount(Count count, CheckContext context) {
// Now count(distinct key_column) is only supported for aggregate-keys and unique-keys OLAP table.
if (count.isDistinct() && count.arity() == 1) {
Optional<ExprId> exprIdOpt = extractSlotId(count.child(0));
if (exprIdOpt.isPresent() && context.exprIdToKeyColumn.containsKey(exprIdOpt.get())) {
@ -492,6 +507,16 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
}
}
/**
 * Pre-aggregation check for hll_union_agg: it is on only when the argument is a slot
 * (or a cast on a slot) whose expression id is registered as a value column.
 */
@Override
public PreAggStatus visitHllUnionAgg(HllUnionAgg hllUnionAgg, CheckContext context) {
    Optional<Slot> argSlot = ExpressionUtils.extractSlotOrCastOnSlot(hllUnionAgg.child());
    boolean onValueColumn = argSlot.isPresent()
            && context.exprIdToValueColumn.containsKey(argSlot.get().getExprId());
    return onValueColumn
            ? PreAggStatus.on()
            : PreAggStatus.off("invalid hll_union_agg: " + hllUnionAgg.toSql());
}
private PreAggStatus checkAggFunc(
AggregateFunction aggFunc,
AggregateType matchingAggType,
@ -711,34 +736,62 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
/**
* count(distinct col) -> bitmap_union_count(mv_bitmap_union_col)
* count(col) -> sum(mv_count_col)
*/
@Override
public Expression visitCount(Count count, RewriteContext context) {
if (count.isDistinct() && count.arity() == 1) {
// count(distinct col) -> bitmap_union_count(mv_bitmap_union_col)
Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(count.child(0));
// count distinct a value column.
if (slotOpt.isPresent() && !context.checkContext.exprIdToKeyColumn.containsKey(
slotOpt.get().getExprId())) {
String bitmapUnionCountColumn = CreateMaterializedViewStmt
String bitmapUnionColumn = CreateMaterializedViewStmt
.mvColumnBuilder(AggregateType.BITMAP_UNION.name().toLowerCase(), slotOpt.get().getName());
Column mvColumn = context.checkContext.scan.getTable().getVisibleColumn(bitmapUnionCountColumn);
// has bitmap_union_count column
Column mvColumn = context.checkContext.scan.getTable().getVisibleColumn(bitmapUnionColumn);
// has bitmap_union column
if (mvColumn != null && context.checkContext.exprIdToValueColumn.containsValue(mvColumn)) {
Slot bitmapUnionCountSlot = context.checkContext.scan.getNonUserVisibleOutput()
Slot bitmapUnionSlot = context.checkContext.scan.getNonUserVisibleOutput()
.stream()
.filter(s -> s.getName().equals(bitmapUnionCountColumn))
.filter(s -> s.getName().equals(bitmapUnionColumn))
.findFirst()
.get();
context.exprRewriteMap.slotMap.put(slotOpt.get(), bitmapUnionCountSlot);
context.exprRewriteMap.projectExprMap.put(slotOpt.get(), bitmapUnionCountSlot);
BitmapUnionCount bitmapUnionCount = new BitmapUnionCount(bitmapUnionCountSlot);
context.exprRewriteMap.slotMap.put(slotOpt.get(), bitmapUnionSlot);
context.exprRewriteMap.projectExprMap.put(slotOpt.get(), bitmapUnionSlot);
BitmapUnionCount bitmapUnionCount = new BitmapUnionCount(bitmapUnionSlot);
context.exprRewriteMap.aggFuncMap.put(count, bitmapUnionCount);
return bitmapUnionCount;
}
}
} else if (!count.isDistinct() && count.arity() == 1) {
// count(col) -> sum(mv_count_col)
Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(count.child(0));
// count a value column.
if (slotOpt.isPresent() && !context.checkContext.exprIdToKeyColumn.containsKey(
slotOpt.get().getExprId())) {
String countColumn = CreateMaterializedViewStmt
.mvColumnBuilder("count", slotOpt.get().getName());
Column mvColumn = context.checkContext.scan.getTable().getVisibleColumn(countColumn);
// has count column
if (mvColumn != null && context.checkContext.exprIdToValueColumn.containsValue(mvColumn)) {
Slot countSlot = context.checkContext.scan.getNonUserVisibleOutput()
.stream()
.filter(s -> s.getName().equals(countColumn))
.findFirst()
.get();
context.exprRewriteMap.slotMap.put(slotOpt.get(), countSlot);
context.exprRewriteMap.projectExprMap.put(slotOpt.get(), countSlot);
Sum sum = new Sum(countSlot);
context.exprRewriteMap.aggFuncMap.put(count, sum);
return sum;
}
}
}
return count;
}
@ -776,6 +829,103 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
return bitmapUnionCount;
}
/**
 * hll_union(hll_hash(col)) -> hll_union(mv_hll_union_col)
 *
 * <p>The rewrite is applied only when the argument is {@code hll_hash} over a slot (or a cast on
 * a slot), the table has a visible materialized-view column named after that slot, that column is
 * one of the value columns known to the check context, and the scan exposes a slot with that name.
 * In every other case the original expression is returned unchanged.
 *
 * @param hllUnion the aggregate function to rewrite
 * @param context rewrite context; on success its slot, project-expression and aggregate-function
 *                maps are updated
 * @return the rewritten {@code HllUnion} over the materialized-view slot, or {@code hllUnion}
 */
@Override
public Expression visitHllUnion(HllUnion hllUnion, RewriteContext context) {
    if (!(hllUnion.child() instanceof HllHash)) {
        return hllUnion;
    }
    HllHash hllHash = (HllHash) hllUnion.child();
    Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(hllHash.child());
    if (!slotOpt.isPresent()) {
        return hllUnion;
    }
    String hllUnionColumn = CreateMaterializedViewStmt
            .mvColumnBuilder(AggregateType.HLL_UNION.name().toLowerCase(), slotOpt.get().getName());
    Column mvColumn = context.checkContext.scan.getTable().getVisibleColumn(hllUnionColumn);
    // has hll_union column
    if (mvColumn == null || !context.checkContext.exprIdToValueColumn.containsValue(mvColumn)) {
        return hllUnion;
    }
    // Do not call Optional.get() blindly: if the scan does not expose the slot, skip the
    // rewrite instead of failing with NoSuchElementException.
    Slot hllUnionSlot = context.checkContext.scan.getNonUserVisibleOutput()
            .stream()
            .filter(s -> s.getName().equals(hllUnionColumn))
            .findFirst()
            .orElse(null);
    if (hllUnionSlot == null) {
        return hllUnion;
    }
    context.exprRewriteMap.slotMap.put(slotOpt.get(), hllUnionSlot);
    context.exprRewriteMap.projectExprMap.put(hllHash, hllUnionSlot);
    HllUnion newHllUnion = new HllUnion(hllUnionSlot);
    context.exprRewriteMap.aggFuncMap.put(hllUnion, newHllUnion);
    return newHllUnion;
}
/**
 * hll_union_agg(hll_hash(col)) -> hll_union_agg(mv_hll_union_col)
 *
 * <p>The rewrite is applied only when the argument is {@code hll_hash} over a slot (or a cast on
 * a slot), the table has a visible materialized-view column named after that slot, that column is
 * one of the value columns known to the check context, and the scan exposes a slot with that name.
 * In every other case the original expression is returned unchanged.
 *
 * @param hllUnionAgg the aggregate function to rewrite
 * @param context rewrite context; on success its slot, project-expression and aggregate-function
 *                maps are updated
 * @return the rewritten {@code HllUnionAgg} over the materialized-view slot, or {@code hllUnionAgg}
 */
@Override
public Expression visitHllUnionAgg(HllUnionAgg hllUnionAgg, RewriteContext context) {
    if (!(hllUnionAgg.child() instanceof HllHash)) {
        return hllUnionAgg;
    }
    HllHash hllHash = (HllHash) hllUnionAgg.child();
    Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(hllHash.child());
    if (!slotOpt.isPresent()) {
        return hllUnionAgg;
    }
    String hllUnionColumn = CreateMaterializedViewStmt
            .mvColumnBuilder(AggregateType.HLL_UNION.name().toLowerCase(), slotOpt.get().getName());
    Column mvColumn = context.checkContext.scan.getTable().getVisibleColumn(hllUnionColumn);
    // has hll_union column
    if (mvColumn == null || !context.checkContext.exprIdToValueColumn.containsValue(mvColumn)) {
        return hllUnionAgg;
    }
    // Do not call Optional.get() blindly: if the scan does not expose the slot, skip the
    // rewrite instead of failing with NoSuchElementException.
    Slot hllUnionSlot = context.checkContext.scan.getNonUserVisibleOutput()
            .stream()
            .filter(s -> s.getName().equals(hllUnionColumn))
            .findFirst()
            .orElse(null);
    if (hllUnionSlot == null) {
        return hllUnionAgg;
    }
    context.exprRewriteMap.slotMap.put(slotOpt.get(), hllUnionSlot);
    context.exprRewriteMap.projectExprMap.put(hllHash, hllUnionSlot);
    HllUnionAgg newHllUnionAgg = new HllUnionAgg(hllUnionSlot);
    context.exprRewriteMap.aggFuncMap.put(hllUnionAgg, newHllUnionAgg);
    return newHllUnionAgg;
}
/**
 * ndv(col) -> hll_union_agg(mv_hll_union_col)
 *
 * <p>The rewrite is applied only when the argument is a slot (or a cast on a slot) that is not a
 * key column, the table has a visible materialized-view column named after that slot, that column
 * is one of the value columns known to the check context, and the scan exposes a slot with that
 * name. In every other case the original expression is returned unchanged.
 *
 * @param ndv the aggregate function to rewrite
 * @param context rewrite context; on success its slot, project-expression and aggregate-function
 *                maps are updated
 * @return an {@code HllUnionAgg} over the materialized-view slot, or {@code ndv}
 */
@Override
public Expression visitNdv(Ndv ndv, RewriteContext context) {
    Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(ndv.child(0));
    // ndv on a value column.
    if (!slotOpt.isPresent()
            || context.checkContext.exprIdToKeyColumn.containsKey(slotOpt.get().getExprId())) {
        return ndv;
    }
    String hllUnionColumn = CreateMaterializedViewStmt
            .mvColumnBuilder(AggregateType.HLL_UNION.name().toLowerCase(), slotOpt.get().getName());
    Column mvColumn = context.checkContext.scan.getTable().getVisibleColumn(hllUnionColumn);
    // has hll_union column
    if (mvColumn == null || !context.checkContext.exprIdToValueColumn.containsValue(mvColumn)) {
        return ndv;
    }
    // Do not call Optional.get() blindly: if the scan does not expose the slot, skip the
    // rewrite instead of failing with NoSuchElementException.
    Slot hllUnionSlot = context.checkContext.scan.getNonUserVisibleOutput()
            .stream()
            .filter(s -> s.getName().equals(hllUnionColumn))
            .findFirst()
            .orElse(null);
    if (hllUnionSlot == null) {
        return ndv;
    }
    context.exprRewriteMap.slotMap.put(slotOpt.get(), hllUnionSlot);
    context.exprRewriteMap.projectExprMap.put(slotOpt.get(), hllUnionSlot);
    HllUnionAgg hllUnionAgg = new HllUnionAgg(hllUnionSlot);
    context.exprRewriteMap.aggFuncMap.put(ndv, hllUnionAgg);
    return hllUnionAgg;
}
}
private List<NamedExpression> replaceAggOutput(

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import com.google.common.collect.ImmutableList;
@ -34,7 +35,6 @@ import java.util.List;
* Rewrite count distinct for bitmap and hll type value.
* <p>
* count(distinct bitmap_col) -> bitmap_union_count(bitmap col)
* todo: add support for HLL type.
*/
public class CountDistinctRewrite extends OneRewriteRuleFactory {
@Override
@ -63,6 +63,9 @@ public class CountDistinctRewrite extends OneRewriteRuleFactory {
if (child.getDataType().isBitmap()) {
return new BitmapUnionCount(child);
}
if (child.getDataType().isHll()) {
return new HllUnionAgg(child);
}
}
return count;
}

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.HllType;
@ -61,4 +62,9 @@ public class HllUnion extends AggregateFunction
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
/**
 * Dispatches to {@code visitHllUnion} on the given visitor, passing this expression and the context.
 */
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitHllUnion(this, context);
}
}

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.HllType;
@ -62,4 +63,9 @@ public class HllUnionAgg extends AggregateFunction
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
/**
 * Dispatches to {@code visitHllUnionAgg} on the given visitor, passing this expression and the context.
 */
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitHllUnionAgg(this, context);
}
}

View File

@ -0,0 +1,121 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BitmapType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.HllType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.JsonType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.NullType;
import org.apache.doris.nereids.types.QuantileStateType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.TimeType;
import org.apache.doris.nereids.types.TimeV2Type;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.types.VarcharType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* AggregateFunction 'ndv'. This class is generated by GenerateFunction.
* <p>
* A unary aggregate function registered under the name "ndv". Every declared signature
* returns BIGINT.
*/
public class Ndv extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
// Accepted single-argument signatures; each one maps its argument type to a BIGINT result.
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(BigIntType.INSTANCE).args(DecimalV3Type.DEFAULT_DECIMAL32),
FunctionSignature.ret(BigIntType.INSTANCE).args(DecimalV3Type.DEFAULT_DECIMAL64),
FunctionSignature.ret(BigIntType.INSTANCE).args(DecimalV3Type.DEFAULT_DECIMAL128),
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(BigIntType.INSTANCE).args(StringType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(DateType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(DateTimeType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(DateV2Type.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(BigIntType.INSTANCE).args(TimeType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(TimeV2Type.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(JsonType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(HllType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BitmapType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(QuantileStateType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(NullType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(CharType.SYSTEM_DEFAULT)
);
/**
* constructor with 1 argument.
*
* @param arg the single argument expression of ndv
*/
public Ndv(Expression arg) {
super("ndv", arg);
}
/**
* withChildren.
* <p>
* Builds a new Ndv over the given children; exactly one child is required,
* otherwise the precondition check throws IllegalArgumentException.
*/
@Override
public Ndv withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new Ndv(children.get(0));
}
/**
* Returns the static list of signatures declared in {@link #SIGNATURES}.
*/
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
/**
* Rebuilds this function with new children. The isDistinct flag is ignored:
* the call simply delegates to {@link #withChildren(List)}.
*/
@Override
public AggregateFunction withDistinctAndChildren(boolean isDistinct, List<Expression> children) {
return withChildren(children);
}
/**
* Dispatches to {@code visitNdv} on the given visitor, passing this expression and the context.
*/
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitNdv(this, context);
}
}

View File

@ -24,10 +24,13 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitOr;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitXor;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
/** AggregateFunctionVisitor. */
@ -77,4 +80,16 @@ public interface AggregateFunctionVisitor<R, C> {
default R visitBitmapUnionCount(BitmapUnionCount bitmapUnionCount, C context) {
return visitAggregateFunction(bitmapUnionCount, context);
}
/** Visits an Ndv; by default falls back to {@code visitAggregateFunction}. */
default R visitNdv(Ndv ndv, C context) {
return visitAggregateFunction(ndv, context);
}
/** Visits an HllUnionAgg; by default falls back to {@code visitAggregateFunction}. */
default R visitHllUnionAgg(HllUnionAgg hllUnionAgg, C context) {
return visitAggregateFunction(hllUnionAgg, context);
}
/** Visits an HllUnion; by default falls back to {@code visitAggregateFunction}. */
default R visitHllUnion(HllUnion hllUnion, C context) {
return visitAggregateFunction(hllUnion, context);
}
}

View File

@ -488,6 +488,10 @@ public abstract class DataType implements AbstractDataType {
return this instanceof BitmapType;
}
/** Returns true when this data type is an instance of {@code HllType}. */
public boolean isHll() {
return this instanceof HllType;
}
public DataType promotion() {
if (PROMOTION_MAP.containsKey(this.getClass())) {
return PROMOTION_MAP.get(this.getClass()).get();

View File

@ -19,9 +19,16 @@ package org.apache.doris.nereids.rules.mv;
import org.apache.doris.catalog.FunctionSet;
import org.apache.doris.common.FeConstants;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.util.PatternMatchSupported;
@ -135,17 +142,14 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
// dorisAssert.query(query2).explainWithout(QUERY_USE_EMPS_MV);
// }
/**
* TODO: enable this when union is supported.
*/
@Disabled
@Test
public void testUnionQueryOnProjectionMV() throws Exception {
String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid from "
+ EMPS_TABLE_NAME + " order by deptno;";
String union = "select empid from " + EMPS_TABLE_NAME + " where deptno > 300" + " union all select empid from"
+ " " + EMPS_TABLE_NAME + " where deptno < 200";
createMv(createMVSql);
testMv(union, EMPS_MV_NAME);
testMvWithTwoTable(union, EMPS_MV_NAME, EMPS_MV_NAME);
}
@Test
@ -167,22 +171,6 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
testMv(query, EMPS_MV_NAME);
}
/*
TODO
The deduplicate materialized view is not yet supported
@Test
public void testAggQueryOnDeduplicatedMV() throws Exception {
String deduplicateSQL = "select deptno, empid, name, salary, commission from " + EMPS_TABLE_NAME + " group "
+ "by" + " deptno, empid, name, salary, commission";
String createMVSql = "create materialized view " + EMPS_MV_NAME + " as " + deduplicateSQL + ";";
String query1 = "select deptno, sum(salary) from (" + deduplicateSQL + ") A group by deptno;";
createMv(createMVSql);
testMv(query1, EMPS_MV_NAME);
String query2 = "select deptno, empid from " + EMPS_TABLE_NAME + ";";
dorisAssert.query(query2).explainWithout(QUERY_USE_EMPS_MV);
}
*/
@Test
public void testAggQueryOnAggMV3() throws Exception {
String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, commission, sum(salary)"
@ -308,11 +296,7 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
testMv(query, EMPS_TABLE_NAME);
}
/**
* Aggregation query with set operand
* TODO: enable this when union is supported.
*/
@Disabled
@Test
public void testAggQueryWithSetOperandOnAggMV() throws Exception {
String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, count(salary) "
+ "from " + EMPS_TABLE_NAME + " group by deptno;";
@ -321,7 +305,7 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
+ "select deptno, count(salary) + count(1) from " + EMPS_TABLE_NAME
+ " group by deptno;";
createMv(createMVSql);
testMv(query, EMPS_TABLE_NAME);
testMvWithTwoTable(query, EMPS_TABLE_NAME, EMPS_TABLE_NAME);
}
@Test
@ -627,28 +611,24 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
testMv(query, EMPS_MV_NAME);
}
/**
* TODO: enable this when union is supported.
*/
@Disabled
@Test
public void testUnionAll() throws Exception {
// String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from "
// + EMPS_TABLE_NAME + " order by empid, deptno;";
// String query = "select empid, deptno from " + EMPS_TABLE_NAME + " where empid >1 union all select empid,"
// + " deptno from " + EMPS_TABLE_NAME + " where empid <0;";
// dorisAssert.withMaterializedView(createEmpsMVsql).query(query).explainContains(QUERY_USE_EMPS_MV, 2);
String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from "
+ EMPS_TABLE_NAME + " order by empid, deptno;";
String query = "select empid, deptno from " + EMPS_TABLE_NAME + " where empid >1 union all select empid,"
+ " deptno from " + EMPS_TABLE_NAME + " where empid <0;";
createMv(createEmpsMVsql);
testMvWithTwoTable(query, EMPS_MV_NAME, EMPS_MV_NAME);
}
/**
* TODO: enable this when union is supported.
*/
@Disabled
@Test
public void testUnionDistinct() throws Exception {
// String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from "
// + EMPS_TABLE_NAME + " order by empid, deptno;";
// String query = "select empid, deptno from " + EMPS_TABLE_NAME + " where empid >1 union select empid,"
// + " deptno from " + EMPS_TABLE_NAME + " where empid <0;";
// dorisAssert.withMaterializedView(createEmpsMVsql).query(query).explainContains(QUERY_USE_EMPS_MV, 2);
String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from "
+ EMPS_TABLE_NAME + " order by empid, deptno;";
createMv(createEmpsMVsql);
String query = "select empid, deptno from " + EMPS_TABLE_NAME + " where empid >1 union select empid,"
+ " deptno from " + EMPS_TABLE_NAME + " where empid <0;";
testMvWithTwoTable(query, EMPS_MV_NAME, EMPS_MV_NAME);
}
/**
@ -668,6 +648,7 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
String query = "select k1, k2 from agg_table;";
// todo: `preagg` should be true when rollup could be used.
singleTableTest(query, "only_keys", false);
dropTable("agg_table", true);
}
/**
@ -855,18 +836,30 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
dropTable(TEST_TABLE_NAME, true);
}
/**
* TODO: enable this when hll is supported.
*/
@Disabled
@Test
public void testAggTableCountDistinctInHllType() throws Exception {
// String aggTable = "CREATE TABLE " + TEST_TABLE_NAME + " (k1 int, v1 hll " + FunctionSet.HLL_UNION
// + ") Aggregate KEY (k1) "
// + "DISTRIBUTED BY HASH(k1) BUCKETS 3 PROPERTIES ('replication_num' = '1');";
// dorisAssert.withTable(aggTable);
// String query = "select k1, count(distinct v1) from " + TEST_TABLE_NAME + " group by k1;";
// dorisAssert.query(query).explainContains(TEST_TABLE_NAME, "hll_union_agg");
// dorisAssert.dropTable(TEST_TABLE_NAME, true);
String aggTable = "CREATE TABLE " + TEST_TABLE_NAME + " (k1 int, v1 hll " + FunctionSet.HLL_UNION
+ ") Aggregate KEY (k1) "
+ "DISTRIBUTED BY HASH(k1) BUCKETS 3 PROPERTIES ('replication_num' = '1');";
createTable(aggTable);
String query = "select k1, count(distinct v1) from " + TEST_TABLE_NAME + " group by k1;";
PlanChecker.from(connectContext)
.analyze(query)
.rewrite()
.matches(logicalAggregate(logicalOlapScan()).when(agg -> {
// k1#0, hll_union_agg(v1#1) AS `count(distinct v1)`#2
List<NamedExpression> output = agg.getOutputExpressions();
Assertions.assertEquals(2, output.size());
NamedExpression output1 = output.get(1);
Assertions.assertTrue(output1 instanceof Alias);
Alias alias = (Alias) output1;
Expression aliasChild = alias.child();
Assertions.assertTrue(aliasChild instanceof HllUnionAgg);
HllUnionAgg hllUnionAgg = (HllUnionAgg) aliasChild;
Assertions.assertEquals("v1", ((Slot) hllUnionAgg.child()).getName());
return true;
}));
dropTable(TEST_TABLE_NAME, true);
}
/**
@ -904,17 +897,14 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
testMv(query, USER_TAG_TABLE_NAME);
}
/**
* TODO: enable this when hll is supported.
*/
@Disabled
@Test
public void testNDVToHll() throws Exception {
// String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, "
// + "`" + FunctionSet.HLL_UNION + "`(" + FunctionSet.HLL_HASH + "(tag_id)) from " + USER_TAG_TABLE_NAME
// + " group by user_id;";
// dorisAssert.withMaterializedView(createUserTagMVSql);
// String query = "select ndv(tag_id) from " + USER_TAG_TABLE_NAME + ";";
// dorisAssert.query(query).explainContains(USER_TAG_MV_NAME, "hll_union_agg");
String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, "
+ "`" + FunctionSet.HLL_UNION + "`(" + FunctionSet.HLL_HASH + "(tag_id)) from " + USER_TAG_TABLE_NAME
+ " group by user_id;";
createMv(createUserTagMVSql);
String query = "select ndv(tag_id) from " + USER_TAG_TABLE_NAME + ";";
testMv(query, USER_TAG_MV_NAME);
}
/**
@ -930,23 +920,43 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
// dorisAssert.query(query).explainContains(USER_TAG_MV_NAME, "hll_union_agg");
}
/**
 * Runs the hll aggregate family over {@code FunctionSet.HLL_HASH(tag_id)} against a
 * materialized view that stores {@code FunctionSet.HLL_UNION(FunctionSet.HLL_HASH(tag_id))}
 * grouped by {@code user_id}.
 *
 * <p>For each query the rewritten plan's aggregate is checked with
 * {@code assertOneAggFuncType}, and the query is then passed to {@code testMv} together
 * with {@code USER_TAG_MV_NAME}:
 * <ul>
 *   <li>{@code FunctionSet.HLL_UNION(...)} is checked against {@code HllUnion}</li>
 *   <li>{@code hll_union_agg(...)} is checked against {@code HllUnionAgg}</li>
 *   <li>{@code hll_raw_agg(...)} is checked against {@code HllUnion}</li>
 * </ul>
 */
@Test
public void testHLLUnionFamilyRewrite() throws Exception {
    // Shared SQL fragments; the three queries differ only in the outer aggregate.
    final String hashedTag = FunctionSet.HLL_HASH + "(tag_id)";
    final String unionCall = "`" + FunctionSet.HLL_UNION + "`(" + hashedTag + ")";
    final String fromBaseTable = " from " + USER_TAG_TABLE_NAME;

    final String mvDdl = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, "
            + unionCall + fromBaseTable + " group by user_id;";
    final String unionQuery = "select " + unionCall + fromBaseTable + ";";
    final String unionAggQuery = "select hll_union_agg(" + hashedTag + ")" + fromBaseTable + ";";
    final String rawAggQuery = "select hll_raw_agg(" + hashedTag + ")" + fromBaseTable + ";";

    createMv(mvDdl);

    // hll_union over the hashed column.
    PlanChecker.from(connectContext)
            .analyze(unionQuery)
            .rewrite()
            .matches(logicalAggregate().when(aggregate -> {
                assertOneAggFuncType(aggregate, HllUnion.class);
                return true;
            }));
    testMv(unionQuery, USER_TAG_MV_NAME);

    // hll_union_agg over the hashed column.
    PlanChecker.from(connectContext)
            .analyze(unionAggQuery)
            .rewrite()
            .matches(logicalAggregate().when(aggregate -> {
                assertOneAggFuncType(aggregate, HllUnionAgg.class);
                return true;
            }));
    testMv(unionAggQuery, USER_TAG_MV_NAME);

    // hll_raw_agg over the hashed column; checked against HllUnion, same as the first case.
    PlanChecker.from(connectContext)
            .analyze(rawAggQuery)
            .rewrite()
            .matches(logicalAggregate().when(aggregate -> {
                assertOneAggFuncType(aggregate, HllUnion.class);
                return true;
            }));
    testMv(rawAggQuery, USER_TAG_MV_NAME);
}
@Test
@ -958,41 +968,20 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
testMv(query, EMPS_TABLE_NAME);
}
/**
 * Creates a materialized view that stores {@code count(tag_id)} grouped by {@code user_id}
 * and runs {@code select count(tag_id)} against the base table.
 *
 * <p>The rewritten plan's aggregate is checked with {@code assertOneAggFuncType} against
 * {@code Sum}, and the query is then passed to {@code testMv} together with
 * {@code USER_TAG_MV_NAME}.
 *
 * <p>NOTE: the legacy-planner version of this test also asserted that
 * {@code select user_name, count(tag_id) ... group by user_name} does not use the
 * materialized view. That negative case is not covered here.
 */
@Test
public void testCountFieldInQuery() throws Exception {
    String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, "
            + "count(tag_id) from " + USER_TAG_TABLE_NAME + " group by user_id;";
    createMv(createUserTagMVSql);
    String query = "select count(tag_id) from " + USER_TAG_TABLE_NAME + ";";
    PlanChecker.from(connectContext)
            .analyze(query)
            .rewrite()
            .matches(logicalAggregate().when(agg -> {
                assertOneAggFuncType(agg, Sum.class);
                return true;
            }));
    testMv(query, USER_TAG_MV_NAME);
}
@Test
@ -1013,17 +1002,20 @@ public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implement
dropTable("agg_table", true);
}
/**
 * Same scenario as a plain {@code count(tag_id)} query over a count materialized view,
 * except that the base table is referenced through the alias {@code t}.
 *
 * <p>The rewritten plan's aggregate is checked with {@code assertOneAggFuncType} against
 * {@code Sum}, and the query is then passed to {@code testMv} together with
 * {@code USER_TAG_MV_NAME}.
 */
@Test
public void testSelectMVWithTableAlias() throws Exception {
    String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, "
            + "count(tag_id) from " + USER_TAG_TABLE_NAME + " group by user_id;";
    createMv(createUserTagMVSql);
    String query = "select count(tag_id) from " + USER_TAG_TABLE_NAME + " t ;";
    PlanChecker.from(connectContext)
            .analyze(query)
            .rewrite()
            .matches(logicalAggregate().when(agg -> {
                assertOneAggFuncType(agg, Sum.class);
                return true;
            }));
    testMv(query, USER_TAG_MV_NAME);
}
@Test