# Copyright 2024 PingCAP, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time

import h5py
import numpy
import peewee
import tabulate
import tidb_vector

from tidb_vector.peewee import VectorField, VectorAdaptor
from tidb_vector.utils import encode_vector

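# The dataset files below follow the ann-benchmarks HDF5 layout
# (https://github.com/erikbern/ann-benchmarks): a "train" dataset holding the
# vectors to index, a "test" dataset holding query vectors, and a "neighbors"
# dataset holding the groundtruth ids of each query's nearest neighbors.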
dataset_path_1 = "./datasets/fashion-mnist-784-euclidean.hdf5"
dataset_path_2 = "./datasets/mnist-784-euclidean.hdf5"
table_name = "recall_test"

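# Connection parameters assume a local TiDB cluster, e.g. one started by
# `tiup playground`, which serves the MySQL protocol on port 4000 by default.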
mysql_db = peewee.MySQLDatabase(
    "test",
    host="127.0.0.1",
    port=4000,
    user="root",
    passwd="",
)


class Sample(peewee.Model):
    class Meta:
        database = mysql_db
        db_table = table_name

    id = peewee.IntegerField(
        primary_key=True,
    )
    vec = VectorField(784)

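# A minimal sketch of the table this model maps to, assuming tidb-vector's
# VectorField translates to TiDB's VECTOR column type:
#
#   CREATE TABLE recall_test (
#       id INT PRIMARY KEY,
#       vec VECTOR(784)
#   );

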
def connect():
    print(
        f"+ Connecting to {mysql_db.connect_params['user']}@{mysql_db.connect_params['host']}...",
        flush=True,
    )
    mysql_db.connect()


def clean():
    mysql_db.drop_tables([Sample], safe=True)


def create_table():
    mysql_db.create_tables([Sample])
    VectorAdaptor(mysql_db).create_vector_index(
        Sample.vec, tidb_vector.DistanceMetric.L2
    )

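# A rough sketch of the DDL that create_vector_index above is expected to
# issue (an assumption about the tidb-vector client, not verified here); a
# TiFlash replica is required before a vector index can be created:
#
#   ALTER TABLE recall_test SET TIFLASH REPLICA 1;
#   CREATE VECTOR INDEX vec_idx ON recall_test ((VEC_L2_DISTANCE(vec)));

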
def load(dataset_path: str, begin: int = 0, end: int = 60000):
    print()
    print("+ Loading data...", flush=True)

    with h5py.File(dataset_path, "r") as data_file:
        data: numpy.ndarray = data_file["train"][()]
        assert end <= len(data)

        data_with_id = [(idx, data[idx]) for idx in range(begin, end)]
        max_data_id = data_with_id[-1][0]

        # Insert (id, vector) tuples in batches of 1000 rows to keep each
        # multi-row INSERT statement reasonably sized.
        for batch in peewee.chunked(data_with_id, 1000):
            print(
                f" - Batch insert [{batch[0][0]}..{batch[-1][0]}] (max PK={max_data_id})...",
                flush=True,
            )
            Sample.insert_many(batch, fields=[Sample.id, Sample.vec]).execute()


def remove(begin: int, end: int):
    print()
    print(f"+ Removing data in range [{begin}..{end})...", flush=True)
    Sample.delete().where(Sample.id >= begin, Sample.id < end).execute()

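# Recall below is recall@100: for each query, the fraction of its groundtruth
# nearest neighbors (the ann-benchmarks datasets ship 100 per query) that also
# appear in the 100 rows returned by the indexed ANN query, expressed as a
# percentage. For example, if 97 of the 100 groundtruth ids are returned, the
# recall for that query is 97%.

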
def check(dataset_path: str, check_tiflash_used_index: bool):
    print()
    print("+ Current index distribution:")
    cursor = mysql_db.execute_sql(
        f"SELECT ROWS_STABLE_INDEXED, ROWS_STABLE_NOT_INDEXED, ROWS_DELTA_INDEXED, ROWS_DELTA_NOT_INDEXED FROM INFORMATION_SCHEMA.TIFLASH_INDEXES WHERE TIDB_TABLE='{table_name}'"
    )
    print(
        tabulate.tabulate(
            cursor.fetchall(),
            headers=[
                "StableIndexed",
                "StableNotIndexed",
                "DeltaIndexed",
                "DeltaNotIndexed",
            ],
            tablefmt="psql",
        ),
        flush=True,
    )

    with h5py.File(dataset_path, "r") as data_file:
        query_rows = data_file["test"][()]
        query_rows_len = min(
            len(query_rows), 200
        )  # Only check the first 200 test rows

        print("+ Execution Plan:")

        with mysql_db.execute_sql(
            # EXPLAIN ANALYZE SELECT id FROM {table_name} ORDER BY VEC_L2_Distance(vec, %s) LIMIT 100
            # In the cluster started by tiup, the tiflash component does not yet include pr-10103, so '*' is used here as a substitute.
            # For details, see: https://github.com/pingcap/tiflash/pull/10103
            f"EXPLAIN ANALYZE SELECT * FROM {table_name} ORDER BY VEC_L2_Distance(vec, %s) LIMIT 100",
            (encode_vector(query_rows[0]),),
        ) as cursor:
            plan = tabulate.tabulate(cursor.fetchall(), tablefmt="psql")
            print(plan, flush=True)
            # The query must run on TiFlash (MPP) and use the ANN index.
            assert "mpp[tiflash]" in plan
            assert "annIndex:L2(vec.." in plan
            if check_tiflash_used_index:
                assert "vector_idx:{" in plan

        print()
        print(f"+ Checking recall (via {query_rows_len} groundtruths)...", flush=True)

        total_recall = 0.0
        total_tests = 0

        for test_rowid in range(query_rows_len):
            query_row: numpy.ndarray = query_rows[test_rowid]
            groundtruth_results_set = set(data_file["neighbors"][test_rowid])

            with mysql_db.execute_sql(
                # SELECT id FROM {table_name} ORDER BY VEC_L2_Distance(vec, %s) LIMIT 100
                # In the cluster started by tiup, the tiflash component does not yet include pr-10103, so '*' is used here as a substitute.
                # For details, see: https://github.com/pingcap/tiflash/pull/10103
                f"SELECT * FROM {table_name} ORDER BY VEC_L2_Distance(vec, %s) LIMIT 100",
                (encode_vector(query_row),),
            ) as cursor:
                actual_results = cursor.fetchall()
                actual_results_set = set(int(row[0]) for row in actual_results)
                recall = (
                    len(groundtruth_results_set & actual_results_set)
                    / len(groundtruth_results_set)
                    * 100
                )
                total_recall += recall
                total_tests += 1

                if recall < 80:
                    print(
                        f" - WARNING: groundtruth #{test_rowid} recall {recall:.2f}%",
                        flush=True,
                    )

        avg_recall = total_recall / total_tests
        print(f" - Average recall: {avg_recall:.2f}%", flush=True)

        # For this dataset recall is expected to be very high, so we set a
        # high bar here.
        assert avg_recall >= 95

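# `ALTER TABLE ... COMPACT` asks TiFlash to merge the delta layer into the
# stable layer; the vector index is then built for the new stable data, which
# is what the ROWS_STABLE_NOT_INDEXED counter polled below tracks.

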
def compact_and_wait_index_built():
    print()
    print("+ Waiting for data to synchronize...", flush=True)
    cursor = mysql_db.execute_sql(f"SELECT COUNT(*) FROM {table_name}")
    print(f" - Current row count: {cursor.fetchone()[0]}", flush=True)
    print("+ Compacting table...", flush=True)
    mysql_db.execute_sql(f"ALTER TABLE {table_name} COMPACT")
    print("+ Waiting for index build to finish...", flush=True)

    start_time = time.time()
    while True:
        cursor = mysql_db.execute_sql(
            f"SELECT ROWS_STABLE_NOT_INDEXED, ROWS_STABLE_INDEXED FROM INFORMATION_SCHEMA.TIFLASH_INDEXES WHERE TIDB_TABLE='{table_name}'"
        )
        row = cursor.fetchone()
        if row is None:
            # Index statistics are not reported yet; retry later.
            time.sleep(10)
            continue

        if row[0] == 0:
            # No stable rows are left unindexed: the build is finished.
            break

        print(f" - StableIndexed: {row[1]}, StableNotIndexed: {row[0]}", flush=True)
        time.sleep(10)

        if time.time() - start_time > 600:
            raise Exception("Index build not finished in 10 minutes")


def main():
    parser = argparse.ArgumentParser(prog="vector_recall")
    parser.add_argument(
        "--check-only",
        help="Only do the check without loading data",
        action="store_true",
    )
    args = parser.parse_args()

    connect()

    if args.check_only:
        print("+ Performing check over existing data...", flush=True)
        check(dataset_path_2, check_tiflash_used_index=True)
        return

    clean()
    create_table()

    print("+ Insert data and ensure there are both stable and delta layers...", flush=True)
    # First insert part of the data and compact it so that it becomes the stable layer
    load(dataset_path_1, 0, 30000)
    compact_and_wait_index_built()
    # Then insert the rest of the data, which stays in the delta layer
    load(dataset_path_1, 30000, 60000)
    # Now check the recall over a mix of stable and delta data
    check(dataset_path_1, check_tiflash_used_index=True)
    print("+ Waiting 10s so that some delta index may be built...", flush=True)
    time.sleep(10)
    check(dataset_path_1, check_tiflash_used_index=True)

    # Remove some data, then insert it again, to check recall with multi-version data
    print("+ Reinsert multi-version data...", flush=True)
    remove(22400, 41234)  # This range covers both delta and stable
    load(dataset_path_1, 22400, 41234)
    check(dataset_path_1, check_tiflash_used_index=True)

    # Remove all data, then insert dataset 2
    print("+ Reinsert multi-version data using dataset 2...", flush=True)
    remove(0, 60000)
    load(dataset_path_2, 0, 20000)
    compact_and_wait_index_built()
    load(dataset_path_2, 20000, 60000)
    check(dataset_path_2, check_tiflash_used_index=True)
    print("+ Waiting 10s so that some delta index may be built...", flush=True)
    time.sleep(10)
    check(dataset_path_2, check_tiflash_used_index=True)

    # Compact all data and check again; this checks the recall for stable data only
    compact_and_wait_index_built()
    check(dataset_path_2, check_tiflash_used_index=True)

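# Typical usage, assuming this file is saved as vector_recall.py and the
# dataset files are already downloaded into ./datasets:
#
#   python vector_recall.py               # full load + recall test
#   python vector_recall.py --check-only  # only re-check existing data

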
if __name__ == "__main__":
    main()