milvus/tests/benchmark/milvus_benchmark/client.py
wt 3c2a46f2be
[skip ci] Update case of benchmark (#10828)
Signed-off-by: wangting0128 <ting.wang@zilliz.com>
2021-10-28 18:14:26 +08:00

559 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
import pdb
import random
import logging
import json
import time, datetime
import traceback
from multiprocessing import Process
from pymilvus import Milvus, DataType
import numpy as np
import utils as util
import config
from milvus_benchmark.runners import utils
# Module-level logger used by every helper in this client module.
logger = logging.getLogger("milvus_benchmark.client")

# Translation table for index names: keys are the spellings used in the yaml
# suite files, values are the index type names used in code.
INDEX_MAP = dict(
    flat="FLAT",
    ivf_flat="IVF_FLAT",
    ivf_sq8="IVF_SQ8",
    nsg="NSG",
    ivf_sq8h="IVF_SQ8_HYBRID",
    ivf_pq="IVF_PQ",
    hnsw="HNSW",
    annoy="ANNOY",
    bin_flat="BIN_FLAT",
    bin_ivf_flat="BIN_IVF_FLAT",
    rhnsw_pq="RHNSW_PQ",
    rhnsw_sq="RHNSW_SQ",
)

# nq / topk used when issuing warm-up searches.
DEFAULT_WARM_QUERY_NQ = 1
DEFAULT_WARM_QUERY_TOPK = 1

# Distance threshold: a top-1 distance at or above this value is treated as wrong.
epsilon = 0.1
def time_wrapper(func):
    """Decorate ``func`` so that its wall-clock execution time is measured.

    The returned wrapper consumes two optional keyword arguments before
    calling ``func`` (they are never forwarded to it):

    log: when truthy (default ``True``), write the elapsed time to the
        debug log.
    rps: when anything other than ``False`` (default ``False``), return
        ``(result, elapsed_seconds)`` instead of just ``result``. The elapsed
        time is rounded to two decimals.
    """
    # Local import: keeps the decorator self-contained without touching the
    # module-level import block.
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Strip the wrapper-only options so ``func`` never sees them.
        log = kwargs.pop("log", True)
        rps = kwargs.pop("rps", False)

        start = time.time()
        result = func(*args, **kwargs)
        # Measured unconditionally so that ``rps`` can be honoured even when
        # logging is switched off.
        elapsed = round(time.time() - start, 2)

        if log:
            logger.debug("Milvus {} run in {}s".format(func.__name__, elapsed))
        if rps is not False:
            return result, elapsed
        return result

    return wrapper
class MilvusClient(object):
    """Benchmark helper wrapping a pymilvus ``Milvus`` client.

    Most methods take an optional ``collection_name``; when it is omitted the
    collection bound to this client (``self._collection_name``) is used.
    Methods decorated with ``time_wrapper`` additionally accept the ``log`` and
    ``rps`` keyword arguments handled by that decorator.
    """

    # Vector dimension used by the random search helpers
    # (query_rand / load_query_rand).
    _RAND_QUERY_DIMENSION = 128

    def __init__(self, collection_name=None, host=None, port=None, timeout=300):
        """Create the underlying ``Milvus`` client, retrying until ``timeout``.

        Args:
            collection_name: default collection for subsequent calls.
            host: server host; falls back to ``config.SERVER_HOST_DEFAULT``.
            port: server port; falls back to ``config.SERVER_PORT_DEFAULT``.
            timeout: seconds during which client creation is retried.

        Raises:
            Exception: "Server connect timeout" when no client could be
                created within ``timeout`` seconds.
        """
        self._collection_name = collection_name
        self._collection_info = None
        # Set by create_collection(); warm_query() looks it up when unset.
        self._dimension = None
        start_time = time.time()
        if not host:
            host = config.SERVER_HOST_DEFAULT
        if not port:
            port = config.SERVER_PORT_DEFAULT

        # retry connect remote server
        deadline = start_time + timeout
        failures = 0
        while time.time() < deadline:
            try:
                self._milvus = Milvus(
                    host=host,
                    port=port,
                    try_connect=False,
                    pre_ping=False)
                break
            except Exception as e:
                logger.error(str(e))
                logger.error("Milvus connect failed: %d times", failures)
                failures += 1
                time.sleep(30)
        else:
            # Only reached when the loop ran out of time without ever
            # creating a client, so self._milvus is never left unset.
            raise Exception("Server connect timeout")

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def _resolve_name(self, collection_name):
        """Return ``collection_name``, or the client's default when it is None."""
        return self._collection_name if collection_name is None else collection_name

    def set_collection(self, collection_name):
        """Set the default collection name used by this client."""
        self._collection_name = collection_name

    # TODO: check_status is not supported by the server

    def check_result_ids(self, result):
        """Verify that the top-1 hit of every query is closer than ``epsilon``.

        Raises:
            Exception: "Distance wrong" for the first query whose top-1
                distance is greater than or equal to ``epsilon``.
        """
        for index, item in enumerate(result):
            top1_distance = item[0].distance
            if top1_distance >= epsilon:
                logger.error(index)
                logger.error(top1_distance)
                raise Exception("Distance wrong")

    @property
    def collection_name(self):
        """Name of the collection this client operates on by default."""
        return self._collection_name

    # only support the given field name
    def create_collection(self, dimension, data_type=DataType.FLOAT_VECTOR, auto_id=False,
                          collection_name=None, other_fields=None):
        """Create a collection with a vector field, an INT64 primary key ``id``
        and optional scalar fields.

        Args:
            dimension: dimension of the vector field; also remembered on the
                client as ``self._dimension``.
            data_type: type of the vector field.
            auto_id: forwarded as the ``auto_id`` create parameter.
            collection_name: target collection; the client's default when empty.
            other_fields: comma-separated names of extra scalar fields. The
                field type is derived from the name prefix: ``int`` -> INT64,
                ``float`` -> FLOAT, ``double`` -> DOUBLE.

        Raises:
            Exception: "Field name not supported" for any other prefix.
                Errors raised by the create call are logged and re-raised.
        """
        self._dimension = dimension
        collection_name = collection_name or self._collection_name
        vec_field_name = utils.get_default_field_name(data_type)
        fields = [
            {"name": vec_field_name, "type": data_type, "params": {"dim": dimension}},
            {"name": "id", "type": DataType.INT64, "is_primary": True},
        ]
        if other_fields:
            for other_field_name in other_fields.split(","):
                if other_field_name.startswith("int"):
                    field_type = DataType.INT64
                elif other_field_name.startswith("float"):
                    field_type = DataType.FLOAT
                elif other_field_name.startswith("double"):
                    field_type = DataType.DOUBLE
                else:
                    raise Exception("Field name not supported")
                fields.append({"name": other_field_name, "type": field_type})
        create_param = {
            "fields": fields,
            "auto_id": auto_id}
        try:
            self._milvus.create_collection(collection_name, create_param)
        except Exception as e:
            logger.error(str(e))
            raise
        else:
            logger.info("Create collection: <%s> successfully", collection_name)

    def create_partition(self, tag, collection_name=None):
        """Create partition ``tag`` in the given (or default) collection."""
        collection_name = collection_name or self._collection_name
        self._milvus.create_partition(collection_name, tag)

    @time_wrapper
    def insert(self, entities, collection_name=None, timeout=None):
        """Insert ``entities`` and return the primary keys of the inserted rows.

        Best effort: a failing insert is logged and ``None`` is returned
        instead of raising.
        """
        tmp_collection_name = self._resolve_name(collection_name)
        try:
            insert_res = self._milvus.insert(tmp_collection_name, entities, timeout=timeout)
            return insert_res.primary_keys
        except Exception as e:
            logger.error(str(e))

    @time_wrapper
    def insert_flush(self, entities, _async=False, collection_name=None):
        """Insert ``entities``, flush the collection, and return the primary keys.

        A failing insert is logged and yields ``None``; the flush is issued in
        either case.
        """
        tmp_collection_name = self._resolve_name(collection_name)
        primary_keys = None
        try:
            primary_keys = self._milvus.insert(tmp_collection_name, entities).primary_keys
        except Exception as e:
            logger.error(str(e))
        # Flush before returning: returning straight from the insert would make
        # the flush unreachable on the success path.
        self._milvus.flush([tmp_collection_name], _async=_async)
        return primary_keys

    def get_dimension(self):
        """Return the ``dim`` of the first vector field of the default
        collection, or ``None`` when the collection has no vector field."""
        info = self.get_info()
        for field in info["fields"]:
            if field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
                return field["params"]["dim"]

    def get_rand_ids(self, length):
        """Return up to ``length`` ids taken from one randomly chosen segment.

        Keeps picking segments of the first partition until one yields ids.
        When the segment holds no more than ``length`` ids, all of them are
        returned.
        """
        segment_ids = []
        while True:
            stats = self.get_stats()
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            try:
                segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["id"])
            except Exception as e:
                logger.error(str(e))
            if not len(segment_ids):
                continue
            if len(segment_ids) > length:
                return random.sample(segment_ids, length)
            logger.debug("Reset length: %d", len(segment_ids))
            return segment_ids

    @time_wrapper
    def get_entities(self, get_ids):
        """Fetch the entities with the given ids from the default collection."""
        return self._milvus.get_entity_by_id(self._collection_name, get_ids)

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete the entities with the given ids."""
        tmp_collection_name = self._resolve_name(collection_name)
        self._milvus.delete_entity_by_id(tmp_collection_name, ids)

    def delete_rand(self):
        """Delete 1-100 random entities, flush, and check that they are gone.

        Raises:
            AssertionError: when a deleted id can still be fetched.
        """
        delete_id_length = random.randint(1, 100)
        logger.debug("%s: length to delete: %d", self._collection_name, delete_id_length)
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d", self._collection_name, self.count())
        get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        for item in get_res:
            # Explicit raise instead of ``assert`` so the check survives ``-O``.
            if item:
                raise AssertionError("Entity still present after delete")

    @time_wrapper
    def flush(self, _async=False, collection_name=None, timeout=None):
        """Flush the given (or default) collection."""
        tmp_collection_name = self._resolve_name(collection_name)
        self._milvus.flush([tmp_collection_name], _async=_async, timeout=timeout)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact the given (or default) collection and return the call's result."""
        tmp_collection_name = self._resolve_name(collection_name)
        # No status check here: check_status() is not available on this class.
        return self._milvus.compact(tmp_collection_name)

    # only support "in" in expr
    @time_wrapper
    def get(self, ids, collection_name=None, timeout=None):
        """Query entities by primary key using an ``id in [...]`` expression."""
        tmp_collection_name = self._resolve_name(collection_name)
        ids_expr = "id in %s" % (str(ids))
        return self._milvus.query(tmp_collection_name, ids_expr, output_fields=None,
                                  partition_names=None, timeout=timeout)

    @time_wrapper
    def create_index(self, field_name, index_type, metric_type, _async=False, index_param=None,
                     collection_name=None):
        """Build an index on ``field_name``.

        Args:
            field_name: field to index.
            index_type: yaml-style index name; must be a key of ``INDEX_MAP``.
            metric_type: metric name, translated by ``utils.metric_type_trans``.
            _async: forwarded to the client's ``create_index``.
            index_param: index build parameters, sent as ``params``.
            collection_name: target collection; the client's default when None.

        Raises:
            KeyError: when ``index_type`` is not a key of ``INDEX_MAP``.
        """
        tmp_collection_name = self._resolve_name(collection_name)
        index_type = INDEX_MAP[index_type]
        metric_type = utils.metric_type_trans(metric_type)
        logger.info("Building index start, collection_name: %s, index_type: %s, metric_type: %s",
                    tmp_collection_name, index_type, metric_type)
        if index_param:
            logger.info(index_param)
        index_params = {
            "index_type": index_type,
            "metric_type": metric_type,
            "params": index_param
        }
        self._milvus.create_index(tmp_collection_name, field_name, index_params, _async=_async)

    # TODO: need to check
    def describe_index(self, field_name, collection_name=None):
        """Return ``{"index_type", "metric_type", "index_param"}`` for a field.

        The index type is translated back to its yaml-style name via
        ``INDEX_MAP``. When the server reports no index, a ``flat`` entry with
        ``None`` metric and params is returned.
        """
        tmp_collection_name = self._resolve_name(collection_name)
        info = self._milvus.describe_index(tmp_collection_name, field_name)
        logger.info(info)
        if not info:
            return {"index_type": "flat", "metric_type": None, "index_param": None}
        # transfer index type name (code name -> yaml name)
        code_to_yaml = {code: name for name, code in INDEX_MAP.items()}
        index_type = info["index_type"]
        return {
            "index_type": code_to_yaml.get(index_type, index_type),
            "metric_type": info["metric_type"],
            "index_param": info["params"],
        }

    def drop_index(self, field_name):
        """Drop the index on ``field_name`` of the default collection."""
        logger.info("Drop index: %s", self._collection_name)
        return self._milvus.drop_index(self._collection_name, field_name)

    @time_wrapper
    def query(self, vector_query, filter_query=None, collection_name=None, timeout=300):
        """ This method corresponds to the search method of milvus """
        tmp_collection_name = self._resolve_name(collection_name)
        params = util.search_param_analysis(vector_query, filter_query)
        params.update({"timeout": timeout})
        logger.debug("Params of search : %s", params)
        return self._milvus.search(tmp_collection_name, **params)

    @time_wrapper
    def warm_query(self, index_field_name, search_param, metric_type, times=2):
        """Run ``times`` small random searches to warm up the default collection.

        Uses ``DEFAULT_WARM_QUERY_NQ`` random vectors and
        ``DEFAULT_WARM_QUERY_TOPK`` as topk.
        """
        dimension = self._dimension
        if dimension is None:
            # create_collection() was not called on this client; ask the server.
            dimension = self.get_dimension()
        query_vectors = [[random.random() for _ in range(dimension)]
                         for _ in range(DEFAULT_WARM_QUERY_NQ)]
        vector_query = {"vector": {index_field_name: {
            "topk": DEFAULT_WARM_QUERY_TOPK,
            "query": query_vectors,
            "metric_type": metric_type,
            "params": search_param}
        }}
        logger.debug("Start warm up query")
        for _ in range(times):
            params = util.search_param_analysis(vector_query, None)
            self._milvus.search(self._collection_name, **params)
        logger.debug("End warm up query")

    @time_wrapper
    def load_and_query(self, vector_query, filter_query=None, collection_name=None, timeout=120):
        """Load the collection, then search it; returns the search result."""
        tmp_collection_name = self._resolve_name(collection_name)
        self.load_collection(tmp_collection_name)
        params = util.search_param_analysis(vector_query, filter_query)
        params.update({"timeout": timeout})
        return self._milvus.search(tmp_collection_name, **params)

    def get_ids(self, result):
        """Return one list entry per query holding that query's ``ids``."""
        return [res.ids for res in result]

    def _rand_vector_query(self, nq_max):
        """Build a random vector query: topk, nq and nprobe are drawn at random,
        the metric is picked from l2/ip."""
        # for ivf search
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(self._RAND_QUERY_DIMENSION)]
                         for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d",
                    self._collection_name, nq, top_k, nprobe)
        vec_field_name = utils.get_default_field_name()
        return {"vector": {vec_field_name: {
            "topk": top_k,
            "query": query_vectors,
            "metric_type": utils.metric_type_trans(metric_type),
            "params": search_param}
        }}

    def query_rand(self, nq_max=100, timeout=None):
        """Search the default collection with a random vector query."""
        self.query(self._rand_vector_query(nq_max), timeout=timeout)

    def load_query_rand(self, nq_max=100, timeout=None):
        """Load the default collection, then search it with a random vector query."""
        self.load_and_query(self._rand_vector_query(nq_max), timeout=timeout)

    # TODO: need to check
    def count(self, collection_name=None):
        """Return (and debug-log) the ``row_count`` from the collection stats."""
        collection_name = self._resolve_name(collection_name)
        row_count = self._milvus.get_collection_stats(collection_name)["row_count"]
        logger.debug("Row count: %d in collection: <%s>", row_count, collection_name)
        return row_count

    def drop(self, timeout=120, collection_name=None):
        """
        drop steps:
        1.drop collection
        2.check collection exist
        3.Set timeout to exit

        The existence check is repeated once per second, at most ``timeout``
        times; running out of attempts is logged as an error, not raised.
        """
        timeout = int(timeout)
        collection_name = self._resolve_name(collection_name)
        logger.info("Start delete collection: %s", collection_name)
        self._milvus.drop_collection(collection_name)
        waited = 0
        while waited < timeout:
            try:
                if not self._milvus.has_collection(collection_name):
                    break
            except Exception as e:
                logger.warning("Collection count failed: {}".format(str(e)))
                break
            time.sleep(1)
            waited += 1
        if waited >= timeout:
            logger.error("Delete collection timeout")

    def get_stats(self):
        """Return the stats of the default collection."""
        return self._milvus.get_collection_stats(self._collection_name)

    def get_info(self, collection_name=None):
        """Return the description of the given (or default) collection."""
        return self._milvus.describe_collection(self._resolve_name(collection_name))

    def show_collections(self):
        """Return the collections listed by the server."""
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        """Return whether the given (or default) collection exists."""
        return self._milvus.has_collection(self._resolve_name(collection_name))

    def clean_db(self):
        """Drop every collection listed by the server."""
        for name in self.show_collections():
            self.drop(collection_name=name)

    @time_wrapper
    def load_collection(self, collection_name=None, timeout=3000):
        """Load the given (or default) collection."""
        return self._milvus.load_collection(self._resolve_name(collection_name), timeout=timeout)

    @time_wrapper
    def release_collection(self, collection_name=None, timeout=3000):
        """Release the given (or default) collection."""
        return self._milvus.release_collection(self._resolve_name(collection_name), timeout=timeout)

    @time_wrapper
    def load_partitions(self, tag_names, collection_name=None, timeout=3000):
        """Load the partitions ``tag_names`` of the given (or default) collection."""
        return self._milvus.load_partitions(self._resolve_name(collection_name), tag_names,
                                            timeout=timeout)

    @time_wrapper
    def release_partitions(self, tag_names, collection_name=None, timeout=3000):
        """Release the partitions ``tag_names`` of the given (or default) collection."""
        return self._milvus.release_partitions(self._resolve_name(collection_name), tag_names,
                                               timeout=timeout)

    @time_wrapper
    def scene_test(self, collection_name=None, vectors=None, ids=None):
        """
        Scene test steps
        1.create collection with the specified collection name
        2.insert data
        3.flush data
        4.create index
        5.drop collection
        """
        logger.debug("[scene_test] Start scene test : %s", collection_name)
        self.create_collection(dimension=128, collection_name=collection_name)
        time.sleep(1)

        collection_info = self.get_info(collection_name)
        entities = utils.generate_entities(collection_info, vectors, ids)
        logger.debug("[scene_test] Start insert : %s", collection_name)
        self.insert(entities, collection_name=collection_name)
        logger.debug("[scene_test] Start flush : %s", collection_name)
        self.flush(collection_name=collection_name)

        logger.debug("[scene_test] Start create index : %s", collection_name)
        self.create_index(field_name='float_vector', index_type="ivf_sq8", metric_type='l2',
                          collection_name=collection_name, index_param={'nlist': 2048})
        time.sleep(59)

        logger.debug("[scene_test] Start drop : %s", collection_name)
        self.drop(collection_name=collection_name)
        logger.debug("[scene_test]Scene test close : %s", collection_name)
# TODO: remove
# def get_server_version(self):
# return self._milvus.server_version()
# def get_server_mode(self):
# return self.cmd("mode")
# def get_server_commit(self):
# return self.cmd("build_commit_id")
# def get_server_config(self):
# return json.loads(self.cmd("get_milvus_config"))
# def get_mem_info(self):
# result = json.loads(self.cmd("get_system_info"))
# result_human = {
# # unit: Gb
# "memory_used": round(int(result["memory_used"]) / (1024 * 1024 * 1024), 2)
# }
# return result_human
# def cmd(self, command):
# res = self._milvus._cmd(command)
# logger.info("Server command: %s, result: %s" % (command, res))
# return res
# @time_wrapper
# def set_config(self, parent_key, child_key, value):
# self._milvus.set_config(parent_key, child_key, value)
# def get_config(self, key):
# return self._milvus.get_config(key)