# milvus/tests/milvus_ann_acc/test.py
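"""Accuracy test for Milvus ANN search.

Loads an ann-benchmarks HDF5 dataset, inserts the train vectors into a
Milvus table, builds IVF indexes, runs top-k queries with the test
vectors, and logs the average recall against the ground-truth neighbors.
"""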
import os
import time
import logging
from logging import handlers

import h5py
import numpy

from client import MilvusClient

LOG_FOLDER = "logs"

logger = logging.getLogger("milvus_ann_acc")
formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s')
if not os.path.exists(LOG_FOLDER):
    os.makedirs(LOG_FOLDER)
# Rotate the log file daily, keeping up to 10 backups
fileTimeHandler = handlers.TimedRotatingFileHandler(os.path.join(LOG_FOLDER, 'acc'), "D", 1, 10)
fileTimeHandler.suffix = "%Y%m%d.log"
fileTimeHandler.setFormatter(formatter)
logging.basicConfig(level=logging.DEBUG)
logger.addHandler(fileTimeHandler)


def get_dataset_fn(dataset_name):
    """Return the path of the HDF5 file for the given dataset."""
    file_path = "/test/milvus/ann_hdf5/"
    if not os.path.exists(file_path):
        raise Exception("%s does not exist" % file_path)
    return os.path.join(file_path, '%s.hdf5' % dataset_name)


def get_dataset(dataset_name):
    """Open the dataset's HDF5 file read-only."""
    hdf5_fn = get_dataset_fn(dataset_name)
    hdf5_f = h5py.File(hdf5_fn, 'r')
    return hdf5_f
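
# The ann-benchmarks HDF5 layout provides at least the "train", "test" and
# "neighbors" datasets used below (the files also carry a "distances"
# dataset, which this script does not read).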


def parse_dataset_name(dataset_name):
    """Split a name such as "glove-25-angular" into (data_type, dimension, metric_type)."""
    data_type = dataset_name.split("-")[0]
    dimension = int(dataset_name.split("-")[1])
    metric = dataset_name.split("-")[-1]
    # metric = dataset.attrs['distance']
    # dimension = len(dataset["train"][0])
    if metric == "euclidean":
        metric_type = "l2"
    elif metric == "angular":
        metric_type = "ip"
    else:
        raise Exception("Unsupported metric type: %s" % metric)
    return ("ann" + data_type, dimension, metric_type)


def get_table_name(dataset_name, index_file_size):
    data_type, dimension, metric_type = parse_dataset_name(dataset_name)
    dataset = get_dataset(dataset_name)
    table_size = len(dataset["train"])
    table_size = str(table_size // 1000000) + "m"
    table_name = data_type + '_' + table_size + '_' + str(index_file_size) + '_' + str(dimension) + '_' + metric_type
    return table_name
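
# Example (assuming the 1M-vector ann-benchmarks files): "sift-128-euclidean"
# with index_file_size=1024 yields the table name "annsift_1m_1024_128_l2".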


def main(dataset_name, index_file_size, nlist=16384, force=False):
    top_k = 10
    nprobes = [32, 128]
    dataset = get_dataset(dataset_name)
    table_name = get_table_name(dataset_name, index_file_size)
    m = MilvusClient(table_name)
    if m.exists_table():
        if force is True:
            logger.info("Re-create table: %s" % table_name)
            m.delete()
            time.sleep(10)
        else:
            logger.info("Table %s already exists" % table_name)
            return
    data_type, dimension, metric_type = parse_dataset_name(dataset_name)
    m.create_table(table_name, dimension, index_file_size, metric_type)
    print(m.describe())
    vectors = numpy.array(dataset["train"])
    query_vectors = numpy.array(dataset["test"])
    # Insert the train vectors in batches of `interval`
    interval = 100000
    loops = len(vectors) // interval + 1
    for i in range(loops):
        start = i * interval
        end = min((i + 1) * interval, len(vectors))
        tmp_vectors = vectors[start:end]
        if start < end:
            m.insert(tmp_vectors, ids=list(range(start, end)))
    time.sleep(60)
    print(m.count())
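    # The explicit ids above follow the HDF5 row order, so the ground-truth
    # "neighbors" indices can be compared directly with the ids Milvus returns.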
    for index_type in ["ivf_flat", "ivf_sq8", "ivf_sq8h"]:
        m.create_index(index_type, nlist)
        print(m.describe_index())
        if m.count() != len(vectors):
            return
        m.preload_table()
        true_ids = numpy.array(dataset["neighbors"])
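        # Recall for one query = |ground-truth top-k ∩ returned ids| / top_k;
        # the value logged below is this recall averaged over all test queries.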
        for nprobe in nprobes:
            print("nprobe: %s" % nprobe)
            sum_ratio = 0.0
            result_ids = m.query(query_vectors, top_k, nprobe)
            # print(result_ids[:10])
            for index, result_item in enumerate(result_ids):
                if len(set(true_ids[index][:top_k])) != len(set(result_item)):
                    logger.info("Result set size mismatch for query %d" % index)
                    # logger.info(query_vectors[index])
                    # logger.info(true_ids[index][:top_k], result_item)
                tmp = set(true_ids[index][:top_k]).intersection(set(result_item))
                sum_ratio = sum_ratio + (len(tmp) / top_k)
            avg_ratio = round(sum_ratio / len(result_ids), 4)
            logger.info(avg_ratio)
        m.drop_index()


if __name__ == "__main__":
    # main("sift-128-euclidean", 1024, force=True)
    for dataset_name in ["glove-25-angular", "sift-128-euclidean"]:
        print(dataset_name)
        for index_file_size in [50, 1024]:
            print("Index file size: %d" % index_file_size)
            main(dataset_name, index_file_size, force=True)
    # m = MilvusClient()