[feature](Nereids) support agg state type in create table (#32171)

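This PR introduces a behavior change: the CREATE TABLE syntax for agg_state columns has changed. For example, `k2 agg_state max_by(int not null,int)` must now be written as `k2 agg_state<max_by(int not null,int)> generic`.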
This commit is contained in:
morrySnow
2024-03-14 14:57:32 +08:00
committed by yiguolei
parent 62023d705d
commit ea2fbfaffa
24 changed files with 121 additions and 111 deletions

View File

@ -41,8 +41,8 @@ Create table example:
```sql
create table a_table(
k1 int null,
k2 agg_state max_by(int not null,int),
k3 agg_state group_concat(string)
k2 agg_state<max_by(int not null,int)> generic,
k3 agg_state<group_concat(string)> generic
)
aggregate key (k1)
distributed BY hash(k1) buckets 3

View File

@ -41,8 +41,8 @@ under the License.
```sql
create table a_table(
k1 int null,
k2 agg_state max_by(int not null,int),
k3 agg_state group_concat(string)
k2 agg_state<max_by(int not null,int)> generic,
k3 agg_state<group_concat(string)> generic
)
aggregate key (k1)
distributed BY hash(k1) buckets 3

View File

@ -72,7 +72,7 @@ public class AggStateType extends ScalarType {
@Override
public String toSql(int depth) {
StringBuilder stringBuilder = new StringBuilder();
stringBuilder.append("AGG_STATE(");
stringBuilder.append("AGG_STATE<").append(functionName).append("(");
for (int i = 0; i < subTypes.size(); i++) {
if (i > 0) {
stringBuilder.append(", ");
@ -82,7 +82,7 @@ public class AggStateType extends ScalarType {
stringBuilder.append(" NULL");
}
}
stringBuilder.append(")");
stringBuilder.append(")>");
return stringBuilder.toString();
}

View File

@ -271,6 +271,7 @@ FRONTENDS: 'FRONTENDS';
FULL: 'FULL';
FUNCTION: 'FUNCTION';
FUNCTIONS: 'FUNCTIONS';
GENERIC: 'GENERIC';
GLOBAL: 'GLOBAL';
GRANT: 'GRANT';
GRANTS: 'GRANTS';

View File

@ -530,7 +530,10 @@ columnDefs
columnDef
: colName=identifier type=dataType
KEY? (aggType=aggTypeDef)? ((NOT NULL) | NULL)? (AUTO_INCREMENT (LEFT_PAREN autoIncInitValue=number RIGHT_PAREN)?)?
KEY?
(aggType=aggTypeDef)?
((NOT)? NULL)?
(AUTO_INCREMENT (LEFT_PAREN autoIncInitValue=number RIGHT_PAREN)?)?
(DEFAULT (nullValue=NULL | INTEGER_VALUE | stringValue=STRING_LITERAL
| CURRENT_TIMESTAMP (LEFT_PAREN defaultValuePrecision=number RIGHT_PAREN)?))?
(ON UPDATE CURRENT_TIMESTAMP (LEFT_PAREN onUpdateValuePrecision=number RIGHT_PAREN)?)?
@ -587,7 +590,7 @@ rollupDef
;
aggTypeDef
: MAX | MIN | SUM | REPLACE | REPLACE_IF_NOT_NULL | HLL_UNION | BITMAP_UNION | QUANTILE_UNION
: MAX | MIN | SUM | REPLACE | REPLACE_IF_NOT_NULL | HLL_UNION | BITMAP_UNION | QUANTILE_UNION | GENERIC
;
tabletList
@ -846,10 +849,17 @@ unitIdentifier
: YEAR | MONTH | WEEK | DAY | HOUR | MINUTE | SECOND
;
// A data type optionally followed by a nullability marker: either NULL or NOT NULL.
// Referenced by the AGG_STATE alternative of dataType for the function's argument types.
dataTypeWithNullable
: dataType ((NOT)? NULL)?
;
// Data type syntax. The #labels name the alternatives:
//   complexDataType   - ARRAY<t>, MAP<k, v>, STRUCT<...>
//   aggStateDataType  - AGG_STATE<fn(argType [[NOT] NULL], ...)>; at least one argument type is required
//   primitiveDataType - a primitive type with an optional parenthesised list of integers (the first may be '*')
dataType
: complex=ARRAY LT dataType GT #complexDataType
| complex=MAP LT dataType COMMA dataType GT #complexDataType
| complex=STRUCT LT complexColTypeList GT #complexDataType
| AGG_STATE LT functionNameIdentifier
LEFT_PAREN dataTypes+=dataTypeWithNullable
(COMMA dataTypes+=dataTypeWithNullable)* RIGHT_PAREN GT #aggStateDataType
| primitiveColType (LEFT_PAREN (INTEGER_VALUE | ASTERISK)
(COMMA INTEGER_VALUE)* RIGHT_PAREN)? #primitiveDataType
;
@ -1061,6 +1071,7 @@ nonReserved
| FREE
| FRONTENDS
| FUNCTION
| GENERIC
| GLOBAL
| GRAPH
| GROUPING

View File

@ -401,6 +401,7 @@ terminal String
KW_FULL,
KW_FUNCTION,
KW_FUNCTIONS,
KW_GENERIC,
KW_GLOBAL,
KW_GRANT,
KW_GRANTS,
@ -3132,6 +3133,10 @@ opt_agg_type ::=
{:
RESULT = AggregateType.QUANTILE_UNION;
:}
| KW_GENERIC
{:
RESULT = AggregateType.GENERIC;
:}
;
opt_partition ::=
@ -3731,31 +3736,11 @@ column_definition ::=
ColumnDef columnDef = new ColumnDef(columnName, typeDef, isKey, null, isAllowNull, autoIncInitValue, defaultValue, comment);
RESULT = columnDef;
:}
| ident:columnName KW_AGG_STATE IDENT:fnName LPAREN type_def_nullable_list:list RPAREN opt_auto_inc_init_value:autoIncInitValue opt_default_value:defaultValue opt_comment:comment
{:
for (TypeDef def : list) {
def.analyze(null);
}
ColumnDef columnDef = new ColumnDef(columnName, new TypeDef(Expr.createAggStateType(fnName,
list.stream().map(TypeDef::getType).collect(Collectors.toList()),
list.stream().map(TypeDef::getNullable).collect(Collectors.toList()))), false, AggregateType.GENERIC_AGGREGATION, false, defaultValue, comment);
RESULT = columnDef;
:}
| ident:columnName type_def:typeDef opt_is_key:isKey opt_agg_type:aggType opt_is_allow_null:isAllowNull opt_auto_inc_init_value:autoIncInitValue opt_default_value:defaultValue opt_comment:comment
{:
ColumnDef columnDef = new ColumnDef(columnName, typeDef, isKey, aggType, isAllowNull, autoIncInitValue, defaultValue, comment);
RESULT = columnDef;
:}
| ident:columnName KW_AGG_STATE opt_is_key:isKey opt_agg_type:aggType LPAREN type_def_nullable_list:list RPAREN opt_default_value:defaultValue opt_comment:comment
{:
for (TypeDef def : list) {
def.analyze(null);
}
ColumnDef columnDef = new ColumnDef(columnName, new TypeDef(Expr.createAggStateType(aggType.name().toLowerCase(),
list.stream().map(TypeDef::getType).collect(Collectors.toList()),
list.stream().map(TypeDef::getNullable).collect(Collectors.toList()))), isKey, AggregateType.GENERIC_AGGREGATION, false, defaultValue, comment);
RESULT = columnDef;
:}
;
index_definition ::=
@ -6553,6 +6538,12 @@ type ::=
{: ScalarType type = ScalarType.createHllType();
RESULT = type;
:}
| KW_AGG_STATE LESSTHAN IDENT:fnName LPAREN type_def_nullable_list:list RPAREN GREATERTHAN
{:
RESULT = Expr.createAggStateType(fnName,
list.stream().map(TypeDef::getType).collect(Collectors.toList()),
list.stream().map(TypeDef::getNullable).collect(Collectors.toList()));
:}
| KW_ALL
{: RESULT = Type.ALL; :}
;
@ -7782,6 +7773,8 @@ keyword ::=
{: RESULT = id; :}
| KW_GLOBAL:id
{: RESULT = id; :}
| KW_GENERIC:id
{: RESULT = id; :}
| KW_GRAPH:id
{: RESULT = id; :}
| KW_HASH:id

View File

@ -332,12 +332,12 @@ public class ColumnDef {
}
// check if aggregate type is valid
if (aggregateType != AggregateType.GENERIC_AGGREGATION
if (aggregateType != AggregateType.GENERIC
&& !aggregateType.checkCompatibility(type.getPrimitiveType())) {
throw new AnalysisException(String.format("Aggregate type %s is not compatible with primitive type %s",
toString(), type.toSql()));
}
if (aggregateType == AggregateType.GENERIC_AGGREGATION) {
if (aggregateType == AggregateType.GENERIC) {
if (!SessionVariable.enableAggState()) {
throw new AnalysisException("agg state not enable, need set enable_agg_state=true");
}

View File

@ -549,7 +549,7 @@ public class CreateMaterializedViewStmt extends DdlStmt {
type = Type.BIGINT;
break;
default:
mvAggregateType = AggregateType.GENERIC_AGGREGATION;
mvAggregateType = AggregateType.GENERIC;
if (functionCallExpr.getParams().isDistinct() || functionCallExpr.getParams().isStar()) {
throw new AnalysisException(
"The Materialized-View's generic aggregation not support star or distinct");

View File

@ -37,7 +37,7 @@ public enum AggregateType {
NONE("NONE"),
BITMAP_UNION("BITMAP_UNION"),
QUANTILE_UNION("QUANTILE_UNION"),
GENERIC_AGGREGATION("GENERIC_AGGREGATION");
GENERIC("GENERIC");
private static EnumMap<AggregateType, EnumSet<PrimitiveType>> compatibilityMap;

View File

@ -128,9 +128,6 @@ public class Column implements Writable, GsonPostProcessable {
@SerializedName(value = "uniqueId")
private int uniqueId;
@SerializedName(value = "genericAggregationName")
private String genericAggregationName;
@SerializedName(value = "clusterKeyId")
private int clusterKeyId = -1;
@ -244,8 +241,8 @@ public class Column implements Writable, GsonPostProcessable {
c.setIsAllowNull(aggState.getSubTypeNullables().get(i));
addChildrenColumn(c);
}
this.genericAggregationName = aggState.getFunctionName();
this.aggregationType = AggregateType.GENERIC_AGGREGATION;
this.isAllowNull = false;
this.aggregationType = AggregateType.GENERIC;
}
}
@ -449,11 +446,7 @@ public class Column implements Writable, GsonPostProcessable {
}
public String getAggregationString() {
if (getAggregationType() == AggregateType.GENERIC_AGGREGATION) {
return getGenericAggregationString();
} else {
return getAggregationType().name();
}
return getAggregationType().name();
}
public boolean isAggregated() {
@ -764,22 +757,6 @@ public class Column implements Writable, GsonPostProcessable {
return toSql(isUniqueTable, false);
}
public String getGenericAggregationString() {
StringBuilder sb = new StringBuilder();
sb.append(genericAggregationName).append("(");
for (int i = 0; i < children.size(); i++) {
if (i != 0) {
sb.append(", ");
}
sb.append(children.get(i).getType().toSql());
if (children.get(i).isAllowNull()) {
sb.append(" NULL");
}
}
sb.append(")");
return sb.toString();
}
public String toSql(boolean isUniqueTable, boolean isCompatible) {
StringBuilder sb = new StringBuilder();
sb.append("`").append(name).append("` ");
@ -791,11 +768,9 @@ public class Column implements Writable, GsonPostProcessable {
} else {
sb.append(typeStr);
}
if (aggregationType == AggregateType.GENERIC_AGGREGATION) {
sb.append(" ").append(getGenericAggregationString());
} else if (aggregationType != null && aggregationType != AggregateType.NONE && !isUniqueTable
if (aggregationType != null && aggregationType != AggregateType.NONE && !isUniqueTable
&& !isAggregationTypeImplicit) {
sb.append(" ").append(aggregationType.name());
sb.append(" ").append(aggregationType.toSql());
}
if (isAllowNull) {
sb.append(" NULL");

View File

@ -23,6 +23,7 @@ import org.apache.doris.analysis.StorageBackend;
import org.apache.doris.analysis.TableName;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.BuiltinAggregateFunctions;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.ScalarType;
@ -42,6 +43,7 @@ import org.apache.doris.mtmv.MTMVRefreshTriggerInfo;
import org.apache.doris.nereids.DorisParser;
import org.apache.doris.nereids.DorisParser.AddConstraintContext;
import org.apache.doris.nereids.DorisParser.AggClauseContext;
import org.apache.doris.nereids.DorisParser.AggStateDataTypeContext;
import org.apache.doris.nereids.DorisParser.AliasQueryContext;
import org.apache.doris.nereids.DorisParser.AliasedQueryContext;
import org.apache.doris.nereids.DorisParser.AlterMTMVContext;
@ -75,6 +77,7 @@ import org.apache.doris.nereids.DorisParser.CreateProcedureContext;
import org.apache.doris.nereids.DorisParser.CreateRowPolicyContext;
import org.apache.doris.nereids.DorisParser.CreateTableContext;
import org.apache.doris.nereids.DorisParser.CteContext;
import org.apache.doris.nereids.DorisParser.DataTypeWithNullableContext;
import org.apache.doris.nereids.DorisParser.DateCeilContext;
import org.apache.doris.nereids.DorisParser.DateFloorContext;
import org.apache.doris.nereids.DorisParser.Date_addContext;
@ -422,6 +425,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.UsingJoin;
import org.apache.doris.nereids.types.AggStateType;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.MapType;
@ -2519,7 +2523,9 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
String colName = ctx.colName.getText();
DataType colType = ctx.type instanceof PrimitiveDataTypeContext
? visitPrimitiveDataType(((PrimitiveDataTypeContext) ctx.type))
: visitComplexDataType(((ComplexDataTypeContext) ctx.type));
: ctx.type instanceof ComplexDataTypeContext
? visitComplexDataType((ComplexDataTypeContext) ctx.type)
: visitAggStateDataType((AggStateDataTypeContext) ctx.type);
colType = colType.conversion();
boolean isKey = ctx.KEY() != null;
boolean isNotNull = ctx.NOT() != null;
@ -3248,6 +3254,32 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
return ExplainLevel.ALL_PLAN;
}
/**
 * Visits a data type with an optional {@code NULL} / {@code NOT NULL} suffix.
 *
 * @param ctx parse context of the {@code dataTypeWithNullable} rule
 * @return a pair of the visited data type and its nullability; the flag is
 *         {@code true} unless the {@code NOT} token is present
 */
@Override
public Pair<DataType, Boolean> visitDataTypeWithNullable(DataTypeWithNullableContext ctx) {
    return ParserUtils.withOrigin(ctx, () -> {
        DataType visitedType = typedVisit(ctx.dataType());
        // Both a bare type and an explicit NULL count as nullable; only NOT NULL does not.
        boolean nullable = ctx.NOT() == null;
        return Pair.of(visitedType, nullable);
    });
}
/**
 * Builds an {@link AggStateType} from an {@code AGG_STATE<fn(argType, ...)>} type expression.
 *
 * @param ctx parse context of the {@code aggStateDataType} alternative
 * @return the agg-state type carrying the function name, the argument types and
 *         the per-argument nullability flags, in declaration order
 * @throws ParseException if the function name is not one of the built-in aggregate function names
 */
@Override
public DataType visitAggStateDataType(AggStateDataTypeContext ctx) {
    return ParserUtils.withOrigin(ctx, () -> {
        // Visit every argument first so that argument errors surface before the function-name check.
        ImmutableList.Builder<DataType> argTypes = ImmutableList.builder();
        ImmutableList.Builder<Boolean> argNullables = ImmutableList.builder();
        for (DataTypeWithNullableContext argCtx : ctx.dataTypes) {
            Pair<DataType, Boolean> typeAndNullable = visitDataTypeWithNullable(argCtx);
            argTypes.add(typeAndNullable.first);
            argNullables.add(typeAndNullable.second);
        }
        String fnName = ctx.functionNameIdentifier().getText();
        if (BuiltinAggregateFunctions.INSTANCE.aggFuncNames.contains(fnName)) {
            return new AggStateType(fnName, argTypes.build(), argNullables.build());
        }
        // TODO use function binder to check function exists
        throw new ParseException("Can not found function '" + fnName + "'", ctx);
    });
}
@Override
public DataType visitPrimitiveDataType(PrimitiveDataTypeContext ctx) {
return ParserUtils.withOrigin(ctx, () -> {

View File

@ -1487,7 +1487,7 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
@Override
public Expression visitAggregateFunction(AggregateFunction aggregateFunction, RewriteContext context) {
String aggStateName = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(
AggregateType.GENERIC_AGGREGATION, StateCombinator.create(aggregateFunction).toSql()));
AggregateType.GENERIC, StateCombinator.create(aggregateFunction).toSql()));
Column mvColumn = context.checkContext.getColumn(aggStateName);
if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {

View File

@ -246,6 +246,20 @@ public class ColumnDefinition {
}
}
if (aggType != null) {
// check if aggregate type is valid
if (aggType != AggregateType.GENERIC
&& !aggType.checkCompatibility(type.toCatalogDataType().getPrimitiveType())) {
throw new AnalysisException(String.format("Aggregate type %s is not compatible with primitive type %s",
aggType, type.toSql()));
}
if (aggType == AggregateType.GENERIC) {
if (!SessionVariable.enableAggState()) {
throw new AnalysisException("agg state not enable, need set enable_agg_state=true");
}
}
}
if (isOlap) {
if (!isKey && keysType.equals(KeysType.UNIQUE_KEYS)) {
aggTypeImplicit = true;
@ -334,14 +348,28 @@ public class ColumnDefinition {
// from old planner CreateTableStmt's analyze method, after call columnDef.analyze(engineName.equals("olap"));
if (isOlap && type.isComplexType()) {
if (aggType != null && aggType != AggregateType.NONE
&& aggType != AggregateType.REPLACE) {
throw new AnalysisException(type.toCatalogDataType().getPrimitiveType()
+ " column can't support aggregation " + aggType);
}
if (isKey) {
throw new AnalysisException(type.toCatalogDataType().getPrimitiveType()
+ " can only be used in the non-key column of the duplicate table at present.");
+ " can only be used in the non-key column at present.");
}
if (type.isAggStateType()) {
if (aggType == null) {
throw new AnalysisException(type.toCatalogDataType().getPrimitiveType()
+ " column must have aggregation type");
} else {
if (aggType != AggregateType.GENERIC
&& aggType != AggregateType.NONE
&& aggType != AggregateType.REPLACE) {
throw new AnalysisException(type.toCatalogDataType().getPrimitiveType()
+ " column can't support aggregation " + aggType);
}
}
isNullable = false;
} else {
if (aggType != null && aggType != AggregateType.NONE && aggType != AggregateType.REPLACE) {
throw new AnalysisException(type.toCatalogDataType().getPrimitiveType()
+ " column can't support aggregation " + aggType);
}
}
}
@ -350,30 +378,6 @@ public class ColumnDefinition {
}
}
/**
* check if is nested complex type.
*/
private boolean isNestedComplexType(DataType dataType) {
if (!dataType.isComplexType()) {
return false;
}
if (dataType instanceof ArrayType) {
if (((ArrayType) dataType).getItemType() instanceof ArrayType) {
return isNestedComplexType(((ArrayType) dataType).getItemType());
} else {
return ((ArrayType) dataType).getItemType().isComplexType();
}
}
if (dataType instanceof MapType) {
return ((MapType) dataType).getKeyType().isComplexType()
|| ((MapType) dataType).getValueType().isComplexType();
}
if (dataType instanceof StructType) {
return ((StructType) dataType).getFields().stream().anyMatch(f -> f.getDataType().isComplexType());
}
return false;
}
// from TypeDef.java analyze()
private void validateDataType(Type catalogType) {
if (catalogType.exceedsMaxNestingDepth()) {

View File

@ -35,8 +35,6 @@ import java.util.stream.Collectors;
*/
public class AggStateType extends DataType {
public static final AggStateType SYSTEM_DEFAULT = new AggStateType(null, ImmutableList.of(), ImmutableList.of());
public static final int WIDTH = 16;
private final List<DataType> subTypes;
@ -94,11 +92,6 @@ public class AggStateType extends DataType {
return "agg_state";
}
@Override
public DataType defaultConcreteType() {
return SYSTEM_DEFAULT;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof AggStateType)) {

View File

@ -251,6 +251,7 @@ import org.apache.doris.qe.SqlModeHelper;
keywordMap.put("function", new Integer(SqlParserSymbols.KW_FUNCTION));
keywordMap.put("functions", new Integer(SqlParserSymbols.KW_FUNCTIONS));
keywordMap.put("type_cast", new Integer(SqlParserSymbols.KW_TYPECAST));
keywordMap.put("generic", new Integer(SqlParserSymbols.KW_GENERIC));
keywordMap.put("global", new Integer(SqlParserSymbols.KW_GLOBAL));
keywordMap.put("grant", new Integer(SqlParserSymbols.KW_GRANT));
keywordMap.put("grants", new Integer(SqlParserSymbols.KW_GRANTS));

View File

@ -46,7 +46,7 @@ suite("test_vertical_compaction_agg_state") {
sql """
CREATE TABLE IF NOT EXISTS ${tableName} (
user_id VARCHAR,
agg_user_id agg_state collect_set(string)
agg_user_id agg_state<collect_set(string)> generic
)ENGINE=OLAP
AGGREGATE KEY(`user_id`)
COMMENT 'OLAP'

View File

@ -21,7 +21,7 @@ suite("test_agg_state_avg") {
sql """
create table a_table(
k1 int not null,
k2 agg_state avg(int not null)
k2 agg_state<avg(int not null)> generic
)
aggregate key (k1)
distributed BY hash(k1)

View File

@ -21,7 +21,7 @@ suite("test_agg_state_group_concat") {
sql """
create table a_table(
k1 int null,
k2 agg_state group_concat(string)
k2 agg_state<group_concat(string)> generic
)
aggregate key (k1)
distributed BY hash(k1) buckets 3

View File

@ -21,7 +21,7 @@ suite("test_agg_state_max") {
sql """
create table a_table(
k1 int not null,
k2 agg_state max(int not null)
k2 agg_state<max(int not null)> generic
)
aggregate key (k1)
distributed BY hash(k1)
@ -60,7 +60,7 @@ suite("test_agg_state_max") {
sql """
create table a_table2(
k1 int not null,
k2 agg_state max(int null)
k2 agg_state<max(int null)> generic
)
aggregate key (k1)
distributed BY hash(k1)

View File

@ -48,7 +48,7 @@ suite("test_agg_state_nereids") {
sql """
create table a_table(
k1 int null,
k2 agg_state max_by(int not null,int)
k2 agg_state<max_by(int not null, int)> generic
)
aggregate key (k1)
distributed BY hash(k1) buckets 3

View File

@ -21,7 +21,7 @@ suite("test_agg_state_quantile_union") {
sql """
create table a_table(
k1 int not null,
k2 agg_state quantile_union(quantile_state not null)
k2 agg_state<quantile_union(quantile_state not null)> generic
)
aggregate key (k1)
distributed BY hash(k1)

View File

@ -45,7 +45,7 @@ suite("test_agg_state") {
sql """
create table a_table(
k1 int null,
k2 agg_state max_by(int not null,int)
k2 agg_state<max_by(int not null,int)> generic
)
aggregate key (k1)
distributed BY hash(k1) buckets 3

View File

@ -22,7 +22,7 @@ suite ("dis_26495") {
sql """ DROP TABLE IF EXISTS doris_test; """
sql """
create table doris_test (a int,b int, agg_st_1 agg_state max_by(int ,int))
create table doris_test (a int,b int, agg_st_1 agg_state<max_by(int ,int)> generic)
DISTRIBUTED BY HASH(a) BUCKETS 1 properties("replication_num" = "1");
"""

View File

@ -24,7 +24,7 @@ suite("test_analyze_with_agg_complex_type") {
device_id bitmap BITMAP_UNION NULL,
hll_test hll hll_union,
qs QUANTILE_STATE QUANTILE_UNION,
agg_st_1 agg_state max_by(int ,int)
agg_st_1 agg_state<max_by(int, int)> GENERIC
)
aggregate key (datekey)
distributed by hash(datekey) buckets 1