[Feature](array) Support array<decimalv3> data type (#16640)

This commit is contained in:
abmdocrt
2023-03-13 10:48:13 +08:00
committed by GitHub
parent 3a6c0e7867
commit 55c42da511
24 changed files with 576 additions and 90 deletions

View File

@ -21,6 +21,7 @@
package org.apache.doris.analysis;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.catalog.ArrayType;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.FunctionSet;
@ -2201,11 +2202,29 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
return Type.DECIMAL128;
} else if (type.getPrimitiveType() == PrimitiveType.DATETIMEV2) {
return Type.DATETIMEV2;
} else if (type.getPrimitiveType() == PrimitiveType.ARRAY) {
return getActualArrayType((ArrayType) type);
}
return type;
}).toArray(Type[]::new);
}
private ArrayType getActualArrayType(ArrayType originArrayType) {
// Now we only support single-level array nesting.
// Multi-layer array nesting will be supported in the future.
Type type = originArrayType.getItemType();
if (type.getPrimitiveType() == PrimitiveType.DECIMAL32) {
return new ArrayType(Type.DECIMAL32);
} else if (type.getPrimitiveType() == PrimitiveType.DECIMAL64) {
return new ArrayType(Type.DECIMAL64);
} else if (type.getPrimitiveType() == PrimitiveType.DECIMAL128) {
return new ArrayType(Type.DECIMAL128);
} else if (type.getPrimitiveType() == PrimitiveType.DATETIMEV2) {
return new ArrayType(Type.DATETIMEV2);
}
return originArrayType;
}
public boolean refToCountStar() {
if (this instanceof SlotRef) {
SlotRef slotRef = (SlotRef) this;

View File

@ -137,6 +137,17 @@ public class FunctionCallExpr extends Expr {
return returnType;
}
};
java.util.function.BiFunction<ArrayList<Expr>, Type, Type> arrayDecimal128Rule
= (children, returnType) -> {
Preconditions.checkArgument(children != null && children.size() > 0);
if (children.get(0).getType().isArrayType() && (
((ArrayType) children.get(0).getType()).getItemType().isDecimalV3())) {
return ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
((ScalarType) ((ArrayType) children.get(0).getType()).getItemType()).getScalarScale());
} else {
return returnType;
}
};
PRECISION_INFER_RULE = new HashMap<>();
PRECISION_INFER_RULE.put("sum", sumRule);
PRECISION_INFER_RULE.put("multi_distinct_sum", sumRule);
@ -172,7 +183,9 @@ public class FunctionCallExpr extends Expr {
PRECISION_INFER_RULE.put("array_max", arrayDateTimeV2OrDecimalV3Rule);
PRECISION_INFER_RULE.put("element_at", arrayDateTimeV2OrDecimalV3Rule);
PRECISION_INFER_RULE.put("%element_extract%", arrayDateTimeV2OrDecimalV3Rule);
PRECISION_INFER_RULE.put("array_avg", arrayDecimal128Rule);
PRECISION_INFER_RULE.put("array_sum", arrayDecimal128Rule);
PRECISION_INFER_RULE.put("array_product", arrayDecimal128Rule);
PRECISION_INFER_RULE.put("round", roundRule);
PRECISION_INFER_RULE.put("round_bankers", roundRule);
PRECISION_INFER_RULE.put("ceil", roundRule);
@ -1382,7 +1395,9 @@ public class FunctionCallExpr extends Expr {
for (int i = 0; i < argTypes.length - orderByElements.size(); ++i) {
// For varargs, we must compare with the last type in callArgs.argTypes.
int ix = Math.min(args.length - 1, i);
if (fnName.getFunction().equalsIgnoreCase("money_format")
if ((fnName.getFunction().equalsIgnoreCase("money_format") || fnName.getFunction()
.equalsIgnoreCase("histogram")
|| fnName.getFunction().equalsIgnoreCase("hist"))
&& children.get(0).getType().isDecimalV3() && args[ix].isDecimalV3()) {
continue;
} else if (fnName.getFunction().equalsIgnoreCase("array")
@ -1399,6 +1414,25 @@ public class FunctionCallExpr extends Expr {
|| (children.get(0).getType().isDecimalV2()
&& ((ArrayType) args[ix]).getItemType().isDecimalV2()))) {
continue;
} else if ((fnName.getFunction().equalsIgnoreCase("array_distinct") || fnName.getFunction()
.equalsIgnoreCase("array_remove") || fnName.getFunction().equalsIgnoreCase("array_sort")
|| fnName.getFunction().equalsIgnoreCase("array_overlap")
|| fnName.getFunction().equalsIgnoreCase("array_union")
|| fnName.getFunction().equalsIgnoreCase("array_intersect")
|| fnName.getFunction().equalsIgnoreCase("array_compact")
|| fnName.getFunction().equalsIgnoreCase("array_slice")
|| fnName.getFunction().equalsIgnoreCase("array_popback")
|| fnName.getFunction().equalsIgnoreCase("array_popfront")
|| fnName.getFunction().equalsIgnoreCase("reverse")
|| fnName.getFunction().equalsIgnoreCase("%element_slice%")
|| fnName.getFunction().equalsIgnoreCase("array_concat")
|| fnName.getFunction().equalsIgnoreCase("array_except"))
&& ((args[ix].isDecimalV3())
|| (children.get(0).getType().isArrayType()
&& (((ArrayType) children.get(0).getType()).getItemType().isDecimalV3())
&& (args[ix].isArrayType())
&& ((ArrayType) args[ix]).getItemType().isDecimalV3()))) {
continue;
} else if (!argTypes[i].matchesType(args[ix]) && !(
argTypes[i].isDateOrDateTime() && args[ix].isDateOrDateTime())
&& (!fn.getReturnType().isDecimalV3()