[Feature](materialized view) support query match mv with agg_state on nereids planner (#21067)

* support create mv contain aggstate column

* update

* update

* update

* support query match mv with agg_state on nereids planner

update

* update

* update
This commit is contained in:
Pxl
2023-07-03 10:19:31 +08:00
committed by GitHub
parent f90e8fcb26
commit 59c1bbd163
6 changed files with 114 additions and 14 deletions

View File

@ -448,11 +448,11 @@ public class NativeInsertStmt extends InsertStmt {
}
targetColumns.add(col);
}
// hll column mush in mentionedColumns
// hll column must in mentionedColumns
for (Column col : targetTable.getBaseSchema()) {
if (col.getType().isObjectStored() && !mentionedColumns.contains(col.getName())) {
throw new AnalysisException(
" object-stored column " + col.getName() + " mush in insert into columns");
"object-stored column " + col.getName() + " must in insert into columns");
}
}
}

View File

@ -28,12 +28,14 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.rules.rewrite.mv.AbstractSelectMaterializedIndexRule.SlotContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
@ -45,6 +47,8 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapHash;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
@ -67,6 +71,7 @@ import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@ -1159,6 +1164,10 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
*/
@Override
public Expression visitCount(Count count, RewriteContext context) {
Expression result = visitAggregateFunction(count, context);
if (result != count) {
return result;
}
if (count.isDistinct() && count.arity() == 1) {
// count(distinct col) -> bitmap_union_count(mv_bitmap_union_col)
Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(count.child(0));
@ -1225,6 +1234,10 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
*/
@Override
public Expression visitBitmapUnionCount(BitmapUnionCount bitmapUnionCount, RewriteContext context) {
Expression result = visitAggregateFunction(bitmapUnionCount, context);
if (result != bitmapUnionCount) {
return result;
}
if (bitmapUnionCount.child() instanceof ToBitmap) {
ToBitmap toBitmap = (ToBitmap) bitmapUnionCount.child();
Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(toBitmap.child());
@ -1291,6 +1304,10 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
*/
@Override
public Expression visitHllUnion(HllUnion hllUnion, RewriteContext context) {
Expression result = visitAggregateFunction(hllUnion, context);
if (result != hllUnion) {
return result;
}
if (hllUnion.child() instanceof HllHash) {
HllHash hllHash = (HllHash) hllUnion.child();
Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(hllHash.child());
@ -1327,6 +1344,10 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
*/
@Override
public Expression visitHllUnionAgg(HllUnionAgg hllUnionAgg, RewriteContext context) {
Expression result = visitAggregateFunction(hllUnionAgg, context);
if (result != hllUnionAgg) {
return result;
}
if (hllUnionAgg.child() instanceof HllHash) {
HllHash hllHash = (HllHash) hllUnionAgg.child();
Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(hllHash.child());
@ -1363,6 +1384,10 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
*/
@Override
public Expression visitNdv(Ndv ndv, RewriteContext context) {
Expression result = visitAggregateFunction(ndv, context);
if (result != ndv) {
return result;
}
Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(ndv.child(0));
// ndv on a value column.
if (slotOpt.isPresent() && !context.checkContext.keyNameToColumn.containsKey(
@ -1391,6 +1416,36 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
}
return ndv;
}
/**
* agg(col) -> agg_merge(mva_generic_aggregation__agg_state(col)) eg: max_by(k2,
* k3) -> max_by_merge(mva_generic_aggregation__max_by_state(k2, k3))
*/
@Override
public Expression visitAggregateFunction(AggregateFunction aggregateFunction, RewriteContext context) {
String aggStateName = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(
AggregateType.GENERIC_AGGREGATION, StateCombinator.create(aggregateFunction).toSql()));
Column mvColumn = context.checkContext.scan.getTable().getVisibleColumn(aggStateName);
if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {
Slot aggStateSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index).stream()
.filter(s -> aggStateName.equalsIgnoreCase(normalizeName(s.getName()))).findFirst()
.orElseThrow(() -> new AnalysisException("cannot find agg state slot when select mv"));
Set<Slot> slots = aggregateFunction.collect(SlotReference.class::isInstance);
for (Slot slot : slots) {
if (!context.checkContext.keyNameToColumn.containsKey(normalizeName(slot.toSql()))) {
context.exprRewriteMap.slotMap.put(slot, aggStateSlot);
context.exprRewriteMap.projectExprMap.put(slot, aggStateSlot);
}
}
MergeCombinator mergeCombinator = new MergeCombinator(Arrays.asList(aggStateSlot), aggregateFunction);
context.exprRewriteMap.aggFuncMap.put(aggregateFunction, mergeCombinator);
return mergeCombinator;
}
return aggregateFunction;
}
}
private List<NamedExpression> replaceAggOutput(

View File

@ -57,6 +57,10 @@ public class StateCombinator extends ScalarFunction
}).collect(ImmutableList.toImmutableList()));
}
public static StateCombinator create(AggregateFunction nested) {
return new StateCombinator(nested.getArguments(), nested);
}
@Override
public StateCombinator withChildren(List<Expression> children) {
return new StateCombinator(children, nested);