[Bug](DecimalV3) fix decimalv3 functions (#19801)

This commit is contained in:
Gabriel
2023-05-19 14:10:01 +08:00
committed by GitHub
parent fcffb1d3de
commit c4900eb658
4 changed files with 93 additions and 18 deletions

View File

@ -64,8 +64,7 @@ namespace doris::vectorized {
// space-saving algorithm
template <typename T>
struct AggregateFunctionTopNData {
using ColVecType =
std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<T>>;
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
void set_paramenters(int input_top_num, int space_expand_rate = 50) {
top_num = input_top_num;
capacity = (uint64_t)top_num * space_expand_rate;
@ -231,8 +230,7 @@ struct AggregateFunctionTopNImplIntInt {
//for topn_array agg
template <typename T, bool has_default_param>
struct AggregateFunctionTopNImplArray {
using ColVecType =
std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<T>>;
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
static void add(AggregateFunctionTopNData<T>& __restrict place, const IColumn** columns,
size_t row_num) {
if constexpr (has_default_param) {
@ -256,8 +254,7 @@ struct AggregateFunctionTopNImplArray {
//for topn_weighted agg
template <typename T, bool has_default_param>
struct AggregateFunctionTopNImplWeight {
using ColVecType =
std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<T>>;
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
static void add(AggregateFunctionTopNData<T>& __restrict place, const IColumn** columns,
size_t row_num) {
if constexpr (has_default_param) {

View File

@ -159,7 +159,8 @@ public:
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
DataTypePtr type = nullptr;
get_least_supertype(DataTypes {arguments[1], arguments[2]}, &type);
DCHECK_NE(type, nullptr);
DCHECK_NE(type, nullptr) << " arguments[1]: " << arguments[1]->get_name()
<< " arguments[2]: " << arguments[2]->get_name();
return type;
}

View File

@ -995,6 +995,21 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
}
}
public static Type getAssignmentCompatibleType(List<Expr> children) {
Type assignmentCompatibleType = Type.INVALID;
for (int i = 0; i < children.size()
&& (assignmentCompatibleType.isDecimalV3() || assignmentCompatibleType.isDatetimeV2()
|| assignmentCompatibleType.isInvalid()); i++) {
if (children.get(i) instanceof NullLiteral) {
continue;
}
assignmentCompatibleType = assignmentCompatibleType.isInvalid() ? children.get(i).type
: ScalarType.getAssignmentCompatibleType(assignmentCompatibleType, children.get(i).type,
true);
}
return assignmentCompatibleType;
}
// Convert this expr into msg (excluding children), which requires setting
// msg.op as well as the expr-specific field.
protected abstract void toThrift(TExprNode msg);

View File

@ -183,16 +183,36 @@ public class FunctionCallExpr extends Expr {
PRECISION_INFER_RULE.put("if", (children, returnType) -> {
Preconditions.checkArgument(children != null && children.size() == 3);
if (children.get(1).getType().isDecimalV3() && children.get(2).getType().isDecimalV3()) {
return ScalarType.createDecimalV3Type(
Math.max(((ScalarType) children.get(1).getType()).decimalPrecision(),
((ScalarType) children.get(2).getType()).decimalPrecision()),
Math.max(((ScalarType) children.get(1).getType()).decimalScale(),
((ScalarType) children.get(2).getType()).decimalScale()));
return Expr.getAssignmentCompatibleType(children.subList(1, children.size()));
} else if (children.get(1).getType().isDatetimeV2() && children.get(2).getType().isDatetimeV2()) {
return ((ScalarType) children.get(1).getType())
.decimalScale() > ((ScalarType) children.get(2).getType()).decimalScale()
? children.get(1).getType()
: children.get(2).getType();
return Expr.getAssignmentCompatibleType(children.subList(1, children.size()));
} else {
return returnType;
}
});
PRECISION_INFER_RULE.put("ifnull", (children, returnType) -> {
Preconditions.checkArgument(children != null && children.size() == 2);
if (children.get(0).getType().isDecimalV3() && children.get(1).getType().isDecimalV3()) {
return Expr.getAssignmentCompatibleType(children);
} else if (children.get(0).getType().isDatetimeV2() && children.get(1).getType().isDatetimeV2()) {
return Expr.getAssignmentCompatibleType(children);
} else {
return returnType;
}
});
PRECISION_INFER_RULE.put("coalesce", (children, returnType) -> {
boolean isDecimalV3 = true;
boolean isDateTimeV2 = true;
Type assignmentCompatibleType = Expr.getAssignmentCompatibleType(children);
for (Expr child : children) {
isDecimalV3 = isDecimalV3 && child.getType().isDecimalV3();
isDateTimeV2 = isDateTimeV2 && child.getType().isDatetimeV2();
}
if ((isDecimalV3 || isDateTimeV2) && assignmentCompatibleType.isValid()) {
return assignmentCompatibleType;
} else {
return returnType;
}
@ -1342,22 +1362,64 @@ public class FunctionCallExpr extends Expr {
Type[] childTypes = collectChildReturnTypes();
Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(childTypes[1], childTypes[2], true);
if (assignmentCompatibleType.isDecimalV3()) {
if (childTypes[1].isDecimalV3() && !childTypes[1].equals(assignmentCompatibleType)) {
if (assignmentCompatibleType.isDecimalV3() && !childTypes[1].equals(assignmentCompatibleType)) {
uncheckedCastChild(assignmentCompatibleType, 1);
}
if (childTypes[2].isDecimalV3() && !childTypes[2].equals(assignmentCompatibleType)) {
if (assignmentCompatibleType.isDecimalV3() && !childTypes[2].equals(assignmentCompatibleType)) {
uncheckedCastChild(assignmentCompatibleType, 2);
}
}
childTypes[0] = Type.BOOLEAN;
childTypes[1] = assignmentCompatibleType;
childTypes[2] = assignmentCompatibleType;
if (childTypes[1].isDecimalV3() && childTypes[2].isDecimalV3()) {
argTypes[1] = assignmentCompatibleType;
argTypes[2] = assignmentCompatibleType;
}
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
if (assignmentCompatibleType.isDatetimeV2()) {
fn.setReturnType(assignmentCompatibleType);
}
} else if (fnName.getFunction().equalsIgnoreCase("ifnull")) {
Type[] childTypes = collectChildReturnTypes();
Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(childTypes[0], childTypes[1], true);
if (assignmentCompatibleType.isDecimalV3()) {
if (assignmentCompatibleType.isDecimalV3() && !childTypes[0].equals(assignmentCompatibleType)) {
uncheckedCastChild(assignmentCompatibleType, 0);
}
if (assignmentCompatibleType.isDecimalV3() && !childTypes[1].equals(assignmentCompatibleType)) {
uncheckedCastChild(assignmentCompatibleType, 1);
}
}
childTypes[0] = assignmentCompatibleType;
childTypes[1] = assignmentCompatibleType;
if (childTypes[1].isDecimalV3() && childTypes[0].isDecimalV3()) {
argTypes[1] = assignmentCompatibleType;
argTypes[0] = assignmentCompatibleType;
}
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else if (fnName.getFunction().equalsIgnoreCase("coalesce") && children.size() > 1) {
Type[] childTypes = collectChildReturnTypes();
Type assignmentCompatibleType = childTypes[0];
for (int i = 1; i < childTypes.length && assignmentCompatibleType.isDecimalV3(); i++) {
assignmentCompatibleType =
ScalarType.getAssignmentCompatibleType(assignmentCompatibleType, childTypes[i], true);
}
if (assignmentCompatibleType.isDecimalV3()) {
for (int i = 0; i < childTypes.length; i++) {
if (assignmentCompatibleType.isDecimalV3() && !childTypes[i].equals(assignmentCompatibleType)) {
uncheckedCastChild(assignmentCompatibleType, i);
argTypes[i] = assignmentCompatibleType;
}
}
}
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else if (AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains(
fnName.getFunction().toLowerCase())) {
// order by elements add as child like windows function. so if we get the