diff --git a/be/src/util/timezone_utils.cpp b/be/src/util/timezone_utils.cpp index 2b22e52126..e4d19946a7 100644 --- a/be/src/util/timezone_utils.cpp +++ b/be/src/util/timezone_utils.cpp @@ -33,7 +33,7 @@ bool TimezoneUtils::find_cctz_time_zone(const std::string& timezone, cctz::time_ 1)) { bool positive = value[0] != '-'; - //Regular expression guarantees hour and minute mush be int + //Regular expression guarantees hour and minute must be int int hour = std::stoi(value.substr(1, 2).as_string()); int minute = std::stoi(value.substr(4, 2).as_string()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/NativeInsertStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/NativeInsertStmt.java index 408fcc0952..591f8191f5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/NativeInsertStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/NativeInsertStmt.java @@ -448,11 +448,11 @@ public class NativeInsertStmt extends InsertStmt { } targetColumns.add(col); } - // hll column mush in mentionedColumns + // hll column must in mentionedColumns for (Column col : targetTable.getBaseSchema()) { if (col.getType().isObjectStored() && !mentionedColumns.contains(col.getName())) { throw new AnalysisException( - " object-stored column " + col.getName() + " mush in insert into columns"); + "object-stored column " + col.getName() + " must in insert into columns"); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java index 6c04521e79..7d0b57d649 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java @@ -28,12 +28,14 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; +import org.apache.doris.nereids.rules.rewrite.mv.AbstractSelectMaterializedIndexRule.SlotContext; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount; @@ -45,6 +47,8 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator; +import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator; import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapHash; import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash; import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap; @@ -67,6 +71,7 @@ import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.common.collect.Streams; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -1159,6 +1164,10 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial */ @Override public Expression visitCount(Count count, RewriteContext context) { + Expression result = visitAggregateFunction(count, context); + if (result != count) { + return result; + } if (count.isDistinct() && count.arity() == 1) { // count(distinct col) -> bitmap_union_count(mv_bitmap_union_col) Optional slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(count.child(0)); @@ -1225,6 +1234,10 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial */ @Override public Expression visitBitmapUnionCount(BitmapUnionCount bitmapUnionCount, RewriteContext context) { + Expression result = visitAggregateFunction(bitmapUnionCount, context); + if (result != bitmapUnionCount) { + return result; + } if (bitmapUnionCount.child() instanceof ToBitmap) { ToBitmap toBitmap = (ToBitmap) bitmapUnionCount.child(); Optional slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(toBitmap.child()); @@ -1291,6 +1304,10 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial */ @Override public Expression visitHllUnion(HllUnion hllUnion, RewriteContext context) { + Expression result = visitAggregateFunction(hllUnion, context); + if (result != hllUnion) { + return result; + } if (hllUnion.child() instanceof HllHash) { HllHash hllHash = (HllHash) hllUnion.child(); Optional slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(hllHash.child()); @@ -1327,6 +1344,10 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial */ @Override public Expression visitHllUnionAgg(HllUnionAgg hllUnionAgg, RewriteContext context) { + Expression result = visitAggregateFunction(hllUnionAgg, context); + if (result != hllUnionAgg) { + return result; + } if (hllUnionAgg.child() instanceof HllHash) { HllHash hllHash = (HllHash) hllUnionAgg.child(); Optional slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(hllHash.child()); @@ -1363,6 +1384,10 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial */ @Override public Expression visitNdv(Ndv ndv, RewriteContext context) { + Expression result = visitAggregateFunction(ndv, context); + if (result != ndv) { + return result; + } Optional slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(ndv.child(0)); // ndv on a value column. if (slotOpt.isPresent() && !context.checkContext.keyNameToColumn.containsKey( @@ -1391,6 +1416,36 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial } return ndv; } + + /** + * agg(col) -> agg_merge(mva_generic_aggregation__agg_state(col)) eg: max_by(k2, + * k3) -> max_by_merge(mva_generic_aggregation__max_by_state(k2, k3)) + */ + @Override + public Expression visitAggregateFunction(AggregateFunction aggregateFunction, RewriteContext context) { + String aggStateName = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder( + AggregateType.GENERIC_AGGREGATION, StateCombinator.create(aggregateFunction).toSql())); + + Column mvColumn = context.checkContext.scan.getTable().getVisibleColumn(aggStateName); + if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) { + Slot aggStateSlot = context.checkContext.scan.getOutputByIndex(context.checkContext.index).stream() + .filter(s -> aggStateName.equalsIgnoreCase(normalizeName(s.getName()))).findFirst() + .orElseThrow(() -> new AnalysisException("cannot find agg state slot when select mv")); + + Set slots = aggregateFunction.collect(SlotReference.class::isInstance); + for (Slot slot : slots) { + if (!context.checkContext.keyNameToColumn.containsKey(normalizeName(slot.toSql()))) { + context.exprRewriteMap.slotMap.put(slot, aggStateSlot); + context.exprRewriteMap.projectExprMap.put(slot, aggStateSlot); + } + } + + MergeCombinator mergeCombinator = new MergeCombinator(Arrays.asList(aggStateSlot), aggregateFunction); + context.exprRewriteMap.aggFuncMap.put(aggregateFunction, mergeCombinator); + return mergeCombinator; + } + return aggregateFunction; + } } private List replaceAggOutput( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java index 9b97a7afd4..db001a6793 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java @@ -57,6 +57,10 @@ public class StateCombinator extends ScalarFunction }).collect(ImmutableList.toImmutableList())); } + public static StateCombinator create(AggregateFunction nested) { + return new StateCombinator(nested.getArguments(), nested); + } + @Override public StateCombinator withChildren(List children) { return new StateCombinator(children, nested); diff --git a/regression-test/data/mv_p0/agg_state/test_agg_state_max_by.out b/regression-test/data/mv_p0/agg_state/test_agg_state_max_by.out index e8082f928a..406e9fa334 100644 --- a/regression-test/data/mv_p0/agg_state/test_agg_state_max_by.out +++ b/regression-test/data/mv_p0/agg_state/test_agg_state_max_by.out @@ -1,8 +1,24 @@ -- This file is automatically generated. You should know what you did if you want to edit this -- !select_star -- \N 4 \N d --4 -4 -4 d +1 -4 -4 d +1 -3 \N c 1 1 1 a -2 2 2 b -3 -3 \N c +1 2 2 b + +-- !select_mv -- +\N \N +1 2 + +-- !select_mv -- +\N \N +1 4 + +-- !select_mv -- +\N \N +1 4 + +-- !select_mv -- +\N \N +1 -4 diff --git a/regression-test/suites/mv_p0/agg_state/test_agg_state_max_by.groovy b/regression-test/suites/mv_p0/agg_state/test_agg_state_max_by.groovy index 071f36bb69..7fd7d4fef1 100644 --- a/regression-test/suites/mv_p0/agg_state/test_agg_state_max_by.groovy +++ b/regression-test/suites/mv_p0/agg_state/test_agg_state_max_by.groovy @@ -36,20 +36,45 @@ suite ("test_agg_state_max_by") { """ sql "insert into d_table select 1,1,1,'a';" - sql "insert into d_table select 2,2,2,'b';" - sql "insert into d_table select 3,-3,null,'c';" + sql "insert into d_table select 1,2,2,'b';" + sql "insert into d_table select 1,-3,null,'c';" sql "insert into d_table(k4,k2) values('d',4);" createMV("create materialized view k1mb as select k1,max_by(k2,k3) from d_table group by k1;") - sql "insert into d_table select -4,-4,-4,'d';" + sql "insert into d_table select 1,-4,-4,'d';" - qt_select_star "select * from d_table order by k1;" -/* + qt_select_star "select * from d_table order by 1,2;" explain { - sql("select k1,max_by(k2,k3) from d_table group by k1 order by k1;") + sql("select k1,max_by(k2,k3) from d_table group by k1 order by 1,2;") contains "(k1mb)" } - qt_select_mv "select k1,max_by(k2,k3) from d_table group by k1 order by k1;" -*/ + qt_select_mv "select k1,max_by(k2,k3) from d_table group by k1 order by 1,2;" + + createMV("create materialized view k1mbcp1 as select k1,max_by(k2+k3,abs(k3)) from d_table group by k1;") + createMV("create materialized view k1mbcp2 as select k1,max_by(k2+k3,k3) from d_table group by k1;") + createMV("create materialized view k1mbcp3 as select k1,max_by(k2,abs(k3)) from d_table group by k1;") + + sql "insert into d_table(k4,k2) values('d',4);" + sql "set enable_nereids_dml = true" + sql "insert into d_table(k4,k2) values('d',4);" + sql "insert into d_table select 1,-4,-4,'d';" + + explain { + sql("select k1,max_by(k2+k3,abs(k3)) from d_table group by k1 order by 1,2;") + contains "(k1mbcp1)" + } + qt_select_mv "select k1,max_by(k2+k3,k3) from d_table group by k1 order by 1,2;" + + explain { + sql("select k1,max_by(k2+k3,k3) from d_table group by k1 order by 1,2;") + contains "(k1mbcp2)" + } + qt_select_mv "select k1,max_by(k2+k3,k3) from d_table group by k1 order by 1,2;" + + explain { + sql("select k1,max_by(k2,abs(k3)) from d_table group by k1 order by 1,2;") + contains "(k1mbcp3)" + } + qt_select_mv "select k1,max_by(k2,abs(k3)) from d_table group by k1 order by 1,2;" }