[AliasFunction] Add support for cast in alias function (#6754)
support #6753
This commit is contained in:
@ -47,8 +47,8 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
|
||||
>
|
||||
> `Function_name`: To create the name of the function, you can include the name of the database. For example: `db1.my_func'.
|
||||
>
|
||||
> `arg_type`: The parameter type of the function is the same as the type defined at the time of table building. Variable-length parameters can be represented by `,...`. If it is a variable-length type, the type of the variable-length part of the parameters is the same as the last non-variable-length parameter type.
|
||||
> **NOTICE**: `ALIAS FUNCTION` variable-length parameters are not supported, and there is at least one parameter.
|
||||
> `arg_type`: The parameter type of the function is the same as the type defined at the time of table building. Variable-length parameters can be represented by `,...`. If it is a variable-length type, the type of the variable-length part of the parameters is the same as the last non-variable-length parameter type.
|
||||
> **NOTICE**: `ALIAS FUNCTION` variable-length parameters are not supported, and there is at least one parameter. In particular, the type `ALL` refers to any data type and can only be used for `ALIAS FUNCTION`.
|
||||
>
|
||||
> `ret_type`: Required for creating a new function. This parameter is not required if you are aliasing an existing function.
|
||||
>
|
||||
@ -130,8 +130,13 @@ If the `function_name` contains the database name, the custom function will be c
|
||||
5. Create a custom alias function
|
||||
|
||||
```
|
||||
-- create a custom functional alias function
|
||||
CREATE ALIAS FUNCTION id_masking(INT) WITH PARAMETER(id)
|
||||
AS CONCAT(LEFT(id, 3), '****', RIGHT(id, 4));
|
||||
|
||||
-- create a custom cast alias function
|
||||
CREATE ALIAS FUNCTION decimal(ALL, INT, INT) WITH PARAMETER(col, precision, scale)
|
||||
AS CAST(col AS decimal(precision, scale));
|
||||
```
|
||||
|
||||
## keyword
|
||||
|
||||
@ -47,8 +47,8 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
|
||||
>
|
||||
> `function_name`: 要创建函数的名字, 可以包含数据库的名字。比如:`db1.my_func`。
|
||||
>
|
||||
> `arg_type`: 函数的参数类型,与建表时定义的类型一致。变长参数时可以使用`, ...`来表示,如果是变长类型,那么变长部分参数的类型与最后一个非变长参数类型一致。
|
||||
> **注意**:`ALIAS FUNCTION` 不支持变长参数,且至少有一个参数。
|
||||
> `arg_type`: 函数的参数类型,与建表时定义的类型一致。变长参数时可以使用`, ...`来表示,如果是变长类型,那么变长部分参数的类型与最后一个非变长参数类型一致。
|
||||
> **注意**:`ALIAS FUNCTION` 不支持变长参数,且至少有一个参数。 特别地,`ALL` 类型指任一数据类型,只可以用于 `ALIAS FUNCTION`.
|
||||
>
|
||||
> `ret_type`: 对创建新的函数来说,是必填项。如果是给已有函数取别名则可不用填写该参数。
|
||||
>
|
||||
@ -131,8 +131,13 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
|
||||
5. 创建一个自定义别名函数
|
||||
|
||||
```
|
||||
-- 创建自定义功能别名函数
|
||||
CREATE ALIAS FUNCTION id_masking(INT) WITH PARAMETER(id)
|
||||
AS CONCAT(LEFT(id, 3), '****', RIGHT(id, 4));
|
||||
|
||||
-- 创建自定义 CAST 别名函数
|
||||
CREATE ALIAS FUNCTION decimal(ALL, INT, INT) WITH PARAMETER(col, precision, scale)
|
||||
AS CAST(col AS decimal(precision, scale));
|
||||
```
|
||||
|
||||
## keyword
|
||||
|
||||
@ -4352,6 +4352,11 @@ type ::=
|
||||
type.setAssignedStrLenInColDefinition();
|
||||
RESULT = type;
|
||||
:}
|
||||
| KW_VARCHAR LPAREN ident_or_text:lenStr RPAREN
|
||||
{: ScalarType type = ScalarType.createVarcharType(lenStr);
|
||||
type.setAssignedStrLenInColDefinition();
|
||||
RESULT = type;
|
||||
:}
|
||||
| KW_VARCHAR
|
||||
{: RESULT = ScalarType.createVarcharType(-1); :}
|
||||
| KW_ARRAY LESSTHAN type:value_type GREATERTHAN
|
||||
@ -4365,6 +4370,11 @@ type ::=
|
||||
type.setAssignedStrLenInColDefinition();
|
||||
RESULT = type;
|
||||
:}
|
||||
| KW_CHAR LPAREN ident_or_text:lenStr RPAREN
|
||||
{: ScalarType type = ScalarType.createCharType(lenStr);
|
||||
type.setAssignedStrLenInColDefinition();
|
||||
RESULT = type;
|
||||
:}
|
||||
| KW_CHAR
|
||||
{: RESULT = ScalarType.createCharType(-1); :}
|
||||
| KW_DECIMAL LPAREN INTEGER_LITERAL:precision RPAREN
|
||||
@ -4373,11 +4383,17 @@ type ::=
|
||||
{: RESULT = ScalarType.createDecimalV2Type(precision.intValue(), scale.intValue()); :}
|
||||
| KW_DECIMAL
|
||||
{: RESULT = ScalarType.createDecimalV2Type(); :}
|
||||
| KW_DECIMAL LPAREN ident_or_text:precision RPAREN
|
||||
{: RESULT = ScalarType.createDecimalV2Type(precision); :}
|
||||
| KW_DECIMAL LPAREN ident_or_text:precision COMMA ident_or_text:scale RPAREN
|
||||
{: RESULT = ScalarType.createDecimalV2Type(precision, scale); :}
|
||||
| KW_HLL
|
||||
{: ScalarType type = ScalarType.createHllType();
|
||||
type.setAssignedStrLenInColDefinition();
|
||||
RESULT = type;
|
||||
:}
|
||||
| KW_ALL
|
||||
{: RESULT = Type.ALL; :}
|
||||
;
|
||||
|
||||
opt_field_length ::=
|
||||
@ -5180,6 +5196,8 @@ keyword ::=
|
||||
{: RESULT = id; :}
|
||||
| KW_CHAIN:id
|
||||
{: RESULT = id; :}
|
||||
| KW_CHAR:id
|
||||
{: RESULT = id; :}
|
||||
| KW_CHARSET:id
|
||||
{: RESULT = id; :}
|
||||
| KW_CHECK:id
|
||||
@ -5210,6 +5228,8 @@ keyword ::=
|
||||
{: RESULT = id; :}
|
||||
| KW_DATETIME:id
|
||||
{: RESULT = id; :}
|
||||
| KW_DECIMAL:id
|
||||
{: RESULT = id; :}
|
||||
| KW_DISTINCTPC:id
|
||||
{: RESULT = id; :}
|
||||
| KW_DISTINCTPCSA:id
|
||||
@ -5456,6 +5476,8 @@ keyword ::=
|
||||
{: RESULT = id; :}
|
||||
| KW_MAP:id
|
||||
{: RESULT = id; :}
|
||||
| KW_VARCHAR:id
|
||||
{: RESULT = id; :}
|
||||
;
|
||||
|
||||
// Identifier that contain keyword
|
||||
|
||||
@ -17,7 +17,11 @@
|
||||
|
||||
package org.apache.doris.analysis;
|
||||
|
||||
import java.io.DataInput;
|
||||
import java.io.DataOutput;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.doris.catalog.Catalog;
|
||||
@ -46,7 +50,7 @@ public class CastExpr extends Expr {
|
||||
private static final Logger LOG = LogManager.getLogger(CastExpr.class);
|
||||
|
||||
// Only set for explicit casts. Null for implicit casts.
|
||||
private final TypeDef targetTypeDef;
|
||||
private TypeDef targetTypeDef;
|
||||
|
||||
// True if this is a "pre-analyzed" implicit cast.
|
||||
private boolean isImplicit;
|
||||
@ -77,6 +81,11 @@ public class CastExpr extends Expr {
|
||||
}
|
||||
}
|
||||
|
||||
// only used restore from readFields.
|
||||
public CastExpr() {
|
||||
|
||||
}
|
||||
|
||||
public CastExpr(Type targetType, Expr e) {
|
||||
super();
|
||||
Preconditions.checkArgument(targetType.isValid());
|
||||
@ -120,6 +129,10 @@ public class CastExpr extends Expr {
|
||||
return "castTo" + targetType.getPrimitiveType().toString();
|
||||
}
|
||||
|
||||
public TypeDef getTargetTypeDef() {
|
||||
return targetTypeDef;
|
||||
}
|
||||
|
||||
public static void initBuiltins(FunctionSet functionSet) {
|
||||
for (Type fromType : Type.getSupportedTypes()) {
|
||||
if (fromType.isNull()) {
|
||||
@ -206,6 +219,10 @@ public class CastExpr extends Expr {
|
||||
}
|
||||
|
||||
public void analyze() throws AnalysisException {
|
||||
// do not analyze ALL cast
|
||||
if (type == Type.ALL) {
|
||||
return;
|
||||
}
|
||||
// cast was asked for in the query, check for validity of cast
|
||||
Type childType = getChild(0).getType();
|
||||
|
||||
@ -327,4 +344,111 @@ public class CastExpr extends Expr {
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(DataOutput out) throws IOException {
|
||||
out.writeBoolean(isImplicit);
|
||||
if (targetTypeDef.getType() instanceof ScalarType) {
|
||||
ScalarType scalarType = (ScalarType) targetTypeDef.getType();
|
||||
scalarType.write(out);
|
||||
} else {
|
||||
throw new IOException("Can not write type " + targetTypeDef.getType());
|
||||
}
|
||||
out.writeInt(children.size());
|
||||
for (Expr expr : children) {
|
||||
Expr.writeTo(expr, out);
|
||||
}
|
||||
}
|
||||
|
||||
public static CastExpr read(DataInput input) throws IOException {
|
||||
CastExpr castExpr = new CastExpr();
|
||||
castExpr.readFields(input);
|
||||
return castExpr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readFields(DataInput in) throws IOException {
|
||||
isImplicit = in.readBoolean();
|
||||
ScalarType scalarType = ScalarType.read(in);
|
||||
targetTypeDef = new TypeDef(scalarType);
|
||||
int counter = in.readInt();
|
||||
for (int i = 0; i < counter; i++) {
|
||||
children.add(Expr.readIn(in));
|
||||
}
|
||||
}
|
||||
|
||||
public CastExpr rewriteExpr(List<String> parameters, List<Expr> inputParamsExprs) throws AnalysisException {
|
||||
// child
|
||||
Expr child = this.getChild(0);
|
||||
Expr newChild = null;
|
||||
if (child instanceof SlotRef) {
|
||||
String columnName = ((SlotRef) child).getColumnName();
|
||||
int index = parameters.indexOf(columnName);
|
||||
if (index != -1) {
|
||||
newChild = inputParamsExprs.get(index);
|
||||
}
|
||||
}
|
||||
// rewrite cast expr in children
|
||||
if (child instanceof CastExpr) {
|
||||
newChild = ((CastExpr) child).rewriteExpr(parameters, inputParamsExprs);
|
||||
}
|
||||
|
||||
// type def
|
||||
ScalarType targetType = (ScalarType) targetTypeDef.getType();
|
||||
PrimitiveType primitiveType = targetType.getPrimitiveType();
|
||||
ScalarType newTargetType = null;
|
||||
switch (primitiveType) {
|
||||
case DECIMALV2:
|
||||
// normal decimal
|
||||
if (targetType.getPrecision() != 0) {
|
||||
newTargetType = targetType;
|
||||
break;
|
||||
}
|
||||
int precision = getDigital(targetType.getScalarPrecisionStr(), parameters, inputParamsExprs);
|
||||
int scale = getDigital(targetType.getScalarScaleStr(), parameters, inputParamsExprs);
|
||||
if (precision != -1 && scale != -1) {
|
||||
newTargetType = ScalarType.createType(primitiveType, 0, precision, scale);
|
||||
} else if (precision != -1 && scale == -1) {
|
||||
newTargetType = ScalarType.createType(primitiveType, 0, precision, ScalarType.DEFAULT_SCALE);
|
||||
}
|
||||
break;
|
||||
case CHAR:
|
||||
case VARCHAR:
|
||||
// normal char/varchar
|
||||
if (targetType.getLength() != -1) {
|
||||
newTargetType = targetType;
|
||||
break;
|
||||
}
|
||||
int len = getDigital(targetType.getLenStr(), parameters, inputParamsExprs);
|
||||
if (len != -1) {
|
||||
newTargetType = ScalarType.createType(primitiveType, len, 0, 0);
|
||||
}
|
||||
// default char/varchar, which len is -1
|
||||
if (len == -1 && targetType.getLength() == -1) {
|
||||
newTargetType = targetType;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
newTargetType = targetType;
|
||||
break;
|
||||
}
|
||||
|
||||
if (newTargetType != null && newChild != null) {
|
||||
TypeDef typeDef = new TypeDef(newTargetType);
|
||||
return new CastExpr(typeDef, newChild);
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
private int getDigital(String desc, List<String> parameters, List<Expr> inputParamsExprs) {
|
||||
int index = parameters.indexOf(desc);
|
||||
if (index != -1) {
|
||||
Expr expr = inputParamsExprs.get(index);
|
||||
if (expr.getType().isIntegerType()) {
|
||||
return ((Long)((IntLiteral) expr).getRealValue()).intValue();
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1665,7 +1665,8 @@ abstract public class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
MAX_LITERAL(10),
|
||||
BINARY_PREDICATE(11),
|
||||
FUNCTION_CALL(12),
|
||||
ARRAY_LITERAL(13);
|
||||
ARRAY_LITERAL(13),
|
||||
CAST_EXPR(14);
|
||||
|
||||
private static Map<Integer, ExprSerCode> codeMap = Maps.newHashMap();
|
||||
|
||||
@ -1715,7 +1716,9 @@ abstract public class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
output.writeInt(ExprSerCode.FUNCTION_CALL.getCode());
|
||||
} else if (expr instanceof ArrayLiteral) {
|
||||
output.writeInt(ExprSerCode.ARRAY_LITERAL.getCode());
|
||||
} else {
|
||||
} else if (expr instanceof CastExpr){
|
||||
output.writeInt(ExprSerCode.CAST_EXPR.getCode());
|
||||
}else {
|
||||
throw new IOException("Unknown class " + expr.getClass().getName());
|
||||
}
|
||||
expr.write(output);
|
||||
@ -1758,6 +1761,8 @@ abstract public class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
return FunctionCallExpr.read(in);
|
||||
case ARRAY_LITERAL:
|
||||
return ArrayLiteral.read(in);
|
||||
case CAST_EXPR:
|
||||
return CastExpr.read(in);
|
||||
default:
|
||||
throw new IOException("Unknown code: " + code);
|
||||
}
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
package org.apache.doris.catalog;
|
||||
|
||||
import org.apache.doris.analysis.CastExpr;
|
||||
import org.apache.doris.analysis.Expr;
|
||||
import org.apache.doris.analysis.FunctionCallExpr;
|
||||
import org.apache.doris.analysis.FunctionName;
|
||||
@ -24,12 +25,14 @@ import org.apache.doris.analysis.SelectStmt;
|
||||
import org.apache.doris.analysis.SlotRef;
|
||||
import org.apache.doris.analysis.SqlParser;
|
||||
import org.apache.doris.analysis.SqlScanner;
|
||||
import org.apache.doris.analysis.TypeDef;
|
||||
import org.apache.doris.common.AnalysisException;
|
||||
import org.apache.doris.common.io.Text;
|
||||
import org.apache.doris.common.util.SqlParserUtils;
|
||||
import org.apache.doris.qe.SqlModeHelper;
|
||||
import org.apache.doris.thrift.TFunctionBinaryType;
|
||||
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.gson.Gson;
|
||||
|
||||
@ -59,6 +62,7 @@ public class AliasFunction extends Function {
|
||||
|
||||
private Expr originFunction;
|
||||
private List<String> parameters = new ArrayList<>();
|
||||
private List<String> typeDefParams = new ArrayList<>();
|
||||
|
||||
// Only used for serialization
|
||||
protected AliasFunction() {
|
||||
@ -152,30 +156,63 @@ public class AliasFunction extends Function {
|
||||
if (parameters.size() != getArgs().length) {
|
||||
throw new AnalysisException("Alias function [" + functionName() + "] args number is not equal to parameters number");
|
||||
}
|
||||
List<Expr> exprs = ((FunctionCallExpr) originFunction).getFnParams().exprs();
|
||||
List<Expr> exprs;
|
||||
if (originFunction instanceof FunctionCallExpr) {
|
||||
exprs = ((FunctionCallExpr) originFunction).getFnParams().exprs();
|
||||
} else if (originFunction instanceof CastExpr) {
|
||||
exprs = originFunction.getChildren();
|
||||
TypeDef targetTypeDef = ((CastExpr) originFunction).getTargetTypeDef();
|
||||
if (targetTypeDef.getType().isScalarType()) {
|
||||
ScalarType scalarType = (ScalarType) targetTypeDef.getType();
|
||||
PrimitiveType primitiveType = scalarType.getPrimitiveType();
|
||||
switch (primitiveType) {
|
||||
case DECIMALV2:
|
||||
if (!Strings.isNullOrEmpty(scalarType.getScalarPrecisionStr())) {
|
||||
typeDefParams.add(scalarType.getScalarPrecisionStr());
|
||||
}
|
||||
if (!Strings.isNullOrEmpty(scalarType.getScalarScaleStr())) {
|
||||
typeDefParams.add(scalarType.getScalarScaleStr());
|
||||
}
|
||||
break;
|
||||
case CHAR:
|
||||
case VARCHAR:
|
||||
if (!Strings.isNullOrEmpty(scalarType.getLenStr())) {
|
||||
typeDefParams.add(scalarType.getLenStr());
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw new AnalysisException("Not supported expr type: " + originFunction);
|
||||
}
|
||||
Set<String> set = new HashSet<>();
|
||||
for (String str : parameters) {
|
||||
if (!set.add(str)) {
|
||||
throw new AnalysisException("Alias function [" + functionName() + "] has duplicate parameter [" + str + "].");
|
||||
}
|
||||
boolean existFlag = false;
|
||||
// check exprs
|
||||
for (Expr expr : exprs) {
|
||||
existFlag |= checkParams(expr, str);
|
||||
}
|
||||
// check targetTypeDef
|
||||
for (String typeDefParam : typeDefParams) {
|
||||
existFlag |= typeDefParam.equals(str);
|
||||
}
|
||||
if (!existFlag) {
|
||||
throw new AnalysisException("Alias function [" + functionName() + "] do not contain parameter [" + str + "].");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private boolean checkParams(Expr expr, String parma) {
|
||||
private boolean checkParams(Expr expr, String param) {
|
||||
for (Expr e : expr.getChildren()) {
|
||||
if (checkParams(e, parma)) {
|
||||
if (checkParams(e, param)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (expr instanceof SlotRef) {
|
||||
if (parma.equals(((SlotRef) expr).getColumnName())) {
|
||||
if (param.equals(((SlotRef) expr).getColumnName())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,7 +59,8 @@ public enum PrimitiveType {
|
||||
STRUCT("MAP", 24, TPrimitiveType.STRUCT),
|
||||
STRING("STRING", 16, TPrimitiveType.STRING),
|
||||
// Unsupported scalar types.
|
||||
BINARY("BINARY", -1, TPrimitiveType.BINARY);
|
||||
BINARY("BINARY", -1, TPrimitiveType.BINARY),
|
||||
ALL("ALL", -1, TPrimitiveType.INVALID_TYPE);
|
||||
|
||||
|
||||
private static final int DATE_INDEX_LEN = 3;
|
||||
@ -611,6 +612,8 @@ public enum PrimitiveType {
|
||||
return MAP;
|
||||
case STRUCT:
|
||||
return STRUCT;
|
||||
case ALL:
|
||||
return ALL;
|
||||
default:
|
||||
return INVALID_TYPE;
|
||||
}
|
||||
|
||||
@ -17,8 +17,13 @@
|
||||
|
||||
package org.apache.doris.catalog;
|
||||
|
||||
import java.io.DataInput;
|
||||
import java.io.DataOutput;
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import org.apache.doris.common.io.Text;
|
||||
import org.apache.doris.persist.gson.GsonUtils;
|
||||
import org.apache.doris.thrift.TColumnType;
|
||||
import org.apache.doris.thrift.TScalarType;
|
||||
import org.apache.doris.thrift.TTypeDesc;
|
||||
@ -85,6 +90,16 @@ public class ScalarType extends Type {
|
||||
@SerializedName(value = "scale")
|
||||
private int scale;
|
||||
|
||||
// Only used for alias function decimal
|
||||
@SerializedName(value = "precisionStr")
|
||||
private String precisionStr;
|
||||
// Only used for alias function decimal
|
||||
@SerializedName(value = "scaleStr")
|
||||
private String scaleStr;
|
||||
// Only used for alias function char/varchar
|
||||
@SerializedName(value = "lenStr")
|
||||
private String lenStr;
|
||||
|
||||
protected ScalarType(PrimitiveType type) {
|
||||
this.type = type;
|
||||
}
|
||||
@ -144,6 +159,8 @@ public class ScalarType extends Type {
|
||||
return DEFAULT_DECIMALV2;
|
||||
case LARGEINT:
|
||||
return LARGEINT;
|
||||
case ALL:
|
||||
return ALL;
|
||||
default:
|
||||
LOG.warn("type={}", type);
|
||||
Preconditions.checkState(false);
|
||||
@ -207,6 +224,12 @@ public class ScalarType extends Type {
|
||||
return type;
|
||||
}
|
||||
|
||||
public static ScalarType createCharType(String lenStr) {
|
||||
ScalarType type = new ScalarType(PrimitiveType.CHAR);
|
||||
type.lenStr = lenStr;
|
||||
return type;
|
||||
}
|
||||
|
||||
public static ScalarType createChar(int len) {
|
||||
ScalarType type = new ScalarType(PrimitiveType.CHAR);
|
||||
type.len = len;
|
||||
@ -230,6 +253,20 @@ public class ScalarType extends Type {
|
||||
return type;
|
||||
}
|
||||
|
||||
public static ScalarType createDecimalV2Type(String precisionStr) {
|
||||
ScalarType type = new ScalarType(PrimitiveType.DECIMALV2);
|
||||
type.precisionStr = precisionStr;
|
||||
type.scaleStr = null;
|
||||
return type;
|
||||
}
|
||||
|
||||
public static ScalarType createDecimalV2Type(String precisionStr, String scaleStr) {
|
||||
ScalarType type = new ScalarType(PrimitiveType.DECIMALV2);
|
||||
type.precisionStr = precisionStr;
|
||||
type.scaleStr = scaleStr;
|
||||
return type;
|
||||
}
|
||||
|
||||
public static ScalarType createDecimalV2TypeInternal(int precision, int scale) {
|
||||
ScalarType type = new ScalarType(PrimitiveType.DECIMALV2);
|
||||
type.precision = Math.min(precision, MAX_PRECISION);
|
||||
@ -244,6 +281,13 @@ public class ScalarType extends Type {
|
||||
return type;
|
||||
}
|
||||
|
||||
public static ScalarType createVarcharType(String lenStr) {
|
||||
// length checked in analysis
|
||||
ScalarType type = new ScalarType(PrimitiveType.VARCHAR);
|
||||
type.lenStr = lenStr;
|
||||
return type;
|
||||
}
|
||||
|
||||
public static ScalarType createStringType() {
|
||||
// length checked in analysis
|
||||
ScalarType type = new ScalarType(PrimitiveType.STRING);
|
||||
@ -296,13 +340,27 @@ public class ScalarType extends Type {
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
switch (type) {
|
||||
case CHAR:
|
||||
stringBuilder.append("char").append("(").append(len).append(")");
|
||||
if (Strings.isNullOrEmpty(lenStr)) {
|
||||
stringBuilder.append("char").append("(").append(len).append(")");
|
||||
} else {
|
||||
stringBuilder.append("char").append("(`").append(lenStr).append("`)");
|
||||
}
|
||||
break;
|
||||
case VARCHAR:
|
||||
stringBuilder.append("varchar").append("(").append(len).append(")");
|
||||
if (Strings.isNullOrEmpty(lenStr)) {
|
||||
stringBuilder.append("varchar").append("(").append(len).append(")");
|
||||
} else {
|
||||
stringBuilder.append("varchar").append("(`").append(lenStr).append("`)");
|
||||
}
|
||||
break;
|
||||
case DECIMALV2:
|
||||
stringBuilder.append("decimal").append("(").append(precision).append(", ").append(scale).append(")");
|
||||
if (Strings.isNullOrEmpty(precisionStr)) {
|
||||
stringBuilder.append("decimal").append("(").append(precision).append(", ").append(scale).append(")");
|
||||
} else if (!Strings.isNullOrEmpty(precisionStr) && !Strings.isNullOrEmpty(scaleStr)) {
|
||||
stringBuilder.append("decimal").append("(`").append(precisionStr).append("`, `").append(scaleStr).append("`)");
|
||||
} else {
|
||||
stringBuilder.append("decimal").append("(`").append(precisionStr).append("`)");
|
||||
}
|
||||
break;
|
||||
case BOOLEAN:
|
||||
return "boolean";
|
||||
@ -393,6 +451,18 @@ public class ScalarType extends Type {
|
||||
public int getScalarScale() { return scale; }
|
||||
public int getScalarPrecision() { return precision; }
|
||||
|
||||
public String getScalarPrecisionStr() {
|
||||
return precisionStr;
|
||||
}
|
||||
|
||||
public String getScalarScaleStr() {
|
||||
return scaleStr;
|
||||
}
|
||||
|
||||
public String getLenStr() {
|
||||
return lenStr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isWildcardDecimal() {
|
||||
return (type == PrimitiveType.DECIMALV2)
|
||||
@ -606,6 +676,11 @@ public class ScalarType extends Type {
|
||||
return INVALID;
|
||||
}
|
||||
|
||||
// for cast all type
|
||||
if (t1.type == PrimitiveType.ALL || t2.type == PrimitiveType.ALL) {
|
||||
return Type.ALL;
|
||||
}
|
||||
|
||||
if (t1.isStringType() || t2.isStringType()) {
|
||||
if (t1.type == PrimitiveType.STRING || t2.type == PrimitiveType.STRING) {
|
||||
return createStringType();
|
||||
@ -708,4 +783,14 @@ public class ScalarType extends Type {
|
||||
result = 31 * result + scale;
|
||||
return result;
|
||||
}
|
||||
|
||||
public void write(DataOutput out) throws IOException {
|
||||
String json = GsonUtils.GSON.toJson(this);
|
||||
Text.writeString(out, json);
|
||||
}
|
||||
|
||||
public static ScalarType read(DataInput input) throws IOException {
|
||||
String json = Text.readString(input);
|
||||
return GsonUtils.GSON.fromJson(json, ScalarType.class);
|
||||
}
|
||||
}
|
||||
|
||||
@ -76,6 +76,8 @@ public abstract class Type {
|
||||
public static final ScalarType HLL = ScalarType.createHllType();
|
||||
public static final ScalarType CHAR = (ScalarType) ScalarType.createCharType(-1);
|
||||
public static final ScalarType BITMAP = new ScalarType(PrimitiveType.BITMAP);
|
||||
// Only used for alias function, to represent any type in function args
|
||||
public static final ScalarType ALL = new ScalarType(PrimitiveType.ALL);
|
||||
public static final MapType Map = new MapType();
|
||||
|
||||
private static ArrayList<ScalarType> integerTypes;
|
||||
@ -944,9 +946,9 @@ public abstract class Type {
|
||||
compatibilityMatrix[TIME.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE;
|
||||
|
||||
// Check all of the necessary entries that should be filled.
|
||||
// ignore binary
|
||||
for (int i = 0; i < PrimitiveType.values().length - 1; ++i) {
|
||||
for (int j = i; j < PrimitiveType.values().length - 1; ++j) {
|
||||
// ignore binary and all
|
||||
for (int i = 0; i < PrimitiveType.values().length - 2; ++i) {
|
||||
for (int j = i; j < PrimitiveType.values().length - 2; ++j) {
|
||||
PrimitiveType t1 = PrimitiveType.values()[i];
|
||||
PrimitiveType t2 = PrimitiveType.values()[j];
|
||||
// DECIMAL, NULL, and INVALID_TYPE are handled separately.
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.rewrite;
|
||||
|
||||
import org.apache.doris.analysis.Analyzer;
|
||||
import org.apache.doris.analysis.CastExpr;
|
||||
import org.apache.doris.analysis.Expr;
|
||||
import org.apache.doris.analysis.FunctionCallExpr;
|
||||
import org.apache.doris.catalog.AliasFunction;
|
||||
@ -39,7 +40,13 @@ public class RewriteAliasFunctionRule implements ExprRewriteRule{
|
||||
if (expr instanceof FunctionCallExpr) {
|
||||
Function fn = expr.getFn();
|
||||
if (fn instanceof AliasFunction) {
|
||||
return ((FunctionCallExpr) expr).rewriteExpr();
|
||||
Expr originFn = ((AliasFunction) fn).getOriginFunction();
|
||||
if (originFn instanceof FunctionCallExpr) {
|
||||
return ((FunctionCallExpr) expr).rewriteExpr();
|
||||
} else if (originFn instanceof CastExpr) {
|
||||
return ((CastExpr) originFn).rewriteExpr(((AliasFunction) fn).getParameters(),
|
||||
((FunctionCallExpr) expr).getParams().exprs());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -19,8 +19,10 @@ package org.apache.doris.catalog;
|
||||
|
||||
import org.apache.doris.analysis.CreateDbStmt;
|
||||
import org.apache.doris.analysis.CreateFunctionStmt;
|
||||
import org.apache.doris.analysis.CreateTableStmt;
|
||||
import org.apache.doris.analysis.Expr;
|
||||
import org.apache.doris.analysis.FunctionCallExpr;
|
||||
import org.apache.doris.analysis.StringLiteral;
|
||||
import org.apache.doris.common.FeConstants;
|
||||
import org.apache.doris.common.jmockit.Deencapsulation;
|
||||
import org.apache.doris.planner.PlanFragment;
|
||||
@ -29,6 +31,7 @@ import org.apache.doris.planner.UnionNode;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.qe.QueryState;
|
||||
import org.apache.doris.qe.StmtExecutor;
|
||||
import org.apache.doris.utframe.DorisAssert;
|
||||
import org.apache.doris.utframe.UtFrameUtils;
|
||||
|
||||
import org.junit.AfterClass;
|
||||
@ -48,11 +51,15 @@ import java.util.UUID;
|
||||
public class CreateFunctionTest {
|
||||
|
||||
private static String runningDir = "fe/mocked/CreateFunctionTest/" + UUID.randomUUID().toString() + "/";
|
||||
private static ConnectContext connectContext;
|
||||
private static DorisAssert dorisAssert;
|
||||
|
||||
@BeforeClass
|
||||
public static void setup() throws Exception {
|
||||
UtFrameUtils.createDorisCluster(runningDir);
|
||||
FeConstants.runningUnitTest = true;
|
||||
// create connect context
|
||||
connectContext = UtFrameUtils.createDefaultCtx();
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
@ -71,6 +78,14 @@ public class CreateFunctionTest {
|
||||
Catalog.getCurrentCatalog().createDb(createDbStmt);
|
||||
System.out.println(Catalog.getCurrentCatalog().getDbNames());
|
||||
|
||||
String createTblStmtStr = "create table db1.tbl1(k1 int, k2 bigint, k3 varchar(10), k4 char(5)) duplicate key(k1) "
|
||||
+ "distributed by hash(k2) buckets 1 properties('replication_num' = '1');";
|
||||
CreateTableStmt createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, connectContext);
|
||||
Catalog.getCurrentCatalog().createTable(createTableStmt);
|
||||
|
||||
dorisAssert = new DorisAssert();
|
||||
dorisAssert.useDatabase("db1");
|
||||
|
||||
Database db = Catalog.getCurrentCatalog().getDbNullable("default_cluster:db1");
|
||||
Assert.assertNotNull(db);
|
||||
|
||||
@ -126,5 +141,89 @@ public class CreateFunctionTest {
|
||||
Assert.assertEquals(1, constExprLists.size());
|
||||
Assert.assertEquals(1, constExprLists.get(0).size());
|
||||
Assert.assertTrue(constExprLists.get(0).get(0) instanceof FunctionCallExpr);
|
||||
|
||||
queryStr = "select db1.id_masking(k1) from db1.tbl1";
|
||||
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("concat(left(`k1`, 3), '****', right(`k1`, 4))"));
|
||||
|
||||
// create alias function with cast
|
||||
// cast any type to decimal with specific precision and scale
|
||||
createFuncStr = "create alias function db1.decimal(all, int, int) with parameter(col, precision, scale)" +
|
||||
" as cast(col as decimal(precision, scale));";
|
||||
createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
|
||||
Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
|
||||
|
||||
functions = db.getFunctions();
|
||||
Assert.assertEquals(3, functions.size());
|
||||
|
||||
queryStr = "select db1.decimal(333, 4, 1);";
|
||||
ctx.getState().reset();
|
||||
stmtExecutor = new StmtExecutor(ctx, queryStr);
|
||||
stmtExecutor.execute();
|
||||
Assert.assertNotEquals(QueryState.MysqlStateType.ERR, ctx.getState().getStateType());
|
||||
planner = stmtExecutor.planner();
|
||||
Assert.assertEquals(1, planner.getFragments().size());
|
||||
fragment = planner.getFragments().get(0);
|
||||
Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
|
||||
unionNode = (UnionNode)fragment.getPlanRoot();
|
||||
constExprLists = Deencapsulation.getField(unionNode, "constExprLists_");
|
||||
System.out.println(constExprLists.get(0).get(0));
|
||||
Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral);
|
||||
|
||||
queryStr = "select db1.decimal(k3, 4, 1) from db1.tbl1;";
|
||||
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4,1))"));
|
||||
|
||||
// cast any type to varchar with fixed length
|
||||
createFuncStr = "create alias function db1.varchar(all, int) with parameter(text, length) as " +
|
||||
"cast(text as varchar(length));";
|
||||
createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
|
||||
Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
|
||||
|
||||
functions = db.getFunctions();
|
||||
Assert.assertEquals(4, functions.size());
|
||||
|
||||
queryStr = "select db1.varchar(333, 4);";
|
||||
ctx.getState().reset();
|
||||
stmtExecutor = new StmtExecutor(ctx, queryStr);
|
||||
stmtExecutor.execute();
|
||||
Assert.assertNotEquals(QueryState.MysqlStateType.ERR, ctx.getState().getStateType());
|
||||
planner = stmtExecutor.planner();
|
||||
Assert.assertEquals(1, planner.getFragments().size());
|
||||
fragment = planner.getFragments().get(0);
|
||||
Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
|
||||
unionNode = (UnionNode)fragment.getPlanRoot();
|
||||
constExprLists = Deencapsulation.getField(unionNode, "constExprLists_");
|
||||
Assert.assertEquals(1, constExprLists.size());
|
||||
Assert.assertEquals(1, constExprLists.get(0).size());
|
||||
Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral);
|
||||
|
||||
queryStr = "select db1.varchar(k1, 4) from db1.tbl1;";
|
||||
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)"));
|
||||
|
||||
// cast any type to char with fixed length
|
||||
createFuncStr = "create alias function db1.char(all, int) with parameter(text, length) as " +
|
||||
"cast(text as char(length));";
|
||||
createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
|
||||
Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
|
||||
|
||||
functions = db.getFunctions();
|
||||
Assert.assertEquals(5, functions.size());
|
||||
|
||||
queryStr = "select db1.char(333, 4);";
|
||||
ctx.getState().reset();
|
||||
stmtExecutor = new StmtExecutor(ctx, queryStr);
|
||||
stmtExecutor.execute();
|
||||
Assert.assertNotEquals(QueryState.MysqlStateType.ERR, ctx.getState().getStateType());
|
||||
planner = stmtExecutor.planner();
|
||||
Assert.assertEquals(1, planner.getFragments().size());
|
||||
fragment = planner.getFragments().get(0);
|
||||
Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
|
||||
unionNode = (UnionNode)fragment.getPlanRoot();
|
||||
constExprLists = Deencapsulation.getField(unionNode, "constExprLists_");
|
||||
Assert.assertEquals(1, constExprLists.size());
|
||||
Assert.assertEquals(1, constExprLists.get(0).size());
|
||||
Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral);
|
||||
|
||||
queryStr = "select db1.char(k1, 4) from db1.tbl1;";
|
||||
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)"));
|
||||
}
|
||||
}
|
||||
|
||||
@ -78,7 +78,8 @@ enum TPrimitiveType {
|
||||
ARRAY,
|
||||
MAP,
|
||||
STRUCT,
|
||||
STRING
|
||||
STRING,
|
||||
ALL
|
||||
}
|
||||
|
||||
enum TTypeNodeType {
|
||||
|
||||
Reference in New Issue
Block a user