[feature](array_function): add support for array_cum_sum function (#18231)
This commit is contained in:
@ -154,6 +154,19 @@ public class FunctionCallExpr extends Expr {
|
||||
return returnType;
|
||||
}
|
||||
};
|
||||
java.util.function.BiFunction<ArrayList<Expr>, Type, Type> arrayDecimal128ArrayRule
|
||||
= (children, returnType) -> {
|
||||
Preconditions.checkArgument(children != null && children.size() > 0);
|
||||
if (children.get(0).getType().isArrayType() && (
|
||||
((ArrayType) children.get(0).getType()).getItemType().isDecimalV3())) {
|
||||
ArrayType childArrayType = (ArrayType) children.get(0).getType();
|
||||
Type itemType = ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
|
||||
((ScalarType) childArrayType.getItemType()).getScalarScale());
|
||||
return ArrayType.create(itemType, childArrayType.getContainsNull());
|
||||
} else {
|
||||
return returnType;
|
||||
}
|
||||
};
|
||||
PRECISION_INFER_RULE = new HashMap<>();
|
||||
PRECISION_INFER_RULE.put("sum", sumRule);
|
||||
PRECISION_INFER_RULE.put("multi_distinct_sum", sumRule);
|
||||
@ -192,6 +205,7 @@ public class FunctionCallExpr extends Expr {
|
||||
PRECISION_INFER_RULE.put("array_avg", arrayDecimal128Rule);
|
||||
PRECISION_INFER_RULE.put("array_sum", arrayDecimal128Rule);
|
||||
PRECISION_INFER_RULE.put("array_product", arrayDecimal128Rule);
|
||||
PRECISION_INFER_RULE.put("array_cum_sum", arrayDecimal128ArrayRule);
|
||||
PRECISION_INFER_RULE.put("round", roundRule);
|
||||
PRECISION_INFER_RULE.put("round_bankers", roundRule);
|
||||
PRECISION_INFER_RULE.put("ceil", roundRule);
|
||||
@ -1089,6 +1103,7 @@ public class FunctionCallExpr extends Expr {
|
||||
|| fnName.getFunction().equalsIgnoreCase("array_product")
|
||||
|| fnName.getFunction().equalsIgnoreCase("array_union")
|
||||
|| fnName.getFunction().equalsIgnoreCase("array_except")
|
||||
|| fnName.getFunction().equalsIgnoreCase("array_cum_sum")
|
||||
|| fnName.getFunction().equalsIgnoreCase("array_intersect")
|
||||
|| fnName.getFunction().equalsIgnoreCase("arrays_overlap")
|
||||
|| fnName.getFunction().equalsIgnoreCase("array_concat")) {
|
||||
@ -1554,6 +1569,7 @@ public class FunctionCallExpr extends Expr {
|
||||
|| fnName.getFunction().equalsIgnoreCase("array_popback")
|
||||
|| fnName.getFunction().equalsIgnoreCase("array_popfront")
|
||||
|| fnName.getFunction().equalsIgnoreCase("array_pushfront")
|
||||
|| fnName.getFunction().equalsIgnoreCase("array_cum_sum")
|
||||
|| fnName.getFunction().equalsIgnoreCase("reverse")
|
||||
|| fnName.getFunction().equalsIgnoreCase("%element_slice%")
|
||||
|| fnName.getFunction().equalsIgnoreCase("array_concat")
|
||||
@ -1628,7 +1644,9 @@ public class FunctionCallExpr extends Expr {
|
||||
fn.setReturnType(Type.MAX_DECIMALV2_TYPE);
|
||||
}
|
||||
|
||||
if (this.type.isDecimalV3() || (this.type.isDatetimeV2()
|
||||
if (this.type.isDecimalV3() || (this.type.isArrayType()
|
||||
&& ((ArrayType) this.type).getItemType().isDecimalV3())
|
||||
|| (this.type.isDatetimeV2()
|
||||
&& !TIME_FUNCTIONS_WITH_PRECISION.contains(fnName.getFunction().toLowerCase()))) {
|
||||
// TODO(gabriel): If type exceeds max precision of DECIMALV3, we should change
|
||||
// it to a double function
|
||||
|
||||
Reference in New Issue
Block a user