diff --git a/be/src/vec/core/decimal_comparison.h b/be/src/vec/core/decimal_comparison.h index 1f2bce4200..4c7cfcf765 100644 --- a/be/src/vec/core/decimal_comparison.h +++ b/be/src/vec/core/decimal_comparison.h @@ -142,9 +142,11 @@ private: Shift shift; if (decimal0 && decimal1) { - auto result_type = decimal_result_type(*decimal0, *decimal1, false, false); - shift.a = result_type.scale_factor_for(*decimal0, false); - shift.b = result_type.scale_factor_for(*decimal1, false); + using Type = std::conditional_t= sizeof(U), T, U>; + auto type_ptr = decimal_result_type(*decimal0, *decimal1, false, false, false); + const DataTypeDecimal* result_type = check_decimal(*type_ptr); + shift.a = result_type->scale_factor_for(*decimal0, false); + shift.b = result_type->scale_factor_for(*decimal1, false); } else if (decimal0) { shift.b = decimal0->get_scale_multiplier(); } else if (decimal1) { diff --git a/be/src/vec/data_types/data_type_decimal.h b/be/src/vec/data_types/data_type_decimal.h index ea29954293..c8e08303a1 100644 --- a/be/src/vec/data_types/data_type_decimal.h +++ b/be/src/vec/data_types/data_type_decimal.h @@ -219,56 +219,30 @@ private: }; template -typename std::enable_if_t<(sizeof(T) >= sizeof(U)), const DataTypeDecimal> decimal_result_type( - const DataTypeDecimal& tx, const DataTypeDecimal& ty, bool is_multiply, - bool is_divide) { +DataTypePtr decimal_result_type(const DataTypeDecimal& tx, const DataTypeDecimal& ty, + bool is_multiply, bool is_divide, bool is_plus_minus) { + using Type = std::conditional_t= sizeof(U), T, U>; if constexpr (IsDecimalV2 && IsDecimalV2) { - return DataTypeDecimal(max_decimal_precision(), 9); + return std::make_shared>((max_decimal_precision(), 9)); } else { - UInt32 scale = (tx.get_scale() > ty.get_scale() ? tx.get_scale() : ty.get_scale()); + UInt32 scale = std::max(tx.get_scale(), ty.get_scale()); + auto precision = max_decimal_precision(); + + size_t multiply_precision = tx.get_precision() + ty.get_precision(); + size_t divide_precision = tx.get_precision() + ty.get_scale(); + size_t plus_minus_precision = + std::max(tx.get_precision() - tx.get_scale(), ty.get_precision() - ty.get_scale()) + + scale; if (is_multiply) { scale = tx.get_scale() + ty.get_scale(); + precision = std::min(multiply_precision, max_decimal_precision()); } else if (is_divide) { scale = tx.get_scale(); + precision = std::min(divide_precision, max_decimal_precision()); + } else if (is_plus_minus) { + precision = std::min(plus_minus_precision, max_decimal_precision()); } - return DataTypeDecimal(max_decimal_precision(), scale); - } -} - -template -typename std::enable_if_t<(sizeof(T) < sizeof(U)), const DataTypeDecimal> decimal_result_type( - const DataTypeDecimal& tx, const DataTypeDecimal& ty, bool is_multiply, - bool is_divide) { - if constexpr (IsDecimalV2 && IsDecimalV2) { - return DataTypeDecimal(max_decimal_precision(), 9); - } else { - UInt32 scale = (tx.get_scale() > ty.get_scale() ? tx.get_scale() : ty.get_scale()); - if (is_multiply) { - scale = tx.get_scale() + ty.get_scale(); - } else if (is_divide) { - scale = tx.get_scale(); - } - return DataTypeDecimal(max_decimal_precision(), scale); - } -} - -template -const DataTypeDecimal decimal_result_type(const DataTypeDecimal& tx, const DataTypeNumber&, - bool, bool) { - if constexpr (IsDecimalV2 && IsDecimalV2) { - return DataTypeDecimal(max_decimal_precision(), 9); - } else { - return DataTypeDecimal(max_decimal_precision(), tx.get_scale()); - } -} - -template -const DataTypeDecimal decimal_result_type(const DataTypeNumber&, const DataTypeDecimal& ty, - bool, bool) { - if constexpr (IsDecimalV2 && IsDecimalV2) { - return DataTypeDecimal(max_decimal_precision(), 9); - } else { - return DataTypeDecimal(max_decimal_precision(), ty.get_scale()); + return create_decimal(precision, scale, false); } } diff --git a/be/src/vec/functions/function_binary_arithmetic.h b/be/src/vec/functions/function_binary_arithmetic.h index 5c98e72486..2a8da748e3 100644 --- a/be/src/vec/functions/function_binary_arithmetic.h +++ b/be/src/vec/functions/function_binary_arithmetic.h @@ -730,10 +730,9 @@ public: if constexpr (!std::is_same_v) { if constexpr (IsDataTypeDecimal && IsDataTypeDecimal) { - ResultDataType result_type = decimal_result_type( - left, right, OpTraits::is_multiply, OpTraits::is_division); - type_res = std::make_shared(result_type.get_precision(), - result_type.get_scale()); + type_res = decimal_result_type(left, right, OpTraits::is_multiply, + OpTraits::is_division, + OpTraits::is_plus_minus); } else if constexpr (IsDataTypeDecimal) { type_res = std::make_shared(left.get_precision(), left.get_scale()); diff --git a/regression-test/data/decimalv3/test_decimalv3.out b/regression-test/data/decimalv3/test_decimalv3.out index 1bb8b045c0..f8d56b4c41 100644 --- a/regression-test/data/decimalv3/test_decimalv3.out +++ b/regression-test/data/decimalv3/test_decimalv3.out @@ -2,3 +2,6 @@ -- !decimalv3 -- 100.000000000000000000 +-- !decimalv3 -- +100.00000000000000000000 + diff --git a/regression-test/suites/decimalv3/test_decimalv3.groovy b/regression-test/suites/decimalv3/test_decimalv3.groovy index 374e554b93..8b8b010240 100644 --- a/regression-test/suites/decimalv3/test_decimalv3.groovy +++ b/regression-test/suites/decimalv3/test_decimalv3.groovy @@ -26,4 +26,5 @@ suite("test_decimalv3") { sql "create view test5_v (amout) as select cast(a*b as decimalv3(38,18)) from test5" qt_decimalv3 "select * from test5_v" + qt_decimalv3 "select cast(a as decimalv3(12,10)) * cast(b as decimalv3(18,10)) from test5" }