mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-04 21:09:06 +08:00
132 lines
4.4 KiB
Python
132 lines
4.4 KiB
Python
|
import os
|
||
|
import pdb
|
||
|
import time
|
||
|
import random
|
||
|
import sys
|
||
|
import h5py
|
||
|
import numpy
|
||
|
import logging
|
||
|
from logging import handlers
|
||
|
|
||
|
from client import MilvusClient
|
||
|
|
||
|
# Directory (relative to the current working directory) that receives the
# rotated accuracy-test log files.
LOG_FOLDER = "logs"
logger = logging.getLogger("milvus_ann_acc")

formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s')

# Create the log directory in-process instead of shelling out to `mkdir -p`:
# this works on platforms without that command, raises on real failures
# instead of ignoring the exit status, and tolerates the directory
# already existing.
os.makedirs(LOG_FOLDER, exist_ok=True)

# Positional arguments are (filename, when, interval, backupCount):
# rotate every 1 day and keep up to 10 rotated files.
fileTimeHandler = handlers.TimedRotatingFileHandler(os.path.join(LOG_FOLDER, 'acc'), "D", 1, 10)
# NOTE(review): a custom suffix may not match the handler's default
# `extMatch` pattern, in which case old files are not pruned by
# backupCount -- confirm against the logging.handlers documentation.
fileTimeHandler.suffix = "%Y%m%d.log"
fileTimeHandler.setFormatter(formatter)

# Root logger at DEBUG so records also reach the default stream handler.
logging.basicConfig(level=logging.DEBUG)
logger.addHandler(fileTimeHandler)
|
||
|
|
||
|
|
||
|
def get_dataset_fn(dataset_name, file_path="/test/milvus/ann_hdf5/"):
    """Return the path of the HDF5 file for `dataset_name`.

    Args:
        dataset_name: Base name of the dataset file, without the ``.hdf5``
            extension (e.g. ``"sift-128-euclidean"``).
        file_path: Directory holding the dataset files. Defaults to the
            location this script has always used, so existing callers are
            unaffected.

    Returns:
        ``<file_path>/<dataset_name>.hdf5``. Only the directory is checked
        for existence; the file itself is not.

    Raises:
        Exception: If `file_path` does not exist.
    """
    if not os.path.exists(file_path):
        raise Exception("%s not exists" % file_path)
    return os.path.join(file_path, '%s.hdf5' % dataset_name)
|
||
|
|
||
|
|
||
|
def get_dataset(dataset_name):
    """Open the HDF5 file for `dataset_name` and return the h5py file object.

    The file is opened explicitly read-only. Older h5py releases default to
    a read/write-or-create mode when no mode is given, which would silently
    create an empty file for a missing dataset instead of failing.

    The caller owns the returned handle and is responsible for closing it.

    Raises:
        Exception: If the dataset directory does not exist (raised by
            `get_dataset_fn`).
    """
    hdf5_fn = get_dataset_fn(dataset_name)
    return h5py.File(hdf5_fn, "r")
|
||
|
|
||
|
|
||
|
def parse_dataset_name(dataset_name):
    """Split a dataset name such as ``"sift-128-euclidean"`` into its parts.

    The name is expected to be ``<data_type>-<dimension>-...-<metric>``:
    the first dash-separated field is the data type, the second the vector
    dimension, and the last the distance metric.

    Args:
        dataset_name: Dash-separated dataset identifier.

    Returns:
        A tuple ``("ann" + data_type, dimension, metric_type)`` where
        `dimension` is an int and `metric_type` is ``"l2"`` for
        ``euclidean`` or ``"ip"`` for ``angular``.

    Raises:
        IndexError: If the name has no dash-separated dimension field.
        ValueError: If the dimension is not an integer, or the metric is
            neither ``euclidean`` nor ``angular``.
    """
    # Metric suffix in the dataset name -> metric type string returned.
    metric_types = {"euclidean": "l2", "angular": "ip"}

    parts = dataset_name.split("-")
    data_type = parts[0]
    dimension = int(parts[1])
    metric = parts[-1]
    if metric not in metric_types:
        # Previously this fell through and failed with an unbound local
        # variable at the return statement; fail with a clear message.
        raise ValueError(
            "Unsupported metric %r in dataset name %r" % (metric, dataset_name))
    return ("ann" + data_type, dimension, metric_types[metric])
