[Bug](DecimalV3) fix decimalv3 functions (#19801)
This commit is contained in:
@ -64,8 +64,7 @@ namespace doris::vectorized {
|
||||
// space-saving algorithm
|
||||
template <typename T>
|
||||
struct AggregateFunctionTopNData {
|
||||
using ColVecType =
|
||||
std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<T>>;
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
void set_paramenters(int input_top_num, int space_expand_rate = 50) {
|
||||
top_num = input_top_num;
|
||||
capacity = (uint64_t)top_num * space_expand_rate;
|
||||
@ -231,8 +230,7 @@ struct AggregateFunctionTopNImplIntInt {
|
||||
//for topn_array agg
|
||||
template <typename T, bool has_default_param>
|
||||
struct AggregateFunctionTopNImplArray {
|
||||
using ColVecType =
|
||||
std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<T>>;
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
static void add(AggregateFunctionTopNData<T>& __restrict place, const IColumn** columns,
|
||||
size_t row_num) {
|
||||
if constexpr (has_default_param) {
|
||||
@ -256,8 +254,7 @@ struct AggregateFunctionTopNImplArray {
|
||||
//for topn_weighted agg
|
||||
template <typename T, bool has_default_param>
|
||||
struct AggregateFunctionTopNImplWeight {
|
||||
using ColVecType =
|
||||
std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<T>>;
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
static void add(AggregateFunctionTopNData<T>& __restrict place, const IColumn** columns,
|
||||
size_t row_num) {
|
||||
if constexpr (has_default_param) {
|
||||
|
||||
@ -159,7 +159,8 @@ public:
|
||||
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
|
||||
DataTypePtr type = nullptr;
|
||||
get_least_supertype(DataTypes {arguments[1], arguments[2]}, &type);
|
||||
DCHECK_NE(type, nullptr);
|
||||
DCHECK_NE(type, nullptr) << " arguments[1]: " << arguments[1]->get_name()
|
||||
<< " arguments[2]: " << arguments[2]->get_name();
|
||||
return type;
|
||||
}
|
||||
|
||||
|
||||
@ -995,6 +995,21 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
}
|
||||
}
|
||||
|
||||
public static Type getAssignmentCompatibleType(List<Expr> children) {
|
||||
Type assignmentCompatibleType = Type.INVALID;
|
||||
for (int i = 0; i < children.size()
|
||||
&& (assignmentCompatibleType.isDecimalV3() || assignmentCompatibleType.isDatetimeV2()
|
||||
|| assignmentCompatibleType.isInvalid()); i++) {
|
||||
if (children.get(i) instanceof NullLiteral) {
|
||||
continue;
|
||||
}
|
||||
assignmentCompatibleType = assignmentCompatibleType.isInvalid() ? children.get(i).type
|
||||
: ScalarType.getAssignmentCompatibleType(assignmentCompatibleType, children.get(i).type,
|
||||
true);
|
||||
}
|
||||
return assignmentCompatibleType;
|
||||
}
|
||||
|
||||
// Convert this expr into msg (excluding children), which requires setting
|
||||
// msg.op as well as the expr-specific field.
|
||||
protected abstract void toThrift(TExprNode msg);
|
||||
|
||||
@ -183,16 +183,36 @@ public class FunctionCallExpr extends Expr {
|
||||
PRECISION_INFER_RULE.put("if", (children, returnType) -> {
|
||||
Preconditions.checkArgument(children != null && children.size() == 3);
|
||||
if (children.get(1).getType().isDecimalV3() && children.get(2).getType().isDecimalV3()) {
|
||||
return ScalarType.createDecimalV3Type(
|
||||
Math.max(((ScalarType) children.get(1).getType()).decimalPrecision(),
|
||||
((ScalarType) children.get(2).getType()).decimalPrecision()),
|
||||
Math.max(((ScalarType) children.get(1).getType()).decimalScale(),
|
||||
((ScalarType) children.get(2).getType()).decimalScale()));
|
||||
return Expr.getAssignmentCompatibleType(children.subList(1, children.size()));
|
||||
} else if (children.get(1).getType().isDatetimeV2() && children.get(2).getType().isDatetimeV2()) {
|
||||
return ((ScalarType) children.get(1).getType())
|
||||
.decimalScale() > ((ScalarType) children.get(2).getType()).decimalScale()
|
||||
? children.get(1).getType()
|
||||
: children.get(2).getType();
|
||||
return Expr.getAssignmentCompatibleType(children.subList(1, children.size()));
|
||||
} else {
|
||||
return returnType;
|
||||
}
|
||||
});
|
||||
|
||||
PRECISION_INFER_RULE.put("ifnull", (children, returnType) -> {
|
||||
Preconditions.checkArgument(children != null && children.size() == 2);
|
||||
if (children.get(0).getType().isDecimalV3() && children.get(1).getType().isDecimalV3()) {
|
||||
return Expr.getAssignmentCompatibleType(children);
|
||||
} else if (children.get(0).getType().isDatetimeV2() && children.get(1).getType().isDatetimeV2()) {
|
||||
return Expr.getAssignmentCompatibleType(children);
|
||||
} else {
|
||||
return returnType;
|
||||
}
|
||||
});
|
||||
|
||||
PRECISION_INFER_RULE.put("coalesce", (children, returnType) -> {
|
||||
boolean isDecimalV3 = true;
|
||||
boolean isDateTimeV2 = true;
|
||||
|
||||
Type assignmentCompatibleType = Expr.getAssignmentCompatibleType(children);
|
||||
for (Expr child : children) {
|
||||
isDecimalV3 = isDecimalV3 && child.getType().isDecimalV3();
|
||||
isDateTimeV2 = isDateTimeV2 && child.getType().isDatetimeV2();
|
||||
}
|
||||
if ((isDecimalV3 || isDateTimeV2) && assignmentCompatibleType.isValid()) {
|
||||
return assignmentCompatibleType;
|
||||
} else {
|
||||
return returnType;
|
||||
}
|
||||
@ -1342,22 +1362,64 @@ public class FunctionCallExpr extends Expr {
|
||||
Type[] childTypes = collectChildReturnTypes();
|
||||
Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(childTypes[1], childTypes[2], true);
|
||||
if (assignmentCompatibleType.isDecimalV3()) {
|
||||
if (childTypes[1].isDecimalV3() && !childTypes[1].equals(assignmentCompatibleType)) {
|
||||
if (assignmentCompatibleType.isDecimalV3() && !childTypes[1].equals(assignmentCompatibleType)) {
|
||||
uncheckedCastChild(assignmentCompatibleType, 1);
|
||||
}
|
||||
if (childTypes[2].isDecimalV3() && !childTypes[2].equals(assignmentCompatibleType)) {
|
||||
if (assignmentCompatibleType.isDecimalV3() && !childTypes[2].equals(assignmentCompatibleType)) {
|
||||
uncheckedCastChild(assignmentCompatibleType, 2);
|
||||
}
|
||||
}
|
||||
childTypes[0] = Type.BOOLEAN;
|
||||
childTypes[1] = assignmentCompatibleType;
|
||||
childTypes[2] = assignmentCompatibleType;
|
||||
|
||||
if (childTypes[1].isDecimalV3() && childTypes[2].isDecimalV3()) {
|
||||
argTypes[1] = assignmentCompatibleType;
|
||||
argTypes[2] = assignmentCompatibleType;
|
||||
}
|
||||
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
if (assignmentCompatibleType.isDatetimeV2()) {
|
||||
fn.setReturnType(assignmentCompatibleType);
|
||||
}
|
||||
|
||||
} else if (fnName.getFunction().equalsIgnoreCase("ifnull")) {
|
||||
Type[] childTypes = collectChildReturnTypes();
|
||||
Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(childTypes[0], childTypes[1], true);
|
||||
if (assignmentCompatibleType.isDecimalV3()) {
|
||||
if (assignmentCompatibleType.isDecimalV3() && !childTypes[0].equals(assignmentCompatibleType)) {
|
||||
uncheckedCastChild(assignmentCompatibleType, 0);
|
||||
}
|
||||
if (assignmentCompatibleType.isDecimalV3() && !childTypes[1].equals(assignmentCompatibleType)) {
|
||||
uncheckedCastChild(assignmentCompatibleType, 1);
|
||||
}
|
||||
}
|
||||
childTypes[0] = assignmentCompatibleType;
|
||||
childTypes[1] = assignmentCompatibleType;
|
||||
|
||||
if (childTypes[1].isDecimalV3() && childTypes[0].isDecimalV3()) {
|
||||
argTypes[1] = assignmentCompatibleType;
|
||||
argTypes[0] = assignmentCompatibleType;
|
||||
}
|
||||
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
} else if (fnName.getFunction().equalsIgnoreCase("coalesce") && children.size() > 1) {
|
||||
Type[] childTypes = collectChildReturnTypes();
|
||||
Type assignmentCompatibleType = childTypes[0];
|
||||
for (int i = 1; i < childTypes.length && assignmentCompatibleType.isDecimalV3(); i++) {
|
||||
assignmentCompatibleType =
|
||||
ScalarType.getAssignmentCompatibleType(assignmentCompatibleType, childTypes[i], true);
|
||||
}
|
||||
if (assignmentCompatibleType.isDecimalV3()) {
|
||||
for (int i = 0; i < childTypes.length; i++) {
|
||||
if (assignmentCompatibleType.isDecimalV3() && !childTypes[i].equals(assignmentCompatibleType)) {
|
||||
uncheckedCastChild(assignmentCompatibleType, i);
|
||||
argTypes[i] = assignmentCompatibleType;
|
||||
}
|
||||
}
|
||||
}
|
||||
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
|
||||
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
|
||||
} else if (AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains(
|
||||
fnName.getFunction().toLowerCase())) {
|
||||
// order by elements add as child like windows function. so if we get the
|
||||
|
||||
Reference in New Issue
Block a user