Implement user-defined aggregate function overloading

This commit is contained in:
obdev
2023-01-05 08:38:19 +00:00
committed by ob-robot
parent a0b1894058
commit cce7bdeb9d
2 changed files with 167 additions and 115 deletions

View File

@ -15,6 +15,8 @@
#include "sql/ob_spi.h"
#include "pl/ob_pl.h"
#include "pl/ob_pl_stmt.h"
#include "pl/ob_pl_resolver.h"
#include "sql/resolver/ob_resolver_utils.h"
namespace oceanbase
{
@ -23,8 +25,46 @@ using namespace share::schema;
namespace sql
{
int ObPlAggUdfFunction::pick_routine(ObSEArray<const ObIRoutineInfo *, 4> &routine_infos,
const ObIRoutineInfo *&routine_info,
ObIArray<ObExprResType> &param_type)
{
int ret = OB_SUCCESS;
routine_info = NULL;
if (routine_infos.count() == 1) {
routine_info = routine_infos.at(0);
} else {
ObSEArray<ObRawExpr *, 4> mock_exec_expr;
ObRawExprFactory *expr_factory = exec_ctx_->get_expr_factory();
CK (OB_NOT_NULL(expr_factory));
CK (OB_NOT_NULL(allocator_));
CK (OB_NOT_NULL(session_info_));
CK (OB_NOT_NULL(exec_ctx_->get_sql_ctx()));
CK (OB_NOT_NULL(exec_ctx_->get_sql_ctx()->schema_guard_));
CK (OB_NOT_NULL(exec_ctx_->get_sql_proxy()));
for (int64_t i = 0; OB_SUCC(ret) && i < param_type.count(); ++i) {
ObConstRawExpr *c_expr = NULL;
OZ (expr_factory->create_raw_expr(T_QUESTIONMARK, c_expr));
OX (c_expr->set_result_type(param_type.at(i)));
OZ (mock_exec_expr.push_back(c_expr));
}
if (OB_SUCC(ret)) {
pl::ObPLPackageGuard package_guard(session_info_->get_effective_tenant_id());
pl::ObPLResolveCtx resolve_ctx(*allocator_,
*session_info_,
*exec_ctx_->get_sql_ctx()->schema_guard_,
package_guard,
*exec_ctx_->get_sql_proxy(),
false /*is_ps*/);
OZ (ObResolverUtils::pick_routine(resolve_ctx, mock_exec_expr, routine_infos, routine_info));
}
}
return ret;
}
int ObPlAggUdfFunction::get_package_routine_info(const ObString &routine_name,
const ObRoutineInfo *&routine_info)
const ObRoutineInfo *&routine_info,
ObIArray<ObExprResType> &param_type)
{
int ret = OB_SUCCESS;
ObSqlCtx *sql_ctx = NULL;
@ -32,6 +72,7 @@ int ObPlAggUdfFunction::get_package_routine_info(const ObString &routine_name,
ObSEArray<const ObIRoutineInfo *, 4> routine_infos;
ObRoutineType routine_type = share::schema::ObRoutineType::ROUTINE_FUNCTION_TYPE;
routine_info = NULL;
const ObIRoutineInfo *base_routine_info = NULL;
const ObUDTTypeInfo *udt_info = NULL;
if (OB_ISNULL(session_info_) || OB_ISNULL(exec_ctx_) ||
OB_ISNULL(sql_ctx = exec_ctx_->get_sql_ctx()) ||
@ -51,8 +92,10 @@ int ObPlAggUdfFunction::get_package_routine_info(const ObString &routine_name,
routine_type,
routine_infos))) {
LOG_WARN("failed to get package routine infos", K(ret));
} else if (OB_UNLIKELY(routine_infos.count() != 1) ||
OB_ISNULL(routine_info = static_cast<const ObRoutineInfo *>(routine_infos.at(0)))) {
} else if (OB_FAIL(pick_routine(routine_infos, base_routine_info, param_type))) {
LOG_WARN("get unexpected error", K(routine_infos), K(base_routine_info), K(ret));
} else if (OB_ISNULL(base_routine_info) ||
OB_ISNULL(routine_info = static_cast<const ObRoutineInfo *>(base_routine_info))) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("get unexpected error", K(routine_infos), K(routine_info), K(ret));
} else {
@ -208,11 +251,9 @@ int ObPlAggUdfFunction::process_init_pl_agg_udf(ObObjParam &pl_obj)
OB_ISNULL(exec_ctx_->get_sql_ctx()->schema_guard_) || OB_ISNULL(allocator_)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("get unexpected null", K(ret), K(exec_ctx_), K(allocator_));
} else if (OB_FAIL(get_package_routine_info(routine_name, routine_info))) {
LOG_WARN("failed to get package routine info", K(ret));
} else {
ObSEArray<ObUDFParamDesc, 4> params_desc;
ObSEArray<ObExprResType, 4> params_type;
ObSEArray<ObUDFParamDesc, 5> params_desc;
ObSEArray<ObExprResType, 5> params_type;
pl::ObPLUDTNS ns(*exec_ctx_->get_sql_ctx()->schema_guard_);
pl::ObPLDataType pl_type;
pl_type.set_user_type_id(pl::PL_RECORD_TYPE, type_id_);
@ -232,11 +273,14 @@ int ObPlAggUdfFunction::process_init_pl_agg_udf(ObObjParam &pl_obj)
common::ObArenaAllocator alloc;
ObExprResType param_type(alloc);
param_type.set_ext();
param_type.set_udt_id(type_id_);
if (OB_FAIL(params_type.push_back(param_type))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(params_desc.push_back(
ObUDFParamDesc(ObUDFParamDesc::OutType::LOCAL_OUT, 0)))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(get_package_routine_info(routine_name, routine_info, params_type))) {
LOG_WARN("failed to get package routine info", K(ret));
} else if (OB_FAIL(call_pl_engine_exectue_udf(*udf_params, routine_info, tmp_result))) {
LOG_WARN("failed to call pl engine exectue udf", K(ret));
} else if (OB_UNLIKELY(udf_params->count() < 1)) {
@ -269,8 +313,6 @@ int ObPlAggUdfFunction::process_calc_pl_agg_udf(ObObjParam &pl_obj,
LOG_WARN("failed to check params validty", K(ret));
} else if (is_null_params) {
/*no nothing*/
} else if (OB_FAIL(get_package_routine_info(routine_name, routine_info))) {
LOG_WARN("failed to get package routine info", K(ret));
} else if (OB_FAIL(process_obj_params(const_cast<ObObj*>(obj_params), param_num))) {
LOG_WARN("failed to process obj params", K(ret));
} else {
@ -284,6 +326,7 @@ int ObPlAggUdfFunction::process_calc_pl_agg_udf(ObObjParam &pl_obj,
common::ObArenaAllocator alloc;
ObExprResType param_type(alloc);
param_type.set_ext();
param_type.set_udt_id(type_id_);
if (OB_FAIL(all_params_type.push_back(param_type))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(all_params_desc.push_back(
@ -308,6 +351,8 @@ int ObPlAggUdfFunction::process_calc_pl_agg_udf(ObObjParam &pl_obj,
} else if (OB_ISNULL(udf_params)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("get unexpected null", K(ret));
} else if (OB_FAIL(get_package_routine_info(routine_name, routine_info, all_params_type))) {
LOG_WARN("failed to get package routine info", K(ret));
} else if (OB_FAIL(call_pl_engine_exectue_udf(*udf_params, routine_info, tmp_result))) {
LOG_WARN("failed to call pl engine exectue udf", K(ret));
} else if (OB_UNLIKELY(udf_params->count() < 1)) {
@ -354,48 +399,49 @@ int ObPlAggUdfFunction::process_merge_pl_agg_udf(ObObjParam &pl_obj,
ObString routine_name(strlen(str), str);
ParamStore *udf_params = NULL;
ObObj tmp_result;
if (OB_FAIL(get_package_routine_info(routine_name, routine_info))) {
//for pl agg udf, type member ODCIAggregateMerge() the first param must be self and is IN OUT,
//the second param is IN. so we need rebuild relation infos.
//see oracle url:https://docs.oracle.com/cd/B28359_01/appdev.111/b28425/ext_agg_ref.htm#CACBJHHI
ObSEArray<ObExprResType, 4> params_type;
ObSEArray<ObUDFParamDesc, 4> params_desc;
ObSEArray<ObUDFParamDesc, 4> all_params_desc;
ObSEArray<ObExprResType, 4> all_params_type;
common::ObArenaAllocator alloc;
ObExprResType param_type(alloc);
param_type.set_ext();
param_type.set_udt_id(type_id_);
if (OB_FAIL(params_type.push_back(param_type))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(params_desc.push_back(ObUDFParamDesc()))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(build_in_params_store(pl_obj, true, &pl_obj2, 1, params_desc,
params_type, udf_params))) {
LOG_WARN("failed to build in params store", K(ret));
} else if (OB_ISNULL(udf_params)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("get unexpected null", K(ret));
} else if (OB_FAIL(all_params_type.push_back(param_type))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(all_params_desc.push_back(
ObUDFParamDesc(ObUDFParamDesc::LOCAL_OUT, 0)))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(all_params_type.push_back(param_type))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(all_params_desc.push_back(ObUDFParamDesc()))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(get_package_routine_info(routine_name, routine_info, all_params_type))) {
LOG_WARN("failed to get package routine info", K(ret));
} else if (OB_FAIL(call_pl_engine_exectue_udf(*udf_params, routine_info, tmp_result))) {
LOG_WARN("failed to call pl engine exectue udf", K(ret));
} else if (OB_UNLIKELY(udf_params->count() < 1)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("get unexpected error", K(ret), K(udf_params->count()));
} else {
//for pl agg udf, type member ODCIAggregateMerge() the first param must be self and is IN OUT,
//the second param is IN. so we need rebuild relation infos.
//see oracle url:https://docs.oracle.com/cd/B28359_01/appdev.111/b28425/ext_agg_ref.htm#CACBJHHI
ObSEArray<ObUDFParamDesc, 4> params_desc;
ObSEArray<ObUDFParamDesc, 4> all_params_desc;
ObSEArray<ObExprResType, 4> params_type;
ObSEArray<ObExprResType, 4> all_params_type;
common::ObArenaAllocator alloc;
ObExprResType param_type(alloc);
param_type.set_ext();
if (OB_FAIL(params_type.push_back(param_type))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(params_desc.push_back(ObUDFParamDesc()))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(build_in_params_store(pl_obj, true, &pl_obj2, 1, params_desc,
params_type, udf_params))) {
LOG_WARN("failed to build in params store", K(ret));
} else if (OB_ISNULL(udf_params)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("get unexpected null", K(ret));
} else if (OB_FAIL(all_params_type.push_back(param_type))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(all_params_desc.push_back(
ObUDFParamDesc(ObUDFParamDesc::LOCAL_OUT, 0)))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(all_params_type.push_back(param_type))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(all_params_desc.push_back(ObUDFParamDesc()))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(call_pl_engine_exectue_udf(*udf_params, routine_info, tmp_result))) {
LOG_WARN("failed to call pl engine exectue udf", K(ret));
} else if (OB_UNLIKELY(udf_params->count() < 1)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("get unexpected error", K(ret), K(udf_params->count()));
} else {
udf_params->at(0).copy_value_or_obj(pl_obj, true);
LOG_TRACE("Succeed to process merge pl agg udf", K(pl_obj), K(tmp_result));
}
udf_params->at(0).copy_value_or_obj(pl_obj, true);
LOG_TRACE("Succeed to process merge pl agg udf", K(pl_obj), K(tmp_result));
}
return ret;
}
@ -408,80 +454,81 @@ int ObPlAggUdfFunction::process_get_pl_agg_udf_result(ObObjParam &pl_obj,
ObString routine_name(strlen(str), str);
ParamStore *udf_params = NULL;
ObObj tmp_result;
if (OB_FAIL(get_package_routine_info(routine_name, routine_info))) {
LOG_WARN("failed to get package routine info", K(ret));
//for pl agg udf, type member ODCIAggregateTerminate() the first param must be self and is IN,
//the second param is OUT, the third param is number. so we need rebuild relation infos.
//see oracle url:https://docs.oracle.com/cd/B28359_01/appdev.111/b28425/ext_agg_ref.htm#CACBJHHI
ObSEArray<ObUDFParamDesc, 4> params_desc;
ObSEArray<ObUDFParamDesc, 4> all_params_desc;
ObSEArray<ObExprResType, 4> params_type;
ObSEArray<ObExprResType, 4> all_params_type;
result.set_meta_type(result_type_);
if (OB_FAIL(params_type.push_back(result_type_))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(params_desc.push_back(
ObUDFParamDesc(ObUDFParamDesc::LOCAL_OUT, 0)))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(build_in_params_store(pl_obj, false, &result, 1, params_desc,
params_type, udf_params))) {
LOG_WARN("failed to build in params store", K(ret));
} else if (OB_ISNULL(udf_params)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("get unexpected null", K(ret));
} else {
//for pl agg udf, type member ODCIAggregateTerminate() the first param must be self and is IN,
//the second param is OUT, the third param is number. so we need rebuild relation infos.
//see oracle url:https://docs.oracle.com/cd/B28359_01/appdev.111/b28425/ext_agg_ref.htm#CACBJHHI
ObSEArray<ObUDFParamDesc, 4> params_desc;
ObSEArray<ObUDFParamDesc, 4> all_params_desc;
ObSEArray<ObExprResType, 4> params_type;
ObSEArray<ObExprResType, 4> all_params_type;
result.set_meta_type(result_type_);
if (OB_FAIL(params_type.push_back(result_type_))) {
common::ObArenaAllocator alloc;
ObExprResType param_type(alloc);
ObExprResType flags_type(alloc);
param_type.set_ext();
param_type.set_udt_id(type_id_);
flags_type.set_number();
number::ObNumber num;
ObObj number;
number.set_number(ObNumberType, num);
ObObjParam param;
param.reset();
number.copy_value_or_obj(param, true);
if (OB_FAIL(udf_params->push_back(param))) {
LOG_WARN("failed to push back obj param");
} else if (OB_FAIL(all_params_type.push_back(param_type))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(params_desc.push_back(
ObUDFParamDesc(ObUDFParamDesc::LOCAL_OUT, 0)))) {
} else if (OB_FAIL(all_params_desc.push_back(ObUDFParamDesc()))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(build_in_params_store(pl_obj, false, &result, 1, params_desc,
params_type, udf_params))) {
LOG_WARN("failed to build in params store", K(ret));
} else if (OB_ISNULL(udf_params)) {
} else if (OB_FAIL(all_params_type.push_back(result_type_))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(all_params_desc.push_back(
ObUDFParamDesc(ObUDFParamDesc::LOCAL_OUT, 1)))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(all_params_type.push_back(flags_type))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(all_params_desc.push_back(ObUDFParamDesc()))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(get_package_routine_info(routine_name, routine_info, all_params_type))) {
LOG_WARN("failed to get package routine info", K(ret));
} else if (OB_FAIL(call_pl_engine_exectue_udf(*udf_params, routine_info, tmp_result))) {
LOG_WARN("failed to call pl engine exectue udf", K(ret));
} else if (OB_UNLIKELY(udf_params->count() < 2)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("get unexpected null", K(ret));
} else {
common::ObArenaAllocator alloc;
ObExprResType param_type(alloc);
ObExprResType flags_type(alloc);
param_type.set_ext();
flags_type.set_number();
number::ObNumber num;
ObObj number;
number.set_number(ObNumberType, num);
ObObjParam param;
param.reset();
number.copy_value_or_obj(param, true);
if (OB_FAIL(udf_params->push_back(param))) {
LOG_WARN("failed to push back obj param");
} else if (OB_FAIL(all_params_type.push_back(param_type))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(all_params_desc.push_back(ObUDFParamDesc()))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(all_params_type.push_back(result_type_))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(all_params_desc.push_back(
ObUDFParamDesc(ObUDFParamDesc::LOCAL_OUT, 1)))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(all_params_type.push_back(flags_type))) {
LOG_WARN("failed to push back type", K(ret));
} else if (OB_FAIL(all_params_desc.push_back(ObUDFParamDesc()))) {
LOG_WARN("failed to push back param desc", K(ret));
} else if (OB_FAIL(call_pl_engine_exectue_udf(*udf_params, routine_info, tmp_result))) {
LOG_WARN("failed to call pl engine exectue udf", K(ret));
} else if (OB_UNLIKELY(udf_params->count() < 2)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("get unexpected error", K(ret), K(udf_params->count()));
} else if (!ob_is_lob_tc(result_type_.get_type())) {
ObObj src_obj;
udf_params->at(1).copy_value_or_obj(src_obj, true);
ObCastMode cast_mode = CM_NONE;
if (OB_FAIL(ObSQLUtils::get_default_cast_mode(session_info_, cast_mode))) {
LOG_WARN("failed to get default cast mode", K(ret));
} else {
ObCastCtx cast_ctx(allocator_, NULL, cast_mode, ObCharset::get_system_collation(), NULL);
if (OB_FAIL(ObObjCaster::to_type(result_type_.get_type(), cast_ctx, src_obj, result))) {
LOG_WARN("failed to cast type", K(ret));
} else {
LOG_TRACE("succeed to process get pl agg udf result", K(src_obj), K(result));
}
}
LOG_WARN("get unexpected error", K(ret), K(udf_params->count()));
} else if (!ob_is_lob_tc(result_type_.get_type())) {
ObObj src_obj;
udf_params->at(1).copy_value_or_obj(src_obj, true);
ObCastMode cast_mode = CM_NONE;
if (OB_FAIL(ObSQLUtils::get_default_cast_mode(session_info_, cast_mode))) {
LOG_WARN("failed to get default cast mode", K(ret));
} else {
udf_params->at(1).copy_value_or_obj(result, true);
LOG_TRACE("succeed to process get pl agg udf result", K(result));
ObCastCtx cast_ctx(allocator_, NULL, cast_mode, ObCharset::get_system_collation(), NULL);
if (OB_FAIL(ObObjCaster::to_type(result_type_.get_type(), cast_ctx, src_obj, result))) {
LOG_WARN("failed to cast type", K(ret));
} else {
LOG_TRACE("succeed to process get pl agg udf result", K(src_obj), K(result));
}
}
} else {
udf_params->at(1).copy_value_or_obj(result, true);
LOG_TRACE("succeed to process get pl agg udf result", K(result));
}
}
return ret;
}

View File

@ -72,8 +72,13 @@ class ObPlAggUdfFunction
int get_package_udf_id(const ObString &routine_name,
const share::schema::ObRoutineInfo *&routine_info);
int pick_routine(ObSEArray<const ObIRoutineInfo *, 4> &routine_infos,
const ObIRoutineInfo *&routine_info,
ObIArray<ObExprResType> &param_type);
int get_package_routine_info(const ObString &routine_name,
const share::schema::ObRoutineInfo *&routine_info);
const share::schema::ObRoutineInfo *&routine_info,
ObIArray<ObExprResType> &param_type);
int check_types(const ObObj *obj_params,
int64_t param_num,