[feature-wip] Optimize Decimal type (#10794)

* [feature-wip](decimalv3) support decimalv3

* [feature-wip] Optimize Decimal type

Co-authored-by: liaoxin <liaoxinbit@126.com>
This commit is contained in:
Gabriel
2022-07-14 10:50:50 +08:00
committed by GitHub
parent bb0d023abd
commit 3b46242483
149 changed files with 4011 additions and 549 deletions

View File

@ -66,6 +66,7 @@ struct OperationTraits {
static constexpr bool is_multiply = std::is_same_v<Op, MultiplyImpl<T, T>>;
static constexpr bool is_division = std::is_same_v<Op, DivideFloatingImpl<T, T>> ||
std::is_same_v<Op, DivideIntegralImpl<T, T>>;
static constexpr bool is_mod = std::is_same_v<Op, ModuloImpl<T, T>>;
static constexpr bool allow_decimal =
std::is_same_v<Op, PlusImpl<T, T>> || std::is_same_v<Op, MinusImpl<T, T>> ||
std::is_same_v<Op, MultiplyImpl<T, T>> || std::is_same_v<Op, ModuloImpl<T, T>> ||
@ -212,7 +213,7 @@ struct BinaryOperationImpl {
/// * no agrs scale. ScaleR = Scale1 + Scale2;
/// / first arg scale. ScaleR = Scale1 (scale_a = DecimalType<B>::get_scale()).
template <typename A, typename B, template <typename, typename> typename Operation,
typename ResultType, bool is_to_null_type, bool check_overflow = false>
typename ResultType, bool is_to_null_type, bool check_overflow = true>
struct DecimalBinaryOperation {
using OpTraits = OperationTraits<Operation>;
@ -251,7 +252,26 @@ struct DecimalBinaryOperation {
ArrayC& c, ResultType scale_a [[maybe_unused]],
ResultType scale_b [[maybe_unused]], NullMap& null_map) {
size_t size = a.size();
if (config::enable_decimalv3) {
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_div(a[i], b[i], scale_a, null_map[i]);
}
return;
} else if constexpr (OpTraits::is_mod) {
if (scale_a != 1) {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_mod<true>(a[i], b[i], scale_a, null_map[i]);
}
return;
} else if (scale_b != 1) {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_mod<false>(a[i], b[i], scale_b, null_map[i]);
}
return;
}
}
}
/// default: use it if no return before
for (size_t i = 0; i < size; ++i) {
c[i] = apply(a[i], b[i], null_map[i]);
@ -296,6 +316,18 @@ struct DecimalBinaryOperation {
c[i] = apply_scaled_div(a[i], b, scale_a, null_map[i]);
}
return;
} else if constexpr (OpTraits::is_mod) {
if (scale_a != 1) {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_mod<true>(a[i], b, scale_a, null_map[i]);
}
return;
} else if (scale_b != 1) {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_mod<false>(a[i], b, scale_b, null_map[i]);
}
return;
}
}
for (size_t i = 0; i < size; ++i) {
@ -341,6 +373,18 @@ struct DecimalBinaryOperation {
c[i] = apply_scaled_div(a, b[i], scale_a, null_map[i]);
}
return;
} else if constexpr (OpTraits::is_mod) {
if (scale_a != 1) {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_mod<true>(a, b[i], scale_a, null_map[i]);
}
return;
} else if (scale_b != 1) {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_mod<false>(a, b[i], scale_b, null_map[i]);
}
return;
}
}
for (size_t i = 0; i < size; ++i) {
@ -372,6 +416,12 @@ struct DecimalBinaryOperation {
}
} else if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
return apply_scaled_div(a, b, scale_a, is_null);
} else if constexpr (OpTraits::is_mod) {
if (scale_a != 1) {
return apply_scaled_mod<true>(a, b, scale_a, is_null);
} else if (scale_b != 1) {
return apply_scaled_mod<false>(a, b, scale_b, is_null);
}
}
return apply(a, b, is_null);
}
@ -451,6 +501,20 @@ struct DecimalBinaryOperation {
private:
/// there's implicit type convertion here
static NativeResultType apply(NativeResultType a, NativeResultType b) {
if (config::enable_decimalv3) {
if constexpr (OpTraits::can_overflow && check_overflow) {
NativeResultType res;
// TODO handle overflow gracefully
if (Op::template apply<NativeResultType>(a, b, res)) {
LOG(WARNING) << "Decimal math overflow";
res = max_decimal_value<ResultType>();
}
return res;
} else {
return Op::template apply<NativeResultType>(a, b);
}
}
// Now, Doris only support decimal +-*/ decimal.
// overflow in consider in operator
DecimalV2Value l(a);
@ -463,6 +527,10 @@ private:
/// null_map for divide and mod
static NativeResultType apply(NativeResultType a, NativeResultType b, UInt8& is_null) {
if (config::enable_decimalv3) {
return Op::template apply<NativeResultType>(a, b, is_null);
}
DecimalV2Value l(a);
DecimalV2Value r(b);
auto ans = Op::template apply(l, r, is_null);
@ -491,8 +559,10 @@ private:
res = Op::template apply<NativeResultType>(a, b);
}
// TODO handle overflow gracefully
if (overflow) {
LOG(FATAL) << "Decimal math overflow";
LOG(WARNING) << "Decimal math overflow";
res = max_decimal_value<ResultType>();
}
} else {
if constexpr (scale_left) {
@ -516,8 +586,10 @@ private:
overflow |= common::mul_overflow(scale, scale, scale);
}
overflow |= common::mul_overflow(a, scale, a);
// TODO handle overflow gracefully
if (overflow) {
LOG(FATAL) << "Decimal math overflow";
LOG(WARNING) << "Decimal math overflow";
return max_decimal_value<ResultType>();
}
} else {
if constexpr (!IsDecimalNumber<A>) {
@ -529,6 +601,31 @@ private:
return apply(a, b, is_null);
}
}
template <bool scale_left>
static NativeResultType apply_scaled_mod(NativeResultType a, NativeResultType b,
NativeResultType scale, UInt8& is_null) {
if constexpr (check_overflow) {
bool overflow = false;
if constexpr (scale_left)
overflow |= common::mul_overflow(a, scale, a);
else
overflow |= common::mul_overflow(b, scale, b);
// TODO handle overflow gracefully
if (overflow) {
LOG(WARNING) << "Decimal math overflow";
return max_decimal_value<ResultType>();
}
} else {
if constexpr (scale_left)
a *= scale;
else
b *= scale;
}
return apply(a, b, is_null);
}
};
/// Used to indicate undefined operation
@ -646,10 +743,16 @@ private:
static auto get_decimal_infos(const LeftDataType& type_left, const RightDataType& type_right) {
ResultDataType type = decimal_result_type(type_left, type_right, OpTraits::is_multiply,
OpTraits::is_division);
typename ResultDataType::FieldType scale_a =
type.scale_factor_for(type_left, OpTraits::is_multiply);
typename ResultDataType::FieldType scale_b =
type.scale_factor_for(type_right, OpTraits::is_multiply || OpTraits::is_division);
typename ResultDataType::FieldType scale_a;
typename ResultDataType::FieldType scale_b;
if constexpr (OpTraits::is_division && IsDataTypeDecimal<RightDataType>) {
scale_a = type_right.get_scale_multiplier();
scale_b = 1;
} else {
scale_a = type.scale_factor_for(type_left, OpTraits::is_multiply);
scale_b = type.scale_factor_for(type_right,
OpTraits::is_multiply || OpTraits::is_division);
}
return std::make_tuple(type, scale_a, scale_b);
}