diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.h b/be/src/vec/aggregate_functions/aggregate_function_topn.h index bf02cc1817..633a36231a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_topn.h +++ b/be/src/vec/aggregate_functions/aggregate_function_topn.h @@ -64,8 +64,7 @@ namespace doris::vectorized { // space-saving algorithm template struct AggregateFunctionTopNData { - using ColVecType = - std::conditional_t, ColumnDecimal, ColumnVector>; + using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; void set_paramenters(int input_top_num, int space_expand_rate = 50) { top_num = input_top_num; capacity = (uint64_t)top_num * space_expand_rate; @@ -231,8 +230,7 @@ struct AggregateFunctionTopNImplIntInt { //for topn_array agg template struct AggregateFunctionTopNImplArray { - using ColVecType = - std::conditional_t, ColumnDecimal, ColumnVector>; + using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; static void add(AggregateFunctionTopNData& __restrict place, const IColumn** columns, size_t row_num) { if constexpr (has_default_param) { @@ -256,8 +254,7 @@ struct AggregateFunctionTopNImplArray { //for topn_weighted agg template struct AggregateFunctionTopNImplWeight { - using ColVecType = - std::conditional_t, ColumnDecimal, ColumnVector>; + using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; static void add(AggregateFunctionTopNData& __restrict place, const IColumn** columns, size_t row_num) { if constexpr (has_default_param) { diff --git a/be/src/vec/functions/if.cpp b/be/src/vec/functions/if.cpp index a4e86f8e51..bcdf147b1d 100644 --- a/be/src/vec/functions/if.cpp +++ b/be/src/vec/functions/if.cpp @@ -159,7 +159,8 @@ public: DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { DataTypePtr type = nullptr; get_least_supertype(DataTypes {arguments[1], arguments[2]}, &type); - DCHECK_NE(type, nullptr); + DCHECK_NE(type, nullptr) << " arguments[1]: " << arguments[1]->get_name() + << " arguments[2]: " << arguments[2]->get_name(); return type; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java index 2d9922e193..091fe7b9b8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java @@ -995,6 +995,21 @@ public abstract class Expr extends TreeNode implements ParseNode, Cloneabl } } + public static Type getAssignmentCompatibleType(List children) { + Type assignmentCompatibleType = Type.INVALID; + for (int i = 0; i < children.size() + && (assignmentCompatibleType.isDecimalV3() || assignmentCompatibleType.isDatetimeV2() + || assignmentCompatibleType.isInvalid()); i++) { + if (children.get(i) instanceof NullLiteral) { + continue; + } + assignmentCompatibleType = assignmentCompatibleType.isInvalid() ? children.get(i).type + : ScalarType.getAssignmentCompatibleType(assignmentCompatibleType, children.get(i).type, + true); + } + return assignmentCompatibleType; + } + // Convert this expr into msg (excluding children), which requires setting // msg.op as well as the expr-specific field. protected abstract void toThrift(TExprNode msg); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index a082e8569a..96a45a54c1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -183,16 +183,36 @@ public class FunctionCallExpr extends Expr { PRECISION_INFER_RULE.put("if", (children, returnType) -> { Preconditions.checkArgument(children != null && children.size() == 3); if (children.get(1).getType().isDecimalV3() && children.get(2).getType().isDecimalV3()) { - return ScalarType.createDecimalV3Type( - Math.max(((ScalarType) children.get(1).getType()).decimalPrecision(), - ((ScalarType) children.get(2).getType()).decimalPrecision()), - Math.max(((ScalarType) children.get(1).getType()).decimalScale(), - ((ScalarType) children.get(2).getType()).decimalScale())); + return Expr.getAssignmentCompatibleType(children.subList(1, children.size())); } else if (children.get(1).getType().isDatetimeV2() && children.get(2).getType().isDatetimeV2()) { - return ((ScalarType) children.get(1).getType()) - .decimalScale() > ((ScalarType) children.get(2).getType()).decimalScale() - ? children.get(1).getType() - : children.get(2).getType(); + return Expr.getAssignmentCompatibleType(children.subList(1, children.size())); + } else { + return returnType; + } + }); + + PRECISION_INFER_RULE.put("ifnull", (children, returnType) -> { + Preconditions.checkArgument(children != null && children.size() == 2); + if (children.get(0).getType().isDecimalV3() && children.get(1).getType().isDecimalV3()) { + return Expr.getAssignmentCompatibleType(children); + } else if (children.get(0).getType().isDatetimeV2() && children.get(1).getType().isDatetimeV2()) { + return Expr.getAssignmentCompatibleType(children); + } else { + return returnType; + } + }); + + PRECISION_INFER_RULE.put("coalesce", (children, returnType) -> { + boolean isDecimalV3 = true; + boolean isDateTimeV2 = true; + + Type assignmentCompatibleType = Expr.getAssignmentCompatibleType(children); + for (Expr child : children) { + isDecimalV3 = isDecimalV3 && child.getType().isDecimalV3(); + isDateTimeV2 = isDateTimeV2 && child.getType().isDatetimeV2(); + } + if ((isDecimalV3 || isDateTimeV2) && assignmentCompatibleType.isValid()) { + return assignmentCompatibleType; } else { return returnType; } @@ -1342,22 +1362,64 @@ public class FunctionCallExpr extends Expr { Type[] childTypes = collectChildReturnTypes(); Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(childTypes[1], childTypes[2], true); if (assignmentCompatibleType.isDecimalV3()) { - if (childTypes[1].isDecimalV3() && !childTypes[1].equals(assignmentCompatibleType)) { + if (assignmentCompatibleType.isDecimalV3() && !childTypes[1].equals(assignmentCompatibleType)) { uncheckedCastChild(assignmentCompatibleType, 1); } - if (childTypes[2].isDecimalV3() && !childTypes[2].equals(assignmentCompatibleType)) { + if (assignmentCompatibleType.isDecimalV3() && !childTypes[2].equals(assignmentCompatibleType)) { uncheckedCastChild(assignmentCompatibleType, 2); } } childTypes[0] = Type.BOOLEAN; childTypes[1] = assignmentCompatibleType; childTypes[2] = assignmentCompatibleType; + + if (childTypes[1].isDecimalV3() && childTypes[2].isDecimalV3()) { + argTypes[1] = assignmentCompatibleType; + argTypes[2] = assignmentCompatibleType; + } fn = getBuiltinFunction(fnName.getFunction(), childTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); if (assignmentCompatibleType.isDatetimeV2()) { fn.setReturnType(assignmentCompatibleType); } + } else if (fnName.getFunction().equalsIgnoreCase("ifnull")) { + Type[] childTypes = collectChildReturnTypes(); + Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(childTypes[0], childTypes[1], true); + if (assignmentCompatibleType.isDecimalV3()) { + if (assignmentCompatibleType.isDecimalV3() && !childTypes[0].equals(assignmentCompatibleType)) { + uncheckedCastChild(assignmentCompatibleType, 0); + } + if (assignmentCompatibleType.isDecimalV3() && !childTypes[1].equals(assignmentCompatibleType)) { + uncheckedCastChild(assignmentCompatibleType, 1); + } + } + childTypes[0] = assignmentCompatibleType; + childTypes[1] = assignmentCompatibleType; + + if (childTypes[1].isDecimalV3() && childTypes[0].isDecimalV3()) { + argTypes[1] = assignmentCompatibleType; + argTypes[0] = assignmentCompatibleType; + } + fn = getBuiltinFunction(fnName.getFunction(), childTypes, + Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); + } else if (fnName.getFunction().equalsIgnoreCase("coalesce") && children.size() > 1) { + Type[] childTypes = collectChildReturnTypes(); + Type assignmentCompatibleType = childTypes[0]; + for (int i = 1; i < childTypes.length && assignmentCompatibleType.isDecimalV3(); i++) { + assignmentCompatibleType = + ScalarType.getAssignmentCompatibleType(assignmentCompatibleType, childTypes[i], true); + } + if (assignmentCompatibleType.isDecimalV3()) { + for (int i = 0; i < childTypes.length; i++) { + if (assignmentCompatibleType.isDecimalV3() && !childTypes[i].equals(assignmentCompatibleType)) { + uncheckedCastChild(assignmentCompatibleType, i); + argTypes[i] = assignmentCompatibleType; + } + } + } + fn = getBuiltinFunction(fnName.getFunction(), childTypes, + Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); } else if (AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains( fnName.getFunction().toLowerCase())) { // order by elements add as child like windows function. so if we get the