[Bug](Decimalv3) coredump of decimalv3 multiply (#15452)
This commit is contained in:
@ -142,9 +142,11 @@ private:
|
||||
|
||||
Shift shift;
|
||||
if (decimal0 && decimal1) {
|
||||
auto result_type = decimal_result_type(*decimal0, *decimal1, false, false);
|
||||
shift.a = result_type.scale_factor_for(*decimal0, false);
|
||||
shift.b = result_type.scale_factor_for(*decimal1, false);
|
||||
using Type = std::conditional_t<sizeof(T) >= sizeof(U), T, U>;
|
||||
auto type_ptr = decimal_result_type(*decimal0, *decimal1, false, false, false);
|
||||
const DataTypeDecimal<Type>* result_type = check_decimal<Type>(*type_ptr);
|
||||
shift.a = result_type->scale_factor_for(*decimal0, false);
|
||||
shift.b = result_type->scale_factor_for(*decimal1, false);
|
||||
} else if (decimal0) {
|
||||
shift.b = decimal0->get_scale_multiplier();
|
||||
} else if (decimal1) {
|
||||
|
||||
@ -219,56 +219,30 @@ private:
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
typename std::enable_if_t<(sizeof(T) >= sizeof(U)), const DataTypeDecimal<T>> decimal_result_type(
|
||||
const DataTypeDecimal<T>& tx, const DataTypeDecimal<U>& ty, bool is_multiply,
|
||||
bool is_divide) {
|
||||
DataTypePtr decimal_result_type(const DataTypeDecimal<T>& tx, const DataTypeDecimal<U>& ty,
|
||||
bool is_multiply, bool is_divide, bool is_plus_minus) {
|
||||
using Type = std::conditional_t<sizeof(T) >= sizeof(U), T, U>;
|
||||
if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) {
|
||||
return DataTypeDecimal<T>(max_decimal_precision<T>(), 9);
|
||||
return std::make_shared<DataTypeDecimal<Type>>((max_decimal_precision<T>(), 9));
|
||||
} else {
|
||||
UInt32 scale = (tx.get_scale() > ty.get_scale() ? tx.get_scale() : ty.get_scale());
|
||||
UInt32 scale = std::max(tx.get_scale(), ty.get_scale());
|
||||
auto precision = max_decimal_precision<Type>();
|
||||
|
||||
size_t multiply_precision = tx.get_precision() + ty.get_precision();
|
||||
size_t divide_precision = tx.get_precision() + ty.get_scale();
|
||||
size_t plus_minus_precision =
|
||||
std::max(tx.get_precision() - tx.get_scale(), ty.get_precision() - ty.get_scale()) +
|
||||
scale;
|
||||
if (is_multiply) {
|
||||
scale = tx.get_scale() + ty.get_scale();
|
||||
precision = std::min(multiply_precision, max_decimal_precision<Decimal128I>());
|
||||
} else if (is_divide) {
|
||||
scale = tx.get_scale();
|
||||
precision = std::min(divide_precision, max_decimal_precision<Decimal128I>());
|
||||
} else if (is_plus_minus) {
|
||||
precision = std::min(plus_minus_precision, max_decimal_precision<Decimal128I>());
|
||||
}
|
||||
return DataTypeDecimal<T>(max_decimal_precision<T>(), scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
typename std::enable_if_t<(sizeof(T) < sizeof(U)), const DataTypeDecimal<U>> decimal_result_type(
|
||||
const DataTypeDecimal<T>& tx, const DataTypeDecimal<U>& ty, bool is_multiply,
|
||||
bool is_divide) {
|
||||
if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) {
|
||||
return DataTypeDecimal<U>(max_decimal_precision<U>(), 9);
|
||||
} else {
|
||||
UInt32 scale = (tx.get_scale() > ty.get_scale() ? tx.get_scale() : ty.get_scale());
|
||||
if (is_multiply) {
|
||||
scale = tx.get_scale() + ty.get_scale();
|
||||
} else if (is_divide) {
|
||||
scale = tx.get_scale();
|
||||
}
|
||||
return DataTypeDecimal<U>(max_decimal_precision<U>(), scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
const DataTypeDecimal<T> decimal_result_type(const DataTypeDecimal<T>& tx, const DataTypeNumber<U>&,
|
||||
bool, bool) {
|
||||
if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) {
|
||||
return DataTypeDecimal<T>(max_decimal_precision<T>(), 9);
|
||||
} else {
|
||||
return DataTypeDecimal<T>(max_decimal_precision<T>(), tx.get_scale());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
const DataTypeDecimal<U> decimal_result_type(const DataTypeNumber<T>&, const DataTypeDecimal<U>& ty,
|
||||
bool, bool) {
|
||||
if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) {
|
||||
return DataTypeDecimal<U>(max_decimal_precision<U>(), 9);
|
||||
} else {
|
||||
return DataTypeDecimal<U>(max_decimal_precision<U>(), ty.get_scale());
|
||||
return create_decimal(precision, scale, false);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -730,10 +730,9 @@ public:
|
||||
if constexpr (!std::is_same_v<ResultDataType, InvalidType>) {
|
||||
if constexpr (IsDataTypeDecimal<LeftDataType> &&
|
||||
IsDataTypeDecimal<RightDataType>) {
|
||||
ResultDataType result_type = decimal_result_type(
|
||||
left, right, OpTraits::is_multiply, OpTraits::is_division);
|
||||
type_res = std::make_shared<ResultDataType>(result_type.get_precision(),
|
||||
result_type.get_scale());
|
||||
type_res = decimal_result_type(left, right, OpTraits::is_multiply,
|
||||
OpTraits::is_division,
|
||||
OpTraits::is_plus_minus);
|
||||
} else if constexpr (IsDataTypeDecimal<LeftDataType>) {
|
||||
type_res = std::make_shared<LeftDataType>(left.get_precision(),
|
||||
left.get_scale());
|
||||
|
||||
@ -2,3 +2,6 @@
|
||||
-- !decimalv3 --
|
||||
100.000000000000000000
|
||||
|
||||
-- !decimalv3 --
|
||||
100.00000000000000000000
|
||||
|
||||
|
||||
@ -26,4 +26,5 @@ suite("test_decimalv3") {
|
||||
sql "create view test5_v (amout) as select cast(a*b as decimalv3(38,18)) from test5"
|
||||
|
||||
qt_decimalv3 "select * from test5_v"
|
||||
qt_decimalv3 "select cast(a as decimalv3(12,10)) * cast(b as decimalv3(18,10)) from test5"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user