# milvus/tests/milvus_ann_acc/test.py

import os
import pdb
import time
import random
import sys
import h5py
import numpy
import logging
from logging import handlers
from client import MilvusClient
# Logging setup: records sent to the "milvus_ann_acc" logger are written to a
# time-rotated file under LOG_FOLDER; basicConfig also configures the root
# logger at DEBUG level.
LOG_FOLDER = "logs"
logger = logging.getLogger("milvus_ann_acc")
formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s')
# Create the folder in-process rather than shelling out to `mkdir -p`;
# exist_ok makes this a no-op when the folder is already present.
os.makedirs(LOG_FOLDER, exist_ok=True)
# Rotate on a one-day interval; backupCount=10 is the retention limit handed
# to the handler.
fileTimeHandler = handlers.TimedRotatingFileHandler(
    os.path.join(LOG_FOLDER, 'acc'), when="D", interval=1, backupCount=10)
fileTimeHandler.suffix = "%Y%m%d.log"
fileTimeHandler.setFormatter(formatter)
logging.basicConfig(level=logging.DEBUG)
logger.addHandler(fileTimeHandler)
def get_dataset_fn(dataset_name, file_path="/test/milvus/ann_hdf5/"):
    """Return the path of the HDF5 file for ``dataset_name``.

    Args:
        dataset_name: Dataset name without extension,
            e.g. ``"sift-128-euclidean"``.
        file_path: Directory that holds the ``.hdf5`` files. The default is
            the location this script has always used.

    Returns:
        ``file_path`` joined with ``"<dataset_name>.hdf5"``. Only the
        directory is checked for existence, not the file itself.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError("%s not exists" % file_path)
    return os.path.join(file_path, '%s.hdf5' % dataset_name)
def get_dataset(dataset_name):
    """Open the HDF5 file for ``dataset_name`` and return the file object.

    The file is opened read-only: this script only reads from it, and an
    explicit mode keeps the behavior independent of h5py's version-specific
    default. The caller is responsible for closing the returned file (it can
    be used as a context manager).

    Args:
        dataset_name: Dataset name without extension.

    Returns:
        The open ``h5py.File``.
    """
    hdf5_fn = get_dataset_fn(dataset_name)
    return h5py.File(hdf5_fn, "r")
def parse_dataset_name(dataset_name):
    """Split a ``<type>-<dimension>-<metric>`` dataset name into its parts.

    The dimension and metric are taken from the name itself, not from the
    HDF5 file contents.

    Args:
        dataset_name: e.g. ``"sift-128-euclidean"`` or ``"glove-25-angular"``.

    Returns:
        A tuple ``("ann" + type, dimension, metric_type)`` where
        ``metric_type`` is ``"l2"`` for ``euclidean`` and ``"ip"`` for
        ``angular``.

    Raises:
        ValueError: If the second field is not an integer, or if the metric
            suffix is neither ``euclidean`` nor ``angular``.
        IndexError: If the name contains no ``-`` separator.
    """
    parts = dataset_name.split("-")
    data_type = parts[0]
    dimension = int(parts[1])
    metric = parts[-1]
    metric_types = {"euclidean": "l2", "angular": "ip"}
    if metric not in metric_types:
        # Fail with a clear message instead of falling through to an
        # unbound-variable error at the return statement.
        raise ValueError(
            "Unsupported metric %r in dataset name %r" % (metric, dataset_name))
    return ("ann" + data_type, dimension, metric_types[metric])
def get_table_name(dataset_name, index_file_size):
    """Build the table name used for a dataset / index-file-size combination.

    The name has the form
    ``<ann-type>_<train size in millions>m_<index_file_size>_<dimension>_<metric_type>``.

    Args:
        dataset_name: Dataset name understood by ``parse_dataset_name``.
        index_file_size: Value embedded in the name via ``str()``.

    Returns:
        The table name string.
    """
    data_type, dimension, metric_type = parse_dataset_name(dataset_name)
    # Only the length of the train set is needed, so close the file as soon
    # as it has been read instead of leaking the handle.
    with get_dataset(dataset_name) as dataset:
        train_count = len(dataset["train"])
    table_size = "%dm" % (train_count // 1000000)
    return "_".join(
        [data_type, table_size, str(index_file_size), str(dimension), metric_type])
def _average_recall(true_ids, result_ids, top_k):
    """Return the mean fraction of the first ``top_k`` true ids found per query, rounded to 4 places."""
    # len() rather than truthiness: the container type returned by the client
    # is not known here.
    if len(result_ids) == 0:
        return 0.0
    recall_sum = 0.0
    for index, result_item in enumerate(result_ids):
        expected = set(true_ids[index][:top_k])
        found = set(result_item)
        if len(expected) != len(found):
            # The query returned a different number of distinct ids than the
            # ground truth holds for this row.
            logger.info("Error happened")
        recall_sum += len(expected & found) / top_k
    return round(recall_sum / len(result_ids), 4)


def main(dataset_name, index_file_size, nlist=16384, force=False):
    """Insert a dataset into a table and log search accuracy per index type.

    Steps: create the table (recreating it when ``force`` is set), insert the
    dataset's ``train`` vectors in batches, then for each index type build
    the index, query with the ``test`` vectors at each nprobe value, and log
    the average overlap with the dataset's ``neighbors`` ids.

    Args:
        dataset_name: Dataset name understood by ``parse_dataset_name``.
        index_file_size: Passed to ``create_table`` and embedded in the
            table name.
        nlist: Passed to ``create_index``.
        force: When the table already exists, delete and recreate it if
            truthy; otherwise log and return without doing anything.

    Returns:
        None. Returns early if the table exists and ``force`` is not set, or
        if the row count reported by the client differs from the number of
        inserted vectors.
    """
    top_k = 10
    nprobes = [32, 128]
    insert_batch = 100000

    table_name = get_table_name(dataset_name, index_file_size)
    m = MilvusClient(table_name)
    if m.exists_table():
        if not force:
            logger.info("Table name: %s existed" % table_name)
            return
        logger.info("Re-create table: %s" % table_name)
        m.delete()
        # NOTE(review): presumably gives the server time to finish the
        # delete before the table is created again - confirm.
        time.sleep(10)
    _, dimension, metric_type = parse_dataset_name(dataset_name)
    m.create_table(table_name, dimension, index_file_size, metric_type)
    print(m.describe())

    # Copy everything needed into memory up front so the HDF5 file can be
    # closed right away instead of staying open for the whole run.
    with get_dataset(dataset_name) as dataset:
        vectors = numpy.array(dataset["train"])
        query_vectors = numpy.array(dataset["test"])
        true_ids = numpy.array(dataset["neighbors"])

    # Insert in fixed-size batches; ids are the row positions in the train set.
    total = len(vectors)
    for start in range(0, total, insert_batch):
        end = min(start + insert_batch, total)
        m.insert(vectors[start:end], ids=list(range(start, end)))

    # NOTE(review): presumably waits for the inserted rows to become
    # countable on the server - confirm.
    time.sleep(60)
    print(m.count())

    for index_type in ["ivf_flat", "ivf_sq8", "ivf_sq8h"]:
        m.create_index(index_type, nlist)
        print(m.describe_index())
        row_count = m.count()
        if row_count != total:
            # Accuracy numbers would be meaningless on a partial table; say
            # why we stop rather than returning silently.
            logger.error(
                "Row count mismatch: table has %s rows, expected %s"
                % (row_count, total))
            return
        m.preload_table()
        for nprobe in nprobes:
            print("nprobe: %s" % nprobe)
            result_ids = m.query(query_vectors, top_k, nprobe)
            logger.info(_average_recall(true_ids, result_ids, top_k))
        m.drop_index()
if __name__ == "__main__":
    # Run the accuracy test for every dataset at every index file size,
    # recreating the table each time (force=True).
    for dataset_name in ("glove-25-angular", "sift-128-euclidean"):
        print(dataset_name)
        for index_file_size in (50, 1024):
            print("Index file size: %d" % index_file_size)
            main(dataset_name, index_file_size, force=True)