Fix nan comparison result of fixed double

This commit is contained in:
hezuojiao 2024-05-16 06:47:53 +00:00 committed by ob-robot
parent 84cdeb1437
commit 238f10498b
6 changed files with 41 additions and 32 deletions

View File

@ -259,7 +259,15 @@ public:
}
const double l = obj1.get_double();
const double r = obj2.get_double();
if (l == r || fabs(l - r) < p) {
if (isnan(l) || isnan(r)) {
if (isnan(l) && isnan(r)) {
ret = 0;
} else if (isnan(l)) {
ret = 1;
} else {
ret = -1;
}
} else if (l == r || fabs(l - r) < p) {
ret = 0;
} else {
ret = (l < r ? -1 : 1);

View File

@ -147,7 +147,17 @@ struct ObFixedDoubleCmp: public ObDefined<>
cmp_ret = 0;
const double l = l_datum.get_double();
const double r = r_datum.get_double();
if (l == r || fabs(l - r) < P) {
if (isnan(l) || isnan(r)) {
if (isnan(l) && isnan(r)) {
cmp_ret = 0;
} else if (isnan(l)) {
// l is nan, r is not nan:left always bigger than right
cmp_ret = 1;
} else {
// l is not nan, r is nan, left always less than right
cmp_ret = -1;
}
} else if (l == r || fabs(l - r) < P) {
cmp_ret = 0;
} else {
cmp_ret = (l < r ? -1 : 1);

View File

@ -632,7 +632,17 @@ struct VecTCCmpCalc<VEC_TC_FIXED_DOUBLE, VEC_TC_FIXED_DOUBLE>
const double l = *reinterpret_cast<const double *>(l_v);
const double r = *reinterpret_cast<const double *>(r_v);
const double P = 5 / LOG_10[l_meta.get_scale() + 1];
if (l == r || fabs(l - r) < P) {
if (isnan(l) || isnan(r)) {
if (isnan(l) && isnan(r)) {
cmp_ret = 0;
} else if (isnan(l)) {
// l is nan, r is not nan:left always bigger than right
cmp_ret = 1;
} else {
// l is not nan, r is nan, left always less than right
cmp_ret = -1;
}
} else if (l == r || fabs(l - r) < P) {
cmp_ret = 0;
} else {
cmp_ret = (l < r ? -1: 1);

View File

@ -305,7 +305,7 @@ static int do_round_by_type(
}
case ObDecimalIntType: {
if (OB_FAIL(ObExprFuncRound::calc_round_decimalint(
in_meta, out_meta, round_scale, x_datum, res_datum))) {
in_meta, out_meta, GET_SCALE_FOR_CALC(round_scale), x_datum, res_datum))) {
LOG_WARN("calc_round_decimalint failed", K(ret), K(in_meta), K(out_meta), K(round_scale));
}
break;
@ -434,7 +434,7 @@ static int do_round_by_type_batch_with_check(const int64_t scale, const ObExpr &
if (x_datum.is_null()) {
results[i].set_null();
} else if (OB_FAIL(ObExprFuncRound::calc_round_decimalint(
in_meta, out_meta, scale, x_datum, results[i]))) {
in_meta, out_meta, GET_SCALE_FOR_CALC(scale), x_datum, results[i]))) {
LOG_WARN("calc_round_decimalint failed", K(ret), K(in_meta), K(out_meta), K(scale));
}
}
@ -567,7 +567,7 @@ static int do_round_by_type_vector(const int64_t scale, const ObExpr &expr,
for (int64_t j = bound.start(); OB_SUCC(ret) && j < bound.end(); ++j) {
CHECK_ROUND_VECTOR();
if (OB_FAIL((ObExprFuncRound::calc_round_decimalint<LeftVec, ResVec>)(
in_meta, out_meta, scale, left_vec, res_vec, j))) {
in_meta, out_meta, GET_SCALE_FOR_CALC(scale), left_vec, res_vec, j))) {
LOG_WARN("calc_round_decimalint failed", K(ret), K(in_meta), K(out_meta), K(scale));
}
eval_flags.set(j);
@ -656,7 +656,7 @@ static int do_round_by_type_batch_without_check(const int64_t scale, const ObExp
const ObDatumMeta &out_meta = expr.datum_meta_;
for (int64_t i = 0; OB_SUCC(ret) && i < batch_size; ++i) {
if (OB_FAIL(ObExprFuncRound::calc_round_decimalint(
in_meta, out_meta, scale, x_datums[i], results[i]))) {
in_meta, out_meta, GET_SCALE_FOR_CALC(scale), x_datums[i], results[i]))) {
LOG_WARN("calc_round_decimalint failed", K(ret), K(in_meta), K(out_meta), K(scale));
}
}

View File

@ -459,13 +459,13 @@ int ObExprUtil::get_mb_str_info(const ObString &str,
double ObExprUtil::round_double(double val, int64_t dec)
{
const double pow_val = std::pow(10, static_cast<double>(std::abs(dec)));
volatile double val_div_tmp = val / pow_val;
volatile double val_mul_tmp = val * pow_val;
volatile double res = 0.0;
const double pow_val = std::pow(10.0, static_cast<double>(std::abs(dec)));
double val_div_tmp = val / pow_val;
double val_mul_tmp = val * pow_val;
double res = 0.0;
if (dec < 0 && std::isinf(pow_val)) {
res = 0.0;
} else if (dec >= 0 && std::isinf(val_mul_tmp)) {
} else if (dec >= 0 && !std::isfinite(val_mul_tmp)) {
res = val;
} else {
res = dec < 0 ? rint(val_div_tmp) * pow_val : rint(val_mul_tmp) / pow_val;
@ -474,23 +474,6 @@ double ObExprUtil::round_double(double val, int64_t dec)
return res;
}
double ObExprUtil::round_double_nearest(double val, int64_t dec)
{
const double pow_val = std::pow(10, static_cast<double>(std::abs(dec)));
volatile double val_div_tmp = val / pow_val;
volatile double val_mul_tmp = val * pow_val;
volatile double res = 0.0;
if (dec < 0 && std::isinf(pow_val)) {
res = 0.0;
} else if (dec >= 0 && std::isinf(val_mul_tmp)) {
res = val;
} else {
res = dec < 0 ? std::round(val_div_tmp) * pow_val : std::round(val_mul_tmp) / pow_val;
}
LOG_DEBUG("round double done", K(val), K(dec), K(res));
return res;
}
uint64_t ObExprUtil::round_uint64(uint64_t val, int64_t dec)
{
uint64_t res = 0;
@ -524,7 +507,7 @@ double ObExprUtil::trunc_double(double val, int64_t dec)
volatile double res = 0.0;
if (dec < 0 && std::isinf(pow_val)) {
res = 0.0;
} else if (dec >= 0 && std::isinf(val_mul_tmp)) {
} else if (dec >= 0 && !std::isfinite(val_mul_tmp)) {
res = val;
} else {
if (val >= 0) {

View File

@ -80,8 +80,6 @@ public:
common::ObIArray<size_t> &byte_offset);
// 将double round到小数点后或者小数点前指定位置
static double round_double(double val, int64_t dec);
// 将double round到小数点后指定位置,.5 -> 1 而不是0
static double round_double_nearest(double val, int64_t dec);
static uint64_t round_uint64(uint64_t val, int64_t dec);
// 将double trunc到小数点后或者小数点前指定位置
static double trunc_double(double val, int64_t dec);