From 675aef7d75f2cfd6a25e588d1dc480c8cbba67ba Mon Sep 17 00:00:00 2001 From: qiye Date: Sun, 10 Oct 2021 23:05:44 +0800 Subject: [PATCH] [AliasFunction] Add support for cast in alias function (#6754) support #6753 --- .../Data Definition/create-function.md | 9 +- .../Data Definition/create-function.md | 9 +- fe/fe-core/src/main/cup/sql_parser.cup | 22 +++ .../org/apache/doris/analysis/CastExpr.java | 126 +++++++++++++++++- .../java/org/apache/doris/analysis/Expr.java | 9 +- .../apache/doris/catalog/AliasFunction.java | 45 ++++++- .../apache/doris/catalog/PrimitiveType.java | 5 +- .../org/apache/doris/catalog/ScalarType.java | 91 ++++++++++++- .../java/org/apache/doris/catalog/Type.java | 8 +- .../rewrite/RewriteAliasFunctionRule.java | 9 +- .../doris/catalog/CreateFunctionTest.java | 99 ++++++++++++++ gensrc/thrift/Types.thrift | 3 +- 12 files changed, 415 insertions(+), 20 deletions(-) diff --git a/docs/en/sql-reference/sql-statements/Data Definition/create-function.md b/docs/en/sql-reference/sql-statements/Data Definition/create-function.md index 7eed4e4b42..417678c2fd 100644 --- a/docs/en/sql-reference/sql-statements/Data Definition/create-function.md +++ b/docs/en/sql-reference/sql-statements/Data Definition/create-function.md @@ -47,8 +47,8 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name > > `Function_name`: To create the name of the function, you can include the name of the database. For example: `db1.my_func'. > -> `arg_type`: The parameter type of the function is the same as the type defined at the time of table building. Variable-length parameters can be represented by `,...`. If it is a variable-length type, the type of the variable-length part of the parameters is the same as the last non-variable-length parameter type. -> **NOTICE**: `ALIAS FUNCTION` variable-length parameters are not supported, and there is at least one parameter. +> `arg_type`: The parameter type of the function is the same as the type defined at the time of table building. Variable-length parameters can be represented by `,...`. If it is a variable-length type, the type of the variable-length part of the parameters is the same as the last non-variable-length parameter type. +> **NOTICE**: `ALIAS FUNCTION` variable-length parameters are not supported, and there is at least one parameter. In particular, the type `ALL` refers to any data type and can only be used for `ALIAS FUNCTION`. > > `ret_type`: Required for creating a new function. This parameter is not required if you are aliasing an existing function. > @@ -130,8 +130,13 @@ If the `function_name` contains the database name, the custom function will be c 5. Create a custom alias function ``` + -- create a custom functional alias function CREATE ALIAS FUNCTION id_masking(INT) WITH PARAMETER(id) AS CONCAT(LEFT(id, 3), '****', RIGHT(id, 4)); + + -- create a custom cast alias function + CREATE ALIAS FUNCTION decimal(ALL, INT, INT) WITH PARAMETER(col, precision, scale) + AS CAST(col AS decimal(precision, scale)); ``` ## keyword diff --git a/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md b/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md index 858262665a..cf6a4fe04c 100644 --- a/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md +++ b/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md @@ -47,8 +47,8 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name > > `function_name`: 要创建函数的名字, 可以包含数据库的名字。比如:`db1.my_func`。 > -> `arg_type`: 函数的参数类型,与建表时定义的类型一致。变长参数时可以使用`, ...`来表示,如果是变长类型,那么变长部分参数的类型与最后一个非变长参数类型一致。 -> **注意**:`ALIAS FUNCTION` 不支持变长参数,且至少有一个参数。 +> `arg_type`: 函数的参数类型,与建表时定义的类型一致。变长参数时可以使用`, ...`来表示,如果是变长类型,那么变长部分参数的类型与最后一个非变长参数类型一致。 +> **注意**:`ALIAS FUNCTION` 不支持变长参数,且至少有一个参数。 特别地,`ALL` 类型指任一数据类型,只可以用于 `ALIAS FUNCTION`. > > `ret_type`: 对创建新的函数来说,是必填项。如果是给已有函数取别名则可不用填写该参数。 > @@ -131,8 +131,13 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name 5. 创建一个自定义别名函数 ``` + -- 创建自定义功能别名函数 CREATE ALIAS FUNCTION id_masking(INT) WITH PARAMETER(id) AS CONCAT(LEFT(id, 3), '****', RIGHT(id, 4)); + + -- 创建自定义 CAST 别名函数 + CREATE ALIAS FUNCTION decimal(ALL, INT, INT) WITH PARAMETER(col, precision, scale) + AS CAST(col AS decimal(precision, scale)); ``` ## keyword diff --git a/fe/fe-core/src/main/cup/sql_parser.cup b/fe/fe-core/src/main/cup/sql_parser.cup index bfb936ee27..20e9c8e8a6 100644 --- a/fe/fe-core/src/main/cup/sql_parser.cup +++ b/fe/fe-core/src/main/cup/sql_parser.cup @@ -4352,6 +4352,11 @@ type ::= type.setAssignedStrLenInColDefinition(); RESULT = type; :} + | KW_VARCHAR LPAREN ident_or_text:lenStr RPAREN + {: ScalarType type = ScalarType.createVarcharType(lenStr); + type.setAssignedStrLenInColDefinition(); + RESULT = type; + :} | KW_VARCHAR {: RESULT = ScalarType.createVarcharType(-1); :} | KW_ARRAY LESSTHAN type:value_type GREATERTHAN @@ -4365,6 +4370,11 @@ type ::= type.setAssignedStrLenInColDefinition(); RESULT = type; :} + | KW_CHAR LPAREN ident_or_text:lenStr RPAREN + {: ScalarType type = ScalarType.createCharType(lenStr); + type.setAssignedStrLenInColDefinition(); + RESULT = type; + :} | KW_CHAR {: RESULT = ScalarType.createCharType(-1); :} | KW_DECIMAL LPAREN INTEGER_LITERAL:precision RPAREN @@ -4373,11 +4383,17 @@ type ::= {: RESULT = ScalarType.createDecimalV2Type(precision.intValue(), scale.intValue()); :} | KW_DECIMAL {: RESULT = ScalarType.createDecimalV2Type(); :} + | KW_DECIMAL LPAREN ident_or_text:precision RPAREN + {: RESULT = ScalarType.createDecimalV2Type(precision); :} + | KW_DECIMAL LPAREN ident_or_text:precision COMMA ident_or_text:scale RPAREN + {: RESULT = ScalarType.createDecimalV2Type(precision, scale); :} | KW_HLL {: ScalarType type = ScalarType.createHllType(); type.setAssignedStrLenInColDefinition(); RESULT = type; :} + | KW_ALL + {: RESULT = Type.ALL; :} ; opt_field_length ::= @@ -5180,6 +5196,8 @@ keyword ::= {: RESULT = id; :} | KW_CHAIN:id {: RESULT = id; :} + | KW_CHAR:id + {: RESULT = id; :} | KW_CHARSET:id {: RESULT = id; :} | KW_CHECK:id @@ -5210,6 +5228,8 @@ keyword ::= {: RESULT = id; :} | KW_DATETIME:id {: RESULT = id; :} + | KW_DECIMAL:id + {: RESULT = id; :} | KW_DISTINCTPC:id {: RESULT = id; :} | KW_DISTINCTPCSA:id @@ -5456,6 +5476,8 @@ keyword ::= {: RESULT = id; :} | KW_MAP:id {: RESULT = id; :} + | KW_VARCHAR:id + {: RESULT = id; :} ; // Identifier that contain keyword diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java index ca46025d63..f071e0fd9e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java @@ -17,7 +17,11 @@ package org.apache.doris.analysis; +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; import java.util.Arrays; +import java.util.List; import java.util.Map; import org.apache.doris.catalog.Catalog; @@ -46,7 +50,7 @@ public class CastExpr extends Expr { private static final Logger LOG = LogManager.getLogger(CastExpr.class); // Only set for explicit casts. Null for implicit casts. - private final TypeDef targetTypeDef; + private TypeDef targetTypeDef; // True if this is a "pre-analyzed" implicit cast. private boolean isImplicit; @@ -77,6 +81,11 @@ public class CastExpr extends Expr { } } + // only used restore from readFields. + public CastExpr() { + + } + public CastExpr(Type targetType, Expr e) { super(); Preconditions.checkArgument(targetType.isValid()); @@ -120,6 +129,10 @@ public class CastExpr extends Expr { return "castTo" + targetType.getPrimitiveType().toString(); } + public TypeDef getTargetTypeDef() { + return targetTypeDef; + } + public static void initBuiltins(FunctionSet functionSet) { for (Type fromType : Type.getSupportedTypes()) { if (fromType.isNull()) { @@ -206,6 +219,10 @@ public class CastExpr extends Expr { } public void analyze() throws AnalysisException { + // do not analyze ALL cast + if (type == Type.ALL) { + return; + } // cast was asked for in the query, check for validity of cast Type childType = getChild(0).getType(); @@ -327,4 +344,111 @@ public class CastExpr extends Expr { } return this; } + + @Override + public void write(DataOutput out) throws IOException { + out.writeBoolean(isImplicit); + if (targetTypeDef.getType() instanceof ScalarType) { + ScalarType scalarType = (ScalarType) targetTypeDef.getType(); + scalarType.write(out); + } else { + throw new IOException("Can not write type " + targetTypeDef.getType()); + } + out.writeInt(children.size()); + for (Expr expr : children) { + Expr.writeTo(expr, out); + } + } + + public static CastExpr read(DataInput input) throws IOException { + CastExpr castExpr = new CastExpr(); + castExpr.readFields(input); + return castExpr; + } + + @Override + public void readFields(DataInput in) throws IOException { + isImplicit = in.readBoolean(); + ScalarType scalarType = ScalarType.read(in); + targetTypeDef = new TypeDef(scalarType); + int counter = in.readInt(); + for (int i = 0; i < counter; i++) { + children.add(Expr.readIn(in)); + } + } + + public CastExpr rewriteExpr(List parameters, List inputParamsExprs) throws AnalysisException { + // child + Expr child = this.getChild(0); + Expr newChild = null; + if (child instanceof SlotRef) { + String columnName = ((SlotRef) child).getColumnName(); + int index = parameters.indexOf(columnName); + if (index != -1) { + newChild = inputParamsExprs.get(index); + } + } + // rewrite cast expr in children + if (child instanceof CastExpr) { + newChild = ((CastExpr) child).rewriteExpr(parameters, inputParamsExprs); + } + + // type def + ScalarType targetType = (ScalarType) targetTypeDef.getType(); + PrimitiveType primitiveType = targetType.getPrimitiveType(); + ScalarType newTargetType = null; + switch (primitiveType) { + case DECIMALV2: + // normal decimal + if (targetType.getPrecision() != 0) { + newTargetType = targetType; + break; + } + int precision = getDigital(targetType.getScalarPrecisionStr(), parameters, inputParamsExprs); + int scale = getDigital(targetType.getScalarScaleStr(), parameters, inputParamsExprs); + if (precision != -1 && scale != -1) { + newTargetType = ScalarType.createType(primitiveType, 0, precision, scale); + } else if (precision != -1 && scale == -1) { + newTargetType = ScalarType.createType(primitiveType, 0, precision, ScalarType.DEFAULT_SCALE); + } + break; + case CHAR: + case VARCHAR: + // normal char/varchar + if (targetType.getLength() != -1) { + newTargetType = targetType; + break; + } + int len = getDigital(targetType.getLenStr(), parameters, inputParamsExprs); + if (len != -1) { + newTargetType = ScalarType.createType(primitiveType, len, 0, 0); + } + // default char/varchar, which len is -1 + if (len == -1 && targetType.getLength() == -1) { + newTargetType = targetType; + } + break; + default: + newTargetType = targetType; + break; + } + + if (newTargetType != null && newChild != null) { + TypeDef typeDef = new TypeDef(newTargetType); + return new CastExpr(typeDef, newChild); + } + + return this; + } + + private int getDigital(String desc, List parameters, List inputParamsExprs) { + int index = parameters.indexOf(desc); + if (index != -1) { + Expr expr = inputParamsExprs.get(index); + if (expr.getType().isIntegerType()) { + return ((Long)((IntLiteral) expr).getRealValue()).intValue(); + } + } + return -1; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java index c935ab5517..a7098a9e1c 100755 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java @@ -1665,7 +1665,8 @@ abstract public class Expr extends TreeNode implements ParseNode, Cloneabl MAX_LITERAL(10), BINARY_PREDICATE(11), FUNCTION_CALL(12), - ARRAY_LITERAL(13); + ARRAY_LITERAL(13), + CAST_EXPR(14); private static Map codeMap = Maps.newHashMap(); @@ -1715,7 +1716,9 @@ abstract public class Expr extends TreeNode implements ParseNode, Cloneabl output.writeInt(ExprSerCode.FUNCTION_CALL.getCode()); } else if (expr instanceof ArrayLiteral) { output.writeInt(ExprSerCode.ARRAY_LITERAL.getCode()); - } else { + } else if (expr instanceof CastExpr){ + output.writeInt(ExprSerCode.CAST_EXPR.getCode()); + }else { throw new IOException("Unknown class " + expr.getClass().getName()); } expr.write(output); @@ -1758,6 +1761,8 @@ abstract public class Expr extends TreeNode implements ParseNode, Cloneabl return FunctionCallExpr.read(in); case ARRAY_LITERAL: return ArrayLiteral.read(in); + case CAST_EXPR: + return CastExpr.read(in); default: throw new IOException("Unknown code: " + code); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java index 53476da792..2e91f33d0f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java @@ -17,6 +17,7 @@ package org.apache.doris.catalog; +import org.apache.doris.analysis.CastExpr; import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.FunctionCallExpr; import org.apache.doris.analysis.FunctionName; @@ -24,12 +25,14 @@ import org.apache.doris.analysis.SelectStmt; import org.apache.doris.analysis.SlotRef; import org.apache.doris.analysis.SqlParser; import org.apache.doris.analysis.SqlScanner; +import org.apache.doris.analysis.TypeDef; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.io.Text; import org.apache.doris.common.util.SqlParserUtils; import org.apache.doris.qe.SqlModeHelper; import org.apache.doris.thrift.TFunctionBinaryType; +import com.google.common.base.Strings; import com.google.common.collect.Lists; import com.google.gson.Gson; @@ -59,6 +62,7 @@ public class AliasFunction extends Function { private Expr originFunction; private List parameters = new ArrayList<>(); + private List typeDefParams = new ArrayList<>(); // Only used for serialization protected AliasFunction() { @@ -152,30 +156,63 @@ public class AliasFunction extends Function { if (parameters.size() != getArgs().length) { throw new AnalysisException("Alias function [" + functionName() + "] args number is not equal to parameters number"); } - List exprs = ((FunctionCallExpr) originFunction).getFnParams().exprs(); + List exprs; + if (originFunction instanceof FunctionCallExpr) { + exprs = ((FunctionCallExpr) originFunction).getFnParams().exprs(); + } else if (originFunction instanceof CastExpr) { + exprs = originFunction.getChildren(); + TypeDef targetTypeDef = ((CastExpr) originFunction).getTargetTypeDef(); + if (targetTypeDef.getType().isScalarType()) { + ScalarType scalarType = (ScalarType) targetTypeDef.getType(); + PrimitiveType primitiveType = scalarType.getPrimitiveType(); + switch (primitiveType) { + case DECIMALV2: + if (!Strings.isNullOrEmpty(scalarType.getScalarPrecisionStr())) { + typeDefParams.add(scalarType.getScalarPrecisionStr()); + } + if (!Strings.isNullOrEmpty(scalarType.getScalarScaleStr())) { + typeDefParams.add(scalarType.getScalarScaleStr()); + } + break; + case CHAR: + case VARCHAR: + if (!Strings.isNullOrEmpty(scalarType.getLenStr())) { + typeDefParams.add(scalarType.getLenStr()); + } + break; + } + } + } else { + throw new AnalysisException("Not supported expr type: " + originFunction); + } Set set = new HashSet<>(); for (String str : parameters) { if (!set.add(str)) { throw new AnalysisException("Alias function [" + functionName() + "] has duplicate parameter [" + str + "]."); } boolean existFlag = false; + // check exprs for (Expr expr : exprs) { existFlag |= checkParams(expr, str); } + // check targetTypeDef + for (String typeDefParam : typeDefParams) { + existFlag |= typeDefParam.equals(str); + } if (!existFlag) { throw new AnalysisException("Alias function [" + functionName() + "] do not contain parameter [" + str + "]."); } } } - private boolean checkParams(Expr expr, String parma) { + private boolean checkParams(Expr expr, String param) { for (Expr e : expr.getChildren()) { - if (checkParams(e, parma)) { + if (checkParams(e, param)) { return true; } } if (expr instanceof SlotRef) { - if (parma.equals(((SlotRef) expr).getColumnName())) { + if (param.equals(((SlotRef) expr).getColumnName())) { return true; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java index 992a9e0868..7def685b5a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java @@ -59,7 +59,8 @@ public enum PrimitiveType { STRUCT("MAP", 24, TPrimitiveType.STRUCT), STRING("STRING", 16, TPrimitiveType.STRING), // Unsupported scalar types. - BINARY("BINARY", -1, TPrimitiveType.BINARY); + BINARY("BINARY", -1, TPrimitiveType.BINARY), + ALL("ALL", -1, TPrimitiveType.INVALID_TYPE); private static final int DATE_INDEX_LEN = 3; @@ -611,6 +612,8 @@ public enum PrimitiveType { return MAP; case STRUCT: return STRUCT; + case ALL: + return ALL; default: return INVALID_TYPE; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java index b56d8a71af..57c9034af4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java @@ -17,8 +17,13 @@ package org.apache.doris.catalog; +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; import java.util.Objects; +import org.apache.doris.common.io.Text; +import org.apache.doris.persist.gson.GsonUtils; import org.apache.doris.thrift.TColumnType; import org.apache.doris.thrift.TScalarType; import org.apache.doris.thrift.TTypeDesc; @@ -85,6 +90,16 @@ public class ScalarType extends Type { @SerializedName(value = "scale") private int scale; + // Only used for alias function decimal + @SerializedName(value = "precisionStr") + private String precisionStr; + // Only used for alias function decimal + @SerializedName(value = "scaleStr") + private String scaleStr; + // Only used for alias function char/varchar + @SerializedName(value = "lenStr") + private String lenStr; + protected ScalarType(PrimitiveType type) { this.type = type; } @@ -144,6 +159,8 @@ public class ScalarType extends Type { return DEFAULT_DECIMALV2; case LARGEINT: return LARGEINT; + case ALL: + return ALL; default: LOG.warn("type={}", type); Preconditions.checkState(false); @@ -207,6 +224,12 @@ public class ScalarType extends Type { return type; } + public static ScalarType createCharType(String lenStr) { + ScalarType type = new ScalarType(PrimitiveType.CHAR); + type.lenStr = lenStr; + return type; + } + public static ScalarType createChar(int len) { ScalarType type = new ScalarType(PrimitiveType.CHAR); type.len = len; @@ -230,6 +253,20 @@ public class ScalarType extends Type { return type; } + public static ScalarType createDecimalV2Type(String precisionStr) { + ScalarType type = new ScalarType(PrimitiveType.DECIMALV2); + type.precisionStr = precisionStr; + type.scaleStr = null; + return type; + } + + public static ScalarType createDecimalV2Type(String precisionStr, String scaleStr) { + ScalarType type = new ScalarType(PrimitiveType.DECIMALV2); + type.precisionStr = precisionStr; + type.scaleStr = scaleStr; + return type; + } + public static ScalarType createDecimalV2TypeInternal(int precision, int scale) { ScalarType type = new ScalarType(PrimitiveType.DECIMALV2); type.precision = Math.min(precision, MAX_PRECISION); @@ -244,6 +281,13 @@ public class ScalarType extends Type { return type; } + public static ScalarType createVarcharType(String lenStr) { + // length checked in analysis + ScalarType type = new ScalarType(PrimitiveType.VARCHAR); + type.lenStr = lenStr; + return type; + } + public static ScalarType createStringType() { // length checked in analysis ScalarType type = new ScalarType(PrimitiveType.STRING); @@ -296,13 +340,27 @@ public class ScalarType extends Type { StringBuilder stringBuilder = new StringBuilder(); switch (type) { case CHAR: - stringBuilder.append("char").append("(").append(len).append(")"); + if (Strings.isNullOrEmpty(lenStr)) { + stringBuilder.append("char").append("(").append(len).append(")"); + } else { + stringBuilder.append("char").append("(`").append(lenStr).append("`)"); + } break; case VARCHAR: - stringBuilder.append("varchar").append("(").append(len).append(")"); + if (Strings.isNullOrEmpty(lenStr)) { + stringBuilder.append("varchar").append("(").append(len).append(")"); + } else { + stringBuilder.append("varchar").append("(`").append(lenStr).append("`)"); + } break; case DECIMALV2: - stringBuilder.append("decimal").append("(").append(precision).append(", ").append(scale).append(")"); + if (Strings.isNullOrEmpty(precisionStr)) { + stringBuilder.append("decimal").append("(").append(precision).append(", ").append(scale).append(")"); + } else if (!Strings.isNullOrEmpty(precisionStr) && !Strings.isNullOrEmpty(scaleStr)) { + stringBuilder.append("decimal").append("(`").append(precisionStr).append("`, `").append(scaleStr).append("`)"); + } else { + stringBuilder.append("decimal").append("(`").append(precisionStr).append("`)"); + } break; case BOOLEAN: return "boolean"; @@ -393,6 +451,18 @@ public class ScalarType extends Type { public int getScalarScale() { return scale; } public int getScalarPrecision() { return precision; } + public String getScalarPrecisionStr() { + return precisionStr; + } + + public String getScalarScaleStr() { + return scaleStr; + } + + public String getLenStr() { + return lenStr; + } + @Override public boolean isWildcardDecimal() { return (type == PrimitiveType.DECIMALV2) @@ -606,6 +676,11 @@ public class ScalarType extends Type { return INVALID; } + // for cast all type + if (t1.type == PrimitiveType.ALL || t2.type == PrimitiveType.ALL) { + return Type.ALL; + } + if (t1.isStringType() || t2.isStringType()) { if (t1.type == PrimitiveType.STRING || t2.type == PrimitiveType.STRING) { return createStringType(); @@ -708,4 +783,14 @@ public class ScalarType extends Type { result = 31 * result + scale; return result; } + + public void write(DataOutput out) throws IOException { + String json = GsonUtils.GSON.toJson(this); + Text.writeString(out, json); + } + + public static ScalarType read(DataInput input) throws IOException { + String json = Text.readString(input); + return GsonUtils.GSON.fromJson(json, ScalarType.class); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java index 7c74d34227..7e2d2ab19c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java @@ -76,6 +76,8 @@ public abstract class Type { public static final ScalarType HLL = ScalarType.createHllType(); public static final ScalarType CHAR = (ScalarType) ScalarType.createCharType(-1); public static final ScalarType BITMAP = new ScalarType(PrimitiveType.BITMAP); + // Only used for alias function, to represent any type in function args + public static final ScalarType ALL = new ScalarType(PrimitiveType.ALL); public static final MapType Map = new MapType(); private static ArrayList integerTypes; @@ -944,9 +946,9 @@ public abstract class Type { compatibilityMatrix[TIME.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; // Check all of the necessary entries that should be filled. - // ignore binary - for (int i = 0; i < PrimitiveType.values().length - 1; ++i) { - for (int j = i; j < PrimitiveType.values().length - 1; ++j) { + // ignore binary and all + for (int i = 0; i < PrimitiveType.values().length - 2; ++i) { + for (int j = i; j < PrimitiveType.values().length - 2; ++j) { PrimitiveType t1 = PrimitiveType.values()[i]; PrimitiveType t2 = PrimitiveType.values()[j]; // DECIMAL, NULL, and INVALID_TYPE are handled separately. diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteAliasFunctionRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteAliasFunctionRule.java index c3a0f3e0f0..09e61fcb21 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteAliasFunctionRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteAliasFunctionRule.java @@ -18,6 +18,7 @@ package org.apache.doris.rewrite; import org.apache.doris.analysis.Analyzer; +import org.apache.doris.analysis.CastExpr; import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.FunctionCallExpr; import org.apache.doris.catalog.AliasFunction; @@ -39,7 +40,13 @@ public class RewriteAliasFunctionRule implements ExprRewriteRule{ if (expr instanceof FunctionCallExpr) { Function fn = expr.getFn(); if (fn instanceof AliasFunction) { - return ((FunctionCallExpr) expr).rewriteExpr(); + Expr originFn = ((AliasFunction) fn).getOriginFunction(); + if (originFn instanceof FunctionCallExpr) { + return ((FunctionCallExpr) expr).rewriteExpr(); + } else if (originFn instanceof CastExpr) { + return ((CastExpr) originFn).rewriteExpr(((AliasFunction) fn).getParameters(), + ((FunctionCallExpr) expr).getParams().exprs()); + } } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java index d56856d84f..17ad6665bc 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java @@ -19,8 +19,10 @@ package org.apache.doris.catalog; import org.apache.doris.analysis.CreateDbStmt; import org.apache.doris.analysis.CreateFunctionStmt; +import org.apache.doris.analysis.CreateTableStmt; import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.FunctionCallExpr; +import org.apache.doris.analysis.StringLiteral; import org.apache.doris.common.FeConstants; import org.apache.doris.common.jmockit.Deencapsulation; import org.apache.doris.planner.PlanFragment; @@ -29,6 +31,7 @@ import org.apache.doris.planner.UnionNode; import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.QueryState; import org.apache.doris.qe.StmtExecutor; +import org.apache.doris.utframe.DorisAssert; import org.apache.doris.utframe.UtFrameUtils; import org.junit.AfterClass; @@ -48,11 +51,15 @@ import java.util.UUID; public class CreateFunctionTest { private static String runningDir = "fe/mocked/CreateFunctionTest/" + UUID.randomUUID().toString() + "/"; + private static ConnectContext connectContext; + private static DorisAssert dorisAssert; @BeforeClass public static void setup() throws Exception { UtFrameUtils.createDorisCluster(runningDir); FeConstants.runningUnitTest = true; + // create connect context + connectContext = UtFrameUtils.createDefaultCtx(); } @AfterClass @@ -71,6 +78,14 @@ public class CreateFunctionTest { Catalog.getCurrentCatalog().createDb(createDbStmt); System.out.println(Catalog.getCurrentCatalog().getDbNames()); + String createTblStmtStr = "create table db1.tbl1(k1 int, k2 bigint, k3 varchar(10), k4 char(5)) duplicate key(k1) " + + "distributed by hash(k2) buckets 1 properties('replication_num' = '1');"; + CreateTableStmt createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, connectContext); + Catalog.getCurrentCatalog().createTable(createTableStmt); + + dorisAssert = new DorisAssert(); + dorisAssert.useDatabase("db1"); + Database db = Catalog.getCurrentCatalog().getDbNullable("default_cluster:db1"); Assert.assertNotNull(db); @@ -126,5 +141,89 @@ public class CreateFunctionTest { Assert.assertEquals(1, constExprLists.size()); Assert.assertEquals(1, constExprLists.get(0).size()); Assert.assertTrue(constExprLists.get(0).get(0) instanceof FunctionCallExpr); + + queryStr = "select db1.id_masking(k1) from db1.tbl1"; + Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("concat(left(`k1`, 3), '****', right(`k1`, 4))")); + + // create alias function with cast + // cast any type to decimal with specific precision and scale + createFuncStr = "create alias function db1.decimal(all, int, int) with parameter(col, precision, scale)" + + " as cast(col as decimal(precision, scale));"; + createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx); + Catalog.getCurrentCatalog().createFunction(createFunctionStmt); + + functions = db.getFunctions(); + Assert.assertEquals(3, functions.size()); + + queryStr = "select db1.decimal(333, 4, 1);"; + ctx.getState().reset(); + stmtExecutor = new StmtExecutor(ctx, queryStr); + stmtExecutor.execute(); + Assert.assertNotEquals(QueryState.MysqlStateType.ERR, ctx.getState().getStateType()); + planner = stmtExecutor.planner(); + Assert.assertEquals(1, planner.getFragments().size()); + fragment = planner.getFragments().get(0); + Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode); + unionNode = (UnionNode)fragment.getPlanRoot(); + constExprLists = Deencapsulation.getField(unionNode, "constExprLists_"); + System.out.println(constExprLists.get(0).get(0)); + Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral); + + queryStr = "select db1.decimal(k3, 4, 1) from db1.tbl1;"; + Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4,1))")); + + // cast any type to varchar with fixed length + createFuncStr = "create alias function db1.varchar(all, int) with parameter(text, length) as " + + "cast(text as varchar(length));"; + createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx); + Catalog.getCurrentCatalog().createFunction(createFunctionStmt); + + functions = db.getFunctions(); + Assert.assertEquals(4, functions.size()); + + queryStr = "select db1.varchar(333, 4);"; + ctx.getState().reset(); + stmtExecutor = new StmtExecutor(ctx, queryStr); + stmtExecutor.execute(); + Assert.assertNotEquals(QueryState.MysqlStateType.ERR, ctx.getState().getStateType()); + planner = stmtExecutor.planner(); + Assert.assertEquals(1, planner.getFragments().size()); + fragment = planner.getFragments().get(0); + Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode); + unionNode = (UnionNode)fragment.getPlanRoot(); + constExprLists = Deencapsulation.getField(unionNode, "constExprLists_"); + Assert.assertEquals(1, constExprLists.size()); + Assert.assertEquals(1, constExprLists.get(0).size()); + Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral); + + queryStr = "select db1.varchar(k1, 4) from db1.tbl1;"; + Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)")); + + // cast any type to char with fixed length + createFuncStr = "create alias function db1.char(all, int) with parameter(text, length) as " + + "cast(text as char(length));"; + createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx); + Catalog.getCurrentCatalog().createFunction(createFunctionStmt); + + functions = db.getFunctions(); + Assert.assertEquals(5, functions.size()); + + queryStr = "select db1.char(333, 4);"; + ctx.getState().reset(); + stmtExecutor = new StmtExecutor(ctx, queryStr); + stmtExecutor.execute(); + Assert.assertNotEquals(QueryState.MysqlStateType.ERR, ctx.getState().getStateType()); + planner = stmtExecutor.planner(); + Assert.assertEquals(1, planner.getFragments().size()); + fragment = planner.getFragments().get(0); + Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode); + unionNode = (UnionNode)fragment.getPlanRoot(); + constExprLists = Deencapsulation.getField(unionNode, "constExprLists_"); + Assert.assertEquals(1, constExprLists.size()); + Assert.assertEquals(1, constExprLists.get(0).size()); + Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral); + + queryStr = "select db1.char(k1, 4) from db1.tbl1;"; + Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)")); } } diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift index 8c9a46ff48..efa657a559 100644 --- a/gensrc/thrift/Types.thrift +++ b/gensrc/thrift/Types.thrift @@ -78,7 +78,8 @@ enum TPrimitiveType { ARRAY, MAP, STRUCT, - STRING + STRING, + ALL } enum TTypeNodeType {