bp #34689 Co-authored-by: TengJianPing <18241664+jacktengg@users.noreply.github.com>
This commit is contained in:
@ -220,12 +220,17 @@ struct BinaryOperationImpl {
|
||||
}
|
||||
};
|
||||
|
||||
#define THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(left_value, op_name, right_value, result_value, \
|
||||
result_type_name) \
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR, \
|
||||
"Arithmetic overflow: {} {} {} = {}, result type: {}", left_value, op_name, \
|
||||
right_value, result_value, result_type_name)
|
||||
/// Binary operations for Decimals need scale args
|
||||
/// +|- scale one of args (which scale factor is not 1). ScaleR = oneof(Scale1, Scale2);
|
||||
/// * no agrs scale. ScaleR = Scale1 + Scale2;
|
||||
/// / first arg scale. ScaleR = Scale1 (scale_a = DecimalType<B>::get_scale()).
|
||||
template <typename LeftDataType, typename RightDataType,
|
||||
template <typename, typename> typename Operation, typename ResultType,
|
||||
template <typename LeftDataType, typename RightDataType, typename ResultDataType,
|
||||
template <typename, typename> typename Operation, typename Name, typename ResultType,
|
||||
bool is_to_null_type, bool check_overflow>
|
||||
struct DecimalBinaryOperation {
|
||||
using A = typename LeftDataType::FieldType;
|
||||
@ -246,8 +251,9 @@ private:
|
||||
|
||||
static void vector_vector(const typename Traits::ArrayA::value_type* __restrict a,
|
||||
const typename Traits::ArrayB::value_type* __restrict b,
|
||||
typename ArrayC::value_type* c, size_t size,
|
||||
const ResultType& max_result_number,
|
||||
typename ArrayC::value_type* c, const LeftDataType& type_left,
|
||||
const RightDataType& type_right, const ResultDataType& type_result,
|
||||
size_t size, const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier) {
|
||||
static_assert(OpTraits::is_plus_minus || OpTraits::is_multiply);
|
||||
if constexpr (OpTraits::is_multiply && IsDecimalV2<A> && IsDecimalV2<B> &&
|
||||
@ -259,7 +265,8 @@ private:
|
||||
[&](auto need_adjust_scale) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
c[i] = typename ArrayC::value_type(apply<need_adjust_scale>(
|
||||
a[i], b[i], max_result_number, scale_diff_multiplier));
|
||||
a[i], b[i], type_left, type_right, type_result,
|
||||
max_result_number, scale_diff_multiplier));
|
||||
}
|
||||
},
|
||||
make_bool_variant(need_adjust_scale && check_overflow));
|
||||
@ -310,8 +317,9 @@ private:
|
||||
}
|
||||
|
||||
static void vector_constant(const typename Traits::ArrayA::value_type* __restrict a, B b,
|
||||
typename ArrayC::value_type* c, size_t size,
|
||||
const ResultType& max_result_number,
|
||||
typename ArrayC::value_type* c, const LeftDataType& type_left,
|
||||
const RightDataType& type_right, const ResultDataType& type_result,
|
||||
size_t size, const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier) {
|
||||
static_assert(!OpTraits::is_division);
|
||||
|
||||
@ -320,7 +328,8 @@ private:
|
||||
[&](auto need_adjust_scale) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = typename ArrayC::value_type(apply<need_adjust_scale>(
|
||||
a[i], b, max_result_number, scale_diff_multiplier));
|
||||
a[i], b, type_left, type_right, type_result, max_result_number,
|
||||
scale_diff_multiplier));
|
||||
}
|
||||
},
|
||||
make_bool_variant(need_adjust_scale));
|
||||
@ -343,15 +352,17 @@ private:
|
||||
}
|
||||
|
||||
static void constant_vector(A a, const typename Traits::ArrayB::value_type* __restrict b,
|
||||
typename ArrayC::value_type* c, size_t size,
|
||||
const ResultType& max_result_number,
|
||||
typename ArrayC::value_type* c, const LeftDataType& type_left,
|
||||
const RightDataType& type_right, const ResultDataType& type_result,
|
||||
size_t size, const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier) {
|
||||
bool need_adjust_scale = scale_diff_multiplier.value > 1;
|
||||
std::visit(
|
||||
[&](auto need_adjust_scale) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = typename ArrayC::value_type(apply<need_adjust_scale>(
|
||||
a, b[i], max_result_number, scale_diff_multiplier));
|
||||
a, b[i], type_left, type_right, type_result, max_result_number,
|
||||
scale_diff_multiplier));
|
||||
}
|
||||
},
|
||||
make_bool_variant(need_adjust_scale));
|
||||
@ -373,9 +384,13 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
static ResultType constant_constant(A a, B b, const ResultType& max_result_number,
|
||||
static ResultType constant_constant(A a, B b, const LeftDataType& type_left,
|
||||
const RightDataType& type_right,
|
||||
const ResultDataType& type_result,
|
||||
const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier) {
|
||||
return ResultType(apply<true>(a, b, max_result_number, scale_diff_multiplier));
|
||||
return ResultType(apply<true>(a, b, type_left, type_right, type_result, max_result_number,
|
||||
scale_diff_multiplier));
|
||||
}
|
||||
|
||||
static ResultType constant_constant(A a, B b, UInt8& is_null,
|
||||
@ -393,9 +408,12 @@ private:
|
||||
}
|
||||
|
||||
public:
|
||||
static ColumnPtr adapt_decimal_constant_constant(A a, B b, const ResultType& max_result_number,
|
||||
static ColumnPtr adapt_decimal_constant_constant(A a, B b, const LeftDataType& type_left,
|
||||
const RightDataType& type_right,
|
||||
const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier,
|
||||
DataTypePtr res_data_type) {
|
||||
auto type_result = assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type);
|
||||
auto column_result = ColumnDecimal<ResultType>::create(
|
||||
1, assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
|
||||
|
||||
@ -410,15 +428,19 @@ public:
|
||||
return ColumnNullable::create(std::move(column_result), std::move(null_map));
|
||||
} else {
|
||||
column_result->get_element(0) =
|
||||
constant_constant(a, b, max_result_number, scale_diff_multiplier);
|
||||
constant_constant(a, b, type_left, type_right, type_result, max_result_number,
|
||||
scale_diff_multiplier);
|
||||
return column_result;
|
||||
}
|
||||
}
|
||||
|
||||
static ColumnPtr adapt_decimal_vector_constant(ColumnPtr column_left, B b,
|
||||
const LeftDataType& type_left,
|
||||
const RightDataType& type_right,
|
||||
const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier,
|
||||
DataTypePtr res_data_type) {
|
||||
auto type_result = assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type);
|
||||
auto column_left_ptr = check_and_get_column<typename Traits::ColumnVectorA>(column_left);
|
||||
auto column_result = ColumnDecimal<ResultType>::create(
|
||||
column_left->size(),
|
||||
@ -436,15 +458,19 @@ public:
|
||||
return ColumnNullable::create(std::move(column_result), std::move(null_map));
|
||||
} else {
|
||||
vector_constant(column_left_ptr->get_data().data(), b, column_result->get_data().data(),
|
||||
column_left->size(), max_result_number, scale_diff_multiplier);
|
||||
type_left, type_right, type_result, column_left->size(),
|
||||
max_result_number, scale_diff_multiplier);
|
||||
return column_result;
|
||||
}
|
||||
}
|
||||
|
||||
static ColumnPtr adapt_decimal_constant_vector(A a, ColumnPtr column_right,
|
||||
const LeftDataType& type_left,
|
||||
const RightDataType& type_right,
|
||||
const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier,
|
||||
DataTypePtr res_data_type) {
|
||||
auto type_result = assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type);
|
||||
auto column_right_ptr = check_and_get_column<typename Traits::ColumnVectorB>(column_right);
|
||||
auto column_result = ColumnDecimal<ResultType>::create(
|
||||
column_right->size(),
|
||||
@ -463,13 +489,15 @@ public:
|
||||
return ColumnNullable::create(std::move(column_result), std::move(null_map));
|
||||
} else {
|
||||
constant_vector(a, column_right_ptr->get_data().data(),
|
||||
column_result->get_data().data(), column_right->size(),
|
||||
max_result_number, scale_diff_multiplier);
|
||||
column_result->get_data().data(), type_left, type_right, type_result,
|
||||
column_right->size(), max_result_number, scale_diff_multiplier);
|
||||
return column_result;
|
||||
}
|
||||
}
|
||||
|
||||
static ColumnPtr adapt_decimal_vector_vector(ColumnPtr column_left, ColumnPtr column_right,
|
||||
const LeftDataType& type_left,
|
||||
const RightDataType& type_right,
|
||||
const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier,
|
||||
DataTypePtr res_data_type) {
|
||||
@ -494,8 +522,8 @@ public:
|
||||
return ColumnNullable::create(std::move(column_result), std::move(null_map));
|
||||
} else {
|
||||
vector_vector(column_left_ptr->get_data().data(), column_right_ptr->get_data().data(),
|
||||
column_result->get_data().data(), column_left->size(), max_result_number,
|
||||
scale_diff_multiplier);
|
||||
column_result->get_data().data(), type_left, type_right, type_result,
|
||||
column_left->size(), max_result_number, scale_diff_multiplier);
|
||||
return column_result;
|
||||
}
|
||||
}
|
||||
@ -504,6 +532,9 @@ private:
|
||||
/// there's implicit type conversion here
|
||||
template <bool need_adjust_scale>
|
||||
static ALWAYS_INLINE NativeResultType apply(NativeResultType a, NativeResultType b,
|
||||
const LeftDataType& type_left,
|
||||
const RightDataType& type_right,
|
||||
const ResultDataType& type_result,
|
||||
const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier) {
|
||||
static_assert(OpTraits::is_plus_minus || OpTraits::is_multiply);
|
||||
@ -512,7 +543,10 @@ private:
|
||||
if constexpr (check_overflow) {
|
||||
auto res = Op::template apply(DecimalV2Value(a), DecimalV2Value(b)).value();
|
||||
if (res > max_result_number.value || res < -max_result_number.value) {
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR, "Arithmetic overflow");
|
||||
THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
|
||||
DecimalV2Value(a).to_string(), Name::name,
|
||||
DecimalV2Value(b).to_string(), DecimalV2Value(res).to_string(),
|
||||
ResultDataType {}.get_name());
|
||||
}
|
||||
return res;
|
||||
} else {
|
||||
@ -524,8 +558,13 @@ private:
|
||||
// TODO handle overflow gracefully
|
||||
if (UNLIKELY(Op::template apply<NativeResultType>(a, b, res))) {
|
||||
if constexpr (OpTraits::is_plus_minus) {
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
|
||||
"Arithmetic overflow");
|
||||
auto result_str =
|
||||
DataTypeDecimal<Decimal256> {BeConsts::MAX_DECIMAL256_PRECISION,
|
||||
type_result.get_scale()}
|
||||
.to_string(Decimal256(res));
|
||||
THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
|
||||
type_left.to_string(A(a)), Name::name, type_right.to_string(B(b)),
|
||||
result_str, type_result.get_name());
|
||||
}
|
||||
// multiply
|
||||
if constexpr (std::is_same_v<NativeResultType, __int128>) {
|
||||
@ -543,14 +582,24 @@ private:
|
||||
// check if final result is overflow
|
||||
if (res256 > wide::Int256(max_result_number.value) ||
|
||||
res256 < wide::Int256(-max_result_number.value)) {
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
|
||||
"Arithmetic overflow");
|
||||
auto result_str =
|
||||
DataTypeDecimal<Decimal256> {BeConsts::MAX_DECIMAL256_PRECISION,
|
||||
type_result.get_scale()}
|
||||
.to_string(Decimal256(res256));
|
||||
THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
|
||||
type_left.to_string(A(a)), Name::name,
|
||||
type_right.to_string(B(b)), result_str, type_result.get_name());
|
||||
} else {
|
||||
res = res256;
|
||||
}
|
||||
} else {
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
|
||||
"Arithmetic overflow");
|
||||
auto result_str =
|
||||
DataTypeDecimal<Decimal256> {BeConsts::MAX_DECIMAL256_PRECISION,
|
||||
type_result.get_scale()}
|
||||
.to_string(Decimal256(res));
|
||||
THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
|
||||
type_left.to_string(A(a)), Name::name, type_right.to_string(B(b)),
|
||||
result_str, type_result.get_name());
|
||||
}
|
||||
} else {
|
||||
// round to final result precision
|
||||
@ -564,8 +613,13 @@ private:
|
||||
}
|
||||
}
|
||||
if (res > max_result_number.value || res < -max_result_number.value) {
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
|
||||
"Arithmetic overflow");
|
||||
auto result_str =
|
||||
DataTypeDecimal<Decimal256> {BeConsts::MAX_DECIMAL256_PRECISION,
|
||||
type_result.get_scale()}
|
||||
.to_string(Decimal256(res));
|
||||
THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
|
||||
type_left.to_string(A(a)), Name::name, type_right.to_string(B(b)),
|
||||
result_str, type_result.get_name());
|
||||
}
|
||||
}
|
||||
return res;
|
||||
@ -597,19 +651,25 @@ private:
|
||||
if constexpr (std::is_same_v<ANS_TYPE, DecimalV2Value>) {
|
||||
if (ans.value() > max_result_number.value ||
|
||||
ans.value() < -max_result_number.value) {
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
|
||||
"Arithmetic overflow");
|
||||
THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
|
||||
DecimalV2Value(a).to_string(), Name::name,
|
||||
DecimalV2Value(b).to_string(), DecimalV2Value(ans).to_string(),
|
||||
ResultDataType {}.get_name());
|
||||
}
|
||||
} else if constexpr (IsDecimalNumber<ANS_TYPE>) {
|
||||
if (ans.value > max_result_number.value ||
|
||||
ans.value < -max_result_number.value) {
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
|
||||
"Arithmetic overflow");
|
||||
THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
|
||||
DecimalV2Value(a).to_string(), Name::name,
|
||||
DecimalV2Value(b).to_string(), DecimalV2Value(ans).to_string(),
|
||||
ResultDataType {}.get_name());
|
||||
}
|
||||
} else {
|
||||
if (ans > max_result_number.value || ans < -max_result_number.value) {
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
|
||||
"Arithmetic overflow");
|
||||
THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
|
||||
DecimalV2Value(a).to_string(), Name::name,
|
||||
DecimalV2Value(b).to_string(), DecimalV2Value(ans).to_string(),
|
||||
ResultDataType {}.get_name());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -714,7 +774,7 @@ struct BinaryOperationTraits {
|
||||
};
|
||||
|
||||
template <typename LeftDataType, typename RightDataType, typename ExpectedResultDataType,
|
||||
template <typename, typename> class Operation, bool is_to_null_type,
|
||||
template <typename, typename> class Operation, typename Name, bool is_to_null_type,
|
||||
bool check_overflow_for_decimal>
|
||||
struct ConstOrVectorAdapter {
|
||||
static constexpr bool result_is_decimal =
|
||||
@ -727,8 +787,8 @@ struct ConstOrVectorAdapter {
|
||||
|
||||
using OperationImpl = std::conditional_t<
|
||||
IsDataTypeDecimal<ResultDataType>,
|
||||
DecimalBinaryOperation<LeftDataType, RightDataType, Operation, ResultType,
|
||||
is_to_null_type, check_overflow_for_decimal>,
|
||||
DecimalBinaryOperation<LeftDataType, RightDataType, ResultDataType, Operation, Name,
|
||||
ResultType, is_to_null_type, check_overflow_for_decimal>,
|
||||
BinaryOperationImpl<A, B, Operation<A, B>, is_to_null_type, ResultType>>;
|
||||
|
||||
static ColumnPtr execute(ColumnPtr column_left, ColumnPtr column_right,
|
||||
@ -785,8 +845,8 @@ private:
|
||||
|
||||
column_result = OperationImpl::adapt_decimal_constant_constant(
|
||||
column_left_ptr->template get_value<A>(),
|
||||
column_right_ptr->template get_value<B>(), max_and_multiplier.first,
|
||||
max_and_multiplier.second, res_data_type);
|
||||
column_right_ptr->template get_value<B>(), type_left, type_right,
|
||||
max_and_multiplier.first, max_and_multiplier.second, res_data_type);
|
||||
|
||||
} else {
|
||||
column_result = OperationImpl::adapt_normal_constant_constant(
|
||||
@ -808,8 +868,8 @@ private:
|
||||
assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type);
|
||||
auto max_and_multiplier = get_max_and_multiplier(type_left, type_right, type_result);
|
||||
return OperationImpl::adapt_decimal_vector_constant(
|
||||
column_left->get_ptr(), column_right_ptr->template get_value<B>(),
|
||||
max_and_multiplier.first, max_and_multiplier.second, res_data_type);
|
||||
column_left->get_ptr(), column_right_ptr->template get_value<B>(), type_left,
|
||||
type_right, max_and_multiplier.first, max_and_multiplier.second, res_data_type);
|
||||
} else {
|
||||
return OperationImpl::adapt_normal_vector_constant(
|
||||
column_left->get_ptr(), column_right_ptr->template get_value<B>());
|
||||
@ -827,8 +887,8 @@ private:
|
||||
assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type);
|
||||
auto max_and_multiplier = get_max_and_multiplier(type_left, type_right, type_result);
|
||||
return OperationImpl::adapt_decimal_constant_vector(
|
||||
column_left_ptr->template get_value<A>(), column_right->get_ptr(),
|
||||
max_and_multiplier.first, max_and_multiplier.second, res_data_type);
|
||||
column_left_ptr->template get_value<A>(), column_right->get_ptr(), type_left,
|
||||
type_right, max_and_multiplier.first, max_and_multiplier.second, res_data_type);
|
||||
} else {
|
||||
return OperationImpl::adapt_normal_constant_vector(
|
||||
column_left_ptr->template get_value<A>(), column_right->get_ptr());
|
||||
@ -843,8 +903,8 @@ private:
|
||||
assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type);
|
||||
auto max_and_multiplier = get_max_and_multiplier(type_left, type_right, type_result);
|
||||
return OperationImpl::adapt_decimal_vector_vector(
|
||||
column_left->get_ptr(), column_right->get_ptr(), max_and_multiplier.first,
|
||||
max_and_multiplier.second, res_data_type);
|
||||
column_left->get_ptr(), column_right->get_ptr(), type_left, type_right,
|
||||
max_and_multiplier.first, max_and_multiplier.second, res_data_type);
|
||||
} else {
|
||||
return OperationImpl::adapt_normal_vector_vector(column_left->get_ptr(),
|
||||
column_right->get_ptr());
|
||||
@ -1004,7 +1064,7 @@ public:
|
||||
LeftDataType, RightDataType,
|
||||
std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
|
||||
ExpectedResultDataType, ResultDataType>,
|
||||
Operation, is_to_null_type,
|
||||
Operation, Name, is_to_null_type,
|
||||
true>::execute(block.get_by_position(arguments[0]).column,
|
||||
block.get_by_position(arguments[1]).column, left,
|
||||
right,
|
||||
@ -1016,7 +1076,7 @@ public:
|
||||
LeftDataType, RightDataType,
|
||||
std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
|
||||
ExpectedResultDataType, ResultDataType>,
|
||||
Operation, is_to_null_type,
|
||||
Operation, Name, is_to_null_type,
|
||||
false>::execute(block.get_by_position(arguments[0]).column,
|
||||
block.get_by_position(arguments[1]).column,
|
||||
left, right,
|
||||
|
||||
Reference in New Issue
Block a user