[vector] ivf post-search adapts the vectorized interface

This commit is contained in:
lyxiong0
2025-02-19 14:16:26 +00:00
committed by ob-robot
parent 77e64f7a59
commit 02000f544b
2 changed files with 326 additions and 149 deletions

View File

@ -87,21 +87,31 @@ int ObDASIvfBaseScanIter::gen_rowkeys_itr()
return ret;
}
int ObDASIvfBaseScanIter::do_table_full_scan(bool &first_scan,
ObTableScanParam &scan_param,
int ObDASIvfBaseScanIter::do_table_full_scan(bool is_vectorized,
const ObDASScanCtDef *ctdef,
ObDASScanRtDef *rtdef,
ObDASScanIter *iter,
int64_t pri_key_cnt,
ObTabletID &tablet_id)
ObTabletID &tablet_id,
bool &first_scan,
ObTableScanParam &scan_param)
{
int ret = OB_SUCCESS;
if (first_scan) {
ObNewRange scan_range;
if (OB_FAIL(ObDasVecScanUtils::init_vec_aux_scan_param(
ls_id_, tablet_id, ctdef, rtdef, tx_desc_, snapshot_, scan_param))) {
LOG_WARN("failed to init scan param", K(ret));
if (is_vectorized) {
if (OB_FAIL(
ObDasVecScanUtils::init_scan_param(ls_id_, tablet_id, ctdef, rtdef, tx_desc_, snapshot_, scan_param, false/*is_get*/))) {
LOG_WARN("failed to generate init vec aux scan param", K(ret));
}
} else {
if (OB_FAIL(ObDasVecScanUtils::init_vec_aux_scan_param(
ls_id_, tablet_id, ctdef, rtdef, tx_desc_, snapshot_, scan_param))) {
LOG_WARN("failed to init scan param", K(ret));
}
}
if (OB_FAIL(ret)) {
} else if (OB_FALSE_IT(ObDasVecScanUtils::set_whole_range(scan_range, ctdef->ref_table_id_))) {
} else if (OB_FAIL(scan_param.key_ranges_.push_back(scan_range))) {
LOG_WARN("failed to append scan range", K(ret));
@ -143,7 +153,7 @@ int ObDASIvfBaseScanIter::do_aux_table_scan(bool &first_scan,
if (first_scan) {
scan_param.need_switch_param_ = false;
if (OB_FAIL(ObDasVecScanUtils::init_scan_param(
ls_id_, tablet_id, ctdef, rtdef, tx_desc_, snapshot_, scan_param, false))) {
ls_id_, tablet_id, ctdef, rtdef, tx_desc_, snapshot_, scan_param, false/*is_get*/))) {
LOG_WARN("failed to init scan param", K(ret));
} else if (OB_FALSE_IT(iter->set_scan_param(scan_param))) {
} else if (OB_FAIL(iter->do_table_scan())) {
@ -538,6 +548,7 @@ int ObDASIvfScanIter::parse_centroid_datum_with_deep_copy(
const ObDASScanCtDef *centroid_ctdef,
ObIAllocator& allocator,
blocksstable::ObDatumRow *datum_row,
bool save_center_vec,
ObString &cid,
ObString &cid_vec)
{
@ -565,13 +576,18 @@ int ObDASIvfScanIter::parse_centroid_datum_with_deep_copy(
// ignoring null vector.
} else if (OB_FAIL(ob_write_string(allocator, tmp_cid, cid))) {
LOG_WARN("failed to write string", K(ret), K(tmp_cid));
} else if (OB_FAIL(ob_write_string(allocator, tmp_cid_vec, cid_vec))) {
LOG_WARN("failed to write string", K(ret), K(tmp_cid_vec));
} else if (save_center_vec) {
if (OB_FAIL(ob_write_string(allocator, tmp_cid_vec, cid_vec))) {
LOG_WARN("failed to write string", K(ret), K(tmp_cid_vec));
}
} else {
cid_vec = tmp_cid_vec;
}
return ret;
}
int ObDASIvfScanIter::generate_nearear_cid_heap(
bool is_vectorized,
share::ObVectorCentorClusterHelper<float, ObString> &nearest_cid_heap,
bool save_center_vec /*= false*/)
{
@ -579,14 +595,57 @@ int ObDASIvfScanIter::generate_nearear_cid_heap(
const ObDASScanCtDef *centroid_ctdef = vec_aux_ctdef_->get_vec_aux_tbl_ctdef(
vec_aux_ctdef_->get_ivf_centroid_tbl_idx(), ObTSCIRScanType::OB_VEC_IVF_CENTROID_SCAN);
ObDASScanRtDef *centroid_rtdef = vec_aux_rtdef_->get_vec_aux_tbl_rtdef(vec_aux_ctdef_->get_ivf_centroid_tbl_idx());
if (OB_FAIL(do_table_full_scan(centroid_iter_first_scan_,
centroid_scan_param_,
centroid_ctdef,
centroid_rtdef,
centroid_iter_,
CENTROID_PRI_KEY_CNT,
centroid_tablet_id_))) {
if (OB_FAIL(do_table_full_scan(is_vectorized,
centroid_ctdef,
centroid_rtdef,
centroid_iter_,
CENTROID_PRI_KEY_CNT,
centroid_tablet_id_,
centroid_iter_first_scan_,
centroid_scan_param_))) {
LOG_WARN("failed to do centroid table scan", K(ret));
} else if (is_vectorized) {
IVF_GET_NEXT_ROWS_BEGIN(centroid_iter_)
if (OB_SUCC(ret)) {
ObEvalCtx::BatchInfoScopeGuard guard(*vec_aux_rtdef_->eval_ctx_);
guard.set_batch_size(scan_row_cnt);
ObExpr *cid_expr = centroid_ctdef->result_output_[CID_IDX];
ObExpr *cid_vec_expr = centroid_ctdef->result_output_[CID_VECTOR_IDX];
ObDatum *cid_datum = cid_expr->locate_batch_datums(*vec_aux_rtdef_->eval_ctx_);
ObDatum *cid_vec_datum = cid_vec_expr->locate_batch_datums(*vec_aux_rtdef_->eval_ctx_);
bool has_lob_header = centroid_ctdef->result_output_.at(CID_VECTOR_IDX)->obj_meta_.has_lob_header();
for (int64_t i = 0; OB_SUCC(ret) && i < scan_row_cnt; ++i) {
guard.set_batch_idx(i);
ObString tmp_cid = cid_datum[i].get_string();
ObString tmp_cid_vec = cid_vec_datum[i].get_string();
ObString cid;
ObString cid_vec;
if (OB_FAIL(ObTextStringHelper::read_real_string_data(
&vec_op_alloc_,
ObLongTextType,
CS_TYPE_BINARY,
has_lob_header,
tmp_cid_vec))) {
LOG_WARN("failed to get real data.", K(ret));
} else if (tmp_cid_vec.empty()) {
// ignoring null vector.
} else if (OB_FAIL(ob_write_string(vec_op_alloc_, tmp_cid, cid))) {
LOG_WARN("failed to write string", K(ret), K(tmp_cid));
} else if (save_center_vec) {
if (OB_FAIL(ob_write_string(vec_op_alloc_, tmp_cid_vec, cid_vec))) {
LOG_WARN("failed to write string", K(ret), K(tmp_cid_vec));
}
} else {
cid_vec = tmp_cid_vec;
}
if (OB_FAIL(ret)) {
} else if (OB_FAIL(nearest_cid_heap.push_center(cid, reinterpret_cast<float *>(cid_vec.ptr()), dim_, save_center_vec))) {
LOG_WARN("failed to push center.", K(ret));
}
}
}
IVF_GET_NEXT_ROWS_END(centroid_iter_, centroid_scan_param_, centroid_tablet_id_)
} else {
storage::ObTableScanIterator *centroid_scan_iter =
static_cast<storage::ObTableScanIterator *>(centroid_iter_->get_output_result_iter());
@ -596,9 +655,6 @@ int ObDASIvfScanIter::generate_nearear_cid_heap(
} else {
// get vec in centroid_iter and get l2_distance(c_vec, search_vec_)
// push in max_cid_heap
const ObDASScanCtDef *centroid_ctdef = vec_aux_ctdef_->get_vec_aux_tbl_ctdef(
vec_aux_ctdef_->get_ivf_centroid_tbl_idx(), ObTSCIRScanType::OB_VEC_IVF_CENTROID_SCAN);
ObString c_vec;
ObString c_id;
while (OB_SUCC(ret)) {
@ -607,7 +663,7 @@ int ObDASIvfScanIter::generate_nearear_cid_heap(
if (OB_ITER_END != ret) {
LOG_WARN("get next row failed.", K(ret));
}
} else if (OB_FAIL(parse_centroid_datum_with_deep_copy(centroid_ctdef, vec_op_alloc_, datum_row, c_id, c_vec))) {
} else if (OB_FAIL(parse_centroid_datum_with_deep_copy(centroid_ctdef, vec_op_alloc_, datum_row, save_center_vec, c_id, c_vec))) {
LOG_WARN("fail to deep copy centroid datum", K(ret));
} else if (c_vec.empty()) {
// ignoring null vector.
@ -627,14 +683,14 @@ int ObDASIvfScanIter::generate_nearear_cid_heap(
return ret;
}
int ObDASIvfScanIter::get_nearest_probe_center_ids(ObIArray<ObString> &nearest_cids)
int ObDASIvfScanIter::get_nearest_probe_center_ids(bool is_vectorized, ObIArray<ObString> &nearest_cids)
{
int ret = OB_SUCCESS;
share::ObVectorCentorClusterHelper<float, ObString> nearest_cid_heap(
vec_op_alloc_, reinterpret_cast<const float *>(real_search_vec_.ptr()), dim_, nprobes_);
if (OB_FAIL(nearest_cids.reserve(nprobes_))) {
LOG_WARN("failed to reserve nearest_cids", K(ret));
} else if (OB_FAIL(generate_nearear_cid_heap(nearest_cid_heap))) {
} else if (OB_FAIL(generate_nearear_cid_heap(is_vectorized, nearest_cid_heap))) {
LOG_WARN("failed to generate nearest cid heap", K(ret), K(nprobes_), K(dim_), K(real_search_vec_));
}
if (OB_SUCC(ret) || ret == OB_ITER_END) {
@ -779,7 +835,7 @@ int ObDASIvfScanIter::parse_cid_vec_datum(
}
template <typename T>
int ObDASIvfScanIter::get_nearest_limit_rowkeys_in_cids(const ObIArray<ObString> &nearest_cids, T *serch_vec)
int ObDASIvfScanIter::get_nearest_limit_rowkeys_in_cids(bool is_vectorized, const ObIArray<ObString> &nearest_cids, T *serch_vec)
{
int ret = OB_SUCCESS;
@ -802,6 +858,38 @@ int ObDASIvfScanIter::get_nearest_limit_rowkeys_in_cids(const ObIArray<ObString>
storage::ObTableScanIterator *cid_vec_scan_iter = nullptr;
if (OB_FAIL(scan_cid_range(cid, cid_vec_pri_key_cnt, cid_vec_ctdef, cid_vec_rtdef, cid_vec_scan_iter))) {
LOG_WARN("fail to scan cid range", K(ret), K(cid), K(cid_vec_pri_key_cnt));
} else if (is_vectorized) {
IVF_GET_NEXT_ROWS_BEGIN(cid_vec_iter_)
if (OB_SUCC(ret)) {
ObEvalCtx::BatchInfoScopeGuard guard(*vec_aux_rtdef_->eval_ctx_);
guard.set_batch_size(scan_row_cnt);
bool has_lob_header = cid_vec_ctdef->result_output_.at(CID_VECTOR_IDX)->obj_meta_.has_lob_header();
ObExpr *cid_expr = cid_vec_ctdef->result_output_[CID_VECTOR_IDX];
ObDatum *cid_datum = cid_expr->locate_batch_datums(*vec_aux_rtdef_->eval_ctx_);
for (int64_t i = 0; OB_SUCC(ret) && i < scan_row_cnt; ++i) {
guard.set_batch_idx(i);
ObRowkey *main_rowkey = nullptr;
ObString vec = cid_datum[i].get_string();
void *buf = nullptr;
ObObj *obj_ptr = nullptr;
if (OB_FAIL(ObTextStringHelper::read_real_string_data(
&vec_op_alloc_,
ObLongTextType,
CS_TYPE_BINARY,
has_lob_header,
vec))) {
LOG_WARN("failed to get real data.", K(ret));
} else if (OB_ISNULL(vec.ptr())) {
// ignoring null vector.
} else if (OB_FAIL(get_main_rowkey_from_cid_vec_datum(cid_vec_ctdef, rowkey_cnt, main_rowkey))) {
LOG_WARN("fail to get main rowkey", K(ret));
} else if (OB_FAIL(nearest_rowkey_heap.push_center(main_rowkey, reinterpret_cast<T *>(vec.ptr()), dim_))) {
LOG_WARN("failed to push center.", K(ret));
}
}
}
IVF_GET_NEXT_ROWS_END(cid_vec_iter_, cid_vec_scan_param_, cid_vec_tablet_id_)
} else {
while (OB_SUCC(ret)) {
ObRowkey *main_rowkey;
@ -819,12 +907,12 @@ int ObDASIvfScanIter::get_nearest_limit_rowkeys_in_cids(const ObIArray<ObString>
LOG_WARN("failed to push center.", K(ret));
}
}
}
if (ret == OB_ITER_END) {
ret = OB_SUCCESS;
if (OB_FAIL(cid_vec_iter_->reuse())) {
LOG_WARN("fail to reuse scan iterator.", K(ret));
if (ret == OB_ITER_END) {
ret = OB_SUCCESS;
if (OB_FAIL(cid_vec_iter_->reuse())) {
LOG_WARN("fail to reuse scan iterator.", K(ret));
}
}
}
}
@ -909,16 +997,16 @@ int ObDASIvfScanIter::process_ivf_scan_brute()
return ret;
}
int ObDASIvfScanIter::process_ivf_scan_post()
int ObDASIvfScanIter::process_ivf_scan_post(bool is_vectorized)
{
int ret = OB_SUCCESS;
ObSEArray<ObString, 8> nearest_cids;
nearest_cids.set_attr(ObMemAttr(MTL_ID(), "VecIdxNearCid"));
if (OB_FAIL(get_nearest_probe_center_ids(nearest_cids))) {
if (OB_FAIL(get_nearest_probe_center_ids(is_vectorized, nearest_cids))) {
// 1. Scan the centroid table, range: [min] ~ [max]; sort by l2_distance(c_vec, search_vec_); limit nprobes;
// return the cid column.
LOG_WARN("failed to get nearest probe center ids", K(ret));
} else if (OB_FAIL(get_nearest_limit_rowkeys_in_cids<float>(nearest_cids,
} else if (OB_FAIL(get_nearest_limit_rowkeys_in_cids<float>(is_vectorized, nearest_cids,
reinterpret_cast<float *>(real_search_vec_.ptr())))) {
// 2. get cidx in (the top-nprobes cids)
// scan cid_vector table, range: [cidx, min] ~ [cidx, max]; sort by l2_distance(vec, search_vec_); limit
@ -933,7 +1021,7 @@ int ObDASIvfScanIter::process_ivf_scan(bool is_vectorized)
int ret = OB_SUCCESS;
if (vec_aux_ctdef_->is_post_filter()) {
if (OB_FAIL(process_ivf_scan_post())) {
if (OB_FAIL(process_ivf_scan_post(is_vectorized))) {
LOG_WARN("failed to process ivf_scan post filter", K(ret), K_(pre_fileter_rowkeys), K_(saved_rowkeys));
}
} else if (OB_FAIL(process_ivf_scan_pre(vec_op_alloc_, is_vectorized))) {
@ -1032,7 +1120,7 @@ int ObDASIvfScanIter::filter_pre_rowkey_batch(const ObIArray<ObString> &nearest_
} else if (OB_FAIL(filter_rowkey_by_cid(nearest_cids, is_vectorized, batch_row_count, index_end))) {
LOG_WARN("filter rowkey batch failed", K(ret), K(nearest_cids), K(is_vectorized), K(batch_row_count));
} else {
int tmp_ret = ret;
int tmp_ret = ret; // NOTE(liyao): 跳过reuse_iter返回值
if (OB_FAIL(ObDasVecScanUtils::reuse_iter(
ls_id_, rowkey_cid_iter_, rowkey_cid_scan_param_, rowkey_cid_tablet_id_))) {
LOG_WARN("failed to reuse rowkey cid iter.", K(ret));
@ -1052,7 +1140,7 @@ int ObDASIvfScanIter::process_ivf_scan_pre(ObIAllocator &allocator, bool is_vect
int64_t batch_row_count = ObVectorParamData::VI_PARAM_DATA_BATCH_SIZE;
ObSEArray<ObString, 8> nearest_cids;
nearest_cids.set_attr(ObMemAttr(MTL_ID(), "VecIdxNearCid"));
if (OB_FAIL(get_nearest_probe_center_ids(nearest_cids))) {
if (OB_FAIL(get_nearest_probe_center_ids(is_vectorized, nearest_cids))) {
LOG_WARN("failed to get nearest probe center ids", K(ret));
} else if (nearest_cids.count() == 1 && nearest_cids.at(0).empty()) {
// create an index on the table, No 1 table is empty
@ -1148,7 +1236,7 @@ int ObDASIvfScanIter::inner_get_next_row()
if (limit_param_.limit_ + limit_param_.offset_ == 0) {
ret = OB_ITER_END;
} else if (OB_ISNULL(saved_rowkeys_itr_)) {
if (OB_FAIL(process_ivf_scan(false))) {
if (OB_FAIL(process_ivf_scan(false/*is_vectorized*/))) {
if (OB_ITER_END != ret) {
LOG_WARN("failed to process ivf scan state", K(ret));
}
@ -1173,7 +1261,7 @@ int ObDASIvfScanIter::inner_get_next_rows(int64_t &count, int64_t capacity)
if (limit_param_.limit_ + limit_param_.offset_ == 0) {
ret = OB_ITER_END;
} else if (OB_ISNULL(saved_rowkeys_itr_)) {
if (OB_FAIL(process_ivf_scan(true))) {
if (OB_FAIL(process_ivf_scan(true/*is_vectorized*/))) {
if (OB_ITER_END != ret) {
LOG_WARN("failed to process ivf scan state", K(ret));
}
@ -1325,9 +1413,10 @@ int ObDASIvfPQScanIter::get_pq_cid_vec_by_pq_cid(const ObString &pq_cid, float *
LOG_WARN("invalid null scan iter", K(ret));
} else {
int row_cnt = 0;
bool has_lob_header = pq_cid_vec_ctdef->result_output_.at(PQ_CENTROID_VEC_IDX)->obj_meta_.has_lob_header();
while (OB_SUCC(ret)) {
ObExpr *pq_center_vec_expr = pq_cid_vec_ctdef->result_output_[PQ_CENTROID_VEC_IDX];
ObRowkey *main_rowkey;
ObRowkey *main_rowkey = nullptr;
ObString vec;
// cid_vec_scan_iter output: [IVF_CID_VEC_CID_COL IVF_CID_VEC_VECTOR_COL ROWKEY]
if (OB_FAIL(cid_vec_scan_iter->get_next_row())) {
@ -1340,7 +1429,7 @@ int ObDASIvfPQScanIter::get_pq_cid_vec_by_pq_cid(const ObString &pq_cid, float *
&vec_op_alloc_,
ObLongTextType,
CS_TYPE_BINARY,
pq_cid_vec_ctdef->result_output_.at(PQ_CENTROID_VEC_IDX)->obj_meta_.has_lob_header(),
has_lob_header,
vec))) {
LOG_WARN("failed to get real data.", K(ret));
} else if (OB_ISNULL(vec.ptr())) {
@ -1370,7 +1459,7 @@ int ObDASIvfPQScanIter::calc_distance_between_pq_ids(
double &distance)
{
int ret = OB_SUCCESS;
distance = 0.0;
double square = 0.0;
for (int j = 0; OB_SUCC(ret) && j < m_; ++j) {
// 3.2.1 pq_center_ids[j] is put into ivf_pq_centroid table to find pq_center_vecs[j]
float *pq_cid_vec = nullptr;
@ -1379,18 +1468,20 @@ int ObDASIvfPQScanIter::calc_distance_between_pq_ids(
} else {
// 3.3.2 Calculate the distance between pq_center_vecs[j] and r(x)[j].
// The sum of j = 0 ~ m is the distance from x to rowkey
double cur_distance = DBL_MAX;
if (OB_FAIL(ObVectorL2Distance::l2_distance_func(splited_residual.at(j), pq_cid_vec, dim_ / m_, cur_distance))) {
double cur_square = DBL_MAX;
if (OB_FAIL(ObVectorL2Distance::l2_square_func(splited_residual.at(j), pq_cid_vec, dim_ / m_, cur_square))) {
LOG_WARN("failed to calc l2 distance", K(ret));
} else {
distance += cur_distance;
square += cur_square;
}
}
} // end for
distance = sqrt(square);
return ret;
}
int ObDASIvfPQScanIter::calc_nearest_limit_rowkeys_in_cids(
bool is_vectorized,
const ObIArray<IvfCidVecPair> &nearest_centers,
float *search_vec)
{
@ -1429,40 +1520,75 @@ int ObDASIvfPQScanIter::calc_nearest_limit_rowkeys_in_cids(
storage::ObTableScanIterator *cid_vec_scan_iter = nullptr;
if (OB_FAIL(scan_cid_range(cur_cid, cid_vec_pri_key_cnt, cid_vec_ctdef, cid_vec_rtdef, cid_vec_scan_iter))) {
LOG_WARN("fail to scan cid range", K(ret), K(cur_cid), K(cid_vec_pri_key_cnt));
}
while (OB_SUCC(ret)) {
ObRowkey *main_rowkey;
ObString vec_arr_str;
ObArrayBinary *pq_center_ids = NULL;
double distance = 0.0;
// cid_vec_scan_iter output: [IVF_CID_VEC_CID_COL IVF_CID_VEC_VECTOR_COL ROWKEY]
if (OB_FAIL(cid_vec_scan_iter->get_next_row())) {
if (OB_ITER_END != ret) {
LOG_WARN("failed to scan vid rowkey iter", K(ret));
}
} else if (OB_FAIL(parse_cid_vec_datum(
tmp_allocator,
cid_vec_column_count,
cid_vec_ctdef,
rowkey_cnt,
main_rowkey,
pq_center_ids))) {
LOG_WARN("fail to parse cid vec datum", K(ret), K(cid_vec_column_count), K(rowkey_cnt));
} else if (OB_ISNULL(pq_center_ids)) {
// ignore null arr
} else if (OB_FAIL(calc_distance_between_pq_ids(*pq_center_ids, splited_residual, distance))) {
LOG_WARN("fail to calc distance between pq ids", K(ret));
} else if (OB_FAIL(nearest_rowkey_heap.push_center(main_rowkey, distance))) {
LOG_WARN("failed to push center.", K(ret));
}
} // end while
} else if (is_vectorized) {
IVF_GET_NEXT_ROWS_BEGIN(cid_vec_iter_)
if (OB_SUCC(ret)) {
ObEvalCtx::BatchInfoScopeGuard guard(*vec_aux_rtdef_->eval_ctx_);
guard.set_batch_size(scan_row_cnt);
ObExpr *cid_expr = cid_vec_ctdef->result_output_[PQ_IDS_IDX];
ObDatum *cid_datum = cid_expr->locate_batch_datums(*vec_aux_rtdef_->eval_ctx_);
if (ret == OB_ITER_END) {
ret = OB_SUCCESS;
if (OB_FAIL(cid_vec_iter_->reuse())) {
LOG_WARN("fail to reuse scan iterator.", K(ret));
} else {
splited_residual.reuse();
for (int64_t i = 0; OB_SUCC(ret) && i < scan_row_cnt; ++i) {
guard.set_batch_idx(i);
if (cid_datum[i].is_null()) {
// continue
} else {
ObRowkey *main_rowkey = nullptr;
ObString vec = cid_datum[i].get_string();
void *buf = nullptr;
ObObj *obj_ptr = nullptr;
ObArrayBinary *pq_center_ids = NULL;
double distance = 0.0;
if (OB_FAIL(get_pq_cids_from_datum(tmp_allocator, vec, pq_center_ids))) {
LOG_WARN("fail to get pq cids from datum", K(ret));
} else if (OB_ISNULL(pq_center_ids)) {
// ignore null arr
} else if (OB_FAIL(get_main_rowkey_from_cid_vec_datum(cid_vec_ctdef, rowkey_cnt, main_rowkey))) {
LOG_WARN("fail to get main rowkey", K(ret));
} else if (OB_FAIL(calc_distance_between_pq_ids(*pq_center_ids, splited_residual, distance))) {
LOG_WARN("fail to calc distance between pq ids", K(ret));
} else if (OB_FAIL(nearest_rowkey_heap.push_center(main_rowkey, distance))) {
LOG_WARN("failed to push center.", K(ret));
}
}
}
}
IVF_GET_NEXT_ROWS_END(cid_vec_iter_, cid_vec_scan_param_, cid_vec_tablet_id_)
} else {
while (OB_SUCC(ret)) {
ObRowkey *main_rowkey = nullptr;
ObString vec_arr_str;
ObArrayBinary *pq_center_ids = NULL;
double distance = 0.0;
// cid_vec_scan_iter output: [IVF_CID_VEC_CID_COL IVF_CID_VEC_VECTOR_COL ROWKEY]
if (OB_FAIL(cid_vec_scan_iter->get_next_row())) {
if (OB_ITER_END != ret) {
LOG_WARN("failed to scan vid rowkey iter", K(ret));
}
} else if (OB_FAIL(parse_cid_vec_datum(
tmp_allocator,
cid_vec_column_count,
cid_vec_ctdef,
rowkey_cnt,
main_rowkey,
pq_center_ids))) {
LOG_WARN("fail to parse cid vec datum", K(ret), K(cid_vec_column_count), K(rowkey_cnt));
} else if (OB_ISNULL(pq_center_ids)) {
// ignore null arr
} else if (OB_FAIL(calc_distance_between_pq_ids(*pq_center_ids, splited_residual, distance))) {
LOG_WARN("fail to calc distance between pq ids", K(ret));
} else if (OB_FAIL(nearest_rowkey_heap.push_center(main_rowkey, distance))) {
LOG_WARN("failed to push center.", K(ret));
}
} // end while
if (ret == OB_ITER_END) {
ret = OB_SUCCESS;
if (OB_FAIL(cid_vec_iter_->reuse())) {
LOG_WARN("fail to reuse scan iterator.", K(ret));
} else {
splited_residual.reuse();
}
}
}
}
@ -1478,14 +1604,14 @@ int ObDASIvfPQScanIter::calc_nearest_limit_rowkeys_in_cids(
return ret;
}
int ObDASIvfPQScanIter::get_nearest_probe_centers(ObIArray<IvfCidVecPair> &nearest_centers)
int ObDASIvfPQScanIter::get_nearest_probe_centers(bool is_vectorized, ObIArray<IvfCidVecPair> &nearest_centers)
{
int ret = OB_SUCCESS;
share::ObVectorCentorClusterHelper<float, ObString> nearest_cid_heap(
vec_op_alloc_, reinterpret_cast<const float *>(real_search_vec_.ptr()), dim_, nprobes_);
if (OB_FAIL(nearest_centers.reserve(nprobes_))) {
LOG_WARN("failed to reserve nearest_centers", K(ret));
} else if (OB_FAIL(generate_nearear_cid_heap(nearest_cid_heap, true/*save_center_vec*/))) {
} else if (OB_FAIL(generate_nearear_cid_heap(is_vectorized, nearest_cid_heap, true/*save_center_vec*/))) {
LOG_WARN("failed to generate nearest cid heap", K(ret), K(nprobes_), K(dim_), K(real_search_vec_));
}
if (OB_SUCC(ret) || ret == OB_ITER_END) {
@ -1521,12 +1647,12 @@ int ObDASIvfPQScanIter::get_rowkey_brute_post()
return ret;
}
int ObDASIvfPQScanIter::process_ivf_scan_post()
int ObDASIvfPQScanIter::process_ivf_scan_post(bool is_vectorized)
{
int ret = OB_SUCCESS;
ObSEArray<IvfCidVecPair, 8> nearest_centers;
nearest_centers.set_attr(ObMemAttr(MTL_ID(), "VecIdxNearCid"));
if (OB_FAIL(get_nearest_probe_centers(nearest_centers))) {
if (OB_FAIL(get_nearest_probe_centers(is_vectorized, nearest_centers))) {
// 1. Scan the ivf_centroid table, calculate the distance between vec_x and cid_vec,
// and get the nearest cluster center (cid 1, cid_vec 1)... (cid n, cid_vec n)
LOG_WARN("failed to get nearest probe center ids", K(ret));
@ -1535,6 +1661,7 @@ int ObDASIvfPQScanIter::process_ivf_scan_post()
LOG_WARN("failed to get limit rowkey brute", K(ret));
}
} else if (OB_FAIL(calc_nearest_limit_rowkeys_in_cids(
is_vectorized,
nearest_centers,
reinterpret_cast<float *>(real_search_vec_.ptr())))) {
// 2. search nearest rowkeys
@ -1749,7 +1876,7 @@ int ObDASIvfPQScanIter::process_ivf_scan_pre(ObIAllocator &allocator, bool is_ve
IvfRowkeyHeap nearest_rowkey_heap(
vec_op_alloc_, reinterpret_cast<const float *>(real_search_vec_.ptr())/*unused*/,
dim_ / m_, get_nprobe(limit_param_, PQ_ID_ENLARGEMENT_FACTOR));
if (OB_FAIL(get_nearest_probe_centers(nearest_centers))) {
if (OB_FAIL(get_nearest_probe_centers(is_vectorized, nearest_centers))) {
// 1. Scan the ivf_centroid table, calculate the distance between vec_x and cid_vec,
// and get the nearest cluster center (cid 1, cid_vec 1)... (cid n, cid_vec n)
LOG_WARN("failed to get nearest probe center ids", K(ret));
@ -1797,27 +1924,6 @@ int ObDASIvfPQScanIter::process_ivf_scan_pre(ObIAllocator &allocator, bool is_ve
return ret;
}
// Drives one IVF-PQ scan and then builds the rowkey iterator.
// When vec_aux_ctdef_ reports is_post_filter() the post-filter path (process_ivf_scan_post) runs,
// otherwise the pre-filter path (process_ivf_scan_pre) runs with vec_op_alloc_ and is_vectorized.
// On success gen_rowkeys_itr() is called; its OB_ITER_END is returned to the caller without a warning log.
// Returns OB_SUCCESS or the first error code produced by the steps above.
int ObDASIvfPQScanIter::process_ivf_scan(bool is_vectorized)
{
int ret = OB_SUCCESS;
if (vec_aux_ctdef_->is_post_filter()) {
// NOTE(review): is_vectorized is not forwarded on this path; process_ivf_scan_post() is called without arguments.
if (OB_FAIL(process_ivf_scan_post())) {
LOG_WARN("failed to process adaptorstate post filter", K(ret), K_(pre_fileter_rowkeys), K_(saved_rowkeys));
}
} else if (OB_FAIL(process_ivf_scan_pre(vec_op_alloc_, is_vectorized))) {
LOG_WARN("failed to process adaptor state hnsw", K(ret));
}
// Skip iterator generation if either scan path failed.
if (OB_FAIL(ret)) {
} else if (OB_FAIL(gen_rowkeys_itr())) {
// OB_ITER_END is passed through silently; only other errors are logged.
if (ret != OB_ITER_END) {
LOG_WARN("failed to gen adapter rowkeys itr", K(saved_rowkeys_));
}
}
return ret;
}
/********************************************************************************************************/
int ObDASIvfSQ8ScanIter::inner_get_next_rows(int64_t &count, int64_t capacity)
{
@ -1825,7 +1931,7 @@ int ObDASIvfSQ8ScanIter::inner_get_next_rows(int64_t &count, int64_t capacity)
if (limit_param_.limit_ + limit_param_.offset_ == 0) {
ret = OB_ITER_END;
} else if (OB_ISNULL(saved_rowkeys_itr_)) {
if (OB_FAIL(process_ivf_scan_sq(true))) {
if (OB_FAIL(process_ivf_scan_sq(true/*is_vectorized*/))) {
if (OB_ITER_END != ret) {
LOG_WARN("failed to process adaptor state", K(ret));
}
@ -1849,7 +1955,7 @@ int ObDASIvfSQ8ScanIter::inner_get_next_row()
if (limit_param_.limit_ + limit_param_.offset_ == 0) {
ret = OB_ITER_END;
} else if (OB_ISNULL(saved_rowkeys_itr_)) {
if (OB_FAIL(process_ivf_scan_sq(false))) {
if (OB_FAIL(process_ivf_scan_sq(false/*is_vectorized*/))) {
if (OB_ITER_END != ret) {
LOG_WARN("failed to process adaptor state", K(ret));
}
@ -1871,7 +1977,7 @@ int ObDASIvfSQ8ScanIter::process_ivf_scan_sq(bool is_vectorized)
int ret = OB_SUCCESS;
if (vec_aux_ctdef_->is_post_filter()) {
if (OB_FAIL(process_ivf_scan_post_sq())) {
if (OB_FAIL(process_ivf_scan_post_sq(is_vectorized))) {
LOG_WARN("failed to process adaptorstate post filter", K(ret), K_(pre_fileter_rowkeys), K_(saved_rowkeys));
}
} else if (OB_FAIL(process_ivf_scan_pre(vec_op_alloc_, is_vectorized))) {
@ -1913,20 +2019,52 @@ int ObDASIvfSQ8ScanIter::inner_release()
return ret;
}
int ObDASIvfSQ8ScanIter::get_real_search_vec_u8(ObString &real_search_vec_u8)
int ObDASIvfSQ8ScanIter::get_real_search_vec_u8(bool is_vectorized, ObString &real_search_vec_u8)
{
int ret = OB_SUCCESS;
const ObDASScanCtDef *sq_meta_ctdef = vec_aux_ctdef_->get_vec_aux_tbl_ctdef(
vec_aux_ctdef_->get_ivf_sq_meta_tbl_idx(), ObTSCIRScanType::OB_VEC_IVF_SPECIAL_AUX_SCAN);
ObDASScanRtDef *sq_meta_rtdef = vec_aux_rtdef_->get_vec_aux_tbl_rtdef(vec_aux_ctdef_->get_ivf_sq_meta_tbl_idx());
if (OB_FAIL(do_table_full_scan(sq_meta_iter_first_scan_,
sq_meta_scan_param_,
ObString min_vec;
ObString step_vec;
if (OB_FAIL(do_table_full_scan(is_vectorized,
sq_meta_ctdef,
sq_meta_rtdef,
sq_meta_iter_,
SQ_MEAT_PRI_KEY_CNT,
sq_meta_tablet_id_))) {
sq_meta_tablet_id_,
sq_meta_iter_first_scan_,
sq_meta_scan_param_))) {
LOG_WARN("failed to do table scan sq_meta", K(ret));
} else if (is_vectorized) {
IVF_GET_NEXT_ROWS_BEGIN(sq_meta_iter_)
if (OB_SUCC(ret)) {
ObEvalCtx::BatchInfoScopeGuard guard(*vec_aux_rtdef_->eval_ctx_);
guard.set_batch_size(scan_row_cnt);
bool has_lob_header = sq_meta_ctdef->result_output_.at(META_VECTOR_IDX)->obj_meta_.has_lob_header();
ObExpr *meta_vec_expr = sq_meta_ctdef->result_output_[META_VECTOR_IDX];
ObDatum *meta_vec_datum = meta_vec_expr->locate_batch_datums(*vec_aux_rtdef_->eval_ctx_);
for (int64_t i = 0; OB_SUCC(ret) && i < scan_row_cnt; ++i) {
guard.set_batch_idx(i);
if (i == ObIvfConstant::SQ8_META_MIN_IDX || i == ObIvfConstant::SQ8_META_STEP_IDX) {
ObString c_vec = meta_vec_datum[i].get_string();
if (OB_FAIL(ObTextStringHelper::read_real_string_data(
&vec_op_alloc_,
ObLongTextType,
CS_TYPE_BINARY,
has_lob_header,
c_vec))) {
LOG_WARN("failed to get real data.", K(ret));
} else if (i == ObIvfConstant::SQ8_META_MIN_IDX) {
min_vec = c_vec;
} else if (i == ObIvfConstant::SQ8_META_STEP_IDX) {
step_vec = c_vec;
}
}
}
}
IVF_GET_NEXT_ROWS_END(sq_meta_iter_, sq_meta_scan_param_, sq_meta_tablet_id_)
} else {
// get min_vec max_vec step_vec in sq_meta table
storage::ObTableScanIterator *sq_meta_scan_iter =
@ -1935,8 +2073,6 @@ int ObDASIvfSQ8ScanIter::get_real_search_vec_u8(ObString &real_search_vec_u8)
ret = OB_ERR_UNEXPECTED;
LOG_WARN("real sql meta iter is null", K(ret));
} else {
ObString min_vec;
ObString step_vec;
int row_index = 0;
while (OB_SUCC(ret)) {
ObString c_vec;
@ -1970,34 +2106,39 @@ int ObDASIvfSQ8ScanIter::get_real_search_vec_u8(ObString &real_search_vec_u8)
}
if (ret == OB_ITER_END) {
ret = OB_SUCCESS;
uint8_t *res_vec = nullptr;
if (OB_ISNULL(min_vec.ptr()) || OB_ISNULL(step_vec.ptr())) {
if (OB_ISNULL(res_vec = reinterpret_cast<uint8_t *>(vec_op_alloc_.alloc(sizeof(uint8_t) * dim_)))) {
ret = OB_ALLOCATE_MEMORY_FAILED;
LOG_WARN("failed to allocate memory", K(ret), K(sizeof(uint8_t) * dim_));
} else {
MEMSET(res_vec, 0, sizeof(uint8_t) * dim_);
}
} else if (OB_FAIL(
ObExprVecIVFSQ8DataVector::cal_u8_data_vector(vec_op_alloc_,
dim_,
reinterpret_cast<float *>(min_vec.ptr()),
reinterpret_cast<float *>(step_vec.ptr()),
reinterpret_cast<float *>(real_search_vec_.ptr()),
res_vec))) {
LOG_WARN("fail to cal u8 data vector", K(ret), K(dim_));
}
if (OB_SUCC(ret)) {
real_search_vec_u8.assign_ptr(reinterpret_cast<char *>(res_vec), dim_ * sizeof(uint8_t));
if (OB_FAIL(sq_meta_iter_->reuse())) {
LOG_WARN("fail to reuse scan iterator.", K(ret));
}
}
}
}
if (OB_SUCC(ret)) {
uint8_t *res_vec = nullptr;
if (OB_ISNULL(min_vec.ptr()) || OB_ISNULL(step_vec.ptr())) {
if (OB_ISNULL(res_vec = reinterpret_cast<uint8_t *>(vec_op_alloc_.alloc(sizeof(uint8_t) * dim_)))) {
ret = OB_ALLOCATE_MEMORY_FAILED;
LOG_WARN("failed to allocate memory", K(ret), K(sizeof(uint8_t) * dim_));
} else {
MEMSET(res_vec, 0, sizeof(uint8_t) * dim_);
}
} else if (OB_FAIL(
ObExprVecIVFSQ8DataVector::cal_u8_data_vector(vec_op_alloc_,
dim_,
reinterpret_cast<float *>(min_vec.ptr()),
reinterpret_cast<float *>(step_vec.ptr()),
reinterpret_cast<float *>(real_search_vec_.ptr()),
res_vec))) {
LOG_WARN("fail to cal u8 data vector", K(ret), K(dim_));
}
if (OB_SUCC(ret)) {
real_search_vec_u8.assign_ptr(reinterpret_cast<char *>(res_vec), dim_ * sizeof(uint8_t));
}
}
return ret;
}
int ObDASIvfSQ8ScanIter::process_ivf_scan_post_sq()
int ObDASIvfSQ8ScanIter::process_ivf_scan_post_sq(bool is_vectorized)
{
int ret = OB_SUCCESS;
// post scan:
@ -2008,13 +2149,13 @@ int ObDASIvfSQ8ScanIter::process_ivf_scan_post_sq()
ObSEArray<ObString, 8> nearest_cids;
nearest_cids.set_attr(ObMemAttr(MTL_ID(), "VecIdxNearCid"));
ObString real_search_vec_u8;
if (OB_FAIL(get_nearest_probe_center_ids(nearest_cids))) {
if (OB_FAIL(get_nearest_probe_center_ids(is_vectorized, nearest_cids))) {
LOG_WARN("failed to get nearest probe center ids", K(ret));
} else if (OB_FAIL(get_real_search_vec_u8(real_search_vec_u8))) {
} else if (OB_FAIL(get_real_search_vec_u8(is_vectorized, real_search_vec_u8))) {
// real_search_vec to real_search_vec_u8 with min_vec max_vec step_vec
LOG_WARN("failed to get real search vec u8", K(ret));
} else if (OB_FAIL(get_nearest_limit_rowkeys_in_cids<uint8_t>(
nearest_cids, reinterpret_cast<uint8_t *>(real_search_vec_u8.ptr())))) {
is_vectorized, nearest_cids, reinterpret_cast<uint8_t *>(real_search_vec_u8.ptr())))) {
LOG_WARN("failed to get nearest limit rowkeys in cids", K(nearest_cids));
}

View File

@ -24,6 +24,35 @@ namespace oceanbase
using namespace common;
namespace sql
{
// Opens a batched read loop over `iter`; the loop is closed by IVF_GET_NEXT_ROWS_END.
//
// Requirements and side effects at the expansion site:
//   - an `int ret` must already be in scope;
//   - the locals index_end, scan_row_cnt and batch_row_count are declared in the enclosing
//     scope, so the macro can be expanded at most once per scope;
//   - the caller's per-batch code goes between the two macros and reads scan_row_cnt.
//
// Each pass calls iter->get_next_rows(scan_row_cnt, batch_row_count):
//   - OB_ITER_END sets index_end, so this pass is the last one;
//   - any other failure is logged and left in ret, which ends the loop;
//   - if rows were returned (scan_row_cnt > 0) and ret is OB_SUCCESS or OB_ITER_END, ret is
//     reset to OB_SUCCESS so the caller's body still processes the final batch.
//
// The macro argument is parenthesized so any pointer-valued expression can be passed safely.
#define IVF_GET_NEXT_ROWS_BEGIN(iter)                                        \
  bool index_end = false;                                                    \
  (iter)->clear_evaluated_flag();                                            \
  int64_t scan_row_cnt = 0;                                                  \
  int64_t batch_row_count = ObVectorParamData::VI_PARAM_DATA_BATCH_SIZE;     \
  while (!index_end && OB_SUCC(ret)) {                                       \
    if (OB_FAIL((iter)->get_next_rows(scan_row_cnt, batch_row_count))) {     \
      if (OB_ITER_END != ret) {                                              \
        LOG_WARN("failed to get next row.", K(ret));                         \
      } else {                                                               \
        index_end = true;                                                    \
      }                                                                      \
    }                                                                        \
    if (OB_FAIL(ret) && OB_ITER_END != ret) {                                \
    } else if (scan_row_cnt > 0) {                                           \
      ret = OB_SUCCESS;                                                      \
    }
// Closes the loop opened by IVF_GET_NEXT_ROWS_BEGIN and relies on the `ret` and `index_end`
// variables that macro works with.
//
// Once the iterator has reported OB_ITER_END (index_end is true), `iter` is handed to
// ObDasVecScanUtils::reuse_iter() together with ls_id_, scan_param and tablet_id:
//   - OB_ITER_END held in ret is mapped to OB_SUCCESS;
//   - any other value held in ret is kept if reuse_iter() succeeds;
//   - if reuse_iter() fails, its error code is what remains in ret.
//
// The warning text is generic because the macro is expanded for several different iterators.
#define IVF_GET_NEXT_ROWS_END(iter, scan_param, tablet_id)                              \
  }                                                                                     \
  if (index_end) {                                                                      \
    int tmp_ret = (ret == OB_ITER_END) ? OB_SUCCESS : ret;                              \
    if (OB_FAIL(ObDasVecScanUtils::reuse_iter(ls_id_, iter, scan_param, tablet_id))) {  \
      LOG_WARN("failed to reuse iter.", K(ret));                                        \
    } else {                                                                            \
      ret = tmp_ret;                                                                    \
    }                                                                                   \
  }
struct ObDASIvfScanIterParam : public ObDASIterParam {
public:
explicit ObDASIvfScanIterParam(const ObVectorIndexAlgorithmType index_type)
@ -153,13 +182,14 @@ protected:
}
protected:
int do_table_full_scan(bool &first_scan,
ObTableScanParam &scan_param,
const ObDASScanCtDef *ctdef,
ObDASScanRtDef *rtdef,
ObDASScanIter *iter,
int64_t pri_key_cnt,
ObTabletID &tablet_id);
int do_table_full_scan(bool is_vectorized,
const ObDASScanCtDef *ctdef,
ObDASScanRtDef *rtdef,
ObDASScanIter *iter,
int64_t pri_key_cnt,
ObTabletID &tablet_id,
bool &first_scan,
ObTableScanParam &scan_param);
int do_aux_table_scan(bool &first_scan,
ObTableScanParam &scan_param,
const ObDASScanCtDef *ctdef,
@ -186,6 +216,9 @@ protected:
static const int64_t SQ_MEAT_PRI_KEY_CNT = 1;
static const int64_t SQ_MEAT_ALL_KEY_CNT = 2;
static const int64_t POST_ENLARGEMENT_FACTOR = 2;
// in centroid table
static const int64_t CID_IDX = 0;
static const int64_t CID_VECTOR_IDX = 1;
protected:
common::ObArenaAllocator vec_op_alloc_;
share::ObLSID ls_id_;
@ -257,13 +290,13 @@ protected:
virtual int inner_get_next_rows(int64_t &count, int64_t capacity) override;
protected:
int get_nearest_probe_center_ids(ObIArray<ObString> &nearest_cids);
int get_nearest_probe_center_ids(bool is_vectorized, ObIArray<ObString> &nearest_cids);
int get_main_rowkey_from_cid_vec_datum(const ObDASScanCtDef *cid_vec_ctdef,
const int64_t rowkey_cnt,
ObRowkey *&main_rowkey);
virtual int process_ivf_scan(bool is_vectorized);
template <typename T>
int get_nearest_limit_rowkeys_in_cids(const ObIArray<ObString> &nearest_cids, T *serch_vec);
int get_nearest_limit_rowkeys_in_cids(bool is_vectorized, const ObIArray<ObString> &nearest_cids, T *serch_vec);
int get_rowkey(ObIAllocator &allocator, ObRowkey *&rowkey) {
const ObDASScanCtDef *ctdef = vec_aux_ctdef_->get_vec_aux_tbl_ctdef(vec_aux_ctdef_->get_ivf_rowkey_cid_tbl_idx(),
ObTSCIRScanType::OB_VEC_IVF_ROWKEY_CID_SCAN);
@ -283,9 +316,10 @@ protected:
bool is_vectorized,
int64_t batch_row_count,
bool &index_end);
virtual int process_ivf_scan_post();
virtual int process_ivf_scan_post(bool is_vectorized);
int process_ivf_scan_brute();
int generate_nearear_cid_heap(
bool is_vectorized,
share::ObVectorCentorClusterHelper<float, ObString> &nearest_cid_heap,
bool save_center_vec = false);
int prepare_cid_range(
@ -309,6 +343,7 @@ protected:
const ObDASScanCtDef *cid_vec_ctdef,
ObIAllocator& allocator,
blocksstable::ObDatumRow *datum_row,
bool save_center_vec,
ObString &cid,
ObString &cid_vec);
};
@ -347,8 +382,7 @@ protected:
}
int inner_release() override;
int process_ivf_scan(bool is_vectorized) override;
int process_ivf_scan_post() override;
int process_ivf_scan_post(bool is_vectorized) override;
int process_ivf_scan_pre(ObIAllocator &allocator, bool is_vectorized) override;
int filter_pre_rowkey_batch(const ObIArray<std::pair<ObString, float *>> &nearest_cids,
bool is_vectorized,
@ -367,10 +401,11 @@ protected:
ObRowkey *&main_rowkey,
ObArrayBinary *&com_key);
int calc_nearest_limit_rowkeys_in_cids(
bool is_vectorized,
const ObIArray<std::pair<ObString, float *>> &nearest_centers,
float *search_vec);
int get_pq_cid_vec_by_pq_cid(const ObString &pq_cid, float *&pq_cid_vec);
int get_nearest_probe_centers(ObIArray<std::pair<ObString, float *>> &nearest_centers);
int get_nearest_probe_centers(bool is_vectorized, ObIArray<std::pair<ObString, float *>> &nearest_centers);
int get_cid_from_pq_rowkey_cid_table(ObIAllocator &allocator, ObString &cid, ObArrayBinary *&pq_cids);
int check_cid_exist(
const ObIArray<std::pair<ObString, float *>> &dst_cids,
@ -416,6 +451,7 @@ private:
class ObDASIvfSQ8ScanIter : public ObDASIvfScanIter
{
public:
static const int64_t META_VECTOR_IDX = 1;
ObDASIvfSQ8ScanIter()
: ObDASIvfScanIter(),
sq_meta_iter_(nullptr),
@ -442,9 +478,9 @@ protected:
virtual int inner_get_next_row() override;
virtual int inner_get_next_rows(int64_t &count, int64_t capacity) override;
int get_real_search_vec_u8(ObString &real_search_vec_u8);
int get_real_search_vec_u8(bool is_vectorized, ObString &real_search_vec_u8);
int process_ivf_scan_sq(bool is_vectorized);
int process_ivf_scan_post_sq();
int process_ivf_scan_post_sq(bool is_vectorized);
private:
ObDASScanIter *sq_meta_iter_;