import os
import threading
import logging
import pdb
import time
import random

import grpc
from multiprocessing import Process
from itertools import product

import numpy as np
import sklearn.preprocessing
from milvus import DataType

from client import MilvusClient
import utils
import parser

logger = logging.getLogger("milvus_benchmark.runner")

VECTORS_PER_FILE = 1000000
SIFT_VECTORS_PER_FILE = 100000
BINARY_VECTORS_PER_FILE = 2000000

MAX_NQ = 10001
FILE_PREFIX = "binary_"

# FOLDER_NAME = 'ann_1000m/source_data'
SRC_BINARY_DATA_DIR = '/test/milvus/raw_data/random/'
SIFT_SRC_DATA_DIR = '/test/milvus/raw_data/sift1b/'
DEEP_SRC_DATA_DIR = '/test/milvus/raw_data/deep1b/'
BINARY_SRC_DATA_DIR = '/test/milvus/raw_data/binary/'
SIFT_SRC_GROUNDTRUTH_DATA_DIR = SIFT_SRC_DATA_DIR + 'gnd'

WARM_TOP_K = 1
WARM_NQ = 1
DEFAULT_DIM = 512

GROUNDTRUTH_MAP = {
    "1000000": "idx_1M.ivecs",
    "2000000": "idx_2M.ivecs",
    "5000000": "idx_5M.ivecs",
    "10000000": "idx_10M.ivecs",
    "20000000": "idx_20M.ivecs",
    "50000000": "idx_50M.ivecs",
    "100000000": "idx_100M.ivecs",
    "200000000": "idx_200M.ivecs",
    "500000000": "idx_500M.ivecs",
    "1000000000": "idx_1000M.ivecs",
}
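

# Illustrative example (derived from the code below): gen_file_name builds the
# per-chunk .npy path for a dataset, e.g. gen_file_name(3, 128, "sift")
# returns '/test/milvus/raw_data/sift1b/binary_128d_00003.npy'.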
def gen_file_name(idx, dimension, data_type):
    s = "%05d" % idx
    fname = FILE_PREFIX + str(dimension) + "d_" + s + ".npy"
    if data_type == "random":
        fname = SRC_BINARY_DATA_DIR + fname
    elif data_type == "sift":
        fname = SIFT_SRC_DATA_DIR + fname
    elif data_type == "deep":
        fname = DEEP_SRC_DATA_DIR + fname
    elif data_type == "binary":
        fname = BINARY_SRC_DATA_DIR + fname
    return fname


def get_vectors_from_binary(nq, dimension, data_type):
    # use the dataset's query file; nq must not exceed MAX_NQ
    if nq > MAX_NQ:
        raise Exception("nq %d exceeds the maximum supported value %d" % (nq, MAX_NQ))
    if data_type == "random":
        file_name = SRC_BINARY_DATA_DIR + 'query_%d.npy' % dimension
    elif data_type == "sift":
        file_name = SIFT_SRC_DATA_DIR + 'query.npy'
    elif data_type == "deep":
        file_name = DEEP_SRC_DATA_DIR + 'query.npy'
    elif data_type == "binary":
        file_name = BINARY_SRC_DATA_DIR + 'query.npy'
    data = np.load(file_name)
    vectors = data[0:nq].tolist()
    return vectors


class Runner(object):
    def __init__(self):
        pass
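
    # gen_executors builds a weighted, shuffled schedule of operation names.
    # Illustrative example: {"insert": {"weight": 2}, "query": {}} yields a
    # shuffled list such as ["insert", "query", "insert"], so callers draw
    # operations in proportion to their configured weights.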
    def gen_executors(self, operations):
        executors = []
        for name, operation in operations.items():
            # operations without an explicit weight default to 1
            weight = operation["weight"] if "weight" in operation else 1
            executors.extend([name] * weight)
        random.shuffle(executors)
        return executors

    def get_vector_type(self, data_type):
        if data_type in ["random", "sift", "deep", "glove"]:
            vector_type = DataType.FLOAT_VECTOR
        elif data_type in ["binary"]:
            vector_type = DataType.BINARY_VECTOR
        else:
            raise Exception("Data type: %s not defined" % data_type)
        return vector_type

    def get_vector_type_from_metric(self, metric_type):
        if metric_type in ["hamming", "jaccard"]:
            vector_type = DataType.BINARY_VECTOR
        else:
            vector_type = DataType.FLOAT_VECTOR
        return vector_type
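
    # normalize prepares vectors for the given metric. For binary metrics each
    # 0/1 vector is bit-packed into bytes; for example,
    # np.packbits([1, 0, 1, 1, 0, 0, 0, 1]) -> array([177], dtype=uint8),
    # i.e. the single byte b'\xb1'.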
    def normalize(self, metric_type, X):
        if metric_type == "ip":
            logger.info("Set normalize for metric_type: %s" % metric_type)
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
            X = X.astype(np.float32)
        elif metric_type == "l2":
            X = X.astype(np.float32)
        elif metric_type in ["jaccard", "hamming", "sub", "super"]:
            tmp = []
            for item in X:
                new_vector = bytes(np.packbits(item, axis=-1).tolist())
                tmp.append(new_vector)
            X = tmp
        return X
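
    # generate_combinations expands list-valued arguments into the full
    # cross-product of parameter sets. Illustrative example:
    #     {"nlist": [1024, 2048], "m": 8}
    # becomes
    #     [{"nlist": 1024, "m": 8}, {"nlist": 2048, "m": 8}]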
    def generate_combinations(self, args):
        if isinstance(args, list):
            args = [el if isinstance(el, list) else [el] for el in args]
            return [list(x) for x in product(*args)]
        elif isinstance(args, dict):
            flat = []
            for k, v in args.items():
                if isinstance(v, list):
                    flat.append([(k, el) for el in v])
                else:
                    flat.append([(k, v)])
            return [dict(x) for x in product(*flat)]
        else:
            raise TypeError("No args handling exists for %s" % type(args).__name__)
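
    # do_insert streams pre-generated .npy files into the collection in batches
    # of ni vectors. Illustrative arithmetic: for data_type "sift" (100000
    # vectors per file), size=1000000 and ni=50000 means 10 files are read and
    # each file is split into 2 insert batches, i.e. 20 inserts in total.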
    def do_insert(self, milvus, collection_name, data_type, dimension, size, ni):
        '''
        @params:
        milvus: server connect instance
        dimension: collection dimension
        # index_file_size: size trigger file merge
        size: row count of vectors to be inserted
        ni: row count of vectors to be inserted per batch
        # store_id: whether to store the ids returned by add_vectors
        @return:
        total_time: total time for all insert operations
        qps: vectors added per second
        ni_time: average time per insert operation
        '''
        bi_res = {}
        total_time = 0.0
        qps = 0.0
        ni_time = 0.0
        if data_type == "random":
            if dimension == 512:
                vectors_per_file = VECTORS_PER_FILE
            elif dimension == 4096:
                vectors_per_file = 100000
            elif dimension == 16384:
                vectors_per_file = 10000
        elif data_type == "sift":
            vectors_per_file = SIFT_VECTORS_PER_FILE
        elif data_type in ["binary"]:
            vectors_per_file = BINARY_VECTORS_PER_FILE
        else:
            raise Exception("data_type: %s not supported" % data_type)
        if size % vectors_per_file or size % ni:
            raise Exception("Invalid collection size or ni: both must divide size evenly")
        i = 0
        while i < (size // vectors_per_file):
            vectors = []
            if vectors_per_file >= ni:
                file_name = gen_file_name(i, dimension, data_type)
                # logger.info("Load npy file: %s start" % file_name)
                data = np.load(file_name)
                # logger.info("Load npy file: %s end" % file_name)
                for j in range(vectors_per_file // ni):
                    vectors = data[j * ni:(j + 1) * ni].tolist()
                    if vectors:
                        # start inserting vectors
                        start_id = i * vectors_per_file + j * ni
                        end_id = start_id + len(vectors)
                        logger.debug("Start id: %s, end id: %s" % (start_id, end_id))
                        ids = [k for k in range(start_id, end_id)]
                        entities = milvus.generate_entities(vectors, ids)
                        ni_start_time = time.time()
                        try:
                            res_ids = milvus.insert(entities, ids=ids)
                        except grpc.RpcError as e:
                            if e.code() == grpc.StatusCode.UNAVAILABLE:
                                logger.debug("Retry insert")

                                # NOTE: the retried ids are not propagated back
                                # to res_ids; the exception is re-raised below
                                # regardless of the retry outcome
                                def retry():
                                    milvus.insert(entities, ids=ids)

                                t0 = threading.Thread(target=retry)
                                t0.start()
                                t0.join()
                                logger.debug("Retry successfully")
                            raise e
                        assert ids == res_ids
                        # milvus.flush()
                        logger.debug(milvus.count())
                        ni_end_time = time.time()
                        total_time = total_time + ni_end_time - ni_start_time
                i += 1
            else:
                vectors.clear()
                loops = ni // vectors_per_file
                for j in range(loops):
                    file_name = gen_file_name(loops * i + j, dimension, data_type)
                    data = np.load(file_name)
                    vectors.extend(data.tolist())
                if vectors:
                    start_id = i * vectors_per_file
                    end_id = start_id + len(vectors)
                    logger.info("Start id: %s, end id: %s" % (start_id, end_id))
                    ids = [k for k in range(start_id, end_id)]
                    entities = milvus.generate_entities(vectors, ids)
                    ni_start_time = time.time()
                    try:
                        res_ids = milvus.insert(entities, ids=ids)
                    except grpc.RpcError as e:
                        if e.code() == grpc.StatusCode.UNAVAILABLE:
                            logger.debug("Retry insert")

                            # NOTE: as above, the retry result is discarded and
                            # the exception is re-raised below regardless
                            def retry():
                                milvus.insert(entities, ids=ids)

                            t0 = threading.Thread(target=retry)
                            t0.start()
                            t0.join()
                            logger.debug("Retry successfully")
                        raise e

                    assert ids == res_ids
                    # milvus.flush()
                    logger.debug(milvus.count())
                    ni_end_time = time.time()
                    total_time = total_time + ni_end_time - ni_start_time
                i += loops
        qps = round(size / total_time, 2)
        ni_time = round(total_time / (size / ni), 2)
        bi_res["total_time"] = round(total_time, 2)
        bi_res["qps"] = qps
        bi_res["ni_time"] = ni_time
        return bi_res
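
    # do_query sweeps every (nq, top_k) pair and records, for each pair, the
    # minimum latency over run_count runs; bi_res is a matrix with one row per
    # nq value and one column per top_k value.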
    def do_query(self, milvus, collection_name, vec_field_name, top_ks, nqs, run_count=1, search_param=None, filter_query=None):
        bi_res = []
        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        for nq in nqs:
            tmp_res = []
            query_vectors = base_query_vectors[0:nq]
            for top_k in top_ks:
                min_query_time = 0.0
                logger.info("Start query, query params: top-k: {}, nq: {}, actual length of vectors: {}".format(top_k, nq, len(query_vectors)))
                for i in range(run_count):
                    logger.debug("Start run query, run %d of %s" % (i + 1, run_count))
                    start_time = time.time()
                    vector_query = {"vector": {vec_field_name: {
                        "topk": top_k,
                        "query": query_vectors,
                        "metric_type": utils.metric_type_trans(metric_type),
                        "params": search_param}
                    }}
                    query_res = milvus.query(vector_query, filter_query=filter_query)
                    interval_time = time.time() - start_time
                    if (i == 0) or (min_query_time > interval_time):
                        min_query_time = interval_time
                logger.info("Min query time: %.2f" % min_query_time)
                tmp_res.append(round(min_query_time, 2))
            bi_res.append(tmp_res)
        return bi_res

    def do_query_qps(self, milvus, query_vectors, top_k, search_param):
        start_time = time.time()
        result = milvus.query(query_vectors, top_k, search_param)
        end_time = time.time()
        return end_time - start_time

    def do_query_ids(self, milvus, collection_name, vec_field_name, top_k, nq, search_param=None, filter_query=None):
        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        query_vectors = base_query_vectors[0:nq]
        logger.info("Start query, query params: top-k: {}, nq: {}, actual length of vectors: {}".format(top_k, nq, len(query_vectors)))
        vector_query = {"vector": {vec_field_name: {
            "topk": top_k,
            "query": query_vectors,
            "metric_type": utils.metric_type_trans(metric_type),
            "params": search_param}
        }}
        query_res = milvus.query(vector_query, filter_query=filter_query)
        result_ids = milvus.get_ids(query_res)
        return result_ids
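
    # do_query_acc dumps the ids returned by a query to a file, one line per
    # query vector with ids tab-separated; compute_accuracy below compares two
    # such files (e.g. a FLAT baseline vs. an index under test) to get recall.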
    def do_query_acc(self, milvus, collection_name, top_k, nq, id_store_name, search_param=None):
        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        vectors = base_query_vectors[0:nq]
        logger.info("Start query, query params: top-k: {}, nq: {}, actual length of vectors: {}".format(top_k, nq, len(vectors)))
        query_res = milvus.query(vectors, top_k, search_param=search_param)
        # if the file already exists, overwrite it
        if os.path.isfile(id_store_name):
            os.remove(id_store_name)
        with open(id_store_name, 'a+') as fd:
            for nq_item in query_res:
                for item in nq_item:
                    fd.write(str(item.id) + '\t')
                fd.write('\n')

    # compute and return the accuracy
    def compute_accuracy(self, flat_file_name, index_file_name):
        flat_id_list = []
        index_id_list = []
        logger.info("Loading flat id file: %s" % flat_file_name)
        with open(flat_file_name, 'r') as flat_id_fd:
            for line in flat_id_fd:
                tmp_list = line.strip("\n").strip().split("\t")
                flat_id_list.append(tmp_list)
        logger.info("Loading index id file: %s" % index_file_name)
        with open(index_file_name) as index_id_fd:
            for line in index_id_fd:
                tmp_list = line.strip("\n").strip().split("\t")
                index_id_list.append(tmp_list)
        if len(flat_id_list) != len(index_id_list):
            raise Exception("Result lengths do not match <flat: %s, index: %s>, exiting accuracy computation ..." % (len(flat_id_list), len(index_id_list)))
        # get the accuracy
        return self.get_recall_value(flat_id_list, index_id_list)
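
    # Recall per query is len(intersection(true, result)) / len(result),
    # averaged over all queries. Illustrative example:
    # true_ids = [["1", "2", "3"]] and result_ids = [["1", "2", "4"]] share
    # two ids, so the returned recall is round(2 / 3, 3) = 0.667.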
    def get_recall_value(self, true_ids, result_ids):
        """
        Use the length of the intersection between the true and returned id
        sets to compute the average recall.
        """
        sum_ratio = 0.0
        for index, item in enumerate(result_ids):
            # tmp = set(item).intersection(set(flat_id_list[index]))
            tmp = set(true_ids[index]).intersection(set(item))
            sum_ratio = sum_ratio + len(tmp) / len(item)
            # logger.debug(sum_ratio)
        return round(sum_ratio / len(result_ids), 3)
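
    # .ivecs layout (as read below): each row stores an int32 count d followed
    # by d int32 ids, so the flat array is reshaped to (-1, d + 1) and the
    # first column (the count) is dropped to get the ground-truth id matrix.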
"""
|
||
|
Implementation based on:
|
||
|
https://github.com/facebookresearch/faiss/blob/master/benchs/datasets.py
|
||
|
"""
|
||
|
def get_groundtruth_ids(self, collection_size):
|
||
|
fname = GROUNDTRUTH_MAP[str(collection_size)]
|
||
|
fname = SIFT_SRC_GROUNDTRUTH_DATA_DIR + "/" + fname
|
||
|
a = np.fromfile(fname, dtype='int32')
|
||
|
d = a[0]
|
||
|
true_ids = a.reshape(-1, d + 1)[:, 1:].copy()
|
||
|
return true_ids
|
||
|
|
||
|

    def get_fields(self, milvus, collection_name):
        fields = []
        info = milvus.get_info(collection_name)
        for item in info["fields"]:
            fields.append(item["name"])
        return fields

    # def get_filter_query(self, filter_query):
    #     for filter in filter_query: