import os
import logging
import pdb
import time
import random
from multiprocessing import Process

import numpy as np

from client import MilvusClient
import utils
import parser
from runner import Runner

logger = logging.getLogger("milvus_benchmark.local_runner")
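
# LocalRunner executes benchmark suites against a single Milvus endpoint in
# three modes: "performance" (insert/search timing), "accuracy" (recall vs. a
# baseline), and "stability" (randomized query + insert load for a set duration).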

class LocalRunner(Runner):
    """run local mode"""
    def __init__(self, ip, port):
        super(LocalRunner, self).__init__()
        self.ip = ip
        self.port = port

    def run(self, definition, run_type=None):
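        # `definition` maps an operation type to its run spec. A hypothetical
        # "performance" example, inferred from the key lookups below (the real
        # schema comes from the benchmark suite files and is not confirmed here):
        #   {"insert": {"run_count": 1,
        #               "params": [{"table_name": "sift_1m_1024_128_l2",
        #                           "ni_per": 100000}]}}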
        if run_type == "performance":
            for op_type, op_value in definition.items():
                run_count = op_value["run_count"]
                run_params = op_value["params"]

                if op_type == "insert":
                    for index, param in enumerate(run_params):
                        table_name = param["table_name"]
                        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                        milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
                        # Drop the table if it already exists so the insert test starts clean
                        if milvus.exists_table():
                            milvus.delete()
                            time.sleep(10)
                        milvus.create_table(table_name, dimension, index_file_size, metric_type)
                        res = self.do_insert(milvus, table_name, data_type, dimension, table_size, param["ni_per"])
                        logger.info(res)

                elif op_type == "query":
                    for index, param in enumerate(run_params):
                        logger.info("Definition param: %s" % str(param))
                        table_name = param["dataset"]
                        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)

                        milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
                        logger.info(milvus.describe())
                        logger.info(milvus.describe_index())
                        # parse index info
                        index_types = param["index.index_types"]
                        nlists = param["index.nlists"]
                        # parse top-k, nq, nprobe
                        top_ks, nqs, nprobes = parser.search_params_parser(param)
                        # milvus.drop_index()

                        for index_type in index_types:
                            for nlist in nlists:
                                # milvus.create_index(index_type, nlist)
                                # preload index
                                logger.info("Start preloading table")
                                milvus.preload_table()
                                logger.info("End preloading table")
                                # Run query test
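                                # One throwaway search (top_k=1, nq=1) warms caches so the
                                # first timed query is not penalized by cold-start costs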
logger.info("Start warm up query")
|
|
|
|
res = self.do_query(milvus, table_name, [1], [1], 1, 1)
|
|
|
|
logger.info("End warm up query")
|
2019-10-16 18:40:31 +08:00
|
|
|
for nprobe in nprobes:
|
|
|
|
logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
|
|
|
|
res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
|
2019-12-04 14:18:20 +08:00
|
|
|
headers = ["nq/topk"]
|
2019-10-16 18:40:31 +08:00
|
|
|
headers.extend([str(top_k) for top_k in top_ks])
|
|
|
|
utils.print_table(headers, nqs, res)
|
|
|
|
|
2019-12-04 14:18:20 +08:00
|
|
|
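        # "accuracy" mode reruns searches and reports recall, either against a
        # brute-force "flat" index run or against precomputed SIFT ground truth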
        elif run_type == "accuracy":
            for op_type, op_value in definition.items():
                if op_type != "query":
                    logger.warning("invalid operation: %s in accuracy test, only query operation is supported" % op_type)
                    break
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    table_name = param["dataset"]
                    sift_acc = False
                    if "sift_acc" in param:
                        sift_acc = param["sift_acc"]
                    (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)

                    milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
                    logger.debug(milvus.show_tables())
                    # Skip this param set if the table does not exist
                    if not milvus.exists_table():
                        logger.warning("Table %s does not exist, continuing with next params ..." % table_name)
                        continue

                    # parse index info
                    index_types = param["index.index_types"]
                    nlists = param["index.nlists"]
                    # parse top-k, nq, nprobe
                    top_ks, nqs, nprobes = parser.search_params_parser(param)

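                    # With sift_acc, recall is measured against the dataset's
                    # ground-truth neighbor ids rather than a flat-index baseline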
                    if sift_acc is True:
                        # preload groundtruth data
                        true_ids_all = self.get_groundtruth_ids(table_size)

                    acc_dict = {}
                    for index_type in index_types:
                        for nlist in nlists:
                            result = milvus.describe_index()
                            logger.info(result)
                            # milvus.drop_index()
                            milvus.create_index(index_type, nlist)
                            # preload index
                            milvus.preload_table()
                            # Run query test
                            for nprobe in nprobes:
                                logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                                for top_k in top_ks:
                                    for nq in nqs:
                                        result_ids = []
                                        id_prefix = "%s_index_%s_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \
                                                    (table_name, index_type, nlist, metric_type, nprobe, top_k, nq)
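                                        # Without sift_acc: persist result ids under id_prefix and, for
                                        # non-flat indexes, score them against the matching "flat" run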
                                        if sift_acc is False:
                                            self.do_query_acc(milvus, table_name, top_k, nq, nprobe, id_prefix)
                                            if index_type != "flat":
                                                # Compute accuracy
                                                base_name = "%s_index_flat_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \
                                                            (table_name, nlist, metric_type, nprobe, top_k, nq)
                                                avg_acc = self.compute_accuracy(base_name, id_prefix)
                                                logger.info("Query: <%s> accuracy: %s" % (id_prefix, avg_acc))
                                        else:
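                                            # With sift_acc: intersect returned ids with ground truth,
                                            # dump per-query overlap to a debug file, and log recall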
                                            result_ids, result_distances = self.do_query_ids(milvus, table_name, top_k, nq, nprobe)
                                            debug_file_ids = "0.5.3_result_ids"
                                            debug_file_distances = "0.5.3_result_distances"
                                            with open(debug_file_ids, "w+") as fd:
                                                total = 0
                                                for index, item in enumerate(result_ids):
                                                    true_item = true_ids_all[:nq, :top_k].tolist()[index]
                                                    tmp = set(item).intersection(set(true_item))
                                                    total = total + len(tmp)
                                                    fd.write("query: N-%d, intersection: %d, total: %d\n" % (index, len(tmp), total))
                                                    fd.write("%s\n" % str(item))
                                                    fd.write("%s\n" % str(true_item))
                                            acc_value = self.get_recall_value(true_ids_all[:nq, :top_k].tolist(), result_ids)
                                            logger.info("Query: <%s> accuracy: %s" % (id_prefix, acc_value))
                                            # # print accuracy table
                                            # headers = [table_name]
                                            # headers.extend([str(top_k) for top_k in top_ks])
                                            # utils.print_table(headers, nqs, res)

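        # "stability" mode keeps one table under randomized search and bulk
        # insert load for a fixed wall-clock window to surface leaks and crashes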
        elif run_type == "stability":
            for op_type, op_value in definition.items():
                if op_type != "query":
                    logger.warning("invalid operation: %s in stability test, only query operation is supported" % op_type)
                    break
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                nq = 100000

                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    table_name = param["dataset"]
                    index_type = param["index_type"]
                    (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)

                    # set default test time; during_time in params is given in minutes
                    if "during_time" not in param:
                        during_time = 100  # seconds
                    else:
                        during_time = int(param["during_time"]) * 60
                    # set default query process num (only used by the multiprocess
                    # variant below, which is currently commented out)
                    if "query_process_num" not in param:
                        query_process_num = 10
                    else:
                        query_process_num = int(param["query_process_num"])

                    milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
                    logger.debug(milvus.show_tables())
                    logger.debug(milvus.describe_index())
                    logger.debug(milvus.count())
                    # Skip this param set if the table does not exist
                    if not milvus.exists_table():
                        logger.warning("Table %s does not exist, continuing with next params ..." % table_name)
                        continue

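                    # One batch of 100000 random vectors is built once and re-inserted
                    # (with the same ids) on every pass through the loop below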
                    start_time = time.time()
                    insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
                    i = 0
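                    # Alternate randomized searches with bulk inserts until during_time elapses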
                    while time.time() < start_time + during_time:
                        # processes = []
                        # # do query
                        # for i in range(query_process_num):
                        #     milvus_instance = MilvusClient(table_name)
                        #     top_k = random.choice([x for x in range(1, 100)])
                        #     nq = random.choice([x for x in range(1, 1000)])
                        #     nprobe = random.choice([x for x in range(1, 500)])
                        #     logger.info(nprobe)
                        #     p = Process(target=self.do_query, args=(milvus_instance, table_name, [top_k], [nq], 64, run_count, ))
                        #     processes.append(p)
                        #     p.start()
                        #     time.sleep(0.1)
                        # for p in processes:
                        #     p.join()
                        i = i + 1
                        milvus_instance = MilvusClient(table_name, ip=self.ip, port=self.port)
                        top_ks = random.sample([x for x in range(1, 100)], 4)
                        nqs = random.sample([x for x in range(1, 200)], 3)
                        nprobe = random.choice([x for x in range(1, 100)])
                        res = self.do_query(milvus_instance, table_name, top_ks, nqs, nprobe, run_count)
                        # milvus_instance = MilvusClient(table_name)
                        status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))])
                        if not status.OK():
                            logger.error(status.message)
                        # status = milvus_instance.create_index(index_type, 16384)
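
# A minimal usage sketch (hypothetical; the real entry point and suite schema
# live in the benchmark's main script, so the param keys here are illustrative):
#
#   runner = LocalRunner(ip="127.0.0.1", port=19530)
#   runner.run({"query": {"run_count": 2,
#                         "params": [{"dataset": "sift_1m_1024_128_l2",
#                                     "index.index_types": ["ivf_sq8"],
#                                     "index.nlists": [16384],
#                                     "top_ks": [10], "nqs": [100],
#                                     "nprobes": [32]}]}},
#              run_type="performance")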