Files
oceanbase/src/sql/engine/expr/ob_expr_vector.cpp
2024-09-19 07:23:47 +00:00

378 lines
14 KiB
C++

/**
* Copyright (c) 2023 OceanBase
* OceanBase is licensed under Mulan PubL v2.
* You can use this software according to the terms and conditions of the Mulan
* PubL v2. You may obtain a copy of Mulan PubL v2 at:
* http://license.coscl.org.cn/MulanPubL-2.0
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY
* KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
* NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. See the
* Mulan PubL v2 for more details.
*/
#define USING_LOG_PREFIX SQL_ENG
#include "sql/engine/expr/ob_expr_vector.h"
#include "sql/engine/ob_subschema_ctx.h"
#include "sql/engine/ob_exec_context.h"
#include "sql/engine/expr/ob_expr_lob_utils.h"
#include "sql/engine/expr/ob_array_expr_utils.h"
#include "sql/engine/expr/ob_array_cast.h"
#include "share/vector_type/ob_vector_l2_distance.h"
#include "share/vector_type/ob_vector_cosine_distance.h"
#include "share/vector_type/ob_vector_ip_distance.h"
#include "share/vector_type/ob_vector_norm.h"
#include "share/vector_type/ob_vector_l1_distance.h"
namespace oceanbase
{
namespace sql
{
ObExprVector::ObExprVector(ObIAllocator &alloc,
ObExprOperatorType type,
const char *name,
int32_t param_num,
int32_t dimension) : ObFuncExprOperator(alloc, type, name, param_num, VALID_FOR_GENERATED_COL, dimension)
{
}
// [a,b,c,...] is array type, there is no dim_cnt_ in ObCollectionArrayType
int ObExprVector::calc_result_type2(
ObExprResType &type,
ObExprResType &type1,
ObExprResType &type2,
common::ObExprTypeCtx &type_ctx) const
{
int ret = OB_SUCCESS;
uint16_t unused_id = UINT16_MAX;
if (OB_FAIL(ObArrayExprUtils::calc_cast_type2(type1, type2, type_ctx, unused_id))) {
LOG_WARN("failed to calc cast type", K(ret), K(type1));
} else {
type.set_type(ObDoubleType);
type.set_calc_type(ObDoubleType);
}
return ret;
}
int ObExprVector::calc_result_type1(
ObExprResType &type,
ObExprResType &type1,
common::ObExprTypeCtx &type_ctx) const
{
int ret = OB_SUCCESS;
if (OB_FAIL(ObArrayExprUtils::calc_cast_type(type1, type_ctx))) {
LOG_WARN("failed to calc cast type", K(ret), K(type1));
} else {
type.set_type(ObDoubleType);
type.set_calc_type(ObDoubleType);
}
return ret;
}
ObExprVectorDistance::ObExprVectorDistance(ObIAllocator &alloc)
: ObExprVector(alloc, T_FUN_SYS_VECTOR_DISTANCE, N_VECTOR_DISTANCE, TWO_OR_THREE, NOT_ROW_DIMENSION)
{}
ObExprVectorDistance::ObExprVectorDistance(
ObIAllocator &alloc,
ObExprOperatorType type,
const char *name,
int32_t param_num,
int32_t dimension)
: ObExprVector(alloc, type, name, param_num, dimension)
{}
ObExprVectorDistance::FuncPtrType ObExprVectorDistance::distance_funcs[] =
{
ObVectorCosineDistance::cosine_distance_func,
ObVectorIpDistance::ip_distance_func,
ObVectorL2Distance::l2_distance_func,
ObVectorL1Distance::l1_distance_func,
ObVectorL2Distance::l2_square_func,
nullptr,
};
int ObExprVectorDistance::calc_result_typeN(
ObExprResType &type,
ObExprResType *types_stack,
int64_t param_num,
common::ObExprTypeCtx &type_ctx) const
{
int ret = OB_SUCCESS;
if (OB_UNLIKELY(param_num > 3)) {
ObString func_name_(get_name());
ret = OB_ERR_PARAM_SIZE;
LOG_USER_ERROR(OB_ERR_PARAM_SIZE, func_name_.length(), func_name_.ptr());
} else if (OB_FAIL(calc_result_type2(type, types_stack[0], types_stack[1], type_ctx))) {
LOG_WARN("failed to calc result type", K(ret));
}
return ret;
}
int ObExprVectorDistance::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr,
ObExpr &rt_expr) const
{
int ret = OB_SUCCESS;
rt_expr.eval_func_ = ObExprVectorDistance::calc_distance;
return ret;
}
int ObExprVectorDistance::calc_distance(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum)
{
int ret = OB_SUCCESS;
ObVecDisType dis_type = ObVecDisType::EUCLIDEAN; // default metric
if (3 == expr.arg_cnt_) {
ObDatum *datum = NULL;
if (OB_FAIL(expr.args_[2]->eval(ctx, datum))) {
LOG_WARN("eval failed", K(ret));
} else if (datum->is_null()) {
ret = OB_INVALID_ARGUMENT;
LOG_WARN("invalid arg", K(ret), K(*datum));
} else {
dis_type = static_cast<ObVecDisType>(datum->get_int());
}
}
if (FAILEDx(calc_distance(expr, ctx, res_datum, dis_type))) {
LOG_WARN("failed to calc distance", K(ret), K(dis_type));
}
return ret;
}
int ObExprVectorDistance::calc_distance(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum, ObVecDisType dis_type)
{
int ret = OB_SUCCESS;
ObEvalCtx::TempAllocGuard tmp_alloc_g(ctx);
common::ObArenaAllocator &tmp_allocator = tmp_alloc_g.get_allocator();
ObIArrayType *arr_l = NULL;
ObIArrayType *arr_r = NULL;
bool contain_null = false;
if (dis_type < ObVecDisType::COSINE || dis_type >= ObVecDisType::MAX_TYPE) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("unexpect distance type", K(ret), K(dis_type));
} else if (OB_FAIL(ObArrayExprUtils::get_type_vector(*(expr.args_[0]), ctx, tmp_allocator, arr_l, contain_null))) {
LOG_WARN("failed to get vector", K(ret), K(*expr.args_[0]));
} else if (OB_FAIL(ObArrayExprUtils::get_type_vector(*(expr.args_[1]), ctx, tmp_allocator, arr_r, contain_null))) {
LOG_WARN("failed to get vector", K(ret), K(*expr.args_[1]));
} else if (contain_null) {
res_datum.set_null();
} else if (OB_ISNULL(arr_l) || OB_ISNULL(arr_r)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("unexpected nullptr", K(ret), K(arr_l), K(arr_r));
} else if (OB_UNLIKELY(arr_l->size() != arr_r->size())) {
ret = OB_ERR_INVALID_VECTOR_DIM;
LOG_WARN("check array validty failed", K(ret), K(arr_l->size()), K(arr_r->size()));
} else if (arr_l->contain_null() || arr_r->contain_null()) {
ret = OB_ERR_NULL_VALUE;
LOG_WARN("array with null can't calculate vector distance", K(ret));
} else {
double distance = 0.0;
const float *data_l = reinterpret_cast<const float*>(arr_l->get_data());
const float *data_r = reinterpret_cast<const float*>(arr_r->get_data());
const uint32_t size = arr_l->size();
if (distance_funcs[dis_type] == nullptr) {
ret = OB_NOT_SUPPORTED;
LOG_WARN("not support", K(ret), K(dis_type));
} else if (OB_FAIL(distance_funcs[dis_type](data_l, data_r, size, distance))) {
if (OB_ERR_NULL_VALUE == ret) {
res_datum.set_null();
ret = OB_SUCCESS; // ignore
} else {
LOG_WARN("failed to calc distance", K(ret), K(dis_type));
}
} else {
res_datum.set_double(distance);
}
}
return ret;
}
ObExprVectorL1Distance::ObExprVectorL1Distance(ObIAllocator &alloc)
: ObExprVectorDistance(alloc, T_FUN_SYS_L1_DISTANCE, N_VECTOR_L1_DISTANCE, 2, NOT_ROW_DIMENSION) {}
int ObExprVectorL1Distance::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr,
ObExpr &rt_expr) const
{
int ret = OB_SUCCESS;
rt_expr.eval_func_ = ObExprVectorL1Distance::calc_l1_distance;
return ret;
}
int ObExprVectorL1Distance::calc_l1_distance(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum)
{
return ObExprVectorDistance::calc_distance(expr, ctx, res_datum, ObVecDisType::MANHATTAN);
}
ObExprVectorL2Distance::ObExprVectorL2Distance(ObIAllocator &alloc)
: ObExprVectorDistance(alloc, T_FUN_SYS_L2_DISTANCE, N_VECTOR_L2_DISTANCE, 2, NOT_ROW_DIMENSION) {}
int ObExprVectorL2Distance::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr,
ObExpr &rt_expr) const
{
int ret = OB_SUCCESS;
rt_expr.eval_func_ = ObExprVectorL2Distance::calc_l2_distance;
return ret;
}
int ObExprVectorL2Distance::calc_l2_distance(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum)
{
return ObExprVectorDistance::calc_distance(expr, ctx, res_datum, ObVecDisType::EUCLIDEAN);
}
ObExprVectorCosineDistance::ObExprVectorCosineDistance(ObIAllocator &alloc)
: ObExprVectorDistance(alloc, T_FUN_SYS_COSINE_DISTANCE, N_VECTOR_COS_DISTANCE, 2, NOT_ROW_DIMENSION) {}
int ObExprVectorCosineDistance::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr,
ObExpr &rt_expr) const
{
int ret = OB_SUCCESS;
rt_expr.eval_func_ = ObExprVectorCosineDistance::calc_cosine_distance;
return ret;
}
int ObExprVectorCosineDistance::calc_cosine_distance(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum)
{
return ObExprVectorDistance::calc_distance(expr, ctx, res_datum, ObVecDisType::COSINE);
}
ObExprVectorIPDistance::ObExprVectorIPDistance(ObIAllocator &alloc)
: ObExprVectorDistance(alloc, T_FUN_SYS_INNER_PRODUCT, N_VECTOR_INNER_PRODUCT, 2, NOT_ROW_DIMENSION) {}
int ObExprVectorIPDistance::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr,
ObExpr &rt_expr) const
{
int ret = OB_SUCCESS;
rt_expr.eval_func_ = ObExprVectorIPDistance::calc_inner_product;
return ret;
}
int ObExprVectorIPDistance::calc_inner_product(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum)
{
return ObExprVectorDistance::calc_distance(expr, ctx, res_datum, ObVecDisType::DOT);
}
ObExprVectorNegativeIPDistance::ObExprVectorNegativeIPDistance(ObIAllocator &alloc)
: ObExprVectorDistance(alloc, T_FUN_SYS_NEGATIVE_INNER_PRODUCT, N_VECTOR_NEGATIVE_INNER_PRODUCT, 2, NOT_ROW_DIMENSION) {}
int ObExprVectorNegativeIPDistance::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr,
ObExpr &rt_expr) const
{
int ret = OB_SUCCESS;
rt_expr.eval_func_ = ObExprVectorNegativeIPDistance::calc_negative_inner_product;
return ret;
}
int ObExprVectorNegativeIPDistance::calc_negative_inner_product(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum)
{
int ret = OB_SUCCESS;
if (OB_FAIL(ObExprVectorDistance::calc_distance(expr, ctx, res_datum, ObVecDisType::DOT))) {
LOG_WARN("fail to calc distance", K(ret), K(ObVecDisType::DOT));
} else if (!res_datum.is_null()) {
double value = -1 * res_datum.get_double();
res_datum.set_double(value);
}
return ret;
}
ObExprVectorDims::ObExprVectorDims(ObIAllocator &alloc)
: ObExprVector(alloc, T_FUN_SYS_VECTOR_DIMS, N_VECTOR_DIMS, 1, NOT_ROW_DIMENSION) {}
int ObExprVectorDims::calc_result_type1(
ObExprResType &type,
ObExprResType &type1,
common::ObExprTypeCtx &type_ctx) const
{
int ret = OB_SUCCESS;
if (OB_FAIL(ObArrayExprUtils::calc_cast_type(type1, type_ctx))) {
LOG_WARN("failed to calc cast type", K(ret), K(type1));
} else {
type.set_type(ObIntType);
type.set_calc_type(ObIntType);
}
return ret;
}
int ObExprVectorDims::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr,
ObExpr &rt_expr) const
{
int ret = OB_SUCCESS;
rt_expr.eval_func_ = ObExprVectorDims::calc_dims;
if (rt_expr.arg_cnt_ != 1 || OB_ISNULL(rt_expr.args_)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("count of children is not 1 or children is null", K(ret), K(rt_expr.arg_cnt_), K(rt_expr.args_));
} else if (rt_expr.args_[0]->type_ == T_FUN_SYS_CAST) {
// return error if cast failed
rt_expr.args_[0]->extra_ &= ~CM_WARN_ON_FAIL;
}
return ret;
}
int ObExprVectorDims::calc_dims(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum)
{
int ret = OB_SUCCESS;
ObEvalCtx::TempAllocGuard tmp_alloc_g(ctx);
common::ObArenaAllocator &tmp_allocator = tmp_alloc_g.get_allocator();
ObIArrayType *arr = NULL;
bool contain_null = false;
if (OB_FAIL(ObArrayExprUtils::get_type_vector(*(expr.args_[0]), ctx, tmp_allocator, arr, contain_null))) {
LOG_WARN("failed to get vector", K(ret), K(*expr.args_[0]));
} else if (contain_null) {
res_datum.set_null();
} else if (OB_ISNULL(arr)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("unexpected nullptr", K(ret), K(arr));
} else if (arr->contain_null()) {
ret = OB_ERR_NULL_VALUE;
LOG_WARN("array with null can't calculate vector norm", K(ret));
} else {
res_datum.set_int(arr->size());
}
return ret;
}
ObExprVectorNorm::ObExprVectorNorm(ObIAllocator &alloc)
: ObExprVector(alloc, T_FUN_SYS_VECTOR_NORM, N_VECTOR_NORM, 1, NOT_ROW_DIMENSION) {}
int ObExprVectorNorm::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr,
ObExpr &rt_expr) const
{
int ret = OB_SUCCESS;
rt_expr.eval_func_ = ObExprVectorNorm::calc_norm;
if (rt_expr.arg_cnt_ != 1 || OB_ISNULL(rt_expr.args_)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("count of children is not 1 or children is null", K(ret), K(rt_expr.arg_cnt_), K(rt_expr.args_));
} else if (rt_expr.args_[0]->type_ == T_FUN_SYS_CAST) {
// return error if cast failed
rt_expr.args_[0]->extra_ &= ~CM_WARN_ON_FAIL;
}
return ret;
}
int ObExprVectorNorm::calc_norm(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum)
{
int ret = OB_SUCCESS;
ObEvalCtx::TempAllocGuard tmp_alloc_g(ctx);
common::ObArenaAllocator &tmp_allocator = tmp_alloc_g.get_allocator();
ObIArrayType *arr = NULL;
bool contain_null = false;
if (OB_FAIL(ObArrayExprUtils::get_type_vector(*(expr.args_[0]), ctx, tmp_allocator, arr, contain_null))) {
LOG_WARN("failed to get vector", K(ret), K(*expr.args_[0]));
} else if (contain_null) {
res_datum.set_null();
} else if (OB_ISNULL(arr)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("unexpected nullptr", K(ret), K(arr));
} else if (arr->contain_null()) {
ret = OB_ERR_NULL_VALUE;
LOG_WARN("array with null can't calculate vector norm", K(ret));
} else {
double norm = 0.0;
const float *data = reinterpret_cast<const float*>(arr->get_data());
if (OB_FAIL(ObVectorNorm::vector_norm_func(data, arr->size(), norm))) {
LOG_WARN("failed to calc vector norm", K(ret));
} else {
res_datum.set_double(norm);
}
}
return ret;
}
} // sql
} // oceanbase