[improvement](decimalv2) support check overflow for decimalv2 arithmetics (#28456)
This commit is contained in:
@ -244,10 +244,10 @@ private:
|
||||
typename ArrayC::value_type* c, size_t size,
|
||||
const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier) {
|
||||
// TODO: handle overflow of decimalv2
|
||||
static_assert(OpTraits::is_plus_minus || OpTraits::is_multiply);
|
||||
if constexpr (OpTraits::is_multiply && IsDecimalV2<A> && IsDecimalV2<B> &&
|
||||
IsDecimalV2<ResultType>) {
|
||||
Op::vector_vector(a, b, c, size);
|
||||
Op::template vector_vector<check_overflow>(a, b, c, size);
|
||||
} else {
|
||||
bool need_adjust_scale = scale_diff_multiplier.value > 1;
|
||||
std::visit(
|
||||
@ -264,33 +264,31 @@ private:
|
||||
/// null_map for divide and mod
|
||||
static void vector_vector(const typename Traits::ArrayA::value_type* __restrict a,
|
||||
const typename Traits::ArrayB::value_type* __restrict b,
|
||||
typename ArrayC::value_type* c, NullMap& null_map, size_t size) {
|
||||
typename ArrayC::value_type* c, NullMap& null_map, size_t size,
|
||||
const ResultType& max_result_number) {
|
||||
static_assert(OpTraits::is_division || OpTraits::is_mod);
|
||||
if constexpr (IsDecimalV2<B> || IsDecimalV2<A>) {
|
||||
/// default: use it if no return before
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = typename ArrayC::value_type(apply(a[i], b[i], null_map[i]));
|
||||
c[i] = typename ArrayC::value_type(
|
||||
apply(a[i], b[i], null_map[i], max_result_number));
|
||||
}
|
||||
} else if constexpr (OpTraits::is_division && (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
if constexpr (IsDecimalNumber<B> && IsDecimalNumber<A>) {
|
||||
c[i] = typename ArrayC::value_type(
|
||||
apply_scaled_div(a[i].value, b[i].value, null_map[i]));
|
||||
apply(a[i].value, b[i].value, null_map[i], max_result_number));
|
||||
} else if constexpr (IsDecimalNumber<A>) {
|
||||
c[i] = typename ArrayC::value_type(
|
||||
apply_scaled_div(a[i].value, b[i], null_map[i]));
|
||||
apply(a[i].value, b[i], null_map[i], max_result_number));
|
||||
} else {
|
||||
c[i] = typename ArrayC::value_type(
|
||||
apply_scaled_div(a[i], b[i].value, null_map[i]));
|
||||
apply(a[i], b[i].value, null_map[i], max_result_number));
|
||||
}
|
||||
}
|
||||
} else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
|
||||
(IsDecimalNumber<B> || IsDecimalNumber<A>)) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
null_map[i] = apply_op_safely(a[i], b[i], c[i].value);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = typename ArrayC::value_type(apply(a[i], b[i], null_map[i]));
|
||||
c[i] = typename ArrayC::value_type(
|
||||
apply(a[i], b[i], null_map[i], max_result_number));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -299,15 +297,8 @@ private:
|
||||
typename ArrayC::value_type* c, size_t size,
|
||||
const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier) {
|
||||
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
// code never executed????
|
||||
c[i] = typename ArrayC::value_type(apply_scaled_div(a[i], b, a));
|
||||
}
|
||||
return;
|
||||
}
|
||||
static_assert(!OpTraits::is_division);
|
||||
|
||||
/// default: use it if no return before
|
||||
bool need_adjust_scale = scale_diff_multiplier.value > 1;
|
||||
std::visit(
|
||||
[&](auto need_adjust_scale) {
|
||||
@ -320,19 +311,17 @@ private:
|
||||
}
|
||||
|
||||
static void vector_constant(const typename Traits::ArrayA::value_type* __restrict a, B b,
|
||||
typename ArrayC::value_type* c, NullMap& null_map, size_t size) {
|
||||
typename ArrayC::value_type* c, NullMap& null_map, size_t size,
|
||||
const ResultType& max_result_number) {
|
||||
static_assert(OpTraits::is_division || OpTraits::is_mod);
|
||||
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = typename ArrayC::value_type(apply_scaled_div(a[i], b.value, null_map[i]));
|
||||
}
|
||||
} else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
|
||||
(IsDecimalNumber<B> || IsDecimalNumber<A>)) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
null_map[i] = apply_op_safely(a[i], b, c[i].value);
|
||||
c[i] = typename ArrayC::value_type(
|
||||
apply(a[i], b.value, null_map[i], max_result_number));
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = typename ArrayC::value_type(apply(a[i], b, null_map[i]));
|
||||
c[i] = typename ArrayC::value_type(apply(a[i], b, null_map[i], max_result_number));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -341,39 +330,29 @@ private:
|
||||
typename ArrayC::value_type* c, size_t size,
|
||||
const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier) {
|
||||
if constexpr (IsDecimalV2<A> || IsDecimalV2<B>) {
|
||||
DecimalV2Value da(a);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = typename ArrayC::value_type(
|
||||
Op::template apply(da, DecimalV2Value(b[i])).value());
|
||||
}
|
||||
} else {
|
||||
bool need_adjust_scale = scale_diff_multiplier.value > 1;
|
||||
std::visit(
|
||||
[&](auto need_adjust_scale) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = typename ArrayC::value_type(apply<need_adjust_scale>(
|
||||
a, b[i], max_result_number, scale_diff_multiplier));
|
||||
}
|
||||
},
|
||||
make_bool_variant(need_adjust_scale));
|
||||
}
|
||||
bool need_adjust_scale = scale_diff_multiplier.value > 1;
|
||||
std::visit(
|
||||
[&](auto need_adjust_scale) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = typename ArrayC::value_type(apply<need_adjust_scale>(
|
||||
a, b[i], max_result_number, scale_diff_multiplier));
|
||||
}
|
||||
},
|
||||
make_bool_variant(need_adjust_scale));
|
||||
}
|
||||
|
||||
static void constant_vector(A a, const typename Traits::ArrayB::value_type* __restrict b,
|
||||
typename ArrayC::value_type* c, NullMap& null_map, size_t size) {
|
||||
typename ArrayC::value_type* c, NullMap& null_map, size_t size,
|
||||
const ResultType& max_result_number) {
|
||||
static_assert(OpTraits::is_division || OpTraits::is_mod);
|
||||
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = typename ArrayC::value_type(apply_scaled_div(a, b[i].value, null_map[i]));
|
||||
}
|
||||
} else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
|
||||
(IsDecimalNumber<B> || IsDecimalNumber<A>)) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
null_map[i] = apply_op_safely(a, b[i], c[i].value);
|
||||
c[i] = typename ArrayC::value_type(
|
||||
apply(a, b[i].value, null_map[i], max_result_number));
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = typename ArrayC::value_type(apply(a, b[i], null_map[i]));
|
||||
c[i] = typename ArrayC::value_type(apply(a, b[i], null_map[i], max_result_number));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -383,20 +362,17 @@ private:
|
||||
return ResultType(apply<true>(a, b, max_result_number, scale_diff_multiplier));
|
||||
}
|
||||
|
||||
static ResultType constant_constant(A a, B b, UInt8& is_null) {
|
||||
static ResultType constant_constant(A a, B b, UInt8& is_null,
|
||||
const ResultType& max_result_number) {
|
||||
static_assert(OpTraits::is_division || OpTraits::is_mod);
|
||||
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
|
||||
if constexpr (IsDecimalNumber<A>) {
|
||||
return ResultType(apply_scaled_div(a.value, b.value, is_null));
|
||||
return ResultType(apply(a.value, b.value, is_null, max_result_number));
|
||||
} else {
|
||||
return ResultType(apply_scaled_div(a, b.value, is_null));
|
||||
return ResultType(apply(a, b.value, is_null, max_result_number));
|
||||
}
|
||||
} else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
|
||||
(IsDecimalNumber<B> || IsDecimalNumber<A>)) {
|
||||
NativeResultType res;
|
||||
is_null = apply_op_safely(a, b, res);
|
||||
return ResultType(res);
|
||||
} else {
|
||||
return ResultType(apply(a, b, is_null));
|
||||
return ResultType(apply(a, b, is_null, max_result_number));
|
||||
}
|
||||
}
|
||||
|
||||
@ -408,13 +384,13 @@ public:
|
||||
1, assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
|
||||
|
||||
if constexpr (check_overflow && !is_to_null_type &&
|
||||
((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
|
||||
IsDecimalV2<B>)) {
|
||||
((!OpTraits::is_multiply && !OpTraits::is_plus_minus))) {
|
||||
LOG(FATAL) << "Invalid function type!";
|
||||
return column_result;
|
||||
} else if constexpr (is_to_null_type) {
|
||||
auto null_map = ColumnUInt8::create(1, 0);
|
||||
column_result->get_element(0) = constant_constant(a, b, null_map->get_element(0));
|
||||
column_result->get_element(0) =
|
||||
constant_constant(a, b, null_map->get_element(0), max_result_number);
|
||||
return ColumnNullable::create(std::move(column_result), std::move(null_map));
|
||||
} else {
|
||||
column_result->get_element(0) =
|
||||
@ -434,14 +410,13 @@ public:
|
||||
DCHECK(column_left_ptr != nullptr);
|
||||
|
||||
if constexpr (check_overflow && !is_to_null_type &&
|
||||
((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
|
||||
IsDecimalV2<B>)) {
|
||||
((!OpTraits::is_multiply && !OpTraits::is_plus_minus))) {
|
||||
LOG(FATAL) << "Invalid function type!";
|
||||
return column_result;
|
||||
} else if constexpr (is_to_null_type) {
|
||||
auto null_map = ColumnUInt8::create(column_left->size(), 0);
|
||||
vector_constant(column_left_ptr->get_data().data(), b, column_result->get_data().data(),
|
||||
null_map->get_data(), column_left->size());
|
||||
null_map->get_data(), column_left->size(), max_result_number);
|
||||
return ColumnNullable::create(std::move(column_result), std::move(null_map));
|
||||
} else {
|
||||
vector_constant(column_left_ptr->get_data().data(), b, column_result->get_data().data(),
|
||||
@ -461,15 +436,14 @@ public:
|
||||
DCHECK(column_right_ptr != nullptr);
|
||||
|
||||
if constexpr (check_overflow && !is_to_null_type &&
|
||||
((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
|
||||
IsDecimalV2<B>)) {
|
||||
((!OpTraits::is_multiply && !OpTraits::is_plus_minus))) {
|
||||
LOG(FATAL) << "Invalid function type!";
|
||||
return column_result;
|
||||
} else if constexpr (is_to_null_type) {
|
||||
auto null_map = ColumnUInt8::create(column_right->size(), 0);
|
||||
constant_vector(a, column_right_ptr->get_data().data(),
|
||||
column_result->get_data().data(), null_map->get_data(),
|
||||
column_right->size());
|
||||
column_right->size(), max_result_number);
|
||||
return ColumnNullable::create(std::move(column_result), std::move(null_map));
|
||||
} else {
|
||||
constant_vector(a, column_right_ptr->get_data().data(),
|
||||
@ -492,15 +466,15 @@ public:
|
||||
DCHECK(column_left_ptr != nullptr && column_right_ptr != nullptr);
|
||||
|
||||
if constexpr (check_overflow && !is_to_null_type &&
|
||||
((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
|
||||
IsDecimalV2<B>)) {
|
||||
((!OpTraits::is_multiply && !OpTraits::is_plus_minus))) {
|
||||
LOG(FATAL) << "Invalid function type!";
|
||||
return column_result;
|
||||
} else if constexpr (is_to_null_type) {
|
||||
// function divide, modulo and pmod
|
||||
auto null_map = ColumnUInt8::create(column_result->size(), 0);
|
||||
vector_vector(column_left_ptr->get_data().data(), column_right_ptr->get_data().data(),
|
||||
column_result->get_data().data(), null_map->get_data(),
|
||||
column_left->size());
|
||||
column_left->size(), max_result_number);
|
||||
return ColumnNullable::create(std::move(column_result), std::move(null_map));
|
||||
} else {
|
||||
vector_vector(column_left_ptr->get_data().data(), column_right_ptr->get_data().data(),
|
||||
@ -516,11 +490,18 @@ private:
|
||||
static ALWAYS_INLINE NativeResultType apply(NativeResultType a, NativeResultType b,
|
||||
const ResultType& max_result_number,
|
||||
const ResultType& scale_diff_multiplier) {
|
||||
// TODO: handle overflow of decimalv2
|
||||
static_assert(OpTraits::is_plus_minus || OpTraits::is_multiply);
|
||||
if constexpr (IsDecimalV2<B> || IsDecimalV2<A>) {
|
||||
// Now, Doris only support decimal +-*/ decimal.
|
||||
// overflow in consider in operator
|
||||
return Op::template apply(DecimalV2Value(a), DecimalV2Value(b)).value();
|
||||
if constexpr (check_overflow) {
|
||||
auto res = Op::template apply(DecimalV2Value(a), DecimalV2Value(b)).value();
|
||||
if (res > max_result_number.value || res < -max_result_number.value) {
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR, "Arithmetic overflow");
|
||||
}
|
||||
return res;
|
||||
} else {
|
||||
return Op::template apply(DecimalV2Value(a), DecimalV2Value(b)).value();
|
||||
}
|
||||
} else {
|
||||
NativeResultType res;
|
||||
if constexpr (OpTraits::can_overflow && check_overflow) {
|
||||
@ -588,11 +569,34 @@ private:
|
||||
|
||||
/// null_map for divide and mod
|
||||
static ALWAYS_INLINE NativeResultType apply(NativeResultType a, NativeResultType b,
|
||||
UInt8& is_null) {
|
||||
UInt8& is_null,
|
||||
const ResultType& max_result_number) {
|
||||
static_assert(OpTraits::is_division || OpTraits::is_mod);
|
||||
if constexpr (IsDecimalV2<B> || IsDecimalV2<A>) {
|
||||
DecimalV2Value l(a);
|
||||
DecimalV2Value r(b);
|
||||
auto ans = Op::template apply(l, r, is_null);
|
||||
using ANS_TYPE = std::decay_t<decltype(ans)>;
|
||||
if constexpr (check_overflow && OpTraits::is_division) {
|
||||
if constexpr (std::is_same_v<ANS_TYPE, DecimalV2Value>) {
|
||||
if (ans.value() > max_result_number.value ||
|
||||
ans.value() < -max_result_number.value) {
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
|
||||
"Arithmetic overflow");
|
||||
}
|
||||
} else if constexpr (IsDecimalNumber<ANS_TYPE>) {
|
||||
if (ans.value > max_result_number.value ||
|
||||
ans.value < -max_result_number.value) {
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
|
||||
"Arithmetic overflow");
|
||||
}
|
||||
} else {
|
||||
if (ans > max_result_number.value || ans < -max_result_number.value) {
|
||||
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
|
||||
"Arithmetic overflow");
|
||||
}
|
||||
}
|
||||
}
|
||||
NativeResultType result {};
|
||||
memcpy(&result, &ans, std::min(sizeof(result), sizeof(ans)));
|
||||
return result;
|
||||
@ -600,25 +604,6 @@ private:
|
||||
return Op::template apply<NativeResultType>(a, b, is_null);
|
||||
}
|
||||
}
|
||||
|
||||
static NativeResultType apply_scaled_div(NativeResultType a, NativeResultType b,
|
||||
UInt8& is_null) {
|
||||
if constexpr (OpTraits::is_division) {
|
||||
return apply(a, b, is_null);
|
||||
}
|
||||
}
|
||||
|
||||
static NativeResultType apply_scaled_mod(NativeResultType a, NativeResultType b,
|
||||
UInt8& is_null) {
|
||||
return apply(a, b, is_null);
|
||||
}
|
||||
|
||||
static ALWAYS_INLINE UInt8 apply_op_safely(NativeResultType a, NativeResultType b,
|
||||
NativeResultType& c) {
|
||||
if constexpr (OpTraits::is_multiply || OpTraits::is_plus_minus) {
|
||||
return Op::template apply(a, b, c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Used to indicate undefined operation
|
||||
@ -979,11 +964,14 @@ public:
|
||||
(IsDataTypeDecimal<LeftDataType> ||
|
||||
IsDataTypeDecimal<RightDataType>))) {
|
||||
if (check_overflow_for_decimal) {
|
||||
if constexpr ((IsDecimalV2<typename LeftDataType::FieldType> ||
|
||||
IsDecimalV2<typename RightDataType::
|
||||
FieldType>)&&!is_to_null_type) {
|
||||
// !is_to_null_type: plus, minus, multiply,
|
||||
// pow, bitxor, bitor, bitand
|
||||
// if check_overflow and params are decimal types:
|
||||
// for functions pow, bitxor, bitor, bitand, return error
|
||||
if constexpr (IsDataTypeDecimal<ResultDataType> && !is_to_null_type &&
|
||||
!OpTraits::is_multiply && !OpTraits::is_plus_minus) {
|
||||
status = Status::Error<ErrorCode::NOT_IMPLEMENTED_ERROR>(
|
||||
"cannot check overflow with decimalv2");
|
||||
"cannot check overflow with decimal for function {}", name);
|
||||
return false;
|
||||
}
|
||||
auto column_result = ConstOrVectorAdapter<
|
||||
|
||||
Reference in New Issue
Block a user