From 02000f544be85fffd69b0dde6d9caa79ffb2b5dc Mon Sep 17 00:00:00 2001 From: lyxiong0 Date: Wed, 19 Feb 2025 14:16:26 +0000 Subject: [PATCH] [vector] ivf post-search adapts the vectorized interface --- src/sql/das/iter/ob_das_ivf_scan_iter.cpp | 409 +++++++++++++++------- src/sql/das/iter/ob_das_ivf_scan_iter.h | 66 +++- 2 files changed, 326 insertions(+), 149 deletions(-) diff --git a/src/sql/das/iter/ob_das_ivf_scan_iter.cpp b/src/sql/das/iter/ob_das_ivf_scan_iter.cpp index 52cdb2bbd4..200164bc5d 100644 --- a/src/sql/das/iter/ob_das_ivf_scan_iter.cpp +++ b/src/sql/das/iter/ob_das_ivf_scan_iter.cpp @@ -87,21 +87,31 @@ int ObDASIvfBaseScanIter::gen_rowkeys_itr() return ret; } -int ObDASIvfBaseScanIter::do_table_full_scan(bool &first_scan, - ObTableScanParam &scan_param, +int ObDASIvfBaseScanIter::do_table_full_scan(bool is_vectorized, const ObDASScanCtDef *ctdef, ObDASScanRtDef *rtdef, ObDASScanIter *iter, int64_t pri_key_cnt, - ObTabletID &tablet_id) + ObTabletID &tablet_id, + bool &first_scan, + ObTableScanParam &scan_param) { int ret = OB_SUCCESS; if (first_scan) { ObNewRange scan_range; - if (OB_FAIL(ObDasVecScanUtils::init_vec_aux_scan_param( - ls_id_, tablet_id, ctdef, rtdef, tx_desc_, snapshot_, scan_param))) { - LOG_WARN("failed to init scan param", K(ret)); + if (is_vectorized) { + if (OB_FAIL( + ObDasVecScanUtils::init_scan_param(ls_id_, tablet_id, ctdef, rtdef, tx_desc_, snapshot_, scan_param, false/*is_get*/))) { + LOG_WARN("failed to generate init vec aux scan param", K(ret)); + } + } else { + if (OB_FAIL(ObDasVecScanUtils::init_vec_aux_scan_param( + ls_id_, tablet_id, ctdef, rtdef, tx_desc_, snapshot_, scan_param))) { + LOG_WARN("failed to init scan param", K(ret)); + } + } + if (OB_FAIL(ret)) { } else if (OB_FALSE_IT(ObDasVecScanUtils::set_whole_range(scan_range, ctdef->ref_table_id_))) { } else if (OB_FAIL(scan_param.key_ranges_.push_back(scan_range))) { LOG_WARN("failed to append scan range", K(ret)); @@ -143,7 +153,7 @@ int ObDASIvfBaseScanIter::do_aux_table_scan(bool &first_scan, if (first_scan) { scan_param.need_switch_param_ = false; if (OB_FAIL(ObDasVecScanUtils::init_scan_param( - ls_id_, tablet_id, ctdef, rtdef, tx_desc_, snapshot_, scan_param, false))) { + ls_id_, tablet_id, ctdef, rtdef, tx_desc_, snapshot_, scan_param, false/*is_get*/))) { LOG_WARN("failed to init scan param", K(ret)); } else if (OB_FALSE_IT(iter->set_scan_param(scan_param))) { } else if (OB_FAIL(iter->do_table_scan())) { @@ -538,6 +548,7 @@ int ObDASIvfScanIter::parse_centroid_datum_with_deep_copy( const ObDASScanCtDef *centroid_ctdef, ObIAllocator& allocator, blocksstable::ObDatumRow *datum_row, + bool save_center_vec, ObString &cid, ObString &cid_vec) { @@ -565,13 +576,18 @@ int ObDASIvfScanIter::parse_centroid_datum_with_deep_copy( // ignoring null vector. } else if (OB_FAIL(ob_write_string(allocator, tmp_cid, cid))) { LOG_WARN("failed to write string", K(ret), K(tmp_cid)); - } else if (OB_FAIL(ob_write_string(allocator, tmp_cid_vec, cid_vec))) { - LOG_WARN("failed to write string", K(ret), K(tmp_cid_vec)); + } else if (save_center_vec) { + if (OB_FAIL(ob_write_string(allocator, tmp_cid_vec, cid_vec))) { + LOG_WARN("failed to write string", K(ret), K(tmp_cid_vec)); + } + } else { + cid_vec = tmp_cid_vec; } return ret; } int ObDASIvfScanIter::generate_nearear_cid_heap( + bool is_vectorized, share::ObVectorCentorClusterHelper &nearest_cid_heap, bool save_center_vec /*= false*/) { @@ -579,14 +595,57 @@ int ObDASIvfScanIter::generate_nearear_cid_heap( const ObDASScanCtDef *centroid_ctdef = vec_aux_ctdef_->get_vec_aux_tbl_ctdef( vec_aux_ctdef_->get_ivf_centroid_tbl_idx(), ObTSCIRScanType::OB_VEC_IVF_CENTROID_SCAN); ObDASScanRtDef *centroid_rtdef = vec_aux_rtdef_->get_vec_aux_tbl_rtdef(vec_aux_ctdef_->get_ivf_centroid_tbl_idx()); - if (OB_FAIL(do_table_full_scan(centroid_iter_first_scan_, - centroid_scan_param_, - centroid_ctdef, - centroid_rtdef, - centroid_iter_, - CENTROID_PRI_KEY_CNT, - centroid_tablet_id_))) { + if (OB_FAIL(do_table_full_scan(is_vectorized, + centroid_ctdef, + centroid_rtdef, + centroid_iter_, + CENTROID_PRI_KEY_CNT, + centroid_tablet_id_, + centroid_iter_first_scan_, + centroid_scan_param_))) { LOG_WARN("failed to do centroid table scan", K(ret)); + } else if (is_vectorized) { + IVF_GET_NEXT_ROWS_BEGIN(centroid_iter_) + if (OB_SUCC(ret)) { + ObEvalCtx::BatchInfoScopeGuard guard(*vec_aux_rtdef_->eval_ctx_); + guard.set_batch_size(scan_row_cnt); + ObExpr *cid_expr = centroid_ctdef->result_output_[CID_IDX]; + ObExpr *cid_vec_expr = centroid_ctdef->result_output_[CID_VECTOR_IDX]; + ObDatum *cid_datum = cid_expr->locate_batch_datums(*vec_aux_rtdef_->eval_ctx_); + ObDatum *cid_vec_datum = cid_vec_expr->locate_batch_datums(*vec_aux_rtdef_->eval_ctx_); + bool has_lob_header = centroid_ctdef->result_output_.at(CID_VECTOR_IDX)->obj_meta_.has_lob_header(); + + for (int64_t i = 0; OB_SUCC(ret) && i < scan_row_cnt; ++i) { + guard.set_batch_idx(i); + ObString tmp_cid = cid_datum[i].get_string(); + ObString tmp_cid_vec = cid_vec_datum[i].get_string(); + ObString cid; + ObString cid_vec; + if (OB_FAIL(ObTextStringHelper::read_real_string_data( + &vec_op_alloc_, + ObLongTextType, + CS_TYPE_BINARY, + has_lob_header, + tmp_cid_vec))) { + LOG_WARN("failed to get real data.", K(ret)); + } else if (tmp_cid_vec.empty()) { + // ignoring null vector. + } else if (OB_FAIL(ob_write_string(vec_op_alloc_, tmp_cid, cid))) { + LOG_WARN("failed to write string", K(ret), K(tmp_cid)); + } else if (save_center_vec) { + if (OB_FAIL(ob_write_string(vec_op_alloc_, tmp_cid_vec, cid_vec))) { + LOG_WARN("failed to write string", K(ret), K(tmp_cid_vec)); + } + } else { + cid_vec = tmp_cid_vec; + } + if (OB_FAIL(ret)) { + } else if (OB_FAIL(nearest_cid_heap.push_center(cid, reinterpret_cast(cid_vec.ptr()), dim_, save_center_vec))) { + LOG_WARN("failed to push center.", K(ret)); + } + } + } + IVF_GET_NEXT_ROWS_END(centroid_iter_, centroid_scan_param_, centroid_tablet_id_) } else { storage::ObTableScanIterator *centroid_scan_iter = static_cast(centroid_iter_->get_output_result_iter()); @@ -596,9 +655,6 @@ int ObDASIvfScanIter::generate_nearear_cid_heap( } else { // get vec in centroid_iter and get l2_distance(c_vec, search_vec_) // push in max_cid_heap - const ObDASScanCtDef *centroid_ctdef = vec_aux_ctdef_->get_vec_aux_tbl_ctdef( - vec_aux_ctdef_->get_ivf_centroid_tbl_idx(), ObTSCIRScanType::OB_VEC_IVF_CENTROID_SCAN); - ObString c_vec; ObString c_id; while (OB_SUCC(ret)) { @@ -607,7 +663,7 @@ int ObDASIvfScanIter::generate_nearear_cid_heap( if (OB_ITER_END != ret) { LOG_WARN("get next row failed.", K(ret)); } - } else if (OB_FAIL(parse_centroid_datum_with_deep_copy(centroid_ctdef, vec_op_alloc_, datum_row, c_id, c_vec))) { + } else if (OB_FAIL(parse_centroid_datum_with_deep_copy(centroid_ctdef, vec_op_alloc_, datum_row, save_center_vec, c_id, c_vec))) { LOG_WARN("fail to deep copy centroid datum", K(ret)); } else if (c_vec.empty()) { // ignoring null vector. @@ -627,14 +683,14 @@ int ObDASIvfScanIter::generate_nearear_cid_heap( return ret; } -int ObDASIvfScanIter::get_nearest_probe_center_ids(ObIArray &nearest_cids) +int ObDASIvfScanIter::get_nearest_probe_center_ids(bool is_vectorized, ObIArray &nearest_cids) { int ret = OB_SUCCESS; share::ObVectorCentorClusterHelper nearest_cid_heap( vec_op_alloc_, reinterpret_cast(real_search_vec_.ptr()), dim_, nprobes_); if (OB_FAIL(nearest_cids.reserve(nprobes_))) { LOG_WARN("failed to reserve nearest_cids", K(ret)); - } else if (OB_FAIL(generate_nearear_cid_heap(nearest_cid_heap))) { + } else if (OB_FAIL(generate_nearear_cid_heap(is_vectorized, nearest_cid_heap))) { LOG_WARN("failed to generate nearest cid heap", K(ret), K(nprobes_), K(dim_), K(real_search_vec_)); } if (OB_SUCC(ret) || ret == OB_ITER_END) { @@ -779,7 +835,7 @@ int ObDASIvfScanIter::parse_cid_vec_datum( } template -int ObDASIvfScanIter::get_nearest_limit_rowkeys_in_cids(const ObIArray &nearest_cids, T *serch_vec) +int ObDASIvfScanIter::get_nearest_limit_rowkeys_in_cids(bool is_vectorized, const ObIArray &nearest_cids, T *serch_vec) { int ret = OB_SUCCESS; @@ -802,6 +858,38 @@ int ObDASIvfScanIter::get_nearest_limit_rowkeys_in_cids(const ObIArray storage::ObTableScanIterator *cid_vec_scan_iter = nullptr; if (OB_FAIL(scan_cid_range(cid, cid_vec_pri_key_cnt, cid_vec_ctdef, cid_vec_rtdef, cid_vec_scan_iter))) { LOG_WARN("fail to scan cid range", K(ret), K(cid), K(cid_vec_pri_key_cnt)); + } else if (is_vectorized) { + IVF_GET_NEXT_ROWS_BEGIN(cid_vec_iter_) + if (OB_SUCC(ret)) { + ObEvalCtx::BatchInfoScopeGuard guard(*vec_aux_rtdef_->eval_ctx_); + guard.set_batch_size(scan_row_cnt); + bool has_lob_header = cid_vec_ctdef->result_output_.at(CID_VECTOR_IDX)->obj_meta_.has_lob_header(); + ObExpr *cid_expr = cid_vec_ctdef->result_output_[CID_VECTOR_IDX]; + ObDatum *cid_datum = cid_expr->locate_batch_datums(*vec_aux_rtdef_->eval_ctx_); + + for (int64_t i = 0; OB_SUCC(ret) && i < scan_row_cnt; ++i) { + guard.set_batch_idx(i); + ObRowkey *main_rowkey = nullptr; + ObString vec = cid_datum[i].get_string(); + void *buf = nullptr; + ObObj *obj_ptr = nullptr; + if (OB_FAIL(ObTextStringHelper::read_real_string_data( + &vec_op_alloc_, + ObLongTextType, + CS_TYPE_BINARY, + has_lob_header, + vec))) { + LOG_WARN("failed to get real data.", K(ret)); + } else if (OB_ISNULL(vec.ptr())) { + // ignoring null vector. + } else if (OB_FAIL(get_main_rowkey_from_cid_vec_datum(cid_vec_ctdef, rowkey_cnt, main_rowkey))) { + LOG_WARN("fail to get main rowkey", K(ret)); + } else if (OB_FAIL(nearest_rowkey_heap.push_center(main_rowkey, reinterpret_cast(vec.ptr()), dim_))) { + LOG_WARN("failed to push center.", K(ret)); + } + } + } + IVF_GET_NEXT_ROWS_END(cid_vec_iter_, cid_vec_scan_param_, cid_vec_tablet_id_) } else { while (OB_SUCC(ret)) { ObRowkey *main_rowkey; @@ -819,12 +907,12 @@ int ObDASIvfScanIter::get_nearest_limit_rowkeys_in_cids(const ObIArray LOG_WARN("failed to push center.", K(ret)); } } - } - if (ret == OB_ITER_END) { - ret = OB_SUCCESS; - if (OB_FAIL(cid_vec_iter_->reuse())) { - LOG_WARN("fail to reuse scan iterator.", K(ret)); + if (ret == OB_ITER_END) { + ret = OB_SUCCESS; + if (OB_FAIL(cid_vec_iter_->reuse())) { + LOG_WARN("fail to reuse scan iterator.", K(ret)); + } } } } @@ -909,16 +997,16 @@ int ObDASIvfScanIter::process_ivf_scan_brute() return ret; } -int ObDASIvfScanIter::process_ivf_scan_post() +int ObDASIvfScanIter::process_ivf_scan_post(bool is_vectorized) { int ret = OB_SUCCESS; ObSEArray nearest_cids; nearest_cids.set_attr(ObMemAttr(MTL_ID(), "VecIdxNearCid")); - if (OB_FAIL(get_nearest_probe_center_ids(nearest_cids))) { + if (OB_FAIL(get_nearest_probe_center_ids(is_vectorized, nearest_cids))) { // 1. Scan the centroid table, range: [min] ~ [max]; sort by l2_distance(c_vec, search_vec_); limit nprobes; // return the cid column. LOG_WARN("failed to get nearest probe center ids", K(ret)); - } else if (OB_FAIL(get_nearest_limit_rowkeys_in_cids(nearest_cids, + } else if (OB_FAIL(get_nearest_limit_rowkeys_in_cids(is_vectorized, nearest_cids, reinterpret_cast(real_search_vec_.ptr())))) { // 2. get cidx in (top-nprobes 个 cid) // scan cid_vector table, range: [cidx, min] ~ [cidx, max]; sort by l2_distance(vec, search_vec_); limit @@ -933,7 +1021,7 @@ int ObDASIvfScanIter::process_ivf_scan(bool is_vectorized) int ret = OB_SUCCESS; if (vec_aux_ctdef_->is_post_filter()) { - if (OB_FAIL(process_ivf_scan_post())) { + if (OB_FAIL(process_ivf_scan_post(is_vectorized))) { LOG_WARN("failed to process ivf_scan post filter", K(ret), K_(pre_fileter_rowkeys), K_(saved_rowkeys)); } } else if (OB_FAIL(process_ivf_scan_pre(vec_op_alloc_, is_vectorized))) { @@ -1032,7 +1120,7 @@ int ObDASIvfScanIter::filter_pre_rowkey_batch(const ObIArray &nearest_ } else if (OB_FAIL(filter_rowkey_by_cid(nearest_cids, is_vectorized, batch_row_count, index_end))) { LOG_WARN("filter rowkey batch failed", K(ret), K(nearest_cids), K(is_vectorized), K(batch_row_count)); } else { - int tmp_ret = ret; + int tmp_ret = ret; // NOTE(liyao): 跳过reuse_iter返回值 if (OB_FAIL(ObDasVecScanUtils::reuse_iter( ls_id_, rowkey_cid_iter_, rowkey_cid_scan_param_, rowkey_cid_tablet_id_))) { LOG_WARN("failed to reuse rowkey cid iter.", K(ret)); @@ -1052,7 +1140,7 @@ int ObDASIvfScanIter::process_ivf_scan_pre(ObIAllocator &allocator, bool is_vect int64_t batch_row_count = ObVectorParamData::VI_PARAM_DATA_BATCH_SIZE; ObSEArray nearest_cids; nearest_cids.set_attr(ObMemAttr(MTL_ID(), "VecIdxNearCid")); - if (OB_FAIL(get_nearest_probe_center_ids(nearest_cids))) { + if (OB_FAIL(get_nearest_probe_center_ids(is_vectorized, nearest_cids))) { LOG_WARN("failed to get nearest probe center ids", K(ret)); } else if (nearest_cids.count() == 1 && nearest_cids.at(0).empty()) { // create an index on the table, No 1 table is empty @@ -1148,7 +1236,7 @@ int ObDASIvfScanIter::inner_get_next_row() if (limit_param_.limit_ + limit_param_.offset_ == 0) { ret = OB_ITER_END; } else if (OB_ISNULL(saved_rowkeys_itr_)) { - if (OB_FAIL(process_ivf_scan(false))) { + if (OB_FAIL(process_ivf_scan(false/*is_vectorized*/))) { if (OB_ITER_END != ret) { LOG_WARN("failed to process ivf scan state", K(ret)); } @@ -1173,7 +1261,7 @@ int ObDASIvfScanIter::inner_get_next_rows(int64_t &count, int64_t capacity) if (limit_param_.limit_ + limit_param_.offset_ == 0) { ret = OB_ITER_END; } else if (OB_ISNULL(saved_rowkeys_itr_)) { - if (OB_FAIL(process_ivf_scan(true))) { + if (OB_FAIL(process_ivf_scan(true/*is_vectorized*/))) { if (OB_ITER_END != ret) { LOG_WARN("failed to process ivf scan state", K(ret)); } @@ -1325,9 +1413,10 @@ int ObDASIvfPQScanIter::get_pq_cid_vec_by_pq_cid(const ObString &pq_cid, float * LOG_WARN("invalid null scan iter", K(ret)); } else { int row_cnt = 0; + bool has_lob_header = pq_cid_vec_ctdef->result_output_.at(PQ_CENTROID_VEC_IDX)->obj_meta_.has_lob_header(); while (OB_SUCC(ret)) { ObExpr *pq_center_vec_expr = pq_cid_vec_ctdef->result_output_[PQ_CENTROID_VEC_IDX]; - ObRowkey *main_rowkey; + ObRowkey *main_rowkey = nullptr; ObString vec; // cid_vec_scan_iter output: [IVF_CID_VEC_CID_COL IVF_CID_VEC_VECTOR_COL ROWKEY] if (OB_FAIL(cid_vec_scan_iter->get_next_row())) { @@ -1340,7 +1429,7 @@ int ObDASIvfPQScanIter::get_pq_cid_vec_by_pq_cid(const ObString &pq_cid, float * &vec_op_alloc_, ObLongTextType, CS_TYPE_BINARY, - pq_cid_vec_ctdef->result_output_.at(PQ_CENTROID_VEC_IDX)->obj_meta_.has_lob_header(), + has_lob_header, vec))) { LOG_WARN("failed to get real data.", K(ret)); } else if (OB_ISNULL(vec.ptr())) { @@ -1370,7 +1459,7 @@ int ObDASIvfPQScanIter::calc_distance_between_pq_ids( double &distance) { int ret = OB_SUCCESS; - distance = 0.0; + double square = 0.0; for (int j = 0; OB_SUCC(ret) && j < m_; ++j) { // 3.2.1 pq_center_ids[j] is put into ivf_pq_centroid table to find pq_center_vecs[j] float *pq_cid_vec = nullptr; @@ -1379,18 +1468,20 @@ int ObDASIvfPQScanIter::calc_distance_between_pq_ids( } else { // 3.3.2 Calculate the distance between pq_center_vecs[j] and r(x)[j]. // The sum of j = 0 ~ m is the distance from x to rowkey - double cur_distance = DBL_MAX; - if (OB_FAIL(ObVectorL2Distance::l2_distance_func(splited_residual.at(j), pq_cid_vec, dim_ / m_, cur_distance))) { + double cur_square = DBL_MAX; + if (OB_FAIL(ObVectorL2Distance::l2_square_func(splited_residual.at(j), pq_cid_vec, dim_ / m_, cur_square))) { LOG_WARN("failed to calc l2 distance", K(ret)); } else { - distance += cur_distance; + square += cur_square; } } } // end for + distance = sqrt(square); return ret; } int ObDASIvfPQScanIter::calc_nearest_limit_rowkeys_in_cids( + bool is_vectorized, const ObIArray &nearest_centers, float *search_vec) { @@ -1429,40 +1520,75 @@ int ObDASIvfPQScanIter::calc_nearest_limit_rowkeys_in_cids( storage::ObTableScanIterator *cid_vec_scan_iter = nullptr; if (OB_FAIL(scan_cid_range(cur_cid, cid_vec_pri_key_cnt, cid_vec_ctdef, cid_vec_rtdef, cid_vec_scan_iter))) { LOG_WARN("fail to scan cid range", K(ret), K(cur_cid), K(cid_vec_pri_key_cnt)); - } - while (OB_SUCC(ret)) { - ObRowkey *main_rowkey; - ObString vec_arr_str; - ObArrayBinary *pq_center_ids = NULL; - double distance = 0.0; - // cid_vec_scan_iter output: [IVF_CID_VEC_CID_COL IVF_CID_VEC_VECTOR_COL ROWKEY] - if (OB_FAIL(cid_vec_scan_iter->get_next_row())) { - if (OB_ITER_END != ret) { - LOG_WARN("failed to scan vid rowkey iter", K(ret)); - } - } else if (OB_FAIL(parse_cid_vec_datum( - tmp_allocator, - cid_vec_column_count, - cid_vec_ctdef, - rowkey_cnt, - main_rowkey, - pq_center_ids))) { - LOG_WARN("fail to parse cid vec datum", K(ret), K(cid_vec_column_count), K(rowkey_cnt)); - } else if (OB_ISNULL(pq_center_ids)) { - // ignore null arr - } else if (OB_FAIL(calc_distance_between_pq_ids(*pq_center_ids, splited_residual, distance))) { - LOG_WARN("fail to calc distance between pq ids", K(ret)); - } else if (OB_FAIL(nearest_rowkey_heap.push_center(main_rowkey, distance))) { - LOG_WARN("failed to push center.", K(ret)); - } - } // end while + } else if (is_vectorized) { + IVF_GET_NEXT_ROWS_BEGIN(cid_vec_iter_) + if (OB_SUCC(ret)) { + ObEvalCtx::BatchInfoScopeGuard guard(*vec_aux_rtdef_->eval_ctx_); + guard.set_batch_size(scan_row_cnt); + ObExpr *cid_expr = cid_vec_ctdef->result_output_[PQ_IDS_IDX]; + ObDatum *cid_datum = cid_expr->locate_batch_datums(*vec_aux_rtdef_->eval_ctx_); - if (ret == OB_ITER_END) { - ret = OB_SUCCESS; - if (OB_FAIL(cid_vec_iter_->reuse())) { - LOG_WARN("fail to reuse scan iterator.", K(ret)); - } else { - splited_residual.reuse(); + for (int64_t i = 0; OB_SUCC(ret) && i < scan_row_cnt; ++i) { + guard.set_batch_idx(i); + if (cid_datum[i].is_null()) { + // continue + } else { + ObRowkey *main_rowkey = nullptr; + ObString vec = cid_datum[i].get_string(); + void *buf = nullptr; + ObObj *obj_ptr = nullptr; + ObArrayBinary *pq_center_ids = NULL; + double distance = 0.0; + if (OB_FAIL(get_pq_cids_from_datum(tmp_allocator, vec, pq_center_ids))) { + LOG_WARN("fail to get pq cids from datum", K(ret)); + } else if (OB_ISNULL(pq_center_ids)) { + // ignore null arr + } else if (OB_FAIL(get_main_rowkey_from_cid_vec_datum(cid_vec_ctdef, rowkey_cnt, main_rowkey))) { + LOG_WARN("fail to get main rowkey", K(ret)); + } else if (OB_FAIL(calc_distance_between_pq_ids(*pq_center_ids, splited_residual, distance))) { + LOG_WARN("fail to calc distance between pq ids", K(ret)); + } else if (OB_FAIL(nearest_rowkey_heap.push_center(main_rowkey, distance))) { + LOG_WARN("failed to push center.", K(ret)); + } + } + } + } + IVF_GET_NEXT_ROWS_END(cid_vec_iter_, cid_vec_scan_param_, cid_vec_tablet_id_) + } else { + while (OB_SUCC(ret)) { + ObRowkey *main_rowkey = nullptr; + ObString vec_arr_str; + ObArrayBinary *pq_center_ids = NULL; + double distance = 0.0; + // cid_vec_scan_iter output: [IVF_CID_VEC_CID_COL IVF_CID_VEC_VECTOR_COL ROWKEY] + if (OB_FAIL(cid_vec_scan_iter->get_next_row())) { + if (OB_ITER_END != ret) { + LOG_WARN("failed to scan vid rowkey iter", K(ret)); + } + } else if (OB_FAIL(parse_cid_vec_datum( + tmp_allocator, + cid_vec_column_count, + cid_vec_ctdef, + rowkey_cnt, + main_rowkey, + pq_center_ids))) { + LOG_WARN("fail to parse cid vec datum", K(ret), K(cid_vec_column_count), K(rowkey_cnt)); + } else if (OB_ISNULL(pq_center_ids)) { + // ignore null arr + } else if (OB_FAIL(calc_distance_between_pq_ids(*pq_center_ids, splited_residual, distance))) { + LOG_WARN("fail to calc distance between pq ids", K(ret)); + } else if (OB_FAIL(nearest_rowkey_heap.push_center(main_rowkey, distance))) { + LOG_WARN("failed to push center.", K(ret)); + } + } // end while + + if (ret == OB_ITER_END) { + ret = OB_SUCCESS; + if (OB_FAIL(cid_vec_iter_->reuse())) { + LOG_WARN("fail to reuse scan iterator.", K(ret)); + } else { + splited_residual.reuse(); + } } } } @@ -1478,14 +1604,14 @@ int ObDASIvfPQScanIter::calc_nearest_limit_rowkeys_in_cids( return ret; } -int ObDASIvfPQScanIter::get_nearest_probe_centers(ObIArray &nearest_centers) +int ObDASIvfPQScanIter::get_nearest_probe_centers(bool is_vectorized, ObIArray &nearest_centers) { int ret = OB_SUCCESS; share::ObVectorCentorClusterHelper nearest_cid_heap( vec_op_alloc_, reinterpret_cast(real_search_vec_.ptr()), dim_, nprobes_); if (OB_FAIL(nearest_centers.reserve(nprobes_))) { LOG_WARN("failed to reserve nearest_centers", K(ret)); - } else if (OB_FAIL(generate_nearear_cid_heap(nearest_cid_heap, true/*save_center_vec*/))) { + } else if (OB_FAIL(generate_nearear_cid_heap(is_vectorized, nearest_cid_heap, true/*save_center_vec*/))) { LOG_WARN("failed to generate nearest cid heap", K(ret), K(nprobes_), K(dim_), K(real_search_vec_)); } if (OB_SUCC(ret) || ret == OB_ITER_END) { @@ -1521,12 +1647,12 @@ int ObDASIvfPQScanIter::get_rowkey_brute_post() return ret; } -int ObDASIvfPQScanIter::process_ivf_scan_post() +int ObDASIvfPQScanIter::process_ivf_scan_post(bool is_vectorized) { int ret = OB_SUCCESS; ObSEArray nearest_centers; nearest_centers.set_attr(ObMemAttr(MTL_ID(), "VecIdxNearCid")); - if (OB_FAIL(get_nearest_probe_centers(nearest_centers))) { + if (OB_FAIL(get_nearest_probe_centers(is_vectorized, nearest_centers))) { // 1. Scan the ivf_centroid table, calculate the distance between vec_x and cid_vec, // and get the nearest cluster center (cid 1, cid_vec 1)... (cid n, cid_vec n) LOG_WARN("failed to get nearest probe center ids", K(ret)); @@ -1535,6 +1661,7 @@ int ObDASIvfPQScanIter::process_ivf_scan_post() LOG_WARN("failed to get limit rowkey brute", K(ret)); } } else if (OB_FAIL(calc_nearest_limit_rowkeys_in_cids( + is_vectorized, nearest_centers, reinterpret_cast(real_search_vec_.ptr())))) { // 2. search nearest rowkeys @@ -1749,7 +1876,7 @@ int ObDASIvfPQScanIter::process_ivf_scan_pre(ObIAllocator &allocator, bool is_ve IvfRowkeyHeap nearest_rowkey_heap( vec_op_alloc_, reinterpret_cast(real_search_vec_.ptr())/*unused*/, dim_ / m_, get_nprobe(limit_param_, PQ_ID_ENLARGEMENT_FACTOR)); - if (OB_FAIL(get_nearest_probe_centers(nearest_centers))) { + if (OB_FAIL(get_nearest_probe_centers(is_vectorized, nearest_centers))) { // 1. Scan the ivf_centroid table, calculate the distance between vec_x and cid_vec, // and get the nearest cluster center (cid 1, cid_vec 1)... (cid n, cid_vec n) LOG_WARN("failed to get nearest probe center ids", K(ret)); @@ -1797,27 +1924,6 @@ int ObDASIvfPQScanIter::process_ivf_scan_pre(ObIAllocator &allocator, bool is_ve return ret; } -int ObDASIvfPQScanIter::process_ivf_scan(bool is_vectorized) -{ - int ret = OB_SUCCESS; - - if (vec_aux_ctdef_->is_post_filter()) { - if (OB_FAIL(process_ivf_scan_post())) { - LOG_WARN("failed to process adaptorstate post filter", K(ret), K_(pre_fileter_rowkeys), K_(saved_rowkeys)); - } - } else if (OB_FAIL(process_ivf_scan_pre(vec_op_alloc_, is_vectorized))) { - LOG_WARN("failed to process adaptor state hnsw", K(ret)); - } - - if (OB_FAIL(ret)) { - } else if (OB_FAIL(gen_rowkeys_itr())) { - if (ret != OB_ITER_END) { - LOG_WARN("failed to gen adapter rowkeys itr", K(saved_rowkeys_)); - } - } - return ret; -} - /********************************************************************************************************/ int ObDASIvfSQ8ScanIter::inner_get_next_rows(int64_t &count, int64_t capacity) { @@ -1825,7 +1931,7 @@ int ObDASIvfSQ8ScanIter::inner_get_next_rows(int64_t &count, int64_t capacity) if (limit_param_.limit_ + limit_param_.offset_ == 0) { ret = OB_ITER_END; } else if (OB_ISNULL(saved_rowkeys_itr_)) { - if (OB_FAIL(process_ivf_scan_sq(true))) { + if (OB_FAIL(process_ivf_scan_sq(true/*is_vectorized*/))) { if (OB_ITER_END != ret) { LOG_WARN("failed to process adaptor state", K(ret)); } @@ -1849,7 +1955,7 @@ int ObDASIvfSQ8ScanIter::inner_get_next_row() if (limit_param_.limit_ + limit_param_.offset_ == 0) { ret = OB_ITER_END; } else if (OB_ISNULL(saved_rowkeys_itr_)) { - if (OB_FAIL(process_ivf_scan_sq(false))) { + if (OB_FAIL(process_ivf_scan_sq(false/*is_vectorized*/))) { if (OB_ITER_END != ret) { LOG_WARN("failed to process adaptor state", K(ret)); } @@ -1871,7 +1977,7 @@ int ObDASIvfSQ8ScanIter::process_ivf_scan_sq(bool is_vectorized) int ret = OB_SUCCESS; if (vec_aux_ctdef_->is_post_filter()) { - if (OB_FAIL(process_ivf_scan_post_sq())) { + if (OB_FAIL(process_ivf_scan_post_sq(is_vectorized))) { LOG_WARN("failed to process adaptorstate post filter", K(ret), K_(pre_fileter_rowkeys), K_(saved_rowkeys)); } } else if (OB_FAIL(process_ivf_scan_pre(vec_op_alloc_, is_vectorized))) { @@ -1913,20 +2019,52 @@ int ObDASIvfSQ8ScanIter::inner_release() return ret; } -int ObDASIvfSQ8ScanIter::get_real_search_vec_u8(ObString &real_search_vec_u8) +int ObDASIvfSQ8ScanIter::get_real_search_vec_u8(bool is_vectorized, ObString &real_search_vec_u8) { int ret = OB_SUCCESS; const ObDASScanCtDef *sq_meta_ctdef = vec_aux_ctdef_->get_vec_aux_tbl_ctdef( vec_aux_ctdef_->get_ivf_sq_meta_tbl_idx(), ObTSCIRScanType::OB_VEC_IVF_SPECIAL_AUX_SCAN); ObDASScanRtDef *sq_meta_rtdef = vec_aux_rtdef_->get_vec_aux_tbl_rtdef(vec_aux_ctdef_->get_ivf_sq_meta_tbl_idx()); - if (OB_FAIL(do_table_full_scan(sq_meta_iter_first_scan_, - sq_meta_scan_param_, + ObString min_vec; + ObString step_vec; + if (OB_FAIL(do_table_full_scan(is_vectorized, sq_meta_ctdef, sq_meta_rtdef, sq_meta_iter_, SQ_MEAT_PRI_KEY_CNT, - sq_meta_tablet_id_))) { + sq_meta_tablet_id_, + sq_meta_iter_first_scan_, + sq_meta_scan_param_))) { LOG_WARN("failed to do table scan sq_meta", K(ret)); + } else if (is_vectorized) { + IVF_GET_NEXT_ROWS_BEGIN(sq_meta_iter_) + if (OB_SUCC(ret)) { + ObEvalCtx::BatchInfoScopeGuard guard(*vec_aux_rtdef_->eval_ctx_); + guard.set_batch_size(scan_row_cnt); + bool has_lob_header = sq_meta_ctdef->result_output_.at(META_VECTOR_IDX)->obj_meta_.has_lob_header(); + ObExpr *meta_vec_expr = sq_meta_ctdef->result_output_[META_VECTOR_IDX]; + ObDatum *meta_vec_datum = meta_vec_expr->locate_batch_datums(*vec_aux_rtdef_->eval_ctx_); + + for (int64_t i = 0; OB_SUCC(ret) && i < scan_row_cnt; ++i) { + guard.set_batch_idx(i); + if (i == ObIvfConstant::SQ8_META_MIN_IDX || i == ObIvfConstant::SQ8_META_STEP_IDX) { + ObString c_vec = meta_vec_datum[i].get_string(); + if (OB_FAIL(ObTextStringHelper::read_real_string_data( + &vec_op_alloc_, + ObLongTextType, + CS_TYPE_BINARY, + has_lob_header, + c_vec))) { + LOG_WARN("failed to get real data.", K(ret)); + } else if (i == ObIvfConstant::SQ8_META_MIN_IDX) { + min_vec = c_vec; + } else if (i == ObIvfConstant::SQ8_META_STEP_IDX) { + step_vec = c_vec; + } + } + } + } + IVF_GET_NEXT_ROWS_END(sq_meta_iter_, sq_meta_scan_param_, sq_meta_tablet_id_) } else { // get min_vec max_vec step_vec in sq_meta table storage::ObTableScanIterator *sq_meta_scan_iter = @@ -1935,8 +2073,6 @@ int ObDASIvfSQ8ScanIter::get_real_search_vec_u8(ObString &real_search_vec_u8) ret = OB_ERR_UNEXPECTED; LOG_WARN("real sql meta iter is null", K(ret)); } else { - ObString min_vec; - ObString step_vec; int row_index = 0; while (OB_SUCC(ret)) { ObString c_vec; @@ -1970,34 +2106,39 @@ int ObDASIvfSQ8ScanIter::get_real_search_vec_u8(ObString &real_search_vec_u8) } if (ret == OB_ITER_END) { - ret = OB_SUCCESS; - uint8_t *res_vec = nullptr; - if (OB_ISNULL(min_vec.ptr()) || OB_ISNULL(step_vec.ptr())) { - if (OB_ISNULL(res_vec = reinterpret_cast(vec_op_alloc_.alloc(sizeof(uint8_t) * dim_)))) { - ret = OB_ALLOCATE_MEMORY_FAILED; - LOG_WARN("failed to allocate memory", K(ret), K(sizeof(uint8_t) * dim_)); - } else { - MEMSET(res_vec, 0, sizeof(uint8_t) * dim_); - } - } else if (OB_FAIL( - ObExprVecIVFSQ8DataVector::cal_u8_data_vector(vec_op_alloc_, - dim_, - reinterpret_cast(min_vec.ptr()), - reinterpret_cast(step_vec.ptr()), - reinterpret_cast(real_search_vec_.ptr()), - res_vec))) { - LOG_WARN("fail to cal u8 data vector", K(ret), K(dim_)); - } - if (OB_SUCC(ret)) { - real_search_vec_u8.assign_ptr(reinterpret_cast(res_vec), dim_ * sizeof(uint8_t)); + if (OB_FAIL(sq_meta_iter_->reuse())) { + LOG_WARN("fail to reuse scan iterator.", K(ret)); } } } } + + if (OB_SUCC(ret)) { + uint8_t *res_vec = nullptr; + if (OB_ISNULL(min_vec.ptr()) || OB_ISNULL(step_vec.ptr())) { + if (OB_ISNULL(res_vec = reinterpret_cast(vec_op_alloc_.alloc(sizeof(uint8_t) * dim_)))) { + ret = OB_ALLOCATE_MEMORY_FAILED; + LOG_WARN("failed to allocate memory", K(ret), K(sizeof(uint8_t) * dim_)); + } else { + MEMSET(res_vec, 0, sizeof(uint8_t) * dim_); + } + } else if (OB_FAIL( + ObExprVecIVFSQ8DataVector::cal_u8_data_vector(vec_op_alloc_, + dim_, + reinterpret_cast(min_vec.ptr()), + reinterpret_cast(step_vec.ptr()), + reinterpret_cast(real_search_vec_.ptr()), + res_vec))) { + LOG_WARN("fail to cal u8 data vector", K(ret), K(dim_)); + } + if (OB_SUCC(ret)) { + real_search_vec_u8.assign_ptr(reinterpret_cast(res_vec), dim_ * sizeof(uint8_t)); + } + } return ret; } -int ObDASIvfSQ8ScanIter::process_ivf_scan_post_sq() +int ObDASIvfSQ8ScanIter::process_ivf_scan_post_sq(bool is_vectorized) { int ret = OB_SUCCESS; // post scan: @@ -2008,13 +2149,13 @@ int ObDASIvfSQ8ScanIter::process_ivf_scan_post_sq() ObSEArray nearest_cids; nearest_cids.set_attr(ObMemAttr(MTL_ID(), "VecIdxNearCid")); ObString real_search_vec_u8; - if (OB_FAIL(get_nearest_probe_center_ids(nearest_cids))) { + if (OB_FAIL(get_nearest_probe_center_ids(is_vectorized, nearest_cids))) { LOG_WARN("failed to get nearest probe center ids", K(ret)); - } else if (OB_FAIL(get_real_search_vec_u8(real_search_vec_u8))) { + } else if (OB_FAIL(get_real_search_vec_u8(is_vectorized, real_search_vec_u8))) { // real_search_vec to real_search_vec_u8 with min_vec max_vec step_vec LOG_WARN("failed to get real search vec u8", K(ret)); } else if (OB_FAIL(get_nearest_limit_rowkeys_in_cids( - nearest_cids, reinterpret_cast(real_search_vec_u8.ptr())))) { + is_vectorized, nearest_cids, reinterpret_cast(real_search_vec_u8.ptr())))) { LOG_WARN("failed to get nearest limit rowkeys in cids", K(nearest_cids)); } diff --git a/src/sql/das/iter/ob_das_ivf_scan_iter.h b/src/sql/das/iter/ob_das_ivf_scan_iter.h index 65d2c6c349..911d9e93b0 100644 --- a/src/sql/das/iter/ob_das_ivf_scan_iter.h +++ b/src/sql/das/iter/ob_das_ivf_scan_iter.h @@ -24,6 +24,35 @@ namespace oceanbase using namespace common; namespace sql { +#define IVF_GET_NEXT_ROWS_BEGIN(iter) \ + bool index_end = false; \ + iter->clear_evaluated_flag(); \ + int64_t scan_row_cnt = 0; \ + int64_t batch_row_count = ObVectorParamData::VI_PARAM_DATA_BATCH_SIZE; \ + while (!index_end && OB_SUCC(ret)) { \ + if (OB_FAIL(iter->get_next_rows(scan_row_cnt, batch_row_count))) { \ + if (OB_ITER_END != ret) { \ + LOG_WARN("failed to get next row.", K(ret)); \ + } else { \ + index_end = true; \ + } \ + } \ + if (OB_FAIL(ret) && OB_ITER_END != ret) { \ + } else if (scan_row_cnt > 0) { \ + ret = OB_SUCCESS; \ + } + +#define IVF_GET_NEXT_ROWS_END(iter, scan_param, tablet_id) \ + } \ + if (index_end) { \ + int tmp_ret = (ret == OB_ITER_END) ? OB_SUCCESS : ret; \ + if (OB_FAIL(ObDasVecScanUtils::reuse_iter(ls_id_, iter, scan_param, tablet_id))) { \ + LOG_WARN("failed to reuse rowkey cid iter.", K(ret)); \ + } else { \ + ret = tmp_ret; \ + } \ + } + struct ObDASIvfScanIterParam : public ObDASIterParam { public: explicit ObDASIvfScanIterParam(const ObVectorIndexAlgorithmType index_type) @@ -153,13 +182,14 @@ protected: } protected: - int do_table_full_scan(bool &first_scan, - ObTableScanParam &scan_param, - const ObDASScanCtDef *ctdef, - ObDASScanRtDef *rtdef, - ObDASScanIter *iter, - int64_t pri_key_cnt, - ObTabletID &tablet_id); + int do_table_full_scan(bool is_vectorized, + const ObDASScanCtDef *ctdef, + ObDASScanRtDef *rtdef, + ObDASScanIter *iter, + int64_t pri_key_cnt, + ObTabletID &tablet_id, + bool &first_scan, + ObTableScanParam &scan_param); int do_aux_table_scan(bool &first_scan, ObTableScanParam &scan_param, const ObDASScanCtDef *ctdef, @@ -186,6 +216,9 @@ protected: static const int64_t SQ_MEAT_PRI_KEY_CNT = 1; static const int64_t SQ_MEAT_ALL_KEY_CNT = 2; static const int64_t POST_ENLARGEMENT_FACTOR = 2; + // in centroid table + static const int64_t CID_IDX = 0; + static const int64_t CID_VECTOR_IDX = 1; protected: common::ObArenaAllocator vec_op_alloc_; share::ObLSID ls_id_; @@ -257,13 +290,13 @@ protected: virtual int inner_get_next_rows(int64_t &count, int64_t capacity) override; protected: - int get_nearest_probe_center_ids(ObIArray &nearest_cids); + int get_nearest_probe_center_ids(bool is_vectorized, ObIArray &nearest_cids); int get_main_rowkey_from_cid_vec_datum(const ObDASScanCtDef *cid_vec_ctdef, const int64_t rowkey_cnt, ObRowkey *&main_rowkey); virtual int process_ivf_scan(bool is_vectorized); template - int get_nearest_limit_rowkeys_in_cids(const ObIArray &nearest_cids, T *serch_vec); + int get_nearest_limit_rowkeys_in_cids(bool is_vectorized, const ObIArray &nearest_cids, T *serch_vec); int get_rowkey(ObIAllocator &allocator, ObRowkey *&rowkey) { const ObDASScanCtDef *ctdef = vec_aux_ctdef_->get_vec_aux_tbl_ctdef(vec_aux_ctdef_->get_ivf_rowkey_cid_tbl_idx(), ObTSCIRScanType::OB_VEC_IVF_ROWKEY_CID_SCAN); @@ -283,9 +316,10 @@ protected: bool is_vectorized, int64_t batch_row_count, bool &index_end); - virtual int process_ivf_scan_post(); + virtual int process_ivf_scan_post(bool is_vectorized); int process_ivf_scan_brute(); int generate_nearear_cid_heap( + bool is_vectorized, share::ObVectorCentorClusterHelper &nearest_cid_heap, bool save_center_vec = false); int prepare_cid_range( @@ -309,6 +343,7 @@ protected: const ObDASScanCtDef *cid_vec_ctdef, ObIAllocator& allocator, blocksstable::ObDatumRow *datum_row, + bool save_center_vec, ObString &cid, ObString &cid_vec); }; @@ -347,8 +382,7 @@ protected: } int inner_release() override; - int process_ivf_scan(bool is_vectorized) override; - int process_ivf_scan_post() override; + int process_ivf_scan_post(bool is_vectorized) override; int process_ivf_scan_pre(ObIAllocator &allocator, bool is_vectorized) override; int filter_pre_rowkey_batch(const ObIArray> &nearest_cids, bool is_vectorized, @@ -367,10 +401,11 @@ protected: ObRowkey *&main_rowkey, ObArrayBinary *&com_key); int calc_nearest_limit_rowkeys_in_cids( + bool is_vectorized, const ObIArray> &nearest_centers, float *search_vec); int get_pq_cid_vec_by_pq_cid(const ObString &pq_cid, float *&pq_cid_vec); - int get_nearest_probe_centers(ObIArray> &nearest_centers); + int get_nearest_probe_centers(bool is_vectorized, ObIArray> &nearest_centers); int get_cid_from_pq_rowkey_cid_table(ObIAllocator &allocator, ObString &cid, ObArrayBinary *&pq_cids); int check_cid_exist( const ObIArray> &dst_cids, @@ -416,6 +451,7 @@ private: class ObDASIvfSQ8ScanIter : public ObDASIvfScanIter { public: + static const int64_t META_VECTOR_IDX = 1; ObDASIvfSQ8ScanIter() : ObDASIvfScanIter(), sq_meta_iter_(nullptr), @@ -442,9 +478,9 @@ protected: virtual int inner_get_next_row() override; virtual int inner_get_next_rows(int64_t &count, int64_t capacity) override; - int get_real_search_vec_u8(ObString &real_search_vec_u8); + int get_real_search_vec_u8(bool is_vectorized, ObString &real_search_vec_u8); int process_ivf_scan_sq(bool is_vectorized); - int process_ivf_scan_post_sq(); + int process_ivf_scan_post_sq(bool is_vectorized); private: ObDASScanIter *sq_meta_iter_;