|
||
|
|
||
|
|
||
|
def get_table_name(dataset_name, index_file_size):
    """Build the table name used for a dataset / index-file-size pair.

    The name has the form
    ``<data_type>_<N>m_<index_file_size>_<dimension>_<metric_type>``,
    where ``N`` is the number of rows in the dataset's ``"train"`` array
    divided by one million, rounded down.

    Args:
        dataset_name: Dash-separated dataset identifier understood by
            `parse_dataset_name`.
        index_file_size: Index file size embedded verbatim in the name.

    Returns:
        The table name string.
    """
    data_type, dimension, metric_type = parse_dataset_name(dataset_name)
    # The file is opened only to read the row count, so close it as soon as
    # that is known rather than leaking the handle.
    with get_dataset(dataset_name) as dataset:
        row_count = len(dataset["train"])
    table_size = str(row_count // 1000000) + "m"
    return "_".join(
        [data_type, table_size, str(index_file_size), str(dimension), metric_type])
|
||
|
|
||
|
|
||
|
def _insert_in_batches(client, vectors, batch_size=100000):
    """Insert `vectors` through `client` in slices of at most `batch_size` rows.

    Each row's position in `vectors` is passed as its explicit id.
    """
    total = len(vectors)
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        client.insert(vectors[start:end], ids=list(range(start, end)))


def _average_recall(true_ids, result_ids, top_k):
    """Return the mean per-query overlap between expected and returned ids.

    For every query, the first `top_k` entries of ``true_ids[i]`` are
    intersected with ``result_ids[i]``; the overlap size divided by `top_k`
    is averaged over all queries and rounded to 4 decimal places.

    Returns None when `result_ids` is empty, since no average exists.
    """
    total = len(result_ids)
    if total == 0:
        return None
    ratio_sum = 0.0
    for index, result_item in enumerate(result_ids):
        expected = set(true_ids[index][:top_k])
        returned = set(result_item)
        if len(expected) != len(returned):
            logger.info("Error happened")
        ratio_sum += len(expected & returned) / top_k
    return round(ratio_sum / total, 4)


def main(dataset_name, index_file_size, nlist=16384, force=False):
    """Load a dataset into a Milvus table and log search accuracy per index.

    Steps: (re)create the table, insert the dataset's ``"train"`` vectors in
    batches, then for each index type build the index, query with the
    ``"test"`` vectors at each nprobe value, and log the average overlap
    with the dataset's ``"neighbors"`` ids before dropping the index.

    Args:
        dataset_name: Dash-separated dataset identifier
            (e.g. ``"sift-128-euclidean"``).
        index_file_size: Passed to table creation and embedded in the
            table name.
        nlist: Passed to `create_index` for every index type.
        force: When exactly ``True``, an existing table is deleted and
            rebuilt; otherwise an existing table makes this function return
            without doing anything.

    Returns:
        None. Results are reported through `print` and the module logger.
    """
    top_k = 10
    nprobes = [32, 128]

    dataset = get_dataset(dataset_name)
    try:
        table_name = get_table_name(dataset_name, index_file_size)
        m = MilvusClient(table_name)
        if m.exists_table():
            if force is True:
                logger.info("Re-create table: %s" % table_name)
                m.delete()
                # Presumably gives the server time to finish the delete
                # before the table is recreated -- TODO confirm.
                time.sleep(10)
            else:
                logger.info("Table name: %s existed" % table_name)
                return
        _, dimension, metric_type = parse_dataset_name(dataset_name)
        m.create_table(table_name, dimension, index_file_size, metric_type)
        print(m.describe())
        vectors = numpy.array(dataset["train"])
        query_vectors = numpy.array(dataset["test"])

        _insert_in_batches(m, vectors)

        # Presumably waits for the inserted rows to become visible to
        # count() -- TODO confirm.
        time.sleep(60)
        print(m.count())

        # Loop-invariant: load the expected neighbor ids once, not once per
        # index type.
        true_ids = numpy.array(dataset["neighbors"])

        for index_type in ["ivf_flat", "ivf_sq8", "ivf_sq8h"]:
            m.create_index(index_type, nlist)
            print(m.describe_index())
            row_count = m.count()
            if row_count != len(vectors):
                # Accuracy numbers would be meaningless on a partial table;
                # report why the run stops instead of returning silently.
                logger.error(
                    "Row count mismatch for table %s: expected %d, got %s",
                    table_name, len(vectors), row_count)
                return
            m.preload_table()
            for nprobe in nprobes:
                print("nprobe: %s" % nprobe)
                result_ids = m.query(query_vectors, top_k, nprobe)
                avg_ratio = _average_recall(true_ids, result_ids, top_k)
                if avg_ratio is None:
                    # Avoids a ZeroDivisionError when the query returned
                    # no results.
                    logger.warning(
                        "Query returned no results (index=%s, nprobe=%s)",
                        index_type, nprobe)
                    continue
                logger.info(avg_ratio)
            m.drop_index()
    finally:
        # Release the HDF5 handle on every exit path.
        dataset.close()
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Run the accuracy test for each dataset at each index file size,
    # always rebuilding the table from scratch.
    for dataset_name in ("glove-25-angular", "sift-128-euclidean"):
        print(dataset_name)
        for index_file_size in [50, 1024]:
            print("Index file size: %d" % index_file_size)
            main(dataset_name, index_file_size, force=True)
|