/** * Copyright (c) 2023 OceanBase * OceanBase is licensed under Mulan PubL v2. * You can use this software according to the terms and conditions of the Mulan * PubL v2. You may obtain a copy of Mulan PubL v2 at: * http://license.coscl.org.cn/MulanPubL-2.0 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY * KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO * NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. See the * Mulan PubL v2 for more details. */ #define USING_LOG_PREFIX SQL_ENG #include "sql/engine/expr/ob_expr_vector.h" #include "sql/engine/ob_subschema_ctx.h" #include "sql/engine/ob_exec_context.h" #include "sql/engine/expr/ob_expr_lob_utils.h" #include "sql/engine/expr/ob_array_expr_utils.h" #include "sql/engine/expr/ob_array_cast.h" #include "share/vector_type/ob_vector_l2_distance.h" #include "share/vector_type/ob_vector_cosine_distance.h" #include "share/vector_type/ob_vector_ip_distance.h" #include "share/vector_type/ob_vector_norm.h" #include "share/vector_type/ob_vector_l1_distance.h" namespace oceanbase { namespace sql { ObExprVector::ObExprVector(ObIAllocator &alloc, ObExprOperatorType type, const char *name, int32_t param_num, int32_t dimension) : ObFuncExprOperator(alloc, type, name, param_num, VALID_FOR_GENERATED_COL, dimension) { } // [a,b,c,...] is array type, there is no dim_cnt_ in ObCollectionArrayType int ObExprVector::calc_result_type2( ObExprResType &type, ObExprResType &type1, ObExprResType &type2, common::ObExprTypeCtx &type_ctx) const { int ret = OB_SUCCESS; uint16_t unused_id = UINT16_MAX; if (OB_FAIL(ObArrayExprUtils::calc_cast_type2(type1, type2, type_ctx, unused_id))) { LOG_WARN("failed to calc cast type", K(ret), K(type1)); } else { type.set_type(ObDoubleType); type.set_calc_type(ObDoubleType); } return ret; } int ObExprVector::calc_result_type1( ObExprResType &type, ObExprResType &type1, common::ObExprTypeCtx &type_ctx) const { int ret = OB_SUCCESS; if (OB_FAIL(ObArrayExprUtils::calc_cast_type(type1, type_ctx))) { LOG_WARN("failed to calc cast type", K(ret), K(type1)); } else { type.set_type(ObDoubleType); type.set_calc_type(ObDoubleType); } return ret; } ObExprVectorDistance::ObExprVectorDistance(ObIAllocator &alloc) : ObExprVector(alloc, T_FUN_SYS_VECTOR_DISTANCE, N_VECTOR_DISTANCE, TWO_OR_THREE, NOT_ROW_DIMENSION) {} ObExprVectorDistance::ObExprVectorDistance( ObIAllocator &alloc, ObExprOperatorType type, const char *name, int32_t param_num, int32_t dimension) : ObExprVector(alloc, type, name, param_num, dimension) {} ObExprVectorDistance::FuncPtrType ObExprVectorDistance::distance_funcs[] = { ObVectorCosineDistance::cosine_distance_func, ObVectorIpDistance::ip_distance_func, ObVectorL2Distance::l2_distance_func, ObVectorL1Distance::l1_distance_func, ObVectorL2Distance::l2_square_func, nullptr, }; int ObExprVectorDistance::calc_result_typeN( ObExprResType &type, ObExprResType *types_stack, int64_t param_num, common::ObExprTypeCtx &type_ctx) const { int ret = OB_SUCCESS; if (OB_UNLIKELY(param_num > 3)) { ObString func_name_(get_name()); ret = OB_ERR_PARAM_SIZE; LOG_USER_ERROR(OB_ERR_PARAM_SIZE, func_name_.length(), func_name_.ptr()); } else if (OB_FAIL(calc_result_type2(type, types_stack[0], types_stack[1], type_ctx))) { LOG_WARN("failed to calc result type", K(ret)); } return ret; } int ObExprVectorDistance::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr, ObExpr &rt_expr) const { int ret = OB_SUCCESS; rt_expr.eval_func_ = ObExprVectorDistance::calc_distance; return ret; } int ObExprVectorDistance::calc_distance(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum) { int ret = OB_SUCCESS; ObVecDisType dis_type = ObVecDisType::EUCLIDEAN; // default metric if (3 == expr.arg_cnt_) { ObDatum *datum = NULL; if (OB_FAIL(expr.args_[2]->eval(ctx, datum))) { LOG_WARN("eval failed", K(ret)); } else if (datum->is_null()) { ret = OB_INVALID_ARGUMENT; LOG_WARN("invalid arg", K(ret), K(*datum)); } else { dis_type = static_cast(datum->get_int()); } } if (FAILEDx(calc_distance(expr, ctx, res_datum, dis_type))) { LOG_WARN("failed to calc distance", K(ret), K(dis_type)); } return ret; } int ObExprVectorDistance::calc_distance(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum, ObVecDisType dis_type) { int ret = OB_SUCCESS; ObEvalCtx::TempAllocGuard tmp_alloc_g(ctx); common::ObArenaAllocator &tmp_allocator = tmp_alloc_g.get_allocator(); ObIArrayType *arr_l = NULL; ObIArrayType *arr_r = NULL; bool contain_null = false; if (dis_type < ObVecDisType::COSINE || dis_type >= ObVecDisType::MAX_TYPE) { ret = OB_ERR_UNEXPECTED; LOG_WARN("unexpect distance type", K(ret), K(dis_type)); } else if (OB_FAIL(ObArrayExprUtils::get_type_vector(*(expr.args_[0]), ctx, tmp_allocator, arr_l, contain_null))) { LOG_WARN("failed to get vector", K(ret), K(*expr.args_[0])); } else if (OB_FAIL(ObArrayExprUtils::get_type_vector(*(expr.args_[1]), ctx, tmp_allocator, arr_r, contain_null))) { LOG_WARN("failed to get vector", K(ret), K(*expr.args_[1])); } else if (contain_null) { res_datum.set_null(); } else if (OB_ISNULL(arr_l) || OB_ISNULL(arr_r)) { ret = OB_ERR_UNEXPECTED; LOG_WARN("unexpected nullptr", K(ret), K(arr_l), K(arr_r)); } else if (OB_UNLIKELY(arr_l->size() != arr_r->size())) { ret = OB_ERR_INVALID_VECTOR_DIM; LOG_WARN("check array validty failed", K(ret), K(arr_l->size()), K(arr_r->size())); } else if (arr_l->contain_null() || arr_r->contain_null()) { ret = OB_ERR_NULL_VALUE; LOG_WARN("array with null can't calculate vector distance", K(ret)); } else { double distance = 0.0; const float *data_l = reinterpret_cast(arr_l->get_data()); const float *data_r = reinterpret_cast(arr_r->get_data()); const uint32_t size = arr_l->size(); if (distance_funcs[dis_type] == nullptr) { ret = OB_NOT_SUPPORTED; LOG_WARN("not support", K(ret), K(dis_type)); } else if (OB_FAIL(distance_funcs[dis_type](data_l, data_r, size, distance))) { if (OB_ERR_NULL_VALUE == ret) { res_datum.set_null(); ret = OB_SUCCESS; // ignore } else { LOG_WARN("failed to calc distance", K(ret), K(dis_type)); } } else { res_datum.set_double(distance); } } return ret; } ObExprVectorL1Distance::ObExprVectorL1Distance(ObIAllocator &alloc) : ObExprVectorDistance(alloc, T_FUN_SYS_L1_DISTANCE, N_VECTOR_L1_DISTANCE, 2, NOT_ROW_DIMENSION) {} int ObExprVectorL1Distance::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr, ObExpr &rt_expr) const { int ret = OB_SUCCESS; rt_expr.eval_func_ = ObExprVectorL1Distance::calc_l1_distance; return ret; } int ObExprVectorL1Distance::calc_l1_distance(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum) { return ObExprVectorDistance::calc_distance(expr, ctx, res_datum, ObVecDisType::MANHATTAN); } ObExprVectorL2Distance::ObExprVectorL2Distance(ObIAllocator &alloc) : ObExprVectorDistance(alloc, T_FUN_SYS_L2_DISTANCE, N_VECTOR_L2_DISTANCE, 2, NOT_ROW_DIMENSION) {} int ObExprVectorL2Distance::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr, ObExpr &rt_expr) const { int ret = OB_SUCCESS; rt_expr.eval_func_ = ObExprVectorL2Distance::calc_l2_distance; return ret; } int ObExprVectorL2Distance::calc_l2_distance(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum) { return ObExprVectorDistance::calc_distance(expr, ctx, res_datum, ObVecDisType::EUCLIDEAN); } ObExprVectorCosineDistance::ObExprVectorCosineDistance(ObIAllocator &alloc) : ObExprVectorDistance(alloc, T_FUN_SYS_COSINE_DISTANCE, N_VECTOR_COS_DISTANCE, 2, NOT_ROW_DIMENSION) {} int ObExprVectorCosineDistance::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr, ObExpr &rt_expr) const { int ret = OB_SUCCESS; rt_expr.eval_func_ = ObExprVectorCosineDistance::calc_cosine_distance; return ret; } int ObExprVectorCosineDistance::calc_cosine_distance(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum) { return ObExprVectorDistance::calc_distance(expr, ctx, res_datum, ObVecDisType::COSINE); } ObExprVectorIPDistance::ObExprVectorIPDistance(ObIAllocator &alloc) : ObExprVectorDistance(alloc, T_FUN_SYS_INNER_PRODUCT, N_VECTOR_INNER_PRODUCT, 2, NOT_ROW_DIMENSION) {} int ObExprVectorIPDistance::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr, ObExpr &rt_expr) const { int ret = OB_SUCCESS; rt_expr.eval_func_ = ObExprVectorIPDistance::calc_inner_product; return ret; } int ObExprVectorIPDistance::calc_inner_product(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum) { return ObExprVectorDistance::calc_distance(expr, ctx, res_datum, ObVecDisType::DOT); } ObExprVectorNegativeIPDistance::ObExprVectorNegativeIPDistance(ObIAllocator &alloc) : ObExprVectorDistance(alloc, T_FUN_SYS_NEGATIVE_INNER_PRODUCT, N_VECTOR_NEGATIVE_INNER_PRODUCT, 2, NOT_ROW_DIMENSION) {} int ObExprVectorNegativeIPDistance::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr, ObExpr &rt_expr) const { int ret = OB_SUCCESS; rt_expr.eval_func_ = ObExprVectorNegativeIPDistance::calc_negative_inner_product; return ret; } int ObExprVectorNegativeIPDistance::calc_negative_inner_product(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum) { int ret = OB_SUCCESS; if (OB_FAIL(ObExprVectorDistance::calc_distance(expr, ctx, res_datum, ObVecDisType::DOT))) { LOG_WARN("fail to calc distance", K(ret), K(ObVecDisType::DOT)); } else if (!res_datum.is_null()) { double value = -1 * res_datum.get_double(); res_datum.set_double(value); } return ret; } ObExprVectorDims::ObExprVectorDims(ObIAllocator &alloc) : ObExprVector(alloc, T_FUN_SYS_VECTOR_DIMS, N_VECTOR_DIMS, 1, NOT_ROW_DIMENSION) {} int ObExprVectorDims::calc_result_type1( ObExprResType &type, ObExprResType &type1, common::ObExprTypeCtx &type_ctx) const { int ret = OB_SUCCESS; if (OB_FAIL(ObArrayExprUtils::calc_cast_type(type1, type_ctx))) { LOG_WARN("failed to calc cast type", K(ret), K(type1)); } else { type.set_type(ObIntType); type.set_calc_type(ObIntType); } return ret; } int ObExprVectorDims::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr, ObExpr &rt_expr) const { int ret = OB_SUCCESS; rt_expr.eval_func_ = ObExprVectorDims::calc_dims; if (rt_expr.arg_cnt_ != 1 || OB_ISNULL(rt_expr.args_)) { ret = OB_ERR_UNEXPECTED; LOG_WARN("count of children is not 1 or children is null", K(ret), K(rt_expr.arg_cnt_), K(rt_expr.args_)); } else if (rt_expr.args_[0]->type_ == T_FUN_SYS_CAST) { // return error if cast failed rt_expr.args_[0]->extra_ &= ~CM_WARN_ON_FAIL; } return ret; } int ObExprVectorDims::calc_dims(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum) { int ret = OB_SUCCESS; ObEvalCtx::TempAllocGuard tmp_alloc_g(ctx); common::ObArenaAllocator &tmp_allocator = tmp_alloc_g.get_allocator(); ObIArrayType *arr = NULL; bool contain_null = false; if (OB_FAIL(ObArrayExprUtils::get_type_vector(*(expr.args_[0]), ctx, tmp_allocator, arr, contain_null))) { LOG_WARN("failed to get vector", K(ret), K(*expr.args_[0])); } else if (contain_null) { res_datum.set_null(); } else if (OB_ISNULL(arr)) { ret = OB_ERR_UNEXPECTED; LOG_WARN("unexpected nullptr", K(ret), K(arr)); } else if (arr->contain_null()) { ret = OB_ERR_NULL_VALUE; LOG_WARN("array with null can't calculate vector norm", K(ret)); } else { res_datum.set_int(arr->size()); } return ret; } ObExprVectorNorm::ObExprVectorNorm(ObIAllocator &alloc) : ObExprVector(alloc, T_FUN_SYS_VECTOR_NORM, N_VECTOR_NORM, 1, NOT_ROW_DIMENSION) {} int ObExprVectorNorm::cg_expr(ObExprCGCtx &expr_cg_ctx, const ObRawExpr &raw_expr, ObExpr &rt_expr) const { int ret = OB_SUCCESS; rt_expr.eval_func_ = ObExprVectorNorm::calc_norm; if (rt_expr.arg_cnt_ != 1 || OB_ISNULL(rt_expr.args_)) { ret = OB_ERR_UNEXPECTED; LOG_WARN("count of children is not 1 or children is null", K(ret), K(rt_expr.arg_cnt_), K(rt_expr.args_)); } else if (rt_expr.args_[0]->type_ == T_FUN_SYS_CAST) { // return error if cast failed rt_expr.args_[0]->extra_ &= ~CM_WARN_ON_FAIL; } return ret; } int ObExprVectorNorm::calc_norm(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum) { int ret = OB_SUCCESS; ObEvalCtx::TempAllocGuard tmp_alloc_g(ctx); common::ObArenaAllocator &tmp_allocator = tmp_alloc_g.get_allocator(); ObIArrayType *arr = NULL; bool contain_null = false; if (OB_FAIL(ObArrayExprUtils::get_type_vector(*(expr.args_[0]), ctx, tmp_allocator, arr, contain_null))) { LOG_WARN("failed to get vector", K(ret), K(*expr.args_[0])); } else if (contain_null) { res_datum.set_null(); } else if (OB_ISNULL(arr)) { ret = OB_ERR_UNEXPECTED; LOG_WARN("unexpected nullptr", K(ret), K(arr)); } else if (arr->contain_null()) { ret = OB_ERR_NULL_VALUE; LOG_WARN("array with null can't calculate vector norm", K(ret)); } else { double norm = 0.0; const float *data = reinterpret_cast(arr->get_data()); if (OB_FAIL(ObVectorNorm::vector_norm_func(data, arr->size(), norm))) { LOG_WARN("failed to calc vector norm", K(ret)); } else { res_datum.set_double(norm); } } return ret; } } // sql } // oceanbase