[feature-wip] Optimize Decimal type (#10794)
* [feature-wip](decimalv3) support decimalv3 * [feature-wip] Optimize Decimal type Co-authored-by: liaoxin <liaoxinbit@126.com>
This commit is contained in:
@ -66,6 +66,7 @@ struct OperationTraits {
|
||||
static constexpr bool is_multiply = std::is_same_v<Op, MultiplyImpl<T, T>>;
|
||||
static constexpr bool is_division = std::is_same_v<Op, DivideFloatingImpl<T, T>> ||
|
||||
std::is_same_v<Op, DivideIntegralImpl<T, T>>;
|
||||
static constexpr bool is_mod = std::is_same_v<Op, ModuloImpl<T, T>>;
|
||||
static constexpr bool allow_decimal =
|
||||
std::is_same_v<Op, PlusImpl<T, T>> || std::is_same_v<Op, MinusImpl<T, T>> ||
|
||||
std::is_same_v<Op, MultiplyImpl<T, T>> || std::is_same_v<Op, ModuloImpl<T, T>> ||
|
||||
@ -212,7 +213,7 @@ struct BinaryOperationImpl {
|
||||
/// * no agrs scale. ScaleR = Scale1 + Scale2;
|
||||
/// / first arg scale. ScaleR = Scale1 (scale_a = DecimalType<B>::get_scale()).
|
||||
template <typename A, typename B, template <typename, typename> typename Operation,
|
||||
typename ResultType, bool is_to_null_type, bool check_overflow = false>
|
||||
typename ResultType, bool is_to_null_type, bool check_overflow = true>
|
||||
struct DecimalBinaryOperation {
|
||||
using OpTraits = OperationTraits<Operation>;
|
||||
|
||||
@ -251,7 +252,26 @@ struct DecimalBinaryOperation {
|
||||
ArrayC& c, ResultType scale_a [[maybe_unused]],
|
||||
ResultType scale_b [[maybe_unused]], NullMap& null_map) {
|
||||
size_t size = a.size();
|
||||
|
||||
if (config::enable_decimalv3) {
|
||||
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = apply_scaled_div(a[i], b[i], scale_a, null_map[i]);
|
||||
}
|
||||
return;
|
||||
} else if constexpr (OpTraits::is_mod) {
|
||||
if (scale_a != 1) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = apply_scaled_mod<true>(a[i], b[i], scale_a, null_map[i]);
|
||||
}
|
||||
return;
|
||||
} else if (scale_b != 1) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = apply_scaled_mod<false>(a[i], b[i], scale_b, null_map[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
/// default: use it if no return before
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = apply(a[i], b[i], null_map[i]);
|
||||
@ -296,6 +316,18 @@ struct DecimalBinaryOperation {
|
||||
c[i] = apply_scaled_div(a[i], b, scale_a, null_map[i]);
|
||||
}
|
||||
return;
|
||||
} else if constexpr (OpTraits::is_mod) {
|
||||
if (scale_a != 1) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = apply_scaled_mod<true>(a[i], b, scale_a, null_map[i]);
|
||||
}
|
||||
return;
|
||||
} else if (scale_b != 1) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = apply_scaled_mod<false>(a[i], b, scale_b, null_map[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
@ -341,6 +373,18 @@ struct DecimalBinaryOperation {
|
||||
c[i] = apply_scaled_div(a, b[i], scale_a, null_map[i]);
|
||||
}
|
||||
return;
|
||||
} else if constexpr (OpTraits::is_mod) {
|
||||
if (scale_a != 1) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = apply_scaled_mod<true>(a, b[i], scale_a, null_map[i]);
|
||||
}
|
||||
return;
|
||||
} else if (scale_b != 1) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
c[i] = apply_scaled_mod<false>(a, b[i], scale_b, null_map[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
@ -372,6 +416,12 @@ struct DecimalBinaryOperation {
|
||||
}
|
||||
} else if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
|
||||
return apply_scaled_div(a, b, scale_a, is_null);
|
||||
} else if constexpr (OpTraits::is_mod) {
|
||||
if (scale_a != 1) {
|
||||
return apply_scaled_mod<true>(a, b, scale_a, is_null);
|
||||
} else if (scale_b != 1) {
|
||||
return apply_scaled_mod<false>(a, b, scale_b, is_null);
|
||||
}
|
||||
}
|
||||
return apply(a, b, is_null);
|
||||
}
|
||||
@ -451,6 +501,20 @@ struct DecimalBinaryOperation {
|
||||
private:
|
||||
/// there's implicit type convertion here
|
||||
static NativeResultType apply(NativeResultType a, NativeResultType b) {
|
||||
if (config::enable_decimalv3) {
|
||||
if constexpr (OpTraits::can_overflow && check_overflow) {
|
||||
NativeResultType res;
|
||||
// TODO handle overflow gracefully
|
||||
if (Op::template apply<NativeResultType>(a, b, res)) {
|
||||
LOG(WARNING) << "Decimal math overflow";
|
||||
res = max_decimal_value<ResultType>();
|
||||
}
|
||||
return res;
|
||||
} else {
|
||||
return Op::template apply<NativeResultType>(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
// Now, Doris only support decimal +-*/ decimal.
|
||||
// overflow in consider in operator
|
||||
DecimalV2Value l(a);
|
||||
@ -463,6 +527,10 @@ private:
|
||||
|
||||
/// null_map for divide and mod
|
||||
static NativeResultType apply(NativeResultType a, NativeResultType b, UInt8& is_null) {
|
||||
if (config::enable_decimalv3) {
|
||||
return Op::template apply<NativeResultType>(a, b, is_null);
|
||||
}
|
||||
|
||||
DecimalV2Value l(a);
|
||||
DecimalV2Value r(b);
|
||||
auto ans = Op::template apply(l, r, is_null);
|
||||
@ -491,8 +559,10 @@ private:
|
||||
res = Op::template apply<NativeResultType>(a, b);
|
||||
}
|
||||
|
||||
// TODO handle overflow gracefully
|
||||
if (overflow) {
|
||||
LOG(FATAL) << "Decimal math overflow";
|
||||
LOG(WARNING) << "Decimal math overflow";
|
||||
res = max_decimal_value<ResultType>();
|
||||
}
|
||||
} else {
|
||||
if constexpr (scale_left) {
|
||||
@ -516,8 +586,10 @@ private:
|
||||
overflow |= common::mul_overflow(scale, scale, scale);
|
||||
}
|
||||
overflow |= common::mul_overflow(a, scale, a);
|
||||
// TODO handle overflow gracefully
|
||||
if (overflow) {
|
||||
LOG(FATAL) << "Decimal math overflow";
|
||||
LOG(WARNING) << "Decimal math overflow";
|
||||
return max_decimal_value<ResultType>();
|
||||
}
|
||||
} else {
|
||||
if constexpr (!IsDecimalNumber<A>) {
|
||||
@ -529,6 +601,31 @@ private:
|
||||
return apply(a, b, is_null);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool scale_left>
|
||||
static NativeResultType apply_scaled_mod(NativeResultType a, NativeResultType b,
|
||||
NativeResultType scale, UInt8& is_null) {
|
||||
if constexpr (check_overflow) {
|
||||
bool overflow = false;
|
||||
if constexpr (scale_left)
|
||||
overflow |= common::mul_overflow(a, scale, a);
|
||||
else
|
||||
overflow |= common::mul_overflow(b, scale, b);
|
||||
|
||||
// TODO handle overflow gracefully
|
||||
if (overflow) {
|
||||
LOG(WARNING) << "Decimal math overflow";
|
||||
return max_decimal_value<ResultType>();
|
||||
}
|
||||
} else {
|
||||
if constexpr (scale_left)
|
||||
a *= scale;
|
||||
else
|
||||
b *= scale;
|
||||
}
|
||||
|
||||
return apply(a, b, is_null);
|
||||
}
|
||||
};
|
||||
|
||||
/// Used to indicate undefined operation
|
||||
@ -646,10 +743,16 @@ private:
|
||||
static auto get_decimal_infos(const LeftDataType& type_left, const RightDataType& type_right) {
|
||||
ResultDataType type = decimal_result_type(type_left, type_right, OpTraits::is_multiply,
|
||||
OpTraits::is_division);
|
||||
typename ResultDataType::FieldType scale_a =
|
||||
type.scale_factor_for(type_left, OpTraits::is_multiply);
|
||||
typename ResultDataType::FieldType scale_b =
|
||||
type.scale_factor_for(type_right, OpTraits::is_multiply || OpTraits::is_division);
|
||||
typename ResultDataType::FieldType scale_a;
|
||||
typename ResultDataType::FieldType scale_b;
|
||||
if constexpr (OpTraits::is_division && IsDataTypeDecimal<RightDataType>) {
|
||||
scale_a = type_right.get_scale_multiplier();
|
||||
scale_b = 1;
|
||||
} else {
|
||||
scale_a = type.scale_factor_for(type_left, OpTraits::is_multiply);
|
||||
scale_b = type.scale_factor_for(type_right,
|
||||
OpTraits::is_multiply || OpTraits::is_division);
|
||||
}
|
||||
return std::make_tuple(type, scale_a, scale_b);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user