[Bug](Decimalv3) coredump of decimalv3 multiply (#15452)

This commit is contained in:
HappenLee
2022-12-29 15:35:17 +08:00
committed by GitHub
parent 89e2fb4301
commit c22ba8e160
5 changed files with 29 additions and 50 deletions

View File

@ -142,9 +142,11 @@ private:
Shift shift;
if (decimal0 && decimal1) {
auto result_type = decimal_result_type(*decimal0, *decimal1, false, false);
shift.a = result_type.scale_factor_for(*decimal0, false);
shift.b = result_type.scale_factor_for(*decimal1, false);
using Type = std::conditional_t<sizeof(T) >= sizeof(U), T, U>;
auto type_ptr = decimal_result_type(*decimal0, *decimal1, false, false, false);
const DataTypeDecimal<Type>* result_type = check_decimal<Type>(*type_ptr);
shift.a = result_type->scale_factor_for(*decimal0, false);
shift.b = result_type->scale_factor_for(*decimal1, false);
} else if (decimal0) {
shift.b = decimal0->get_scale_multiplier();
} else if (decimal1) {

View File

@ -219,56 +219,30 @@ private:
};
template <typename T, typename U>
typename std::enable_if_t<(sizeof(T) >= sizeof(U)), const DataTypeDecimal<T>> decimal_result_type(
const DataTypeDecimal<T>& tx, const DataTypeDecimal<U>& ty, bool is_multiply,
bool is_divide) {
DataTypePtr decimal_result_type(const DataTypeDecimal<T>& tx, const DataTypeDecimal<U>& ty,
bool is_multiply, bool is_divide, bool is_plus_minus) {
using Type = std::conditional_t<sizeof(T) >= sizeof(U), T, U>;
if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) {
return DataTypeDecimal<T>(max_decimal_precision<T>(), 9);
return std::make_shared<DataTypeDecimal<Type>>((max_decimal_precision<T>(), 9));
} else {
UInt32 scale = (tx.get_scale() > ty.get_scale() ? tx.get_scale() : ty.get_scale());
UInt32 scale = std::max(tx.get_scale(), ty.get_scale());
auto precision = max_decimal_precision<Type>();
size_t multiply_precision = tx.get_precision() + ty.get_precision();
size_t divide_precision = tx.get_precision() + ty.get_scale();
size_t plus_minus_precision =
std::max(tx.get_precision() - tx.get_scale(), ty.get_precision() - ty.get_scale()) +
scale;
if (is_multiply) {
scale = tx.get_scale() + ty.get_scale();
precision = std::min(multiply_precision, max_decimal_precision<Decimal128I>());
} else if (is_divide) {
scale = tx.get_scale();
precision = std::min(divide_precision, max_decimal_precision<Decimal128I>());
} else if (is_plus_minus) {
precision = std::min(plus_minus_precision, max_decimal_precision<Decimal128I>());
}
return DataTypeDecimal<T>(max_decimal_precision<T>(), scale);
}
}
template <typename T, typename U>
typename std::enable_if_t<(sizeof(T) < sizeof(U)), const DataTypeDecimal<U>> decimal_result_type(
const DataTypeDecimal<T>& tx, const DataTypeDecimal<U>& ty, bool is_multiply,
bool is_divide) {
if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) {
return DataTypeDecimal<U>(max_decimal_precision<U>(), 9);
} else {
UInt32 scale = (tx.get_scale() > ty.get_scale() ? tx.get_scale() : ty.get_scale());
if (is_multiply) {
scale = tx.get_scale() + ty.get_scale();
} else if (is_divide) {
scale = tx.get_scale();
}
return DataTypeDecimal<U>(max_decimal_precision<U>(), scale);
}
}
template <typename T, typename U>
const DataTypeDecimal<T> decimal_result_type(const DataTypeDecimal<T>& tx, const DataTypeNumber<U>&,
bool, bool) {
if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) {
return DataTypeDecimal<T>(max_decimal_precision<T>(), 9);
} else {
return DataTypeDecimal<T>(max_decimal_precision<T>(), tx.get_scale());
}
}
template <typename T, typename U>
const DataTypeDecimal<U> decimal_result_type(const DataTypeNumber<T>&, const DataTypeDecimal<U>& ty,
bool, bool) {
if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) {
return DataTypeDecimal<U>(max_decimal_precision<U>(), 9);
} else {
return DataTypeDecimal<U>(max_decimal_precision<U>(), ty.get_scale());
return create_decimal(precision, scale, false);
}
}

View File

@ -730,10 +730,9 @@ public:
if constexpr (!std::is_same_v<ResultDataType, InvalidType>) {
if constexpr (IsDataTypeDecimal<LeftDataType> &&
IsDataTypeDecimal<RightDataType>) {
ResultDataType result_type = decimal_result_type(
left, right, OpTraits::is_multiply, OpTraits::is_division);
type_res = std::make_shared<ResultDataType>(result_type.get_precision(),
result_type.get_scale());
type_res = decimal_result_type(left, right, OpTraits::is_multiply,
OpTraits::is_division,
OpTraits::is_plus_minus);
} else if constexpr (IsDataTypeDecimal<LeftDataType>) {
type_res = std::make_shared<LeftDataType>(left.get_precision(),
left.get_scale());

View File

@ -2,3 +2,6 @@
-- !decimalv3 --
100.000000000000000000
-- !decimalv3 --
100.00000000000000000000

View File

@ -26,4 +26,5 @@ suite("test_decimalv3") {
sql "create view test5_v (amout) as select cast(a*b as decimalv3(38,18)) from test5"
qt_decimalv3 "select * from test5_v"
qt_decimalv3 "select cast(a as decimalv3(12,10)) * cast(b as decimalv3(18,10)) from test5"
}