mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-05 05:18:52 +08:00
19aca9a653
Signed-off-by: wangting0128 <ting.wang@zilliz.com>
89 lines
3.4 KiB
Python
89 lines
3.4 KiB
Python
import random
|
|
import logging
|
|
# import math
|
|
from locust import TaskSet, task
|
|
from . import utils
|
|
|
|
# Module logger, registered under the benchmark runner's hierarchical name.
logger = logging.getLogger(name="milvus_benchmark.runners.locust_tasks")
|
|
|
|
|
|
class Tasks(TaskSet):
    """Locust task set that drives the Milvus benchmark client.

    Each ``@task`` method issues one client operation. Per-operation settings
    are read from ``self.params[<op name>]``. Collection metadata comes from
    ``self.op_info``. Pre-generated data comes from ``self.values``, using the
    keys ``"X"``, ``"ids"`` and ``"get_ids"``.

    NOTE(review): ``params``, ``op_info`` and ``values`` are never assigned in
    this class. Presumably the locust runner attaches them before the tasks
    start; confirm against the caller.
    """

    @task
    def query(self):
        """Run a vector search through ``client.query`` (30 s timeout).

        This benchmark calls the search operation "query". Settings come from
        ``self.params["query"]``:

        - Required: ``top_k``, ``nq`` and ``search_param``.
        - Optional: ``metric_type`` (default ``utils.DEFAULT_METRIC_TYPE``),
          ``filters`` (a list of ``{"range": str}`` / ``{"term": str}`` entries)
          and ``guarantee_timestamp`` (default ``None``).
        """
        op = "query"
        op_params = self.params[op]
        vector_query = {"vector": {self.op_info["vector_field_name"]: {
            "topk": op_params["top_k"],
            "query": self.values["X"][:op_params["nq"]],
            "metric_type": op_params.get("metric_type", utils.DEFAULT_METRIC_TYPE),
            "params": op_params["search_param"],
        }}}

        filter_query = []
        # `or []`: a bare `filters:` key in YAML loads as None, which is not iterable.
        for filter_spec in op_params.get("filters") or []:
            if not isinstance(filter_spec, dict):
                continue
            # An entry carrying both keys contributes both filters, range first.
            for filter_kind in ("range", "term"):
                if filter_kind in filter_spec:
                    # SECURITY NOTE(review): eval() executes arbitrary code from the
                    # benchmark config. This is only acceptable while configs are
                    # trusted. Consider ast.literal_eval if every filter string is a
                    # plain literal.
                    filter_query.append(eval(filter_spec[filter_kind]))

        self.client.query(vector_query, filter_query=filter_query, log=False,
                          guarantee_timestamp=op_params.get("guarantee_timestamp"),
                          timeout=30)

    @task
    def flush(self):
        """Call ``client.flush`` (30 s timeout)."""
        self.client.flush(log=False, timeout=30)

    @task
    def load(self):
        """Call ``client.load_collection`` (30 s timeout)."""
        self.client.load_collection(timeout=30)

    @task
    def release(self):
        """Release the collection, then load it again (30 s timeout on the load).

        NOTE(review): the immediate reload presumably keeps the collection
        searchable for the other tasks. Confirm this is the intent.
        """
        self.client.release_collection()
        self.client.load_collection(timeout=30)

    def _entities_for(self, op):
        """Build the first ``self.params[op]["ni_per"]`` entities from the pre-generated vectors and ids.

        This is a plain method (no ``@task``), so locust does not schedule it.
        """
        ni_per = self.params[op]["ni_per"]
        return utils.generate_entities(self.op_info["collection_info"],
                                       self.values["X"][:ni_per],
                                       self.values["ids"][:ni_per])

    @task
    def insert(self):
        """Insert one batch of ``ni_per`` entities (300 s timeout)."""
        self.client.insert(self._entities_for("insert"), log=False, timeout=300)

    @task
    def insert_flush(self):
        """Insert one batch of ``ni_per`` entities, then flush.

        The timeouts mirror the standalone ``insert`` (300 s) and ``flush``
        (30 s) tasks, so both code paths are bounded the same way.
        """
        self.client.insert(self._entities_for("insert_flush"), log=False, timeout=300)
        self.client.flush(log=False, timeout=30)

    @task
    def insert_rand(self):
        """Delegate to ``client.insert_rand``.

        Presumably this inserts randomly generated entities; confirm against
        the client.
        """
        self.client.insert_rand(log=False)

    @task
    def get(self):
        """Fetch entities by id through ``client.get`` (300 s timeout).

        Sends the first ``self.params["get"]["ids_length"]`` ids of
        ``self.values["get_ids"]``.
        """
        ids_length = self.params["get"]["ids_length"]
        self.client.get(self.values["get_ids"][:ids_length], timeout=300)

    @task
    def scene_test(self):
        """Run ``client.scene_test`` against a randomly named collection.

        The collection is named ``scene_test_<1..10000>_<10001..999999>``.
        The first 3000 pre-generated vectors and ids are passed along.
        """
        op = "scene_test"
        collection_name = op + '_' + str(random.randint(1, 10000)) + '_' + str(random.randint(10001, 999999))
        nb = 3000  # number of pre-generated vectors/ids handed to the scene test
        self.client.scene_test(collection_name, vectors=self.values["X"][:nb], ids=self.values["ids"][:nb])
|