[AliasFunction] Add support for cast in alias function (#6754)

support #6753
This commit is contained in:
qiye
2021-10-10 23:05:44 +08:00
committed by GitHub
parent 0941322dd6
commit 675aef7d75
12 changed files with 415 additions and 20 deletions

View File

@ -47,8 +47,8 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
>
> `Function_name`: To create the name of the function, you can include the name of the database. For example: `db1.my_func'.
>
> `arg_type`: The parameter type of the function is the same as the type defined at the time of table building. Variable-length parameters can be represented by `,...`. If it is a variable-length type, the type of the variable-length part of the parameters is the same as the last non-variable-length parameter type.
> **NOTICE**: `ALIAS FUNCTION` variable-length parameters are not supported, and there is at least one parameter.
> `arg_type`: The parameter type of the function is the same as the type defined at the time of table building. Variable-length parameters can be represented by `,...`. If it is a variable-length type, the type of the variable-length part of the parameters is the same as the last non-variable-length parameter type.
> **NOTICE**: `ALIAS FUNCTION` variable-length parameters are not supported, and there is at least one parameter. In particular, the type `ALL` refers to any data type and can only be used for `ALIAS FUNCTION`.
>
> `ret_type`: Required for creating a new function. This parameter is not required if you are aliasing an existing function.
>
@ -130,8 +130,13 @@ If the `function_name` contains the database name, the custom function will be c
5. Create a custom alias function
```
-- create a custom functional alias function
CREATE ALIAS FUNCTION id_masking(INT) WITH PARAMETER(id)
AS CONCAT(LEFT(id, 3), '****', RIGHT(id, 4));
-- create a custom cast alias function
CREATE ALIAS FUNCTION decimal(ALL, INT, INT) WITH PARAMETER(col, precision, scale)
AS CAST(col AS decimal(precision, scale));
```
## keyword

View File

@ -47,8 +47,8 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
>
> `function_name`: 要创建函数的名字, 可以包含数据库的名字。比如:`db1.my_func`。
>
> `arg_type`: 函数的参数类型,与建表时定义的类型一致。变长参数时可以使用`, ...`来表示,如果是变长类型,那么变长部分参数的类型与最后一个非变长参数类型一致。
> **注意**:`ALIAS FUNCTION` 不支持变长参数,且至少有一个参数。
> `arg_type`: 函数的参数类型,与建表时定义的类型一致。变长参数时可以使用`, ...`来表示,如果是变长类型,那么变长部分参数的类型与最后一个非变长参数类型一致。
> **注意**:`ALIAS FUNCTION` 不支持变长参数,且至少有一个参数。 特别地,`ALL` 类型指任一数据类型,只可以用于 `ALIAS FUNCTION`.
>
> `ret_type`: 对创建新的函数来说,是必填项。如果是给已有函数取别名则可不用填写该参数。
>
@ -131,8 +131,13 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
5. 创建一个自定义别名函数
```
-- 创建自定义功能别名函数
CREATE ALIAS FUNCTION id_masking(INT) WITH PARAMETER(id)
AS CONCAT(LEFT(id, 3), '****', RIGHT(id, 4));
-- 创建自定义 CAST 别名函数
CREATE ALIAS FUNCTION decimal(ALL, INT, INT) WITH PARAMETER(col, precision, scale)
AS CAST(col AS decimal(precision, scale));
```
## keyword

View File

@ -4352,6 +4352,11 @@ type ::=
type.setAssignedStrLenInColDefinition();
RESULT = type;
:}
| KW_VARCHAR LPAREN ident_or_text:lenStr RPAREN
{: ScalarType type = ScalarType.createVarcharType(lenStr);
type.setAssignedStrLenInColDefinition();
RESULT = type;
:}
| KW_VARCHAR
{: RESULT = ScalarType.createVarcharType(-1); :}
| KW_ARRAY LESSTHAN type:value_type GREATERTHAN
@ -4365,6 +4370,11 @@ type ::=
type.setAssignedStrLenInColDefinition();
RESULT = type;
:}
| KW_CHAR LPAREN ident_or_text:lenStr RPAREN
{: ScalarType type = ScalarType.createCharType(lenStr);
type.setAssignedStrLenInColDefinition();
RESULT = type;
:}
| KW_CHAR
{: RESULT = ScalarType.createCharType(-1); :}
| KW_DECIMAL LPAREN INTEGER_LITERAL:precision RPAREN
@ -4373,11 +4383,17 @@ type ::=
{: RESULT = ScalarType.createDecimalV2Type(precision.intValue(), scale.intValue()); :}
| KW_DECIMAL
{: RESULT = ScalarType.createDecimalV2Type(); :}
| KW_DECIMAL LPAREN ident_or_text:precision RPAREN
{: RESULT = ScalarType.createDecimalV2Type(precision); :}
| KW_DECIMAL LPAREN ident_or_text:precision COMMA ident_or_text:scale RPAREN
{: RESULT = ScalarType.createDecimalV2Type(precision, scale); :}
| KW_HLL
{: ScalarType type = ScalarType.createHllType();
type.setAssignedStrLenInColDefinition();
RESULT = type;
:}
| KW_ALL
{: RESULT = Type.ALL; :}
;
opt_field_length ::=
@ -5180,6 +5196,8 @@ keyword ::=
{: RESULT = id; :}
| KW_CHAIN:id
{: RESULT = id; :}
| KW_CHAR:id
{: RESULT = id; :}
| KW_CHARSET:id
{: RESULT = id; :}
| KW_CHECK:id
@ -5210,6 +5228,8 @@ keyword ::=
{: RESULT = id; :}
| KW_DATETIME:id
{: RESULT = id; :}
| KW_DECIMAL:id
{: RESULT = id; :}
| KW_DISTINCTPC:id
{: RESULT = id; :}
| KW_DISTINCTPCSA:id
@ -5456,6 +5476,8 @@ keyword ::=
{: RESULT = id; :}
| KW_MAP:id
{: RESULT = id; :}
| KW_VARCHAR:id
{: RESULT = id; :}
;
// Identifier that contain keyword

View File

@ -17,7 +17,11 @@
package org.apache.doris.analysis;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.doris.catalog.Catalog;
@ -46,7 +50,7 @@ public class CastExpr extends Expr {
private static final Logger LOG = LogManager.getLogger(CastExpr.class);
// Only set for explicit casts. Null for implicit casts.
private final TypeDef targetTypeDef;
private TypeDef targetTypeDef;
// True if this is a "pre-analyzed" implicit cast.
private boolean isImplicit;
@ -77,6 +81,11 @@ public class CastExpr extends Expr {
}
}
// only used restore from readFields.
public CastExpr() {
}
public CastExpr(Type targetType, Expr e) {
super();
Preconditions.checkArgument(targetType.isValid());
@ -120,6 +129,10 @@ public class CastExpr extends Expr {
return "castTo" + targetType.getPrimitiveType().toString();
}
public TypeDef getTargetTypeDef() {
return targetTypeDef;
}
public static void initBuiltins(FunctionSet functionSet) {
for (Type fromType : Type.getSupportedTypes()) {
if (fromType.isNull()) {
@ -206,6 +219,10 @@ public class CastExpr extends Expr {
}
public void analyze() throws AnalysisException {
// do not analyze ALL cast
if (type == Type.ALL) {
return;
}
// cast was asked for in the query, check for validity of cast
Type childType = getChild(0).getType();
@ -327,4 +344,111 @@ public class CastExpr extends Expr {
}
return this;
}
@Override
public void write(DataOutput out) throws IOException {
out.writeBoolean(isImplicit);
if (targetTypeDef.getType() instanceof ScalarType) {
ScalarType scalarType = (ScalarType) targetTypeDef.getType();
scalarType.write(out);
} else {
throw new IOException("Can not write type " + targetTypeDef.getType());
}
out.writeInt(children.size());
for (Expr expr : children) {
Expr.writeTo(expr, out);
}
}
public static CastExpr read(DataInput input) throws IOException {
CastExpr castExpr = new CastExpr();
castExpr.readFields(input);
return castExpr;
}
@Override
public void readFields(DataInput in) throws IOException {
isImplicit = in.readBoolean();
ScalarType scalarType = ScalarType.read(in);
targetTypeDef = new TypeDef(scalarType);
int counter = in.readInt();
for (int i = 0; i < counter; i++) {
children.add(Expr.readIn(in));
}
}
public CastExpr rewriteExpr(List<String> parameters, List<Expr> inputParamsExprs) throws AnalysisException {
// child
Expr child = this.getChild(0);
Expr newChild = null;
if (child instanceof SlotRef) {
String columnName = ((SlotRef) child).getColumnName();
int index = parameters.indexOf(columnName);
if (index != -1) {
newChild = inputParamsExprs.get(index);
}
}
// rewrite cast expr in children
if (child instanceof CastExpr) {
newChild = ((CastExpr) child).rewriteExpr(parameters, inputParamsExprs);
}
// type def
ScalarType targetType = (ScalarType) targetTypeDef.getType();
PrimitiveType primitiveType = targetType.getPrimitiveType();
ScalarType newTargetType = null;
switch (primitiveType) {
case DECIMALV2:
// normal decimal
if (targetType.getPrecision() != 0) {
newTargetType = targetType;
break;
}
int precision = getDigital(targetType.getScalarPrecisionStr(), parameters, inputParamsExprs);
int scale = getDigital(targetType.getScalarScaleStr(), parameters, inputParamsExprs);
if (precision != -1 && scale != -1) {
newTargetType = ScalarType.createType(primitiveType, 0, precision, scale);
} else if (precision != -1 && scale == -1) {
newTargetType = ScalarType.createType(primitiveType, 0, precision, ScalarType.DEFAULT_SCALE);
}
break;
case CHAR:
case VARCHAR:
// normal char/varchar
if (targetType.getLength() != -1) {
newTargetType = targetType;
break;
}
int len = getDigital(targetType.getLenStr(), parameters, inputParamsExprs);
if (len != -1) {
newTargetType = ScalarType.createType(primitiveType, len, 0, 0);
}
// default char/varchar, which len is -1
if (len == -1 && targetType.getLength() == -1) {
newTargetType = targetType;
}
break;
default:
newTargetType = targetType;
break;
}
if (newTargetType != null && newChild != null) {
TypeDef typeDef = new TypeDef(newTargetType);
return new CastExpr(typeDef, newChild);
}
return this;
}
private int getDigital(String desc, List<String> parameters, List<Expr> inputParamsExprs) {
int index = parameters.indexOf(desc);
if (index != -1) {
Expr expr = inputParamsExprs.get(index);
if (expr.getType().isIntegerType()) {
return ((Long)((IntLiteral) expr).getRealValue()).intValue();
}
}
return -1;
}
}

View File

@ -1665,7 +1665,8 @@ abstract public class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
MAX_LITERAL(10),
BINARY_PREDICATE(11),
FUNCTION_CALL(12),
ARRAY_LITERAL(13);
ARRAY_LITERAL(13),
CAST_EXPR(14);
private static Map<Integer, ExprSerCode> codeMap = Maps.newHashMap();
@ -1715,7 +1716,9 @@ abstract public class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
output.writeInt(ExprSerCode.FUNCTION_CALL.getCode());
} else if (expr instanceof ArrayLiteral) {
output.writeInt(ExprSerCode.ARRAY_LITERAL.getCode());
} else {
} else if (expr instanceof CastExpr){
output.writeInt(ExprSerCode.CAST_EXPR.getCode());
}else {
throw new IOException("Unknown class " + expr.getClass().getName());
}
expr.write(output);
@ -1758,6 +1761,8 @@ abstract public class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
return FunctionCallExpr.read(in);
case ARRAY_LITERAL:
return ArrayLiteral.read(in);
case CAST_EXPR:
return CastExpr.read(in);
default:
throw new IOException("Unknown code: " + code);
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.catalog;
import org.apache.doris.analysis.CastExpr;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.FunctionName;
@ -24,12 +25,14 @@ import org.apache.doris.analysis.SelectStmt;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.SqlParser;
import org.apache.doris.analysis.SqlScanner;
import org.apache.doris.analysis.TypeDef;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.io.Text;
import org.apache.doris.common.util.SqlParserUtils;
import org.apache.doris.qe.SqlModeHelper;
import org.apache.doris.thrift.TFunctionBinaryType;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.gson.Gson;
@ -59,6 +62,7 @@ public class AliasFunction extends Function {
private Expr originFunction;
private List<String> parameters = new ArrayList<>();
private List<String> typeDefParams = new ArrayList<>();
// Only used for serialization
protected AliasFunction() {
@ -152,30 +156,63 @@ public class AliasFunction extends Function {
if (parameters.size() != getArgs().length) {
throw new AnalysisException("Alias function [" + functionName() + "] args number is not equal to parameters number");
}
List<Expr> exprs = ((FunctionCallExpr) originFunction).getFnParams().exprs();
List<Expr> exprs;
if (originFunction instanceof FunctionCallExpr) {
exprs = ((FunctionCallExpr) originFunction).getFnParams().exprs();
} else if (originFunction instanceof CastExpr) {
exprs = originFunction.getChildren();
TypeDef targetTypeDef = ((CastExpr) originFunction).getTargetTypeDef();
if (targetTypeDef.getType().isScalarType()) {
ScalarType scalarType = (ScalarType) targetTypeDef.getType();
PrimitiveType primitiveType = scalarType.getPrimitiveType();
switch (primitiveType) {
case DECIMALV2:
if (!Strings.isNullOrEmpty(scalarType.getScalarPrecisionStr())) {
typeDefParams.add(scalarType.getScalarPrecisionStr());
}
if (!Strings.isNullOrEmpty(scalarType.getScalarScaleStr())) {
typeDefParams.add(scalarType.getScalarScaleStr());
}
break;
case CHAR:
case VARCHAR:
if (!Strings.isNullOrEmpty(scalarType.getLenStr())) {
typeDefParams.add(scalarType.getLenStr());
}
break;
}
}
} else {
throw new AnalysisException("Not supported expr type: " + originFunction);
}
Set<String> set = new HashSet<>();
for (String str : parameters) {
if (!set.add(str)) {
throw new AnalysisException("Alias function [" + functionName() + "] has duplicate parameter [" + str + "].");
}
boolean existFlag = false;
// check exprs
for (Expr expr : exprs) {
existFlag |= checkParams(expr, str);
}
// check targetTypeDef
for (String typeDefParam : typeDefParams) {
existFlag |= typeDefParam.equals(str);
}
if (!existFlag) {
throw new AnalysisException("Alias function [" + functionName() + "] do not contain parameter [" + str + "].");
}
}
}
private boolean checkParams(Expr expr, String parma) {
private boolean checkParams(Expr expr, String param) {
for (Expr e : expr.getChildren()) {
if (checkParams(e, parma)) {
if (checkParams(e, param)) {
return true;
}
}
if (expr instanceof SlotRef) {
if (parma.equals(((SlotRef) expr).getColumnName())) {
if (param.equals(((SlotRef) expr).getColumnName())) {
return true;
}
}

View File

@ -59,7 +59,8 @@ public enum PrimitiveType {
STRUCT("MAP", 24, TPrimitiveType.STRUCT),
STRING("STRING", 16, TPrimitiveType.STRING),
// Unsupported scalar types.
BINARY("BINARY", -1, TPrimitiveType.BINARY);
BINARY("BINARY", -1, TPrimitiveType.BINARY),
ALL("ALL", -1, TPrimitiveType.INVALID_TYPE);
private static final int DATE_INDEX_LEN = 3;
@ -611,6 +612,8 @@ public enum PrimitiveType {
return MAP;
case STRUCT:
return STRUCT;
case ALL:
return ALL;
default:
return INVALID_TYPE;
}

View File

@ -17,8 +17,13 @@
package org.apache.doris.catalog;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Objects;
import org.apache.doris.common.io.Text;
import org.apache.doris.persist.gson.GsonUtils;
import org.apache.doris.thrift.TColumnType;
import org.apache.doris.thrift.TScalarType;
import org.apache.doris.thrift.TTypeDesc;
@ -85,6 +90,16 @@ public class ScalarType extends Type {
@SerializedName(value = "scale")
private int scale;
// Only used for alias function decimal
@SerializedName(value = "precisionStr")
private String precisionStr;
// Only used for alias function decimal
@SerializedName(value = "scaleStr")
private String scaleStr;
// Only used for alias function char/varchar
@SerializedName(value = "lenStr")
private String lenStr;
protected ScalarType(PrimitiveType type) {
this.type = type;
}
@ -144,6 +159,8 @@ public class ScalarType extends Type {
return DEFAULT_DECIMALV2;
case LARGEINT:
return LARGEINT;
case ALL:
return ALL;
default:
LOG.warn("type={}", type);
Preconditions.checkState(false);
@ -207,6 +224,12 @@ public class ScalarType extends Type {
return type;
}
public static ScalarType createCharType(String lenStr) {
ScalarType type = new ScalarType(PrimitiveType.CHAR);
type.lenStr = lenStr;
return type;
}
public static ScalarType createChar(int len) {
ScalarType type = new ScalarType(PrimitiveType.CHAR);
type.len = len;
@ -230,6 +253,20 @@ public class ScalarType extends Type {
return type;
}
public static ScalarType createDecimalV2Type(String precisionStr) {
ScalarType type = new ScalarType(PrimitiveType.DECIMALV2);
type.precisionStr = precisionStr;
type.scaleStr = null;
return type;
}
public static ScalarType createDecimalV2Type(String precisionStr, String scaleStr) {
ScalarType type = new ScalarType(PrimitiveType.DECIMALV2);
type.precisionStr = precisionStr;
type.scaleStr = scaleStr;
return type;
}
public static ScalarType createDecimalV2TypeInternal(int precision, int scale) {
ScalarType type = new ScalarType(PrimitiveType.DECIMALV2);
type.precision = Math.min(precision, MAX_PRECISION);
@ -244,6 +281,13 @@ public class ScalarType extends Type {
return type;
}
public static ScalarType createVarcharType(String lenStr) {
// length checked in analysis
ScalarType type = new ScalarType(PrimitiveType.VARCHAR);
type.lenStr = lenStr;
return type;
}
public static ScalarType createStringType() {
// length checked in analysis
ScalarType type = new ScalarType(PrimitiveType.STRING);
@ -296,13 +340,27 @@ public class ScalarType extends Type {
StringBuilder stringBuilder = new StringBuilder();
switch (type) {
case CHAR:
stringBuilder.append("char").append("(").append(len).append(")");
if (Strings.isNullOrEmpty(lenStr)) {
stringBuilder.append("char").append("(").append(len).append(")");
} else {
stringBuilder.append("char").append("(`").append(lenStr).append("`)");
}
break;
case VARCHAR:
stringBuilder.append("varchar").append("(").append(len).append(")");
if (Strings.isNullOrEmpty(lenStr)) {
stringBuilder.append("varchar").append("(").append(len).append(")");
} else {
stringBuilder.append("varchar").append("(`").append(lenStr).append("`)");
}
break;
case DECIMALV2:
stringBuilder.append("decimal").append("(").append(precision).append(", ").append(scale).append(")");
if (Strings.isNullOrEmpty(precisionStr)) {
stringBuilder.append("decimal").append("(").append(precision).append(", ").append(scale).append(")");
} else if (!Strings.isNullOrEmpty(precisionStr) && !Strings.isNullOrEmpty(scaleStr)) {
stringBuilder.append("decimal").append("(`").append(precisionStr).append("`, `").append(scaleStr).append("`)");
} else {
stringBuilder.append("decimal").append("(`").append(precisionStr).append("`)");
}
break;
case BOOLEAN:
return "boolean";
@ -393,6 +451,18 @@ public class ScalarType extends Type {
public int getScalarScale() { return scale; }
public int getScalarPrecision() { return precision; }
public String getScalarPrecisionStr() {
return precisionStr;
}
public String getScalarScaleStr() {
return scaleStr;
}
public String getLenStr() {
return lenStr;
}
@Override
public boolean isWildcardDecimal() {
return (type == PrimitiveType.DECIMALV2)
@ -606,6 +676,11 @@ public class ScalarType extends Type {
return INVALID;
}
// for cast all type
if (t1.type == PrimitiveType.ALL || t2.type == PrimitiveType.ALL) {
return Type.ALL;
}
if (t1.isStringType() || t2.isStringType()) {
if (t1.type == PrimitiveType.STRING || t2.type == PrimitiveType.STRING) {
return createStringType();
@ -708,4 +783,14 @@ public class ScalarType extends Type {
result = 31 * result + scale;
return result;
}
public void write(DataOutput out) throws IOException {
String json = GsonUtils.GSON.toJson(this);
Text.writeString(out, json);
}
public static ScalarType read(DataInput input) throws IOException {
String json = Text.readString(input);
return GsonUtils.GSON.fromJson(json, ScalarType.class);
}
}

View File

@ -76,6 +76,8 @@ public abstract class Type {
public static final ScalarType HLL = ScalarType.createHllType();
public static final ScalarType CHAR = (ScalarType) ScalarType.createCharType(-1);
public static final ScalarType BITMAP = new ScalarType(PrimitiveType.BITMAP);
// Only used for alias function, to represent any type in function args
public static final ScalarType ALL = new ScalarType(PrimitiveType.ALL);
public static final MapType Map = new MapType();
private static ArrayList<ScalarType> integerTypes;
@ -944,9 +946,9 @@ public abstract class Type {
compatibilityMatrix[TIME.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE;
// Check all of the necessary entries that should be filled.
// ignore binary
for (int i = 0; i < PrimitiveType.values().length - 1; ++i) {
for (int j = i; j < PrimitiveType.values().length - 1; ++j) {
// ignore binary and all
for (int i = 0; i < PrimitiveType.values().length - 2; ++i) {
for (int j = i; j < PrimitiveType.values().length - 2; ++j) {
PrimitiveType t1 = PrimitiveType.values()[i];
PrimitiveType t2 = PrimitiveType.values()[j];
// DECIMAL, NULL, and INVALID_TYPE are handled separately.

View File

@ -18,6 +18,7 @@
package org.apache.doris.rewrite;
import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.CastExpr;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.catalog.AliasFunction;
@ -39,7 +40,13 @@ public class RewriteAliasFunctionRule implements ExprRewriteRule{
if (expr instanceof FunctionCallExpr) {
Function fn = expr.getFn();
if (fn instanceof AliasFunction) {
return ((FunctionCallExpr) expr).rewriteExpr();
Expr originFn = ((AliasFunction) fn).getOriginFunction();
if (originFn instanceof FunctionCallExpr) {
return ((FunctionCallExpr) expr).rewriteExpr();
} else if (originFn instanceof CastExpr) {
return ((CastExpr) originFn).rewriteExpr(((AliasFunction) fn).getParameters(),
((FunctionCallExpr) expr).getParams().exprs());
}
}
}

View File

@ -19,8 +19,10 @@ package org.apache.doris.catalog;
import org.apache.doris.analysis.CreateDbStmt;
import org.apache.doris.analysis.CreateFunctionStmt;
import org.apache.doris.analysis.CreateTableStmt;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.StringLiteral;
import org.apache.doris.common.FeConstants;
import org.apache.doris.common.jmockit.Deencapsulation;
import org.apache.doris.planner.PlanFragment;
@ -29,6 +31,7 @@ import org.apache.doris.planner.UnionNode;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.QueryState;
import org.apache.doris.qe.StmtExecutor;
import org.apache.doris.utframe.DorisAssert;
import org.apache.doris.utframe.UtFrameUtils;
import org.junit.AfterClass;
@ -48,11 +51,15 @@ import java.util.UUID;
public class CreateFunctionTest {
private static String runningDir = "fe/mocked/CreateFunctionTest/" + UUID.randomUUID().toString() + "/";
private static ConnectContext connectContext;
private static DorisAssert dorisAssert;
@BeforeClass
public static void setup() throws Exception {
UtFrameUtils.createDorisCluster(runningDir);
FeConstants.runningUnitTest = true;
// create connect context
connectContext = UtFrameUtils.createDefaultCtx();
}
@AfterClass
@ -71,6 +78,14 @@ public class CreateFunctionTest {
Catalog.getCurrentCatalog().createDb(createDbStmt);
System.out.println(Catalog.getCurrentCatalog().getDbNames());
String createTblStmtStr = "create table db1.tbl1(k1 int, k2 bigint, k3 varchar(10), k4 char(5)) duplicate key(k1) "
+ "distributed by hash(k2) buckets 1 properties('replication_num' = '1');";
CreateTableStmt createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, connectContext);
Catalog.getCurrentCatalog().createTable(createTableStmt);
dorisAssert = new DorisAssert();
dorisAssert.useDatabase("db1");
Database db = Catalog.getCurrentCatalog().getDbNullable("default_cluster:db1");
Assert.assertNotNull(db);
@ -126,5 +141,89 @@ public class CreateFunctionTest {
Assert.assertEquals(1, constExprLists.size());
Assert.assertEquals(1, constExprLists.get(0).size());
Assert.assertTrue(constExprLists.get(0).get(0) instanceof FunctionCallExpr);
queryStr = "select db1.id_masking(k1) from db1.tbl1";
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("concat(left(`k1`, 3), '****', right(`k1`, 4))"));
// create alias function with cast
// cast any type to decimal with specific precision and scale
createFuncStr = "create alias function db1.decimal(all, int, int) with parameter(col, precision, scale)" +
" as cast(col as decimal(precision, scale));";
createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
functions = db.getFunctions();
Assert.assertEquals(3, functions.size());
queryStr = "select db1.decimal(333, 4, 1);";
ctx.getState().reset();
stmtExecutor = new StmtExecutor(ctx, queryStr);
stmtExecutor.execute();
Assert.assertNotEquals(QueryState.MysqlStateType.ERR, ctx.getState().getStateType());
planner = stmtExecutor.planner();
Assert.assertEquals(1, planner.getFragments().size());
fragment = planner.getFragments().get(0);
Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
unionNode = (UnionNode)fragment.getPlanRoot();
constExprLists = Deencapsulation.getField(unionNode, "constExprLists_");
System.out.println(constExprLists.get(0).get(0));
Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral);
queryStr = "select db1.decimal(k3, 4, 1) from db1.tbl1;";
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4,1))"));
// cast any type to varchar with fixed length
createFuncStr = "create alias function db1.varchar(all, int) with parameter(text, length) as " +
"cast(text as varchar(length));";
createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
functions = db.getFunctions();
Assert.assertEquals(4, functions.size());
queryStr = "select db1.varchar(333, 4);";
ctx.getState().reset();
stmtExecutor = new StmtExecutor(ctx, queryStr);
stmtExecutor.execute();
Assert.assertNotEquals(QueryState.MysqlStateType.ERR, ctx.getState().getStateType());
planner = stmtExecutor.planner();
Assert.assertEquals(1, planner.getFragments().size());
fragment = planner.getFragments().get(0);
Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
unionNode = (UnionNode)fragment.getPlanRoot();
constExprLists = Deencapsulation.getField(unionNode, "constExprLists_");
Assert.assertEquals(1, constExprLists.size());
Assert.assertEquals(1, constExprLists.get(0).size());
Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral);
queryStr = "select db1.varchar(k1, 4) from db1.tbl1;";
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)"));
// cast any type to char with fixed length
createFuncStr = "create alias function db1.char(all, int) with parameter(text, length) as " +
"cast(text as char(length));";
createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
functions = db.getFunctions();
Assert.assertEquals(5, functions.size());
queryStr = "select db1.char(333, 4);";
ctx.getState().reset();
stmtExecutor = new StmtExecutor(ctx, queryStr);
stmtExecutor.execute();
Assert.assertNotEquals(QueryState.MysqlStateType.ERR, ctx.getState().getStateType());
planner = stmtExecutor.planner();
Assert.assertEquals(1, planner.getFragments().size());
fragment = planner.getFragments().get(0);
Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
unionNode = (UnionNode)fragment.getPlanRoot();
constExprLists = Deencapsulation.getField(unionNode, "constExprLists_");
Assert.assertEquals(1, constExprLists.size());
Assert.assertEquals(1, constExprLists.get(0).size());
Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral);
queryStr = "select db1.char(k1, 4) from db1.tbl1;";
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)"));
}
}

View File

@ -78,7 +78,8 @@ enum TPrimitiveType {
ARRAY,
MAP,
STRUCT,
STRING
STRING,
ALL
}
enum TTypeNodeType {