diff --git a/deps/oblib/src/common/object/ob_obj_compare.h b/deps/oblib/src/common/object/ob_obj_compare.h index d93287bf6..f7c080eed 100644 --- a/deps/oblib/src/common/object/ob_obj_compare.h +++ b/deps/oblib/src/common/object/ob_obj_compare.h @@ -259,7 +259,15 @@ public: } const double l = obj1.get_double(); const double r = obj2.get_double(); - if (l == r || fabs(l - r) < p) { + if (isnan(l) || isnan(r)) { + if (isnan(l) && isnan(r)) { + ret = 0; + } else if (isnan(l)) { + ret = 1; + } else { + ret = -1; + } + } else if (l == r || fabs(l - r) < p) { ret = 0; } else { ret = (l < r ? -1 : 1); diff --git a/src/share/datum/ob_datum_cmp_func_def.h b/src/share/datum/ob_datum_cmp_func_def.h index bee82eedb..2d9d3e870 100644 --- a/src/share/datum/ob_datum_cmp_func_def.h +++ b/src/share/datum/ob_datum_cmp_func_def.h @@ -147,7 +147,17 @@ struct ObFixedDoubleCmp: public ObDefined<> cmp_ret = 0; const double l = l_datum.get_double(); const double r = r_datum.get_double(); - if (l == r || fabs(l - r) < P) { + if (isnan(l) || isnan(r)) { + if (isnan(l) && isnan(r)) { + cmp_ret = 0; + } else if (isnan(l)) { + // l is nan, r is not nan:left always bigger than right + cmp_ret = 1; + } else { + // l is not nan, r is nan, left always less than right + cmp_ret = -1; + } + } else if (l == r || fabs(l - r) < P) { cmp_ret = 0; } else { cmp_ret = (l < r ? -1 : 1); diff --git a/src/share/vector/vector_basic_op.h b/src/share/vector/vector_basic_op.h index 283530703..56b753c1f 100644 --- a/src/share/vector/vector_basic_op.h +++ b/src/share/vector/vector_basic_op.h @@ -632,7 +632,17 @@ struct VecTCCmpCalc const double l = *reinterpret_cast(l_v); const double r = *reinterpret_cast(r_v); const double P = 5 / LOG_10[l_meta.get_scale() + 1]; - if (l == r || fabs(l - r) < P) { + if (isnan(l) || isnan(r)) { + if (isnan(l) && isnan(r)) { + cmp_ret = 0; + } else if (isnan(l)) { + // l is nan, r is not nan:left always bigger than right + cmp_ret = 1; + } else { + // l is not nan, r is nan, left always less than right + cmp_ret = -1; + } + } else if (l == r || fabs(l - r) < P) { cmp_ret = 0; } else { cmp_ret = (l < r ? -1: 1); diff --git a/src/sql/engine/expr/ob_expr_func_round.cpp b/src/sql/engine/expr/ob_expr_func_round.cpp index 501cce07d..1235789c6 100644 --- a/src/sql/engine/expr/ob_expr_func_round.cpp +++ b/src/sql/engine/expr/ob_expr_func_round.cpp @@ -305,7 +305,7 @@ static int do_round_by_type( } case ObDecimalIntType: { if (OB_FAIL(ObExprFuncRound::calc_round_decimalint( - in_meta, out_meta, round_scale, x_datum, res_datum))) { + in_meta, out_meta, GET_SCALE_FOR_CALC(round_scale), x_datum, res_datum))) { LOG_WARN("calc_round_decimalint failed", K(ret), K(in_meta), K(out_meta), K(round_scale)); } break; @@ -434,7 +434,7 @@ static int do_round_by_type_batch_with_check(const int64_t scale, const ObExpr & if (x_datum.is_null()) { results[i].set_null(); } else if (OB_FAIL(ObExprFuncRound::calc_round_decimalint( - in_meta, out_meta, scale, x_datum, results[i]))) { + in_meta, out_meta, GET_SCALE_FOR_CALC(scale), x_datum, results[i]))) { LOG_WARN("calc_round_decimalint failed", K(ret), K(in_meta), K(out_meta), K(scale)); } } @@ -567,7 +567,7 @@ static int do_round_by_type_vector(const int64_t scale, const ObExpr &expr, for (int64_t j = bound.start(); OB_SUCC(ret) && j < bound.end(); ++j) { CHECK_ROUND_VECTOR(); if (OB_FAIL((ObExprFuncRound::calc_round_decimalint)( - in_meta, out_meta, scale, left_vec, res_vec, j))) { + in_meta, out_meta, GET_SCALE_FOR_CALC(scale), left_vec, res_vec, j))) { LOG_WARN("calc_round_decimalint failed", K(ret), K(in_meta), K(out_meta), K(scale)); } eval_flags.set(j); @@ -656,7 +656,7 @@ static int do_round_by_type_batch_without_check(const int64_t scale, const ObExp const ObDatumMeta &out_meta = expr.datum_meta_; for (int64_t i = 0; OB_SUCC(ret) && i < batch_size; ++i) { if (OB_FAIL(ObExprFuncRound::calc_round_decimalint( - in_meta, out_meta, scale, x_datums[i], results[i]))) { + in_meta, out_meta, GET_SCALE_FOR_CALC(scale), x_datums[i], results[i]))) { LOG_WARN("calc_round_decimalint failed", K(ret), K(in_meta), K(out_meta), K(scale)); } } diff --git a/src/sql/engine/expr/ob_expr_util.cpp b/src/sql/engine/expr/ob_expr_util.cpp index 7b06d87bb..9e4ddeef2 100644 --- a/src/sql/engine/expr/ob_expr_util.cpp +++ b/src/sql/engine/expr/ob_expr_util.cpp @@ -459,13 +459,13 @@ int ObExprUtil::get_mb_str_info(const ObString &str, double ObExprUtil::round_double(double val, int64_t dec) { - const double pow_val = std::pow(10, static_cast(std::abs(dec))); - volatile double val_div_tmp = val / pow_val; - volatile double val_mul_tmp = val * pow_val; - volatile double res = 0.0; + const double pow_val = std::pow(10.0, static_cast(std::abs(dec))); + double val_div_tmp = val / pow_val; + double val_mul_tmp = val * pow_val; + double res = 0.0; if (dec < 0 && std::isinf(pow_val)) { res = 0.0; - } else if (dec >= 0 && std::isinf(val_mul_tmp)) { + } else if (dec >= 0 && !std::isfinite(val_mul_tmp)) { res = val; } else { res = dec < 0 ? rint(val_div_tmp) * pow_val : rint(val_mul_tmp) / pow_val; @@ -474,23 +474,6 @@ double ObExprUtil::round_double(double val, int64_t dec) return res; } -double ObExprUtil::round_double_nearest(double val, int64_t dec) -{ - const double pow_val = std::pow(10, static_cast(std::abs(dec))); - volatile double val_div_tmp = val / pow_val; - volatile double val_mul_tmp = val * pow_val; - volatile double res = 0.0; - if (dec < 0 && std::isinf(pow_val)) { - res = 0.0; - } else if (dec >= 0 && std::isinf(val_mul_tmp)) { - res = val; - } else { - res = dec < 0 ? std::round(val_div_tmp) * pow_val : std::round(val_mul_tmp) / pow_val; - } - LOG_DEBUG("round double done", K(val), K(dec), K(res)); - return res; -} - uint64_t ObExprUtil::round_uint64(uint64_t val, int64_t dec) { uint64_t res = 0; @@ -524,7 +507,7 @@ double ObExprUtil::trunc_double(double val, int64_t dec) volatile double res = 0.0; if (dec < 0 && std::isinf(pow_val)) { res = 0.0; - } else if (dec >= 0 && std::isinf(val_mul_tmp)) { + } else if (dec >= 0 && !std::isfinite(val_mul_tmp)) { res = val; } else { if (val >= 0) { diff --git a/src/sql/engine/expr/ob_expr_util.h b/src/sql/engine/expr/ob_expr_util.h index 9e3db848b..e9dfa98e9 100644 --- a/src/sql/engine/expr/ob_expr_util.h +++ b/src/sql/engine/expr/ob_expr_util.h @@ -80,8 +80,6 @@ public: common::ObIArray &byte_offset); // 将double round到小数点后或者小数点前指定位置 static double round_double(double val, int64_t dec); - // 将double round到小数点后指定位置,.5 -> 1 而不是0 - static double round_double_nearest(double val, int64_t dec); static uint64_t round_uint64(uint64_t val, int64_t dec); // 将double trunc到小数点后或者小数点前指定位置 static double trunc_double(double val, int64_t dec);