[FEAT MERGE] impl vectorization 2.0

Co-authored-by: Naynahs <cfzy002@126.com>
Co-authored-by: hwx65 <1780011298@qq.com>
Co-authored-by: oceanoverflow <oceanoverflow@gmail.com>
This commit is contained in:
obdev
2023-12-22 03:43:19 +00:00
committed by ob-robot
parent 1178245448
commit b6773084c6
592 changed files with 358124 additions and 303288 deletions

View File

@ -257,6 +257,29 @@ int ObExprFuncRound::calc_round_decimalint(
return ret;
}
template <typename LeftVec, typename ResVec>
int ObExprFuncRound::calc_round_decimalint(
const ObDatumMeta &in_meta, const ObDatumMeta &out_meta, const int64_t round_scale,
LeftVec *left_vec, ResVec *res_vec, const int64_t &idx)
{
int ret = OB_SUCCESS;
if (in_meta.scale_ != round_scale
|| get_decimalint_type(in_meta.precision_) != get_decimalint_type(out_meta.precision_)) {
ObDecimalIntBuilder res_val;
ObDatum left_datum(left_vec->get_payload(idx), left_vec->get_length(idx), false);
if (OB_FAIL(do_round_decimalint(
in_meta.precision_, in_meta.scale_, out_meta.precision_, out_meta.scale_, round_scale,
left_datum, res_val))) {
LOG_WARN("do_round_decimalint failed", K(ret), K(in_meta), K(out_meta), K(round_scale));
} else {
res_vec->set_decimal_int(idx, res_val.get_decimal_int(), res_val.get_int_bytes());
}
} else {
res_vec->set_decimal_int(idx, left_vec->get_decimal_int(idx), left_vec->get_length(idx));
}
return ret;
}
static int do_round_by_type(
const ObDatumMeta &in_meta, const ObDatumMeta &out_meta, const int64_t round_scale,
const ObDatum &x_datum, ObEvalCtx &ctx,
@ -340,12 +363,26 @@ static bool is_batch_need_cal_all(const ObDatum *x_datums,
const ObBitVector &eval_flags,
const int64_t batch_size)
{
bool ret = ObBitVector::bit_op_zero(skip, eval_flags, batch_size,
bool is_need = ObBitVector::bit_op_zero(skip, eval_flags, batch_size,
[](uint64_t l, uint64_t r) { return l | r; });
for (int64_t i = 0; ret && i < batch_size; ++i) {
ret = !(x_datums[i].is_null());
for (int64_t i = 0; is_need && i < batch_size; ++i) {
is_need = !(x_datums[i].is_null());
}
return ret;
return is_need;
}
template <typename LeftVec>
static bool is_vector_need_cal_all(LeftVec *left_vec,
const ObBitVector &skip,
const ObBitVector &eval_flags,
const EvalBound &bound)
{
bool is_need = ObBitVector::bit_op_zero(skip, eval_flags, bound,
[](uint64_t l, uint64_t r) { return l | r; });
for (int64_t j = bound.start(); is_need && j < bound.end(); ++j) {
is_need = !(left_vec->is_null(j));
}
return is_need;
}
static int do_round_by_type_batch_with_check(const int64_t scale, const ObExpr &expr,
@ -483,6 +520,111 @@ static int do_round_by_type_batch_with_check(const int64_t scale, const ObExpr &
return ret;
}
#define CHECK_ROUND_VECTOR() \
if (IsCheck) { \
if (skip.at(j) || eval_flags.at(j)) { \
continue; \
} else if (left_vec->is_null(j)) { \
res_vec->set_null(j); \
eval_flags.set(j); \
continue; \
} \
}
template <typename LeftVec, typename ResVec, bool IsCheck>
static int do_round_by_type_vector(const int64_t scale, const ObExpr &expr,
ObEvalCtx &ctx, const ObBitVector &skip,
const EvalBound &bound)
{
int ret = OB_SUCCESS;
if (!IsCheck) { UNUSED(skip); }
ObBitVector &eval_flags = expr.get_evaluated_flags(ctx);
ResVec *res_vec = static_cast<ResVec *>(expr.get_vector(ctx));
LeftVec *left_vec = static_cast<LeftVec *>(expr.args_[0]->get_vector(ctx));
const ObObjType x_type = expr.args_[0]->datum_meta_.type_;
switch (x_type) {
case ObNumberType:
case ObUNumberType: {
for (int64_t j = bound.start(); OB_SUCC(ret) && j < bound.end(); ++j) {
CHECK_ROUND_VECTOR();
const number::ObNumber x_nmb(left_vec->get_number(j));
number::ObNumber res_nmb;
ObNumStackOnceAlloc tmp_alloc;
if (OB_FAIL(res_nmb.from(x_nmb, tmp_alloc))) {
LOG_WARN("get num from x_nmb failed", K(ret), K(x_nmb));
} else if (OB_FAIL(res_nmb.round(GET_SCALE_FOR_CALC(scale)))) {
LOG_WARN("eval round of res_nmb failed", K(ret), K(scale), K(res_nmb));
} else {
res_vec->set_number(j, res_nmb);
eval_flags.set(j);
}
}
break;
}
case ObDecimalIntType: {
const ObDatumMeta &in_meta = expr.args_[0]->datum_meta_;
const ObDatumMeta &out_meta = expr.datum_meta_;
for (int64_t j = bound.start(); OB_SUCC(ret) && j < bound.end(); ++j) {
CHECK_ROUND_VECTOR();
if (OB_FAIL((ObExprFuncRound::calc_round_decimalint<LeftVec, ResVec>)(
in_meta, out_meta, scale, left_vec, res_vec, j))) {
LOG_WARN("calc_round_decimalint failed", K(ret), K(in_meta), K(out_meta), K(scale));
}
eval_flags.set(j);
}
break;
}
case ObFloatType: {
for (int64_t j = bound.start(); OB_SUCC(ret) && j < bound.end(); ++j) {
CHECK_ROUND_VECTOR();
// if in Oracle mode, param_num must be 1(scale is 0)
// MySQL mode cannot be here. because if param type is float, calc type will be double.
res_vec->set_float(j, ObExprUtil::round_double(left_vec->get_float(j), scale));
eval_flags.set(j);
}
break;
}
case ObDoubleType: {
for (int64_t j = bound.start(); OB_SUCC(ret) && j < bound.end(); ++j) {
CHECK_ROUND_VECTOR();
// if in Oracle mode, param_num must be 1(scale is 0)
res_vec->set_double(j, ObExprUtil::round_double(left_vec->get_double(j), scale));
eval_flags.set(j);
}
break;
}
case ObIntType: {
for (int64_t j = bound.start(); OB_SUCC(ret) && j < bound.end(); ++j) {
CHECK_ROUND_VECTOR();
int64_t x_int = left_vec->get_int(j);
bool neg = x_int < 0;
x_int = neg ? -x_int : x_int;
int64_t res_int = static_cast<int64_t>(ObExprUtil::round_uint64(x_int, scale));
res_int = neg ? -res_int : res_int;
res_vec->set_int(j, res_int);
eval_flags.set(j);
}
break;
}
case ObUInt64Type: {
for (int64_t j = bound.start(); OB_SUCC(ret) && j < bound.end(); ++j) {
CHECK_ROUND_VECTOR();
uint64_t x_uint = left_vec->get_uint(j);
uint64_t res_uint = ObExprUtil::round_uint64(x_uint, scale);
res_vec->set_uint(j, res_uint);
eval_flags.set(j);
}
break;
}
default: {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("unexpected arg type", K(ret), K(x_type));
break;
}
}
return ret;
}
static int do_round_by_type_batch_without_check(const int64_t scale, const ObExpr &expr,
ObEvalCtx &ctx, const int64_t batch_size)
{
@ -609,6 +751,100 @@ int ObExprFuncRound::calc_round_expr_numeric1_batch(const ObExpr &expr,
return ret;
}
#define ROUND_DISPATCH_VECTOR_IN_LEFT_ARG_FORMAT(func_name, res_vec) \
switch (left_format) { \
case VEC_FIXED: { \
ret = func_name<ObFixedLengthBase, res_vec>(expr, ctx, skip, bound); \
break; \
} \
case VEC_DISCRETE: { \
ret = func_name<ObDiscreteFormat, res_vec>(expr, ctx, skip, bound); \
break; \
} \
case VEC_CONTINUOUS: { \
ret = func_name<ObContinuousFormat, res_vec>(expr, ctx, skip, bound); \
break; \
} \
case VEC_UNIFORM: { \
ret = func_name<ObUniformFormat<false>, res_vec>(expr, ctx, skip, bound); \
break; \
} \
case VEC_UNIFORM_CONST: { \
ret = func_name<ObUniformFormat<true>, res_vec>(expr, ctx, skip, bound); \
break; \
} \
default: { \
ret = func_name<ObVectorBase, res_vec>(expr, ctx, skip, bound); \
} \
}
#define ROUND_DISPATCH_VECTOR_IN_RES_ARG_FORMAT(func_name) \
switch (res_format) { \
case VEC_FIXED: { \
ROUND_DISPATCH_VECTOR_IN_LEFT_ARG_FORMAT(func_name, ObFixedLengthBase); \
break; \
} \
case VEC_DISCRETE: { \
ROUND_DISPATCH_VECTOR_IN_LEFT_ARG_FORMAT(func_name, ObDiscreteFormat); \
break; \
} \
case VEC_CONTINUOUS: { \
ROUND_DISPATCH_VECTOR_IN_LEFT_ARG_FORMAT(func_name, ObContinuousFormat); \
break; \
} \
case VEC_UNIFORM: { \
ROUND_DISPATCH_VECTOR_IN_LEFT_ARG_FORMAT(func_name, ObUniformFormat<false>); \
break; \
} \
case VEC_UNIFORM_CONST: { \
ROUND_DISPATCH_VECTOR_IN_LEFT_ARG_FORMAT(func_name, ObUniformFormat<true>); \
break; \
} \
default: { \
ROUND_DISPATCH_VECTOR_IN_LEFT_ARG_FORMAT(func_name, ObVectorBase); \
} \
}
int ObExprFuncRound::calc_round_expr_numeric1_vector(const ObExpr &expr,
ObEvalCtx &ctx,
const ObBitVector &skip,
const EvalBound &bound)
{
int ret = OB_SUCCESS;
if (OB_FAIL(expr.args_[0]->eval_vector(ctx, skip, bound))) {
LOG_WARN("eval arg failed", K(ret), K(expr));
} else {
VectorFormat res_format = expr.get_format(ctx);
VectorFormat left_format = expr.args_[0]->get_format(ctx);
ROUND_DISPATCH_VECTOR_IN_RES_ARG_FORMAT(inner_calc_round_expr_numeric1_vector);
}
return ret;
}
template <typename LeftVec, typename ResVec>
int ObExprFuncRound::inner_calc_round_expr_numeric1_vector(const ObExpr &expr,
ObEvalCtx &ctx,
const ObBitVector &skip,
const EvalBound &bound)
{
int ret = OB_SUCCESS;
ResVec *res_vec = static_cast<ResVec *>(expr.get_vector(ctx));
LeftVec *left_vec = static_cast<LeftVec *>(expr.args_[0]->get_vector(ctx));
ObBitVector &eval_flags = expr.get_evaluated_flags(ctx);
if (is_vector_need_cal_all<LeftVec>(left_vec, skip, eval_flags, bound)) {
if (OB_FAIL((do_round_by_type_vector<LeftVec, ResVec, false>)(0, expr, ctx, skip, bound))) {
const ObObjType x_type = expr.args_[0]->datum_meta_.type_;
LOG_WARN("calc round by type failed", K(ret), K(x_type), K(expr));
}
} else {
if (OB_FAIL((do_round_by_type_vector<LeftVec, ResVec, true>)(0, expr, ctx, skip, bound))) {
const ObObjType x_type = expr.args_[0]->datum_meta_.type_;
LOG_WARN("calc round by type failed", K(ret), K(x_type), K(expr));
}
}
return ret;
}
int calc_round_expr_numeric2(const sql::ObExpr &expr, sql::ObEvalCtx &ctx,
sql::ObDatum &res_datum)
{
@ -725,6 +961,89 @@ int ObExprFuncRound::calc_round_expr_numeric2_batch(const ObExpr &expr,
return ret;
}
int ObExprFuncRound::calc_round_expr_numeric2_vector(const ObExpr &expr,
ObEvalCtx &ctx,
const ObBitVector &skip,
const EvalBound &bound)
{
int ret = OB_SUCCESS;
if (OB_FAIL(expr.args_[0]->eval_vector(ctx, skip, bound))) {
LOG_WARN("eval arg failed", K(ret), K(expr));
} else {
VectorFormat res_format = expr.get_format(ctx);
VectorFormat left_format = expr.args_[0]->get_format(ctx);
ROUND_DISPATCH_VECTOR_IN_RES_ARG_FORMAT(inner_calc_round_expr_numeric2_vector);
}
return ret;
}
template <typename LeftVec, typename ResVec>
int ObExprFuncRound::inner_calc_round_expr_numeric2_vector(const ObExpr &expr,
ObEvalCtx &ctx,
const ObBitVector &skip,
const EvalBound &bound)
{
int ret = OB_SUCCESS;
ObDatum *fmt_datum = NULL;
if (OB_FAIL(expr.args_[1]->eval(ctx, fmt_datum))) {
LOG_WARN("eval arg failed", K(ret), K(expr));
} else {
int64_t scale = 0;
// get scale
const ObObjType fmt_type = expr.args_[1]->datum_meta_.type_;
if (fmt_datum->is_null()) {
// do nothing
} else if (ObNumberType == fmt_type) {
const number::ObNumber fmt_nmb(fmt_datum->get_number());
if (OB_FAIL(fmt_nmb.extract_valid_int64_with_trunc(scale))) {
LOG_WARN("extract_valid_int64_with_trunc failed", K(ret), K(fmt_nmb));
}
} else if (ObIntType == fmt_type) {
scale = fmt_datum->get_int();
} else {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("unexpected fmt type", K(ret), K(fmt_type), K(expr));
}
if (OB_SUCC(ret)) {
if (fmt_datum->is_null()) {
ObBitVector &eval_flags = expr.get_evaluated_flags(ctx);
ResVec *res_vec = static_cast<ResVec *>(expr.get_vector(ctx));
for (int64_t j = bound.start(); OB_SUCC(ret) && j < bound.end(); ++j) {
eval_flags.set(j);
res_vec->set_null(j);
}
} else {
if (is_mysql_mode()
&& (ob_is_number_tc(expr.args_[0]->datum_meta_.get_type())
|| ob_is_decimal_int_tc(expr.args_[0]->datum_meta_.get_type()))) {
if (expr.args_[0]->datum_meta_.scale_ < scale
// eg : select round(123.123, 100);
// -> result is 123.123
|| expr.datum_meta_.scale_ < scale) {
// eg : select round(123.123456789123456789123456789123456789, 50);
// -> result accuracy is precision:34, scale:30 (max result scale is 30)
scale = expr.datum_meta_.scale_;
}
}
LeftVec *left_vec = static_cast<LeftVec *>(expr.args_[0]->get_vector(ctx));
ObBitVector &eval_flags = expr.get_evaluated_flags(ctx);
if (is_vector_need_cal_all<LeftVec>(left_vec, skip, eval_flags, bound)) {
if (OB_FAIL((do_round_by_type_vector<LeftVec, ResVec, false>)(scale, expr, ctx, skip, bound))) {
const ObObjType x_type = expr.args_[0]->datum_meta_.type_;
LOG_WARN("calc round by type failed", K(ret), K(x_type), K(expr));
}
} else {
if (OB_FAIL((do_round_by_type_vector<LeftVec, ResVec, true>)(scale, expr, ctx, skip, bound))) {
const ObObjType x_type = expr.args_[0]->datum_meta_.type_;
LOG_WARN("calc round by type failed", K(ret), K(x_type), K(expr));
}
}
}
}
}
return ret;
}
int calc_round_expr_datetime_inner(const ObDatum &x_datum, const ObString &fmt_str,
ObEvalCtx &ctx, int64_t &dt,
const sql::ObExpr &expr)
@ -816,6 +1135,53 @@ int ObExprFuncRound::calc_round_expr_datetime1_batch(const ObExpr &expr,
return ret;
}
int ObExprFuncRound::calc_round_expr_datetime1_vector(const ObExpr &expr,
ObEvalCtx &ctx,
const ObBitVector &skip,
const EvalBound &bound)
{
int ret = OB_SUCCESS;
if (OB_FAIL(expr.args_[0]->eval_vector(ctx, skip, bound))) {
LOG_WARN("eval arg0 failed", K(ret), K(expr));
} else {
VectorFormat res_format = expr.get_format(ctx);
VectorFormat left_format = expr.args_[0]->get_format(ctx);
ROUND_DISPATCH_VECTOR_IN_RES_ARG_FORMAT(inner_calc_round_expr_datetime1_vector);
}
return ret;
}
template <typename LeftVec, typename ResVec>
int ObExprFuncRound::inner_calc_round_expr_datetime1_vector(const ObExpr &expr,
ObEvalCtx &ctx,
const ObBitVector &skip,
const EvalBound &bound)
{
int ret = OB_SUCCESS;
ObString fmt_str("DD");
ObBitVector &eval_flags = expr.get_evaluated_flags(ctx);
ResVec *res_vec = static_cast<ResVec *>(expr.get_vector(ctx));
LeftVec *left_vec = static_cast<LeftVec *>(expr.args_[0]->get_vector(ctx));
for (int64_t j = bound.start(); OB_SUCC(ret) && j < bound.end(); ++j) {
if (skip.at(j) || eval_flags.at(j)) {
continue;
}
int64_t dt = 0;
if (left_vec->is_null(j)) {
res_vec->set_null(j);
} else {
ObDatum left_datum(left_vec->get_payload(j), left_vec->get_length(j), false);
if (OB_FAIL(calc_round_expr_datetime_inner(left_datum, fmt_str, ctx, dt, expr))) {
LOG_WARN("calc_round_expr_datetime_inner failed", K(ret));
} else {
res_vec->set_datetime(j, dt);
}
}
eval_flags.set(j);
}
return ret;
}
int calc_round_expr_datetime2(const sql::ObExpr &expr, sql::ObEvalCtx &ctx,
sql::ObDatum &res_datum)
{
@ -871,6 +1237,58 @@ int ObExprFuncRound::calc_round_expr_datetime2_batch(const ObExpr &expr,
return ret;
}
int ObExprFuncRound::calc_round_expr_datetime2_vector(const ObExpr &expr,
ObEvalCtx &ctx,
const ObBitVector &skip,
const EvalBound &bound)
{
int ret = OB_SUCCESS;
if (OB_FAIL(expr.args_[0]->eval_vector(ctx, skip, bound))) {
LOG_WARN("eval arg failed", K(ret), K(expr));
} else {
VectorFormat res_format = expr.get_format(ctx);
VectorFormat left_format = expr.args_[0]->get_format(ctx);
ROUND_DISPATCH_VECTOR_IN_RES_ARG_FORMAT(inner_calc_round_expr_datetime2_vector);
}
return ret;
}
template <typename LeftVec, typename ResVec>
int ObExprFuncRound::inner_calc_round_expr_datetime2_vector(const ObExpr &expr,
ObEvalCtx &ctx,
const ObBitVector &skip,
const EvalBound &bound)
{
int ret = OB_SUCCESS;
ObBitVector &eval_flags = expr.get_evaluated_flags(ctx);
ObDatum *fmt_datum = NULL;
if (OB_FAIL(expr.args_[1]->eval(ctx, fmt_datum))) {
LOG_WARN("eval arg failed", K(ret), K(expr));
} else {
LeftVec *left_vec = static_cast<LeftVec *>(expr.args_[0]->get_vector(ctx));
ResVec *res_vec = static_cast<ResVec *>(expr.get_vector(ctx));
for (int64_t j = bound.start(); OB_SUCC(ret) && j < bound.end(); ++j) {
if (skip.at(j) || eval_flags.at(j)) {
continue;
}
int64_t dt = 0;
eval_flags.set(j);
if (left_vec->is_null(j) || fmt_datum->is_null()) {
res_vec->set_null(j);
} else {
ObDatum left_datum(left_vec->get_payload(j), left_vec->get_length(j), false);
if (OB_FAIL(calc_round_expr_datetime_inner(left_datum, fmt_datum->get_string(),
ctx, dt, expr))) {
LOG_WARN("calc_round_expr_datetime_inner failed", K(ret));
} else {
res_vec->set_datetime(j, dt);
}
}
}
}
return ret;
}
int ObExprFuncRound::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr,
ObExpr &rt_expr) const
{
@ -895,12 +1313,14 @@ int ObExprFuncRound::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr
// Only implement vectorization when parameter 0 is batch and parameter 1 is constant
if (rt_expr.args_[0]->is_batch_result() && !(rt_expr.args_[1]->is_batch_result())) {
rt_expr.eval_batch_func_ = calc_round_expr_datetime2_batch;
rt_expr.eval_vector_func_ = calc_round_expr_datetime2_vector;
}
} else {
rt_expr.eval_func_ = calc_round_expr_numeric2;
// Only implement vectorization when parameter 0 is batch and parameter 1 is constant
if (rt_expr.args_[0]->is_batch_result() && !(rt_expr.args_[1]->is_batch_result())) {
rt_expr.eval_batch_func_ = calc_round_expr_numeric2_batch;
rt_expr.eval_vector_func_ = calc_round_expr_numeric2_vector;
}
}
} else {
@ -908,15 +1328,18 @@ int ObExprFuncRound::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr
// Only implement vectorization when parameter 0 is batch and parameter 1 is constant
if (rt_expr.args_[0]->is_batch_result() && !(rt_expr.args_[1]->is_batch_result())) {
rt_expr.eval_batch_func_ = calc_round_expr_numeric2_batch;
rt_expr.eval_vector_func_ = calc_round_expr_numeric2_vector;
}
}
} else {
if (ObDateTimeType == x_type) {
rt_expr.eval_func_ = calc_round_expr_datetime1;
rt_expr.eval_batch_func_ = calc_round_expr_datetime1_batch;
rt_expr.eval_vector_func_ = calc_round_expr_datetime1_vector;
} else {
rt_expr.eval_func_ = calc_round_expr_numeric1;
rt_expr.eval_batch_func_ = calc_round_expr_numeric1_batch;
rt_expr.eval_vector_func_ = calc_round_expr_numeric1_vector;
}
}
}