[Improvement](DECIMAL) Improve decimal operation (#18437)

This commit is contained in:
Gabriel
2023-04-07 15:58:28 +08:00
committed by GitHub
parent c5d9e8529a
commit f6f4dac1d0
4 changed files with 61 additions and 51 deletions

View File

@ -231,13 +231,12 @@ struct DecimalBinaryOperation {
using Traits = NumberTraits::BinaryOperatorTraits<A, B>;
using ArrayC = typename ColumnDecimal<ResultType>::Container;
static void vector_vector(const typename Traits::ArrayA& a, const typename Traits::ArrayB& b,
ArrayC& c) {
size_t size = a.size();
static void vector_vector(const typename Traits::ArrayA::value_type* __restrict a,
const typename Traits::ArrayB::value_type* __restrict b,
typename ArrayC::value_type* c, size_t size) {
if constexpr (OpTraits::is_multiply && IsDecimalV2<A> && IsDecimalV2<B> &&
IsDecimalV2<ResultType>) {
Op::vector_vector(a, b, c);
Op::vector_vector(a, b, c, size);
} else {
for (size_t i = 0; i < size; i++) {
c[i] = apply(a[i], b[i]);
@ -246,9 +245,9 @@ struct DecimalBinaryOperation {
}
/// null_map for divide and mod
static void vector_vector(const typename Traits::ArrayA& a, const typename Traits::ArrayB& b,
ArrayC& c, NullMap& null_map) {
size_t size = a.size();
static void vector_vector(const typename Traits::ArrayA::value_type* __restrict a,
const typename Traits::ArrayB::value_type* __restrict b,
typename ArrayC::value_type* c, NullMap& null_map, size_t size) {
if constexpr (IsDecimalV2<B> || IsDecimalV2<A>) {
/// default: use it if no return before
for (size_t i = 0; i < size; ++i) {
@ -266,8 +265,8 @@ struct DecimalBinaryOperation {
}
}
static void vector_constant(const typename Traits::ArrayA& a, B b, ArrayC& c) {
size_t size = a.size();
static void vector_constant(const typename Traits::ArrayA::value_type* __restrict a, B b,
typename ArrayC::value_type* c, size_t size) {
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_div(a[i], b);
@ -281,9 +280,8 @@ struct DecimalBinaryOperation {
}
}
static void vector_constant(const typename Traits::ArrayA& a, B b, ArrayC& c,
NullMap& null_map) {
size_t size = a.size();
static void vector_constant(const typename Traits::ArrayA::value_type* __restrict a, B b,
typename ArrayC::value_type* c, NullMap& null_map, size_t size) {
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_div(a[i], b, null_map[i]);
@ -300,8 +298,8 @@ struct DecimalBinaryOperation {
}
}
static void constant_vector(A a, const typename Traits::ArrayB& b, ArrayC& c) {
size_t size = b.size();
static void constant_vector(A a, const typename Traits::ArrayB::value_type* __restrict b,
typename ArrayC::value_type* c, size_t size) {
if constexpr (IsDecimalV2<A> || IsDecimalV2<B>) {
DecimalV2Value da(a);
for (size_t i = 0; i < size; ++i) {
@ -314,9 +312,8 @@ struct DecimalBinaryOperation {
}
}
static void constant_vector(A a, const typename Traits::ArrayB& b, ArrayC& c,
NullMap& null_map) {
size_t size = b.size();
static void constant_vector(A a, const typename Traits::ArrayB::value_type* __restrict b,
typename ArrayC::value_type* c, NullMap& null_map, size_t size) {
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_div(a, b[i], null_map[i]);
@ -382,11 +379,12 @@ struct DecimalBinaryOperation {
return column_result;
} else if constexpr (return_nullable_type || is_to_null_type) {
auto null_map = ColumnUInt8::create(column_left->size(), 0);
vector_constant(column_left_ptr->get_data(), b, column_result->get_data(),
null_map->get_data());
vector_constant(column_left_ptr->get_data().data(), b, column_result->get_data().data(),
null_map->get_data(), column_left->size());
return ColumnNullable::create(std::move(column_result), std::move(null_map));
} else {
vector_constant(column_left_ptr->get_data(), b, column_result->get_data());
vector_constant(column_left_ptr->get_data().data(), b, column_result->get_data().data(),
column_left->size());
return column_result;
}
}
@ -406,11 +404,13 @@ struct DecimalBinaryOperation {
return column_result;
} else if constexpr (return_nullable_type || is_to_null_type) {
auto null_map = ColumnUInt8::create(column_right->size(), 0);
constant_vector(a, column_right_ptr->get_data(), column_result->get_data(),
null_map->get_data());
constant_vector(a, column_right_ptr->get_data().data(),
column_result->get_data().data(), null_map->get_data(),
column_right->size());
return ColumnNullable::create(std::move(column_result), std::move(null_map));
} else {
constant_vector(a, column_right_ptr->get_data(), column_result->get_data());
constant_vector(a, column_right_ptr->get_data().data(),
column_result->get_data().data(), column_right->size());
return column_result;
}
}
@ -432,19 +432,20 @@ struct DecimalBinaryOperation {
return column_result;
} else if constexpr (return_nullable_type || is_to_null_type) {
auto null_map = ColumnUInt8::create(column_result->size(), 0);
vector_vector(column_left_ptr->get_data(), column_right_ptr->get_data(),
column_result->get_data(), null_map->get_data());
vector_vector(column_left_ptr->get_data().data(), column_right_ptr->get_data().data(),
column_result->get_data().data(), null_map->get_data(),
column_left->size());
return ColumnNullable::create(std::move(column_result), std::move(null_map));
} else {
vector_vector(column_left_ptr->get_data(), column_right_ptr->get_data(),
column_result->get_data());
vector_vector(column_left_ptr->get_data().data(), column_right_ptr->get_data().data(),
column_result->get_data().data(), column_left->size());
return column_result;
}
}
private:
/// there's implicit type conversion here
static NativeResultType apply(NativeResultType a, NativeResultType b) {
static ALWAYS_INLINE NativeResultType apply(NativeResultType a, NativeResultType b) {
if constexpr (IsDecimalV2<B> || IsDecimalV2<A>) {
// Now, Doris only support decimal +-*/ decimal.
// overflow in consider in operator
@ -465,7 +466,8 @@ private:
}
/// null_map for divide and mod
static NativeResultType apply(NativeResultType a, NativeResultType b, UInt8& is_null) {
static ALWAYS_INLINE NativeResultType apply(NativeResultType a, NativeResultType b,
UInt8& is_null) {
if constexpr (IsDecimalV2<B> || IsDecimalV2<A>) {
DecimalV2Value l(a);
DecimalV2Value r(b);
@ -516,7 +518,8 @@ private:
return apply(a, b, is_null);
}
static UInt8 apply_op_safely(NativeResultType a, NativeResultType b, NativeResultType& c) {
static ALWAYS_INLINE UInt8 apply_op_safely(NativeResultType a, NativeResultType b,
NativeResultType& c) {
if constexpr (OpTraits::is_multiply || OpTraits::is_plus_minus) {
return Op::template apply(a, b, c);
}