[feature](array_function): add support for array_cum_sum function (#18231)

This commit is contained in:
brody715
2023-04-27 09:57:13 +08:00
committed by GitHub
parent 6eb12640a1
commit 20395ce501
14 changed files with 526 additions and 1 deletions

View File

@ -154,6 +154,19 @@ public class FunctionCallExpr extends Expr {
return returnType;
}
};
java.util.function.BiFunction<ArrayList<Expr>, Type, Type> arrayDecimal128ArrayRule
= (children, returnType) -> {
Preconditions.checkArgument(children != null && children.size() > 0);
if (children.get(0).getType().isArrayType() && (
((ArrayType) children.get(0).getType()).getItemType().isDecimalV3())) {
ArrayType childArrayType = (ArrayType) children.get(0).getType();
Type itemType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
((ScalarType) childArrayType.getItemType()).getScalarScale());
return ArrayType.create(itemType, childArrayType.getContainsNull());
} else {
return returnType;
}
};
PRECISION_INFER_RULE = new HashMap<>();
PRECISION_INFER_RULE.put("sum", sumRule);
PRECISION_INFER_RULE.put("multi_distinct_sum", sumRule);
@ -192,6 +205,7 @@ public class FunctionCallExpr extends Expr {
PRECISION_INFER_RULE.put("array_avg", arrayDecimal128Rule);
PRECISION_INFER_RULE.put("array_sum", arrayDecimal128Rule);
PRECISION_INFER_RULE.put("array_product", arrayDecimal128Rule);
PRECISION_INFER_RULE.put("array_cum_sum", arrayDecimal128ArrayRule);
PRECISION_INFER_RULE.put("round", roundRule);
PRECISION_INFER_RULE.put("round_bankers", roundRule);
PRECISION_INFER_RULE.put("ceil", roundRule);
@ -1089,6 +1103,7 @@ public class FunctionCallExpr extends Expr {
|| fnName.getFunction().equalsIgnoreCase("array_product")
|| fnName.getFunction().equalsIgnoreCase("array_union")
|| fnName.getFunction().equalsIgnoreCase("array_except")
|| fnName.getFunction().equalsIgnoreCase("array_cum_sum")
|| fnName.getFunction().equalsIgnoreCase("array_intersect")
|| fnName.getFunction().equalsIgnoreCase("arrays_overlap")
|| fnName.getFunction().equalsIgnoreCase("array_concat")) {
@ -1554,6 +1569,7 @@ public class FunctionCallExpr extends Expr {
|| fnName.getFunction().equalsIgnoreCase("array_popback")
|| fnName.getFunction().equalsIgnoreCase("array_popfront")
|| fnName.getFunction().equalsIgnoreCase("array_pushfront")
|| fnName.getFunction().equalsIgnoreCase("array_cum_sum")
|| fnName.getFunction().equalsIgnoreCase("reverse")
|| fnName.getFunction().equalsIgnoreCase("%element_slice%")
|| fnName.getFunction().equalsIgnoreCase("array_concat")
@ -1628,7 +1644,9 @@ public class FunctionCallExpr extends Expr {
fn.setReturnType(Type.MAX_DECIMALV2_TYPE);
}
if (this.type.isDecimalV3() || (this.type.isDatetimeV2()
if (this.type.isDecimalV3() || (this.type.isArrayType()
&& ((ArrayType) this.type).getItemType().isDecimalV3())
|| (this.type.isDatetimeV2()
&& !TIME_FUNCTIONS_WITH_PRECISION.contains(fnName.getFunction().toLowerCase()))) {
// TODO(gabriel): If type exceeds max precision of DECIMALV3, we should change
// it to a double function