[Bug](decimalv3) Use correct decimal scale for function round (#17232)
Co-authored-by: maochongxin <maochongxin@gmail.com>
This commit is contained in:
@ -67,12 +67,16 @@ import java.util.Set;
|
||||
|
||||
// TODO: for aggregations, we need to unify the code paths for builtins and UDAs.
|
||||
public class FunctionCallExpr extends Expr {
|
||||
public static final ImmutableSet<String> STDDEV_FUNCTION_SET =
|
||||
new ImmutableSortedSet.Builder(String.CASE_INSENSITIVE_ORDER)
|
||||
.add("stddev").add("stddev_val").add("stddev_samp").add("stddev_pop")
|
||||
.add("variance").add("variance_pop").add("variance_pop").add("var_samp").add("var_pop").build();
|
||||
public static final ImmutableSet<String> STDDEV_FUNCTION_SET = new ImmutableSortedSet.Builder(
|
||||
String.CASE_INSENSITIVE_ORDER)
|
||||
.add("stddev").add("stddev_val").add("stddev_samp").add("stddev_pop")
|
||||
.add("variance").add("variance_pop").add("variance_pop").add("var_samp").add("var_pop").build();
|
||||
public static final Map<String, java.util.function.BiFunction<ArrayList<Expr>, Type, Type>> PRECISION_INFER_RULE;
|
||||
public static final java.util.function.BiFunction<ArrayList<Expr>, Type, Type> DEFAULT_PRECISION_INFER_RULE;
|
||||
public static final ImmutableSet<String> ROUND_FUNCTION_SET = new ImmutableSortedSet.Builder(
|
||||
String.CASE_INSENSITIVE_ORDER)
|
||||
.add("round").add("round_bankers").add("ceil").add("floor")
|
||||
.add("truncate").add("dround").add("dceil").add("dfloor").build();
|
||||
|
||||
static {
|
||||
java.util.function.BiFunction<ArrayList<Expr>, Type, Type> sumRule = (children, returnType) -> {
|
||||
@ -101,11 +105,10 @@ public class FunctionCallExpr extends Expr {
|
||||
java.util.function.BiFunction<ArrayList<Expr>, Type, Type> roundRule = (children, returnType) -> {
|
||||
Preconditions.checkArgument(children != null && children.size() > 0);
|
||||
if (children.size() == 1 && children.get(0).getType().isDecimalV3()) {
|
||||
return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(),
|
||||
((ScalarType) children.get(0).getType()).decimalScale());
|
||||
return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(), 0);
|
||||
} else if (children.size() == 2) {
|
||||
Preconditions.checkArgument(children.get(1) instanceof IntLiteral
|
||||
|| (children.get(1) instanceof CastExpr
|
||||
|| (children.get(1) instanceof CastExpr
|
||||
&& children.get(1).getChild(0) instanceof IntLiteral),
|
||||
"2nd argument of function round/floor/ceil/truncate must be literal");
|
||||
if (children.get(1) instanceof CastExpr && children.get(1).getChild(0) instanceof IntLiteral) {
|
||||
@ -114,8 +117,9 @@ public class FunctionCallExpr extends Expr {
|
||||
} else {
|
||||
children.get(1).setType(Type.INT);
|
||||
}
|
||||
int scaleArg = (int) (((IntLiteral) children.get(1)).getValue());
|
||||
return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(),
|
||||
((ScalarType) children.get(0).getType()).decimalScale());
|
||||
Math.max(scaleArg, 0));
|
||||
} else {
|
||||
return returnType;
|
||||
}
|
||||
@ -155,9 +159,10 @@ public class FunctionCallExpr extends Expr {
|
||||
Math.max(((ScalarType) children.get(1).getType()).decimalScale(),
|
||||
((ScalarType) children.get(2).getType()).decimalScale()));
|
||||
} else if (children.get(1).getType().isDatetimeV2() && children.get(2).getType().isDatetimeV2()) {
|
||||
return ((ScalarType) children.get(1).getType()).decimalScale()
|
||||
> ((ScalarType) children.get(2).getType()).decimalScale()
|
||||
? children.get(1).getType() : children.get(2).getType();
|
||||
return ((ScalarType) children.get(1).getType())
|
||||
.decimalScale() > ((ScalarType) children.get(2).getType()).decimalScale()
|
||||
? children.get(1).getType()
|
||||
: children.get(2).getType();
|
||||
} else {
|
||||
return returnType;
|
||||
}
|
||||
@ -178,9 +183,9 @@ public class FunctionCallExpr extends Expr {
|
||||
PRECISION_INFER_RULE.put("truncate", roundRule);
|
||||
}
|
||||
|
||||
public static final ImmutableSet<String> TIME_FUNCTIONS_WITH_PRECISION =
|
||||
new ImmutableSortedSet.Builder(String.CASE_INSENSITIVE_ORDER)
|
||||
.add("now").add("current_timestamp").add("localtime").add("localtimestamp").build();
|
||||
public static final ImmutableSet<String> TIME_FUNCTIONS_WITH_PRECISION = new ImmutableSortedSet.Builder(
|
||||
String.CASE_INSENSITIVE_ORDER)
|
||||
.add("now").add("current_timestamp").add("localtime").add("localtimestamp").build();
|
||||
public static final int STDDEV_DECIMAL_SCALE = 9;
|
||||
private static final String ELEMENT_EXTRACT_FN_NAME = "%element_extract%";
|
||||
|
||||
@ -199,7 +204,8 @@ public class FunctionCallExpr extends Expr {
|
||||
// check table function
|
||||
private boolean isTableFnCall = false;
|
||||
|
||||
// Indicates whether this is a merge aggregation function that should use the merge
|
||||
// Indicates whether this is a merge aggregation function that should use the
|
||||
// merge
|
||||
// instead of the update symbol. This flag also affects the behavior of
|
||||
// resetAnalysisState() which is used during expr substitution.
|
||||
private boolean isMergeAggFn;
|
||||
@ -211,7 +217,8 @@ public class FunctionCallExpr extends Expr {
|
||||
|
||||
private boolean isRewrote = false;
|
||||
|
||||
// TODO: this field will be removed when we support analyze aggregate function in the nereids framework.
|
||||
// TODO: this field will be removed when we support analyze aggregate function
|
||||
// in the nereids framework.
|
||||
private boolean shouldFinalizeForNereids = true;
|
||||
|
||||
// this field is set by nereids, so we would not get arg types by the children.
|
||||
@ -466,7 +473,7 @@ public class FunctionCallExpr extends Expr {
|
||||
}
|
||||
int len = children.size();
|
||||
List<String> result = Lists.newArrayList();
|
||||
//XXX_diff are used by nereids only
|
||||
// XXX_diff are used by nereids only
|
||||
if (fnName.getFunction().equalsIgnoreCase("years_diff")
|
||||
|| fnName.getFunction().equalsIgnoreCase("months_diff")
|
||||
|| fnName.getFunction().equalsIgnoreCase("days_diff")
|
||||
@ -477,7 +484,7 @@ public class FunctionCallExpr extends Expr {
|
||||
sb.append(children.get(0).toSql()).append(")");
|
||||
return sb.toString();
|
||||
}
|
||||
//used by nereids END
|
||||
// used by nereids END
|
||||
|
||||
if (fnName.getFunction().equalsIgnoreCase("json_array")
|
||||
|| fnName.getFunction().equalsIgnoreCase("json_object")) {
|
||||
@ -1115,7 +1122,7 @@ public class FunctionCallExpr extends Expr {
|
||||
if (!VectorizedUtil.isVectorized()) {
|
||||
type = getChild(0).type.getMaxResolutionType();
|
||||
}
|
||||
fn = getBuiltinFunction(fnName.getFunction(), new Type[] {type},
|
||||
fn = getBuiltinFunction(fnName.getFunction(), new Type[] { type },
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
} else if (fnName.getFunction().equalsIgnoreCase("count_distinct")) {
|
||||
Type compatibleType = this.children.get(0).getType();
|
||||
@ -1128,7 +1135,7 @@ public class FunctionCallExpr extends Expr {
|
||||
}
|
||||
}
|
||||
|
||||
fn = getBuiltinFunction(fnName.getFunction(), new Type[] {compatibleType},
|
||||
fn = getBuiltinFunction(fnName.getFunction(), new Type[] { compatibleType },
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
} else if (fnName.getFunction().equalsIgnoreCase(FunctionSet.WINDOW_FUNNEL)) {
|
||||
if (fnParams.exprs() == null || fnParams.exprs().size() < 4) {
|
||||
@ -1174,12 +1181,12 @@ public class FunctionCallExpr extends Expr {
|
||||
for (int i = 0; i < children.size(); i++) {
|
||||
if (children.get(i).type != Type.BOOLEAN) {
|
||||
throw new AnalysisException("All params of "
|
||||
+ fnName + " function must be boolean");
|
||||
+ fnName + " function must be boolean");
|
||||
}
|
||||
childTypes[i] = children.get(i).type;
|
||||
}
|
||||
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
} else if (fnName.getFunction().equalsIgnoreCase(FunctionSet.SEQUENCE_MATCH)
|
||||
|| fnName.getFunction().equalsIgnoreCase(FunctionSet.SEQUENCE_COUNT)) {
|
||||
if (fnParams.exprs() == null || fnParams.exprs().size() < 4) {
|
||||
@ -1205,12 +1212,12 @@ public class FunctionCallExpr extends Expr {
|
||||
for (int i = 2; i < children.size(); i++) {
|
||||
if (children.get(i).type != Type.BOOLEAN) {
|
||||
throw new AnalysisException("The 3th and subsequent params of "
|
||||
+ fnName + " function must be boolean");
|
||||
+ fnName + " function must be boolean");
|
||||
}
|
||||
childTypes[i] = children.get(i).type;
|
||||
}
|
||||
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
} else if (fnName.getFunction().equalsIgnoreCase("if")) {
|
||||
Type[] childTypes = collectChildReturnTypes();
|
||||
Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(childTypes[1], childTypes[2], true);
|
||||
@ -1230,7 +1237,6 @@ public class FunctionCallExpr extends Expr {
|
||||
fn.setReturnType(assignmentCompatibleType);
|
||||
}
|
||||
|
||||
|
||||
} else if (AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains(
|
||||
fnName.getFunction().toLowerCase())) {
|
||||
// order by elements add as child like windows function. so if we get the
|
||||
@ -1242,7 +1248,7 @@ public class FunctionCallExpr extends Expr {
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
} else if (STDDEV_FUNCTION_SET.contains(fnName.getFunction().toLowerCase()) && children.size() == 1
|
||||
&& collectChildReturnTypes()[0].isDecimalV3()) {
|
||||
fn = getBuiltinFunction(fnName.getFunction(), new Type[] {Type.DOUBLE},
|
||||
fn = getBuiltinFunction(fnName.getFunction(), new Type[] { Type.DOUBLE },
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
} else {
|
||||
// now first find table function in table function sets
|
||||
@ -1254,6 +1260,10 @@ public class FunctionCallExpr extends Expr {
|
||||
throw new AnalysisException(getFunctionNotFoundError(argTypes));
|
||||
}
|
||||
} else {
|
||||
if (ROUND_FUNCTION_SET.contains(fnName.getFunction()) && children.size() == 2
|
||||
&& children.get(0).getType().isDecimalV3() && children.get(1) instanceof IntLiteral) {
|
||||
children.get(1).setType(Type.INT);
|
||||
}
|
||||
// now first find function in built-in functions
|
||||
if (Strings.isNullOrEmpty(fnName.getDb())) {
|
||||
Type[] childTypes = collectChildReturnTypes();
|
||||
@ -1279,8 +1289,8 @@ public class FunctionCallExpr extends Expr {
|
||||
// TODO(gaoxin): ExternalDatabase not implement udf yet.
|
||||
DatabaseIf db = Env.getCurrentEnv().getInternalCatalog().getDbNullable(dbName);
|
||||
if (db != null && (db instanceof Database)) {
|
||||
Function searchDesc =
|
||||
new Function(fnName, Arrays.asList(collectChildReturnTypes()), Type.INVALID, false);
|
||||
Function searchDesc = new Function(fnName, Arrays.asList(collectChildReturnTypes()),
|
||||
Type.INVALID, false);
|
||||
fn = ((Database) db).getFunction(searchDesc,
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
}
|
||||
@ -1366,22 +1376,23 @@ public class FunctionCallExpr extends Expr {
|
||||
continue;
|
||||
} else if (fnName.getFunction().equalsIgnoreCase("array")
|
||||
&& (children.get(0).getType().isDecimalV3() && args[ix].isDecimalV3()
|
||||
|| children.get(0).getType().isDatetimeV2() && args[ix].isDatetimeV2())) {
|
||||
|| children.get(0).getType().isDatetimeV2() && args[ix].isDatetimeV2())) {
|
||||
continue;
|
||||
} else if ((fnName.getFunction().equalsIgnoreCase("array_min") || fnName.getFunction()
|
||||
.equalsIgnoreCase("array_max") || fnName.getFunction().equalsIgnoreCase("element_at"))
|
||||
&& ((
|
||||
children.get(0).getType().isDecimalV3() && ((ArrayType) args[ix]).getItemType()
|
||||
.isDecimalV3())
|
||||
|| (children.get(0).getType().isDatetimeV2()
|
||||
&& ((ArrayType) args[ix]).getItemType().isDatetimeV2())
|
||||
|| (children.get(0).getType().isDecimalV2()
|
||||
&& ((ArrayType) args[ix]).getItemType().isDecimalV2()))) {
|
||||
|| (children.get(0).getType().isDatetimeV2()
|
||||
&& ((ArrayType) args[ix]).getItemType().isDatetimeV2())
|
||||
|| (children.get(0).getType().isDecimalV2()
|
||||
&& ((ArrayType) args[ix]).getItemType().isDecimalV2()))) {
|
||||
continue;
|
||||
} else if (!argTypes[i].matchesType(args[ix]) && !(
|
||||
argTypes[i].isDateOrDateTime() && args[ix].isDateOrDateTime())
|
||||
} else if (!argTypes[i].matchesType(args[ix])
|
||||
&& !(argTypes[i].isDateOrDateTime() && args[ix].isDateOrDateTime())
|
||||
&& (!fn.getReturnType().isDecimalV3()
|
||||
|| (argTypes[i].isValid() && !argTypes[i].isDecimalV3() && args[ix].isDecimalV3()))) {
|
||||
|| (argTypes[i].isValid() && !argTypes[i].isDecimalV3()
|
||||
&& args[ix].isDecimalV3()))) {
|
||||
uncheckedCastChild(args[ix], i);
|
||||
}
|
||||
}
|
||||
@ -1389,28 +1400,30 @@ public class FunctionCallExpr extends Expr {
|
||||
}
|
||||
|
||||
/**
|
||||
* The return type of str_to_date depends on whether the time part is included in the format.
|
||||
* The return type of str_to_date depends on whether the time part is included
|
||||
* in the format.
|
||||
* If included, it is datetime, otherwise it is date.
|
||||
* If the format parameter is not constant, the return type will be datetime.
|
||||
* The above judgment has been completed in the FE query planning stage,
|
||||
* so here we directly set the value type to the return type set in the query plan.
|
||||
* so here we directly set the value type to the return type set in the query
|
||||
* plan.
|
||||
*
|
||||
* For example:
|
||||
* A table with one column k1 varchar, and has 2 lines:
|
||||
* "%Y-%m-%d"
|
||||
* "%Y-%m-%d %H:%i:%s"
|
||||
* "%Y-%m-%d"
|
||||
* "%Y-%m-%d %H:%i:%s"
|
||||
* Query:
|
||||
* SELECT str_to_date("2020-09-01", k1) from tbl;
|
||||
* SELECT str_to_date("2020-09-01", k1) from tbl;
|
||||
* Result will be:
|
||||
* 2020-09-01 00:00:00
|
||||
* 2020-09-01 00:00:00
|
||||
* 2020-09-01 00:00:00
|
||||
* 2020-09-01 00:00:00
|
||||
*
|
||||
* Query:
|
||||
* SELECT str_to_date("2020-09-01", "%Y-%m-%d");
|
||||
* SELECT str_to_date("2020-09-01", "%Y-%m-%d");
|
||||
* Return type is DATE
|
||||
*
|
||||
* Query:
|
||||
* SELECT str_to_date("2020-09-01", "%Y-%m-%d %H:%i:%s");
|
||||
* SELECT str_to_date("2020-09-01", "%Y-%m-%d %H:%i:%s");
|
||||
* Return type is DATETIME
|
||||
*/
|
||||
if (fn.getFunctionName().getFunction().equals("str_to_date")) {
|
||||
@ -1442,7 +1455,8 @@ public class FunctionCallExpr extends Expr {
|
||||
|
||||
if (this.type.isDecimalV3() || (this.type.isDatetimeV2()
|
||||
&& !TIME_FUNCTIONS_WITH_PRECISION.contains(fnName.getFunction().toLowerCase()))) {
|
||||
// TODO(gabriel): If type exceeds max precision of DECIMALV3, we should change it to a double function
|
||||
// TODO(gabriel): If type exceeds max precision of DECIMALV3, we should change
|
||||
// it to a double function
|
||||
this.type = PRECISION_INFER_RULE.getOrDefault(fnName.getFunction(), DEFAULT_PRECISION_INFER_RULE)
|
||||
.apply(children, this.type);
|
||||
}
|
||||
@ -1503,7 +1517,7 @@ public class FunctionCallExpr extends Expr {
|
||||
|
||||
private static int parseNumber(String s) {
|
||||
|
||||
String[] n = s.split(""); //array of strings
|
||||
String[] n = s.split(""); // array of strings
|
||||
int num = 0;
|
||||
for (String value : n) {
|
||||
// validating numbers
|
||||
@ -1589,7 +1603,8 @@ public class FunctionCallExpr extends Expr {
|
||||
+ "] args number is not equal to it's definition");
|
||||
List<Expr> oriParamsExprs = oriExpr.fnParams.exprs();
|
||||
|
||||
// replace origin function params exprs' with input params expr depending on parameter name
|
||||
// replace origin function params exprs' with input params expr depending on
|
||||
// parameter name
|
||||
for (int i = 0; i < oriParamsExprs.size(); i++) {
|
||||
Expr expr = replaceParams(parameters, inputParamsExprs, oriParamsExprs.get(i));
|
||||
oriParamsExprs.set(i, expr);
|
||||
@ -1605,7 +1620,8 @@ public class FunctionCallExpr extends Expr {
|
||||
}
|
||||
|
||||
/**
|
||||
* replace origin function expr and it's children with input params exprs depending on parameter name
|
||||
* replace origin function expr and it's children with input params exprs
|
||||
* depending on parameter name
|
||||
*
|
||||
* @param parameters
|
||||
* @param inputParamsExprs
|
||||
@ -1626,7 +1642,8 @@ public class FunctionCallExpr extends Expr {
|
||||
return inputParamsExprs.get(index);
|
||||
}
|
||||
}
|
||||
// Initialize literalExpr without type information, because literalExpr does not save type information
|
||||
// Initialize literalExpr without type information, because literalExpr does not
|
||||
// save type information
|
||||
// when it is persisted, so after fe restart, read the image,
|
||||
// it will be missing type and report an error during analyze.
|
||||
if (oriExpr instanceof LiteralExpr && oriExpr.getType().equals(Type.INVALID)) {
|
||||
@ -1684,7 +1701,8 @@ public class FunctionCallExpr extends Expr {
|
||||
|
||||
@Override
|
||||
protected boolean isConstantImpl() {
|
||||
// TODO: we can't correctly determine const-ness before analyzing 'fn_'. We should
|
||||
// TODO: we can't correctly determine const-ness before analyzing 'fn_'. We
|
||||
// should
|
||||
// rework logic so that we do not call this function on unanalyzed exprs.
|
||||
// Aggregate functions are never constant.
|
||||
if (fn instanceof AggregateFunction || fn == null) {
|
||||
@ -1761,7 +1779,7 @@ public class FunctionCallExpr extends Expr {
|
||||
if (fnName.getFunction().equalsIgnoreCase("sum")) {
|
||||
// Prevent the cast type in vector exec engine
|
||||
Type childType = argTypes.get(0).getMaxResolutionType();
|
||||
fn = getBuiltinFunction(fnName.getFunction(), new Type[] {childType},
|
||||
fn = getBuiltinFunction(fnName.getFunction(), new Type[] { childType },
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
type = fn.getReturnType();
|
||||
} else if (fnName.getFunction().equalsIgnoreCase("count")) {
|
||||
@ -1778,7 +1796,7 @@ public class FunctionCallExpr extends Expr {
|
||||
|| fnName.getFunction().equalsIgnoreCase("avg")
|
||||
|| fnName.getFunction().equalsIgnoreCase("weekOfYear")) {
|
||||
Type childType = argTypes.get(0);
|
||||
fn = getBuiltinFunction(fnName.getFunction(), new Type[] {childType},
|
||||
fn = getBuiltinFunction(fnName.getFunction(), new Type[] { childType },
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
type = fn.getReturnType();
|
||||
} else {
|
||||
@ -1798,7 +1816,8 @@ public class FunctionCallExpr extends Expr {
|
||||
}
|
||||
|
||||
/**
|
||||
* NOTICE: This function only used for Nereids, should not call it if u don't know what it is mean.
|
||||
* NOTICE: This function only used for Nereids, should not call it if u don't
|
||||
* know what it is mean.
|
||||
*/
|
||||
public void setMergeForNereids(boolean isMergeAggFn) {
|
||||
this.isMergeAggFn = isMergeAggFn;
|
||||
|
||||
Reference in New Issue
Block a user