[improvement](decimalv2) support check overflow for decimalv2 arithmetics (#28456)

This commit is contained in:
TengJianPing
2023-12-18 10:54:25 +08:00
committed by GitHub
parent 6e855dd198
commit fbe5a7c244
9 changed files with 795 additions and 115 deletions

View File

@ -244,10 +244,10 @@ private:
typename ArrayC::value_type* c, size_t size,
const ResultType& max_result_number,
const ResultType& scale_diff_multiplier) {
// TODO: handle overflow of decimalv2
static_assert(OpTraits::is_plus_minus || OpTraits::is_multiply);
if constexpr (OpTraits::is_multiply && IsDecimalV2<A> && IsDecimalV2<B> &&
IsDecimalV2<ResultType>) {
Op::vector_vector(a, b, c, size);
Op::template vector_vector<check_overflow>(a, b, c, size);
} else {
bool need_adjust_scale = scale_diff_multiplier.value > 1;
std::visit(
@ -264,33 +264,31 @@ private:
/// null_map for divide and mod
static void vector_vector(const typename Traits::ArrayA::value_type* __restrict a,
const typename Traits::ArrayB::value_type* __restrict b,
typename ArrayC::value_type* c, NullMap& null_map, size_t size) {
typename ArrayC::value_type* c, NullMap& null_map, size_t size,
const ResultType& max_result_number) {
static_assert(OpTraits::is_division || OpTraits::is_mod);
if constexpr (IsDecimalV2<B> || IsDecimalV2<A>) {
/// default: use it if no return before
for (size_t i = 0; i < size; ++i) {
c[i] = typename ArrayC::value_type(apply(a[i], b[i], null_map[i]));
c[i] = typename ArrayC::value_type(
apply(a[i], b[i], null_map[i], max_result_number));
}
} else if constexpr (OpTraits::is_division && (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
for (size_t i = 0; i < size; ++i) {
if constexpr (IsDecimalNumber<B> && IsDecimalNumber<A>) {
c[i] = typename ArrayC::value_type(
apply_scaled_div(a[i].value, b[i].value, null_map[i]));
apply(a[i].value, b[i].value, null_map[i], max_result_number));
} else if constexpr (IsDecimalNumber<A>) {
c[i] = typename ArrayC::value_type(
apply_scaled_div(a[i].value, b[i], null_map[i]));
apply(a[i].value, b[i], null_map[i], max_result_number));
} else {
c[i] = typename ArrayC::value_type(
apply_scaled_div(a[i], b[i].value, null_map[i]));
apply(a[i], b[i].value, null_map[i], max_result_number));
}
}
} else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
(IsDecimalNumber<B> || IsDecimalNumber<A>)) {
for (size_t i = 0; i < size; ++i) {
null_map[i] = apply_op_safely(a[i], b[i], c[i].value);
}
} else {
for (size_t i = 0; i < size; ++i) {
c[i] = typename ArrayC::value_type(apply(a[i], b[i], null_map[i]));
c[i] = typename ArrayC::value_type(
apply(a[i], b[i], null_map[i], max_result_number));
}
}
}
@ -299,15 +297,8 @@ private:
typename ArrayC::value_type* c, size_t size,
const ResultType& max_result_number,
const ResultType& scale_diff_multiplier) {
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
for (size_t i = 0; i < size; ++i) {
// code never executed????
c[i] = typename ArrayC::value_type(apply_scaled_div(a[i], b, a));
}
return;
}
static_assert(!OpTraits::is_division);
/// default: use it if no return before
bool need_adjust_scale = scale_diff_multiplier.value > 1;
std::visit(
[&](auto need_adjust_scale) {
@ -320,19 +311,17 @@ private:
}
static void vector_constant(const typename Traits::ArrayA::value_type* __restrict a, B b,
typename ArrayC::value_type* c, NullMap& null_map, size_t size) {
typename ArrayC::value_type* c, NullMap& null_map, size_t size,
const ResultType& max_result_number) {
static_assert(OpTraits::is_division || OpTraits::is_mod);
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
for (size_t i = 0; i < size; ++i) {
c[i] = typename ArrayC::value_type(apply_scaled_div(a[i], b.value, null_map[i]));
}
} else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
(IsDecimalNumber<B> || IsDecimalNumber<A>)) {
for (size_t i = 0; i < size; ++i) {
null_map[i] = apply_op_safely(a[i], b, c[i].value);
c[i] = typename ArrayC::value_type(
apply(a[i], b.value, null_map[i], max_result_number));
}
} else {
for (size_t i = 0; i < size; ++i) {
c[i] = typename ArrayC::value_type(apply(a[i], b, null_map[i]));
c[i] = typename ArrayC::value_type(apply(a[i], b, null_map[i], max_result_number));
}
}
}
@ -341,39 +330,29 @@ private:
typename ArrayC::value_type* c, size_t size,
const ResultType& max_result_number,
const ResultType& scale_diff_multiplier) {
if constexpr (IsDecimalV2<A> || IsDecimalV2<B>) {
DecimalV2Value da(a);
for (size_t i = 0; i < size; ++i) {
c[i] = typename ArrayC::value_type(
Op::template apply(da, DecimalV2Value(b[i])).value());
}
} else {
bool need_adjust_scale = scale_diff_multiplier.value > 1;
std::visit(
[&](auto need_adjust_scale) {
for (size_t i = 0; i < size; ++i) {
c[i] = typename ArrayC::value_type(apply<need_adjust_scale>(
a, b[i], max_result_number, scale_diff_multiplier));
}
},
make_bool_variant(need_adjust_scale));
}
bool need_adjust_scale = scale_diff_multiplier.value > 1;
std::visit(
[&](auto need_adjust_scale) {
for (size_t i = 0; i < size; ++i) {
c[i] = typename ArrayC::value_type(apply<need_adjust_scale>(
a, b[i], max_result_number, scale_diff_multiplier));
}
},
make_bool_variant(need_adjust_scale));
}
static void constant_vector(A a, const typename Traits::ArrayB::value_type* __restrict b,
typename ArrayC::value_type* c, NullMap& null_map, size_t size) {
typename ArrayC::value_type* c, NullMap& null_map, size_t size,
const ResultType& max_result_number) {
static_assert(OpTraits::is_division || OpTraits::is_mod);
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
for (size_t i = 0; i < size; ++i) {
c[i] = typename ArrayC::value_type(apply_scaled_div(a, b[i].value, null_map[i]));
}
} else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
(IsDecimalNumber<B> || IsDecimalNumber<A>)) {
for (size_t i = 0; i < size; ++i) {
null_map[i] = apply_op_safely(a, b[i], c[i].value);
c[i] = typename ArrayC::value_type(
apply(a, b[i].value, null_map[i], max_result_number));
}
} else {
for (size_t i = 0; i < size; ++i) {
c[i] = typename ArrayC::value_type(apply(a, b[i], null_map[i]));
c[i] = typename ArrayC::value_type(apply(a, b[i], null_map[i], max_result_number));
}
}
}
@ -383,20 +362,17 @@ private:
return ResultType(apply<true>(a, b, max_result_number, scale_diff_multiplier));
}
static ResultType constant_constant(A a, B b, UInt8& is_null) {
static ResultType constant_constant(A a, B b, UInt8& is_null,
const ResultType& max_result_number) {
static_assert(OpTraits::is_division || OpTraits::is_mod);
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
if constexpr (IsDecimalNumber<A>) {
return ResultType(apply_scaled_div(a.value, b.value, is_null));
return ResultType(apply(a.value, b.value, is_null, max_result_number));
} else {
return ResultType(apply_scaled_div(a, b.value, is_null));
return ResultType(apply(a, b.value, is_null, max_result_number));
}
} else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
(IsDecimalNumber<B> || IsDecimalNumber<A>)) {
NativeResultType res;
is_null = apply_op_safely(a, b, res);
return ResultType(res);
} else {
return ResultType(apply(a, b, is_null));
return ResultType(apply(a, b, is_null, max_result_number));
}
}
@ -408,13 +384,13 @@ public:
1, assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
if constexpr (check_overflow && !is_to_null_type &&
((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
IsDecimalV2<B>)) {
((!OpTraits::is_multiply && !OpTraits::is_plus_minus))) {
LOG(FATAL) << "Invalid function type!";
return column_result;
} else if constexpr (is_to_null_type) {
auto null_map = ColumnUInt8::create(1, 0);
column_result->get_element(0) = constant_constant(a, b, null_map->get_element(0));
column_result->get_element(0) =
constant_constant(a, b, null_map->get_element(0), max_result_number);
return ColumnNullable::create(std::move(column_result), std::move(null_map));
} else {
column_result->get_element(0) =
@ -434,14 +410,13 @@ public:
DCHECK(column_left_ptr != nullptr);
if constexpr (check_overflow && !is_to_null_type &&
((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
IsDecimalV2<B>)) {
((!OpTraits::is_multiply && !OpTraits::is_plus_minus))) {
LOG(FATAL) << "Invalid function type!";
return column_result;
} else if constexpr (is_to_null_type) {
auto null_map = ColumnUInt8::create(column_left->size(), 0);
vector_constant(column_left_ptr->get_data().data(), b, column_result->get_data().data(),
null_map->get_data(), column_left->size());
null_map->get_data(), column_left->size(), max_result_number);
return ColumnNullable::create(std::move(column_result), std::move(null_map));
} else {
vector_constant(column_left_ptr->get_data().data(), b, column_result->get_data().data(),
@ -461,15 +436,14 @@ public:
DCHECK(column_right_ptr != nullptr);
if constexpr (check_overflow && !is_to_null_type &&
((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
IsDecimalV2<B>)) {
((!OpTraits::is_multiply && !OpTraits::is_plus_minus))) {
LOG(FATAL) << "Invalid function type!";
return column_result;
} else if constexpr (is_to_null_type) {
auto null_map = ColumnUInt8::create(column_right->size(), 0);
constant_vector(a, column_right_ptr->get_data().data(),
column_result->get_data().data(), null_map->get_data(),
column_right->size());
column_right->size(), max_result_number);
return ColumnNullable::create(std::move(column_result), std::move(null_map));
} else {
constant_vector(a, column_right_ptr->get_data().data(),
@ -492,15 +466,15 @@ public:
DCHECK(column_left_ptr != nullptr && column_right_ptr != nullptr);
if constexpr (check_overflow && !is_to_null_type &&
((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
IsDecimalV2<B>)) {
((!OpTraits::is_multiply && !OpTraits::is_plus_minus))) {
LOG(FATAL) << "Invalid function type!";
return column_result;
} else if constexpr (is_to_null_type) {
// function divide, modulo and pmod
auto null_map = ColumnUInt8::create(column_result->size(), 0);
vector_vector(column_left_ptr->get_data().data(), column_right_ptr->get_data().data(),
column_result->get_data().data(), null_map->get_data(),
column_left->size());
column_left->size(), max_result_number);
return ColumnNullable::create(std::move(column_result), std::move(null_map));
} else {
vector_vector(column_left_ptr->get_data().data(), column_right_ptr->get_data().data(),
@ -516,11 +490,18 @@ private:
static ALWAYS_INLINE NativeResultType apply(NativeResultType a, NativeResultType b,
const ResultType& max_result_number,
const ResultType& scale_diff_multiplier) {
// TODO: handle overflow of decimalv2
static_assert(OpTraits::is_plus_minus || OpTraits::is_multiply);
if constexpr (IsDecimalV2<B> || IsDecimalV2<A>) {
// Now, Doris only support decimal +-*/ decimal.
// overflow in consider in operator
return Op::template apply(DecimalV2Value(a), DecimalV2Value(b)).value();
if constexpr (check_overflow) {
auto res = Op::template apply(DecimalV2Value(a), DecimalV2Value(b)).value();
if (res > max_result_number.value || res < -max_result_number.value) {
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR, "Arithmetic overflow");
}
return res;
} else {
return Op::template apply(DecimalV2Value(a), DecimalV2Value(b)).value();
}
} else {
NativeResultType res;
if constexpr (OpTraits::can_overflow && check_overflow) {
@ -588,11 +569,34 @@ private:
/// null_map for divide and mod
static ALWAYS_INLINE NativeResultType apply(NativeResultType a, NativeResultType b,
UInt8& is_null) {
UInt8& is_null,
const ResultType& max_result_number) {
static_assert(OpTraits::is_division || OpTraits::is_mod);
if constexpr (IsDecimalV2<B> || IsDecimalV2<A>) {
DecimalV2Value l(a);
DecimalV2Value r(b);
auto ans = Op::template apply(l, r, is_null);
using ANS_TYPE = std::decay_t<decltype(ans)>;
if constexpr (check_overflow && OpTraits::is_division) {
if constexpr (std::is_same_v<ANS_TYPE, DecimalV2Value>) {
if (ans.value() > max_result_number.value ||
ans.value() < -max_result_number.value) {
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
"Arithmetic overflow");
}
} else if constexpr (IsDecimalNumber<ANS_TYPE>) {
if (ans.value > max_result_number.value ||
ans.value < -max_result_number.value) {
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
"Arithmetic overflow");
}
} else {
if (ans > max_result_number.value || ans < -max_result_number.value) {
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
"Arithmetic overflow");
}
}
}
NativeResultType result {};
memcpy(&result, &ans, std::min(sizeof(result), sizeof(ans)));
return result;
@ -600,25 +604,6 @@ private:
return Op::template apply<NativeResultType>(a, b, is_null);
}
}
static NativeResultType apply_scaled_div(NativeResultType a, NativeResultType b,
UInt8& is_null) {
if constexpr (OpTraits::is_division) {
return apply(a, b, is_null);
}
}
static NativeResultType apply_scaled_mod(NativeResultType a, NativeResultType b,
UInt8& is_null) {
return apply(a, b, is_null);
}
static ALWAYS_INLINE UInt8 apply_op_safely(NativeResultType a, NativeResultType b,
NativeResultType& c) {
if constexpr (OpTraits::is_multiply || OpTraits::is_plus_minus) {
return Op::template apply(a, b, c);
}
}
};
/// Used to indicate undefined operation
@ -979,11 +964,14 @@ public:
(IsDataTypeDecimal<LeftDataType> ||
IsDataTypeDecimal<RightDataType>))) {
if (check_overflow_for_decimal) {
if constexpr ((IsDecimalV2<typename LeftDataType::FieldType> ||
IsDecimalV2<typename RightDataType::
FieldType>)&&!is_to_null_type) {
// !is_to_null_type: plus, minus, multiply,
// pow, bitxor, bitor, bitand
// if check_overflow and params are decimal types:
// for functions pow, bitxor, bitor, bitand, return error
if constexpr (IsDataTypeDecimal<ResultDataType> && !is_to_null_type &&
!OpTraits::is_multiply && !OpTraits::is_plus_minus) {
status = Status::Error<ErrorCode::NOT_IMPLEMENTED_ERROR>(
"cannot check overflow with decimalv2");
"cannot check overflow with decimal for function {}", name);
return false;
}
auto column_result = ConstOrVectorAdapter<