[test]Update restful api test (#25581)

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
zhuwenxing 2023-08-01 15:43:05 +08:00 committed by GitHub
parent 7fec0d61cc
commit eade5f9b7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 1979 additions and 2098 deletions

View File

@ -1,17 +0,0 @@
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Alias(RestClient):
def drop_alias():
pass
def alter_alias():
pass
def create_alias():
pass

View File

@ -1,62 +0,0 @@
import json
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Collection(RestClient):
@DELETE("collection")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def drop_collection(self, payload):
"""Drop a collection"""
@GET("collection")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def describe_collection(self, payload):
"""Describe a collection"""
@POST("collection")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def create_collection(self, payload):
"""Create a collection"""
@GET("collection/existence")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def has_collection(self, payload):
"""Check if a collection exists"""
@DELETE("collection/load")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def release_collection(self, payload):
"""Release a collection"""
@POST("collection/load")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def load_collection(self, payload):
"""Load a collection"""
@GET("collection/statistics")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def get_collection_statistics(self, payload):
"""Get collection statistics"""
@GET("collections")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def show_collections(self, payload):
"""Show collections"""
if __name__ == '__main__':
client = Collection("http://localhost:19121/api/v1")
print(client)

View File

@ -1,19 +0,0 @@
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Credential(RestClient):
def delete_credential():
pass
def update_credential():
pass
def create_credential():
pass
def list_credentials():
pass

View File

@ -1,63 +0,0 @@
import json
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Entity(RestClient):
@POST("distance")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def calc_distance(self, payload):
""" Calculate distance between two points """
@DELETE("entities")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def delete(self, payload):
"""delete entities"""
@POST("entities")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def insert(self, payload):
"""insert entities"""
@POST("persist")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def flush(self, payload):
"""flush entities"""
@POST("persist/segment-info")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def get_persistent_segment_info(self, payload):
"""get persistent segment info"""
@POST("persist/state")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def get_flush_state(self, payload):
"""get flush state"""
@POST("query")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def query(self, payload):
"""query entities"""
@POST("query-segment-info")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def get_query_segment_info(self, payload):
"""get query segment info"""
@POST("search")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def search(self, payload):
"""search entities"""

View File

@ -1,18 +0,0 @@
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Import(RestClient):
def list_import_tasks():
pass
def exec_import():
pass
def get_import_state():
pass

View File

@ -1,38 +0,0 @@
import json
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Index(RestClient):
@DELETE("/index")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def drop_index(self, payload):
"""Drop an index"""
@GET("/index")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def describe_index(self, payload):
"""Describe an index"""
@POST("index")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def create_index(self, payload):
"""create index"""
@GET("index/progress")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def get_index_build_progress(self, payload):
"""get index build progress"""
@GET("index/state")
@body("payload", lambda p: json.dumps(p))
@on(200, lambda r: r.json())
def get_index_state(self, payload):
"""get index state"""

View File

@ -1,11 +0,0 @@
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Metrics(RestClient):
def get_metrics():
pass

View File

@ -0,0 +1,257 @@
import json
import requests
import time
import uuid
from utils.util_log import test_log as logger
def logger_request_response(response, url, tt, headers, data, str_data, str_response, method):
if len(data) > 2000:
data = data[:1000] + "..." + data[-1000:]
try:
if response.status_code == 200:
if ('code' in response.json() and response.json()["code"] == 200) or ('Code' in response.json() and response.json()["Code"] == 0):
logger.debug(
f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {str_data}, response: {str_response}")
else:
logger.error(
f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {data}, response: {response.text}")
else:
logger.error(
f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {data}, response: {response.text}")
except Exception as e:
logger.error(e)
logger.error(
f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {data}, response: {response.text}")
class Requests:
def __init__(self, url=None, api_key=None):
self.url = url
self.api_key = api_key
self.headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}',
'RequestId': str(uuid.uuid1())
}
def update_headers(self):
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}',
'RequestId': str(uuid.uuid1())
}
return headers
def post(self, url, headers=None, data=None):
headers = headers if headers is not None else self.update_headers()
data = json.dumps(data)
str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data
t0 = time.time()
response = requests.post(url, headers=headers, data=data)
tt = time.time() - t0
str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text
logger_request_response(response, url, tt, headers, data, str_data, str_response, "post")
return response
def get(self, url, headers=None, params=None, data=None):
headers = headers if headers is not None else self.update_headers()
data = json.dumps(data)
str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data
t0 = time.time()
if data is None or data == "null":
response = requests.get(url, headers=headers, params=params)
else:
response = requests.get(url, headers=headers, params=params, data=data)
tt = time.time() - t0
str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text
logger_request_response(response, url, tt, headers, data, str_data, str_response, "get")
return response
def put(self, url, headers=None, data=None):
headers = headers if headers is not None else self.update_headers()
data = json.dumps(data)
str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data
t0 = time.time()
response = requests.put(url, headers=headers, data=data)
tt = time.time() - t0
str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text
logger_request_response(response, url, tt, headers, data, str_data, str_response, "put")
return response
def delete(self, url, headers=None, data=None):
headers = headers if headers is not None else self.update_headers()
data = json.dumps(data)
str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data
t0 = time.time()
response = requests.delete(url, headers=headers, data=data)
tt = time.time() - t0
str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text
logger_request_response(response, url, tt, headers, data, str_data, str_response, "delete")
return response
class VectorClient(Requests):
def __init__(self, url, api_key, protocol="http"):
super().__init__(url, api_key)
self.protocol = protocol
self.url = url
self.api_key = api_key
self.db_name = None
self.headers = self.update_headers()
def update_headers(self):
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}',
'RequestId': str(uuid.uuid1())
}
return headers
def vector_search(self, payload, db_name="default", timeout=10):
time.sleep(1)
url = f'{self.protocol}://{self.url}/vector/search'
if self.db_name is not None:
payload["dbName"] = self.db_name
if db_name != "default":
payload["dbName"] = db_name
response = self.post(url, headers=self.update_headers(), data=payload)
rsp = response.json()
if "data" in rsp and len(rsp["data"]) == 0:
t0 = time.time()
while time.time() - t0 < timeout:
response = self.post(url, headers=self.update_headers(), data=payload)
rsp = response.json()
if len(rsp["data"]) > 0:
break
time.sleep(1)
else:
response = self.post(url, headers=self.update_headers(), data=payload)
rsp = response.json()
if "data" in rsp and len(rsp["data"]) == 0:
logger.info(f"after {timeout}s, still no data")
return response.json()
def vector_query(self, payload, db_name="default", timeout=10):
time.sleep(1)
url = f'{self.protocol}://{self.url}/vector/query'
if self.db_name is not None:
payload["dbName"] = self.db_name
if db_name != "default":
payload["dbName"] = db_name
response = self.post(url, headers=self.update_headers(), data=payload)
rsp = response.json()
if "data" in rsp and len(rsp["data"]) == 0:
t0 = time.time()
while time.time() - t0 < timeout:
response = self.post(url, headers=self.update_headers(), data=payload)
rsp = response.json()
if len(rsp["data"]) > 0:
break
time.sleep(1)
else:
response = self.post(url, headers=self.update_headers(), data=payload)
rsp = response.json()
if "data" in rsp and len(rsp["data"]) == 0:
logger.info(f"after {timeout}s, still no data")
return response.json()
def vector_get(self, payload, db_name="default"):
time.sleep(1)
url = f'{self.protocol}://{self.url}/vector/get'
if self.db_name is not None:
payload["dbName"] = self.db_name
if db_name != "default":
payload["dbName"] = db_name
response = self.post(url, headers=self.update_headers(), data=payload)
return response.json()
def vector_delete(self, payload, db_name="default"):
url = f'{self.protocol}://{self.url}/vector/delete'
if self.db_name is not None:
payload["dbName"] = self.db_name
if db_name != "default":
payload["dbName"] = db_name
response = self.post(url, headers=self.update_headers(), data=payload)
return response.json()
def vector_insert(self, payload, db_name="default"):
url = f'{self.protocol}://{self.url}/vector/insert'
if self.db_name is not None:
payload["dbName"] = self.db_name
if db_name != "default":
payload["dbName"] = db_name
response = self.post(url, headers=self.update_headers(), data=payload)
return response.json()
class CollectionClient(Requests):
def __init__(self, url, api_key, protocol="http"):
super().__init__(url, api_key)
self.protocol = protocol
self.url = url
self.api_key = api_key
self.db_name = None
self.headers = self.update_headers()
def update_headers(self):
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}',
'RequestId': str(uuid.uuid1())
}
return headers
def collection_list(self, db_name="default"):
url = f'{self.protocol}://{self.url}/vector/collections'
params = {}
if self.db_name is not None:
params = {
"dbName": self.db_name
}
if db_name != "default":
params = {
"dbName": db_name
}
response = self.get(url, headers=self.update_headers(), params=params)
res = response.json()
return res
def collection_create(self, payload, db_name="default"):
time.sleep(1) # wait for collection created and in case of rate limit
url = f'{self.protocol}://{self.url}/vector/collections/create'
if self.db_name is not None:
payload["dbName"] = self.db_name
if db_name != "default":
payload["dbName"] = db_name
response = self.post(url, headers=self.update_headers(), data=payload)
return response.json()
def collection_describe(self, collection_name, db_name="default"):
url = f'{self.protocol}://{self.url}/vector/collections/describe'
params = {"collectionName": collection_name}
if self.db_name is not None:
params = {
"collectionName": collection_name,
"dbName": self.db_name
}
if db_name != "default":
params = {
"collectionName": collection_name,
"dbName": db_name
}
response = self.get(url, headers=self.update_headers(), params=params)
return response.json()
def collection_drop(self, payload, db_name="default"):
time.sleep(1) # wait for collection drop and in case of rate limit
url = f'{self.protocol}://{self.url}/vector/collections/drop'
if self.db_name is not None:
payload["dbName"] = self.db_name
if db_name != "default":
payload["dbName"] = db_name
response = self.post(url, headers=self.update_headers(), data=payload)
return response.json()

View File

@ -1,21 +0,0 @@
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Ops(RestClient):
def manual_compaction():
pass
def get_compaction_plans():
pass
def get_compaction_state():
pass
def load_balance():
pass
def get_replicas():
pass

View File

@ -1,27 +0,0 @@
from decorest import GET, POST, DELETE
from decorest import HttpStatus, RestClient
from decorest import accept, body, content, endpoint, form
from decorest import header, multipart, on, query, stream, timeout
class Partition(RestClient):
def drop_partition():
pass
def create_partition():
pass
def has_partition():
pass
def get_partition_statistics():
pass
def show_partitions():
pass
def release_partition():
pass
def load_partition():
pass

View File

@ -1,43 +0,0 @@
from datetime import date, datetime
from typing import List, Union, Optional
from pydantic import BaseModel, UUID4, conlist
from pydantic_factories import ModelFactory
class Person(BaseModel):
def __init__(self, length):
super().__init__()
self.len = length
id: UUID4
name: str
hobbies: List[str]
age: Union[float, int]
birthday: Union[datetime, date]
class Pet(BaseModel):
name: str
age: int
class PetFactory(BaseModel):
name: str
pet: Pet
age: Optional[int] = None
sample = {
"name": "John",
"pet": {
"name": "Fido",
"age": 3
}
}
result = PetFactory(**sample)
print(result)

View File

@ -1,72 +0,0 @@
from time import sleep
from decorest import HttpStatus, RestClient
from models.schema import CollectionSchema
from base.collection_service import CollectionService
from base.index_service import IndexService
from base.entity_service import EntityService
from utils.util_log import test_log as log
from common import common_func as cf
from common import common_type as ct
class Base:
"""init base class"""
endpoint = None
collection_service = None
index_service = None
entity_service = None
collection_name = None
collection_object_list = []
def setup_class(self):
log.info("setup class")
def teardown_class(self):
log.info("teardown class")
def setup_method(self, method):
log.info(("*" * 35) + " setup " + ("*" * 35))
log.info("[setup_method] Start setup test case %s." % method.__name__)
host = cf.param_info.param_host
port = cf.param_info.param_port
self.endpoint = "http://" + host + ":" + str(port) + "/api/v1"
self.collection_service = CollectionService(self.endpoint)
self.index_service = IndexService(self.endpoint)
self.entity_service = EntityService(self.endpoint)
def teardown_method(self, method):
res = self.collection_service.has_collection(collection_name=self.collection_name)
log.info(f"collection {self.collection_name} exists: {res}")
if res["value"] is True:
res = self.collection_service.drop_collection(self.collection_name)
log.info(f"drop collection {self.collection_name} res: {res}")
res = self.collection_service.show_collections()
all_collections = res["collection_names"]
union_collections = set(all_collections) & set(self.collection_object_list)
for collection in union_collections:
res = self.collection_service.drop_collection(collection)
log.info(f"drop collection {collection} res: {res}")
log.info("[teardown_method] Start teardown test case %s." % method.__name__)
log.info(("*" * 35) + " teardown " + ("*" * 35))
class TestBase(Base):
"""init test base class"""
def init_collection(self, name=None, schema=None):
collection_name = cf.gen_unique_str("test") if name is None else name
self.collection_name = collection_name
self.collection_object_list.append(collection_name)
if schema is None:
schema = cf.gen_default_schema(collection_name=collection_name)
# create collection
res = self.collection_service.create_collection(collection_name=collection_name, schema=schema)
log.info(f"create collection name: {collection_name} with schema: {schema}")
return collection_name, schema

View File

@ -1,96 +0,0 @@
from api.collection import Collection
from utils.util_log import test_log as log
from models import milvus
TIMEOUT = 30
class CollectionService:
def __init__(self, endpoint=None, timeout=None):
if timeout is None:
timeout = TIMEOUT
if endpoint is None:
endpoint = "http://localhost:9091/api/v1"
self._collection = Collection(endpoint=endpoint)
def create_collection(self, collection_name, consistency_level=1, schema=None, shards_num=2):
payload = {
"collection_name": collection_name,
"consistency_level": consistency_level,
"schema": schema,
"shards_num": shards_num
}
log.info(f"payload: {payload}")
# payload = milvus.CreateCollectionRequest(collection_name=collection_name,
# consistency_level=consistency_level,
# schema=schema,
# shards_num=shards_num)
# payload = payload.dict()
rsp = self._collection.create_collection(payload)
return rsp
def has_collection(self, collection_name=None, time_stamp=0):
payload = {
"collection_name": collection_name,
"time_stamp": time_stamp
}
# payload = milvus.HasCollectionRequest(collection_name=collection_name, time_stamp=time_stamp)
# payload = payload.dict()
return self._collection.has_collection(payload)
def drop_collection(self, collection_name):
payload = {
"collection_name": collection_name
}
# payload = milvus.DropCollectionRequest(collection_name=collection_name)
# payload = payload.dict()
return self._collection.drop_collection(payload)
def describe_collection(self, collection_name, collection_id=None, time_stamp=0):
payload = {
"collection_name": collection_name,
"collection_id": collection_id,
"time_stamp": time_stamp
}
# payload = milvus.DescribeCollectionRequest(collection_name=collection_name,
# collectionID=collection_id,
# time_stamp=time_stamp)
# payload = payload.dict()
return self._collection.describe_collection(payload)
def load_collection(self, collection_name, replica_number=1):
payload = {
"collection_name": collection_name,
"replica_number": replica_number
}
# payload = milvus.LoadCollectionRequest(collection_name=collection_name, replica_number=replica_number)
# payload = payload.dict()
return self._collection.load_collection(payload)
def release_collection(self, collection_name):
payload = {
"collection_name": collection_name
}
# payload = milvus.ReleaseCollectionRequest(collection_name=collection_name)
# payload = payload.dict()
return self._collection.release_collection(payload)
def get_collection_statistics(self, collection_name):
payload = {
"collection_name": collection_name
}
# payload = milvus.GetCollectionStatisticsRequest(collection_name=collection_name)
# payload = payload.dict()
return self._collection.get_collection_statistics(payload)
def show_collections(self, collection_names=None, type=0):
payload = {
"collection_names": collection_names,
"type": type
}
# payload = milvus.ShowCollectionsRequest(collection_names=collection_names, type=type)
# payload = payload.dict()
return self._collection.show_collections(payload)

View File

@ -1,182 +0,0 @@
from api.entity import Entity
from common import common_type as ct
from utils.util_log import test_log as log
from models import common, schema, milvus, server
TIMEOUT = 30
class EntityService:
def __init__(self, endpoint=None, timeout=None):
if timeout is None:
timeout = TIMEOUT
if endpoint is None:
endpoint = "http://localhost:9091/api/v1"
self._entity = Entity(endpoint=endpoint)
def calc_distance(self, base=None, op_left=None, op_right=None, params=None):
payload = {
"base": base,
"op_left": op_left,
"op_right": op_right,
"params": params
}
# payload = milvus.CalcDistanceRequest(base=base, op_left=op_left, op_right=op_right, params=params)
# payload = payload.dict()
return self._entity.calc_distance(payload)
def delete(self, base=None, collection_name=None, db_name=None, expr=None, hash_keys=None, partition_name=None):
payload = {
"base": base,
"collection_name": collection_name,
"db_name": db_name,
"expr": expr,
"hash_keys": hash_keys,
"partition_name": partition_name
}
# payload = server.DeleteRequest(base=base,
# collection_name=collection_name,
# db_name=db_name,
# expr=expr,
# hash_keys=hash_keys,
# partition_name=partition_name)
# payload = payload.dict()
return self._entity.delete(payload)
def insert(self, base=None, collection_name=None, db_name=None, fields_data=None, hash_keys=None, num_rows=None,
partition_name=None, check_task=True):
payload = {
"base": base,
"collection_name": collection_name,
"db_name": db_name,
"fields_data": fields_data,
"hash_keys": hash_keys,
"num_rows": num_rows,
"partition_name": partition_name
}
# payload = milvus.InsertRequest(base=base,
# collection_name=collection_name,
# db_name=db_name,
# fields_data=fields_data,
# hash_keys=hash_keys,
# num_rows=num_rows,
# partition_name=partition_name)
# payload = payload.dict()
rsp = self._entity.insert(payload)
if check_task:
assert rsp["status"] == {}
assert rsp["insert_cnt"] == num_rows
return rsp
def flush(self, base=None, collection_names=None, db_name=None, check_task=True):
payload = {
"base": base,
"collection_names": collection_names,
"db_name": db_name
}
# payload = server.FlushRequest(base=base,
# collection_names=collection_names,
# db_name=db_name)
# payload = payload.dict()
rsp = self._entity.flush(payload)
if check_task:
assert rsp["status"] == {}
def get_persistent_segment_info(self, base=None, collection_name=None, db_name=None):
payload = {
"base": base,
"collection_name": collection_name,
"db_name": db_name
}
# payload = server.GetPersistentSegmentInfoRequest(base=base,
# collection_name=collection_name,
# db_name=db_name)
# payload = payload.dict()
return self._entity.get_persistent_segment_info(payload)
def get_flush_state(self, segment_ids=None):
payload = {
"segment_ids": segment_ids
}
# payload = server.GetFlushStateRequest(segment_ids=segment_ids)
# payload = payload.dict()
return self._entity.get_flush_state(payload)
def query(self, base=None, collection_name=None, db_name=None, expr=None,
guarantee_timestamp=None, output_fields=None, partition_names=None, travel_timestamp=None,
check_task=True):
payload = {
"base": base,
"collection_name": collection_name,
"db_name": db_name,
"expr": expr,
"guarantee_timestamp": guarantee_timestamp,
"output_fields": output_fields,
"partition_names": partition_names,
"travel_timestamp": travel_timestamp
}
#
# payload = server.QueryRequest(base=base, collection_name=collection_name, db_name=db_name, expr=expr,
# guarantee_timestamp=guarantee_timestamp, output_fields=output_fields,
# partition_names=partition_names, travel_timestamp=travel_timestamp)
# payload = payload.dict()
rsp = self._entity.query(payload)
if check_task:
fields_data = rsp["fields_data"]
for field_data in fields_data:
if field_data["field_name"] in expr:
data = field_data["Field"]["Scalars"]["Data"]["LongData"]["data"]
for d in data:
s = expr.replace(field_data["field_name"], str(d))
assert eval(s) is True
return rsp
def get_query_segment_info(self, base=None, collection_name=None, db_name=None):
payload = {
"base": base,
"collection_name": collection_name,
"db_name": db_name
}
# payload = server.GetQuerySegmentInfoRequest(base=base,
# collection_name=collection_name,
# db_name=db_name)
# payload = payload.dict()
return self._entity.get_query_segment_info(payload)
def search(self, base=None, collection_name=None, vectors=None, db_name=None, dsl=None,
output_fields=None, dsl_type=1,
guarantee_timestamp=None, partition_names=None, placeholder_group=None,
search_params=None, travel_timestamp=None, check_task=True):
payload = {
"base": base,
"collection_name": collection_name,
"output_fields": output_fields,
"vectors": vectors,
"db_name": db_name,
"dsl": dsl,
"dsl_type": dsl_type,
"guarantee_timestamp": guarantee_timestamp,
"partition_names": partition_names,
"placeholder_group": placeholder_group,
"search_params": search_params,
"travel_timestamp": travel_timestamp
}
# payload = server.SearchRequest(base=base, collection_name=collection_name, db_name=db_name, dsl=dsl,
# dsl_type=dsl_type, guarantee_timestamp=guarantee_timestamp,
# partition_names=partition_names, placeholder_group=placeholder_group,
# search_params=search_params, travel_timestamp=travel_timestamp)
# payload = payload.dict()
rsp = self._entity.search(payload)
if check_task:
assert rsp["status"] == {}
assert rsp["results"]["num_queries"] == len(vectors)
assert len(rsp["results"]["ids"]["IdField"]["IntId"]["data"]) == sum(rsp["results"]["topks"])
return rsp

View File

@ -0,0 +1,41 @@
from enum import Enum
class BaseError(Enum):
pass
class VectorInsertError(BaseError):
pass
class VectorSearchError(BaseError):
pass
class VectorGetError(BaseError):
pass
class VectorQueryError(BaseError):
pass
class VectorDeleteError(BaseError):
pass
class CollectionListError(BaseError):
pass
class CollectionCreateError(BaseError):
pass
class CollectionDropError(BaseError):
pass
class CollectionDescribeError(BaseError):
pass

View File

@ -1,57 +0,0 @@
from api.index import Index
from models import common, schema, milvus, server
TIMEOUT = 30
class IndexService:
def __init__(self, endpoint=None, timeout=None):
if timeout is None:
timeout = TIMEOUT
if endpoint is None:
endpoint = "http://localhost:9091/api/v1"
self._index = Index(endpoint=endpoint)
def drop_index(self, base, collection_name, db_name, field_name, index_name):
payload = server.DropIndexRequest(base=base, collection_name=collection_name,
db_name=db_name, field_name=field_name, index_name=index_name)
payload = payload.dict()
return self._index.drop_index(payload)
def describe_index(self, base, collection_name, db_name, field_name, index_name):
payload = server.DescribeIndexRequest(base=base, collection_name=collection_name,
db_name=db_name, field_name=field_name, index_name=index_name)
payload = payload.dict()
return self._index.describe_index(payload)
def create_index(self, base=None, collection_name=None, db_name=None, extra_params=None,
field_name=None, index_name=None):
payload = {
"base": base,
"collection_name": collection_name,
"db_name": db_name,
"extra_params": extra_params,
"field_name": field_name,
"index_name": index_name
}
# payload = server.CreateIndexRequest(base=base, collection_name=collection_name, db_name=db_name,
# extra_params=extra_params, field_name=field_name, index_name=index_name)
# payload = payload.dict()
return self._index.create_index(payload)
def get_index_build_progress(self, base, collection_name, db_name, field_name, index_name):
payload = server.GetIndexBuildProgressRequest(base=base, collection_name=collection_name,
db_name=db_name, field_name=field_name, index_name=index_name)
payload = payload.dict()
return self._index.get_index_build_progress(payload)
def get_index_state(self, base, collection_name, db_name, field_name, index_name):
payload = server.GetIndexStateRequest(base=base, collection_name=collection_name,
db_name=db_name, field_name=field_name, index_name=index_name)
payload = payload.dict()
return self._index.get_index_state(payload)

View File

@ -0,0 +1,114 @@
import json
import sys
import pytest
import time
from pymilvus import connections, db
from utils.util_log import test_log as logger
from api.milvus import VectorClient, CollectionClient
from utils.utils import get_data_by_payload
def get_config():
pass
class Base:
name = None
host = None
port = None
url = None
api_key = None
username = None
password = None
invalid_api_key = None
vector_client = None
collection_client = None
class TestBase(Base):
def teardown_method(self):
self.collection_client.api_key = self.api_key
all_collections = self.collection_client.collection_list()['data']
if self.name in all_collections:
logger.info(f"collection {self.name} exist, drop it")
payload = {
"collectionName": self.name,
}
try:
rsp = self.collection_client.collection_drop(payload)
except Exception as e:
logger.error(e)
@pytest.fixture(scope="function", autouse=True)
def init_client(self, host, port, username, password):
self.host = host
self.port = port
self.url = f"{host}:{port}/v1"
self.username = username
self.password = password
self.api_key = f"{self.username}:{self.password}"
self.invalid_api_key = "invalid_token"
self.vector_client = VectorClient(self.url, self.api_key)
self.collection_client = CollectionClient(self.url, self.api_key)
def init_collection(self, collection_name, pk_field="id", metric_type="L2", dim=128, nb=100, batch_size=1000):
# create collection
schema_payload = {
"collectionName": collection_name,
"dimension": dim,
"metricType": metric_type,
"description": "test collection",
"primaryField": pk_field,
"vectorField": "vector",
}
rsp = self.collection_client.collection_create(schema_payload)
assert rsp['code'] == 200
self.wait_collection_load_completed(collection_name)
batch_size = batch_size
batch = nb // batch_size
# in case of nb < batch_size
if batch == 0:
batch = 1
batch_size = nb
data = []
for i in range(batch):
nb = batch_size
data = get_data_by_payload(schema_payload, nb)
payload = {
"collectionName": collection_name,
"data": data
}
body_size = sys.getsizeof(json.dumps(payload))
logger.info(f"body size: {body_size / 1024 / 1024} MB")
rsp = self.vector_client.vector_insert(payload)
assert rsp['code'] == 200
return schema_payload, data
def wait_collection_load_completed(self, name):
t0 = time.time()
timeout = 60
while True and time.time() - t0 < timeout:
rsp = self.collection_client.collection_describe(name)
if "data" in rsp and "load" in rsp["data"] and rsp["data"]["load"] == "LoadStateLoaded":
break
else:
time.sleep(5)
def create_database(self, db_name="default"):
connections.connect(host=self.host, port=self.port)
all_db = db.list_database()
logger.info(f"all database: {all_db}")
if db_name not in all_db:
logger.info(f"create database: {db_name}")
try:
db.create_database(db_name=db_name)
except Exception as e:
logger.error(e)
def update_database(self, db_name="default"):
self.create_database(db_name=db_name)
self.collection_client.db_name = db_name
self.vector_client.db_name = db_name

View File

@ -1,27 +0,0 @@
class CheckTasks:
""" The name of the method used to check the result """
check_nothing = "check_nothing"
err_res = "error_response"
ccr = "check_connection_result"
check_collection_property = "check_collection_property"
check_partition_property = "check_partition_property"
check_search_results = "check_search_results"
check_query_results = "check_query_results"
check_query_empty = "check_query_empty" # verify that query result is empty
check_query_not_empty = "check_query_not_empty"
check_distance = "check_distance"
check_delete_compact = "check_delete_compact"
check_merge_compact = "check_merge_compact"
check_role_property = "check_role_property"
check_permission_deny = "check_permission_deny"
check_value_equal = "check_value_equal"
class ResponseChecker:
def __init__(self, check_task, check_items):
self.check_task = check_task
self.check_items = check_items

View File

@ -1,21 +0,0 @@
from utils.util_log import test_log as log
def ip_check(ip):
if ip == "localhost":
return True
if not isinstance(ip, str):
log.error("[IP_CHECK] IP(%s) is not a string." % ip)
return False
return True
def number_check(num):
if str(num).isdigit():
return True
else:
log.error("[NUMBER_CHECK] Number(%s) is not a numbers." % num)
return False

View File

@ -1,292 +0,0 @@
import json
import os
import random
import string
import numpy as np
from enum import Enum
from common import common_type as ct
from utils.util_log import test_log as log
class ParamInfo:
def __init__(self):
self.param_host = ""
self.param_port = ""
def prepare_param_info(self, host, http_port):
self.param_host = host
self.param_port = http_port
param_info = ParamInfo()
class DataType(Enum):
Bool: 1
Int8: 2
Int16: 3
Int32: 4
Int64: 5
Float: 10
Double: 11
String: 20
VarChar: 21
BinaryVector: 100
FloatVector: 101
def gen_unique_str(str_value=None):
prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
return "test_" + prefix if str_value is None else str_value + "_" + prefix
def gen_field(name=ct.default_bool_field_name, description=ct.default_desc, type_params=None, index_params=None,
data_type="Int64", is_primary_key=False, auto_id=False, dim=128, max_length=256):
data_type_map = {
"Bool": 1,
"Int8": 2,
"Int16": 3,
"Int32": 4,
"Int64": 5,
"Float": 10,
"Double": 11,
"String": 20,
"VarChar": 21,
"BinaryVector": 100,
"FloatVector": 101,
}
if data_type == "Int64":
is_primary_key = True
auto_id = True
if type_params is None:
type_params = []
if index_params is None:
index_params = []
if data_type in ["FloatVector", "BinaryVector"]:
type_params = [{"key": "dim", "value": str(dim)}]
if data_type in ["String", "VarChar"]:
type_params = [{"key": "max_length", "value": str(dim)}]
return {
"name": name,
"description": description,
"data_type": data_type_map.get(data_type, 0),
"type_params": type_params,
"index_params": index_params,
"is_primary_key": is_primary_key,
"auto_id": auto_id,
}
def gen_schema(name, fields, description=ct.default_desc, auto_id=False):
return {
"name": name,
"description": description,
"auto_id": auto_id,
"fields": fields,
}
def gen_default_schema(data_types=None, dim=ct.default_dim, collection_name=None):
if data_types is None:
data_types = ["Int64", "Float", "VarChar", "FloatVector"]
fields = []
for data_type in data_types:
if data_type in ["FloatVector", "BinaryVector"]:
fields.append(gen_field(name=data_type, data_type=data_type, type_params=[{"key": "dim", "value": dim}]))
else:
fields.append(gen_field(name=data_type, data_type=data_type))
return {
"autoID": True,
"fields": fields,
"description": ct.default_desc,
"name": collection_name,
}
def gen_fields_data(schema=None, nb=ct.default_nb,):
if schema is None:
schema = gen_default_schema()
fields = schema["fields"]
fields_data = []
for field in fields:
if field["data_type"] == 1:
fields_data.append([random.choice([True, False]) for i in range(nb)])
elif field["data_type"] == 2:
fields_data.append([i for i in range(nb)])
elif field["data_type"] == 3:
fields_data.append([i for i in range(nb)])
elif field["data_type"] == 4:
fields_data.append([i for i in range(nb)])
elif field["data_type"] == 5:
fields_data.append([i for i in range(nb)])
elif field["data_type"] == 10:
fields_data.append([np.float64(i) for i in range(nb)]) # json not support float32
elif field["data_type"] == 11:
fields_data.append([np.float64(i) for i in range(nb)])
elif field["data_type"] == 20:
fields_data.append([gen_unique_str((str(i))) for i in range(nb)])
elif field["data_type"] == 21:
fields_data.append([gen_unique_str(str(i)) for i in range(nb)])
elif field["data_type"] == 100:
dim = ct.default_dim
for k, v in field["type_params"]:
if k == "dim":
dim = int(v)
break
fields_data.append(gen_binary_vectors(nb, dim))
elif field["data_type"] == 101:
dim = ct.default_dim
for k, v in field["type_params"]:
if k == "dim":
dim = int(v)
break
fields_data.append(gen_float_vectors(nb, dim))
else:
log.error("Unknown data type.")
fields_data_body = []
for i, field in enumerate(fields):
fields_data_body.append({
"field_name": field["name"],
"type": field["data_type"],
"field": fields_data[i],
})
return fields_data_body
def get_vector_field(schema):
for field in schema["fields"]:
if field["data_type"] in [100, 101]:
return field["name"]
return None
def get_varchar_field(schema):
for field in schema["fields"]:
if field["data_type"] == 21:
return field["name"]
return None
def gen_vectors(nq=None, schema=None):
if nq is None:
nq = ct.default_nq
dim = ct.default_dim
data_type = 101
for field in schema["fields"]:
if field["data_type"] in [100, 101]:
dim = ct.default_dim
data_type = field["data_type"]
for k, v in field["type_params"]:
if k == "dim":
dim = int(v)
break
if data_type == 100:
return gen_binary_vectors(nq, dim)
if data_type == 101:
return gen_float_vectors(nq, dim)
def gen_float_vectors(nb, dim):
return [[np.float64(random.uniform(-1.0, 1.0)) for _ in range(dim)] for _ in range(nb)] # json not support float32
def gen_binary_vectors(nb, dim):
raw_vectors = []
binary_vectors = []
for _ in range(nb):
raw_vector = [random.randint(0, 1) for _ in range(dim)]
raw_vectors.append(raw_vector)
# packs a binary-valued array into bits in a unit8 array, and bytes array_of_ints
binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
return binary_vectors
def gen_index_params(index_type=None):
if index_type is None:
index_params = ct.default_index_params
else:
index_params = ct.all_index_params_map[index_type]
extra_params = []
for k, v in index_params.items():
item = {"key": k, "value": json.dumps(v) if isinstance(v, dict) else str(v)}
extra_params.append(item)
return extra_params
def gen_search_param_by_index_type(index_type, metric_type="L2"):
search_params = []
if index_type in ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]:
for nprobe in [10]:
ivf_search_params = {"metric_type": metric_type, "params": {"nprobe": nprobe}}
search_params.append(ivf_search_params)
elif index_type in ["BIN_FLAT", "BIN_IVF_FLAT"]:
for nprobe in [10]:
bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}}
search_params.append(bin_search_params)
elif index_type in ["HNSW"]:
for ef in [64]:
hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}}
search_params.append(hnsw_search_param)
elif index_type == "ANNOY":
for search_k in [1000]:
annoy_search_param = {"metric_type": metric_type, "params": {"search_k": search_k}}
search_params.append(annoy_search_param)
else:
log.info("Invalid index_type.")
raise Exception("Invalid index_type.")
return search_params
def gen_search_params(index_type=None, anns_field=ct.default_float_vec_field_name,
topk=ct.default_top_k):
if index_type is None:
search_params = gen_search_param_by_index_type(ct.default_index_type)[0]
else:
search_params = gen_search_param_by_index_type(index_type)[0]
extra_params = []
for k, v in search_params.items():
item = {"key": k, "value": json.dumps(v) if isinstance(v, dict) else str(v)}
extra_params.append(item)
extra_params.append({"key": "anns_field", "value": anns_field})
extra_params.append({"key": "topk", "value": str(topk)})
return extra_params
def gen_search_vectors(dim, nb, is_binary=False):
if is_binary:
return gen_binary_vectors(nb, dim)
return gen_float_vectors(nb, dim)
def modify_file(file_path_list, is_modify=False, input_content=""):
"""
file_path_list : file list -> list[<file_path>]
is_modify : does the file need to be reset
input_content the content that need to insert to the file
"""
if not isinstance(file_path_list, list):
log.error("[modify_file] file is not a list.")
for file_path in file_path_list:
folder_path, file_name = os.path.split(file_path)
if not os.path.isdir(folder_path):
log.debug("[modify_file] folder(%s) is not exist." % folder_path)
os.makedirs(folder_path)
if not os.path.isfile(file_path):
log.error("[modify_file] file(%s) is not exist." % file_path)
else:
if is_modify is True:
log.debug("[modify_file] start modifying file(%s)..." % file_path)
with open(file_path, "r+") as f:
f.seek(0)
f.truncate()
f.write(input_content)
f.close()
log.info("[modify_file] file(%s) modification is complete." % file_path_list)
if __name__ == '__main__':
a = gen_binary_vectors(10, 128)
print(a)

View File

@ -1,80 +0,0 @@
""" Initialized parameters """
port = 19530
epsilon = 0.000001
namespace = "milvus"
default_flush_interval = 1
big_flush_interval = 1000
default_drop_interval = 3
default_dim = 128
default_nb = 3000
default_nb_medium = 5000
default_top_k = 10
default_nq = 2
default_limit = 10
default_search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
default_search_ip_params = {"metric_type": "IP", "params": {"nprobe": 10}}
default_search_binary_params = {"metric_type": "JACCARD", "params": {"nprobe": 10}}
default_index_type = "HNSW"
default_index_params = {"index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}, "metric_type": "L2"}
default_varchar_index = {}
default_binary_index = {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"}
default_diskann_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}}
default_diskann_search_params = {"metric_type": "L2", "params": {"search_list": 30}}
max_top_k = 16384
max_partition_num = 4096 # 256
default_segment_row_limit = 1000
default_server_segment_row_limit = 1024 * 512
default_alias = "default"
default_user = "root"
default_password = "Milvus"
default_bool_field_name = "Bool"
default_int8_field_name = "Int8"
default_int16_field_name = "Int16"
default_int32_field_name = "Int32"
default_int64_field_name = "Int64"
default_float_field_name = "Float"
default_double_field_name = "Double"
default_string_field_name = "Varchar"
default_float_vec_field_name = "FloatVector"
another_float_vec_field_name = "FloatVector1"
default_binary_vec_field_name = "BinaryVector"
default_partition_name = "_default"
default_tag = "1970_01_01"
row_count = "row_count"
default_length = 65535
default_desc = ""
default_collection_desc = "default collection"
default_index_name = "default_index_name"
default_binary_desc = "default binary collection"
collection_desc = "collection"
int_field_desc = "int64 type field"
float_field_desc = "float type field"
float_vec_field_desc = "float vector type field"
binary_vec_field_desc = "binary vector type field"
max_dim = 32768
min_dim = 1
gracefulTime = 1
default_nlist = 128
compact_segment_num_threshold = 4
compact_delta_ratio_reciprocal = 5 # compact_delta_binlog_ratio is 0.2
compact_retention_duration = 40 # compaction travel time retention range 20s
max_compaction_interval = 60 # the max time interval (s) from the last compaction
max_field_num = 256 # Maximum number of fields in a collection
default_dsl = f"{default_int64_field_name} in [2,4,6,8]"
default_expr = f"{default_int64_field_name} in [2,4,6,8]"
metric_types = []
all_index_types = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY", "DISKANN", "BIN_FLAT", "BIN_IVF_FLAT"]
all_index_params_map = {"FLAT": {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"},
"IVF_FLAT": {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"},
"IVF_SQ8": {"index_type": "IVF_SQ8", "params": {"nlist": 128}, "metric_type": "L2"},
"IVF_PQ": {"index_type": "IVF_PQ", "params": {"nlist": 128, "m": 16, "nbits": 8},
"metric_type": "L2"},
"HNSW": {"index_type": "HNSW", "params": {"M": 48, "efConstruction": 500}, "metric_type": "L2"},
"ANNOY": {"index_type": "ANNOY", "params": {"n_trees": 50}, "metric_type": "L2"},
"DISKANN": {"index_type": "DISKANN", "params": {}, "metric_type": "L2"},
"BIN_FLAT": {"index_type": "BIN_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"},
"BIN_IVF_FLAT": {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128},
"metric_type": "JACCARD"}
}

View File

@ -1,15 +1,12 @@
import pytest
import common.common_func as cf
from check.param_check import ip_check, number_check
from config.log_config import log_config
from utils.util_log import test_log as log
from common.common_func import param_info
import yaml
def pytest_addoption(parser):
parser.addoption("--host", action="store", default="127.0.0.1", help="Milvus host")
parser.addoption("--port", action="store", default="9091", help="Milvus http port")
parser.addoption('--clean_log', action='store_true', default=False, help="clean log before testing")
parser.addoption("--host", action="store", default="127.0.0.1", help="host")
parser.addoption("--port", action="store", default="19530", help="port")
parser.addoption("--username", action="store", default="root", help="email")
parser.addoption("--password", action="store", default="Milvus", help="password")
@pytest.fixture
@ -23,27 +20,11 @@ def port(request):
@pytest.fixture
def clean_log(request):
return request.config.getoption("--clean_log")
def username(request):
return request.config.getoption("--username")
@pytest.fixture(scope="session", autouse=True)
def initialize_env(request):
""" clean log before testing """
host = request.config.getoption("--host")
port = request.config.getoption("--port")
clean_log = request.config.getoption("--clean_log")
@pytest.fixture
def password(request):
return request.config.getoption("--password")
""" params check """
assert ip_check(host) and number_check(port)
""" modify log files """
file_path_list = [log_config.log_debug, log_config.log_info, log_config.log_err]
if log_config.log_worker != "":
file_path_list.append(log_config.log_worker)
cf.modify_file(file_path_list=file_path_list, is_modify=clean_log)
log.info("#" * 80)
log.info("[initialize_milvus] Log cleaned up, start testing...")
param_info.prepare_param_info(host, port)

View File

@ -1,3 +0,0 @@
# generated by datamodel-codegen:
# filename: openapi.json
# timestamp: 2022-12-08T02:46:08+00:00

View File

@ -1,28 +0,0 @@
# generated by datamodel-codegen:
# filename: openapi.json
# timestamp: 2022-12-08T02:46:08+00:00
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel, Field
class KeyDataPair(BaseModel):
data: Optional[List[int]] = None
key: Optional[str] = None
class KeyValuePair(BaseModel):
key: Optional[str] = Field(None, example='dim')
value: Optional[str] = Field(None, example='128')
class MsgBase(BaseModel):
msg_type: Optional[int] = Field(None, description='Not useful for now')
class Status(BaseModel):
error_code: Optional[int] = None
reason: Optional[str] = None

View File

@ -1,138 +0,0 @@
# generated by datamodel-codegen:
# filename: openapi.json
# timestamp: 2022-12-08T02:46:08+00:00
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel, Field
from models import common, schema
class DescribeCollectionRequest(BaseModel):
collection_name: Optional[str] = None
collectionID: Optional[int] = Field(
None, description='The collection ID you want to describe'
)
time_stamp: Optional[int] = Field(
None,
description='If time_stamp is not zero, will describe collection success when time_stamp >= created collection timestamp, otherwise will throw error.',
)
class DropCollectionRequest(BaseModel):
collection_name: Optional[str] = Field(
None, description='The unique collection name in milvus.(Required)'
)
class FieldData(BaseModel):
field: Optional[List] = None
field_id: Optional[int] = None
field_name: Optional[str] = None
type: Optional[int] = Field(
None,
description='0: "None",\n1: "Bool",\n2: "Int8",\n3: "Int16",\n4: "Int32",\n5: "Int64",\n10: "Float",\n11: "Double",\n20: "String",\n21: "VarChar",\n100: "BinaryVector",\n101: "FloatVector",',
)
class GetCollectionStatisticsRequest(BaseModel):
collection_name: Optional[str] = Field(
None, description='The collection name you want get statistics'
)
class HasCollectionRequest(BaseModel):
collection_name: Optional[str] = Field(
None, description='The unique collection name in milvus.(Required)'
)
time_stamp: Optional[int] = Field(
None,
description='If time_stamp is not zero, will return true when time_stamp >= created collection timestamp, otherwise will return false.',
)
class InsertRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = None
db_name: Optional[str] = None
fields_data: Optional[List[FieldData]] = None
hash_keys: Optional[List[int]] = None
num_rows: Optional[int] = None
partition_name: Optional[str] = None
class LoadCollectionRequest(BaseModel):
collection_name: Optional[str] = Field(
None, description='The collection name you want to load'
)
replica_number: Optional[int] = Field(
None, description='The replica number to load, default by 1'
)
class ReleaseCollectionRequest(BaseModel):
collection_name: Optional[str] = Field(
None, description='The collection name you want to release'
)
class ShowCollectionsRequest(BaseModel):
collection_names: Optional[List[str]] = Field(
None,
description="When type is InMemory, will return these collection's inMemory_percentages.(Optional)",
)
type: Optional[int] = Field(
None,
description='Decide return Loaded collections or All collections(Optional)',
)
class VectorIDs(BaseModel):
collection_name: Optional[str] = None
field_name: Optional[str] = None
id_array: Optional[List[int]] = None
partition_names: Optional[List[str]] = None
class VectorsArray(BaseModel):
binary_vectors: Optional[List[int]] = Field(
None,
description='Vectors is an array of binary vector divided by given dim. Disabled when IDs is set',
)
dim: Optional[int] = Field(
None, description='Dim of vectors or binary_vectors, not needed when use ids'
)
ids: Optional[VectorIDs] = None
vectors: Optional[List[float]] = Field(
None,
description='Vectors is an array of vector divided by given dim. Disabled when ids or binary_vectors is set',
)
class CalcDistanceRequest(BaseModel):
base: Optional[common.MsgBase] = None
op_left: Optional[VectorsArray] = None
op_right: Optional[VectorsArray] = None
params: Optional[List[common.KeyValuePair]] = None
class CreateCollectionRequest(BaseModel):
collection_name: str = Field(
...,
description='The unique collection name in milvus.(Required)',
example='book',
)
consistency_level: int = Field(
...,
description='The consistency level that the collection used, modification is not supported now.\n"Strong": 0,\n"Session": 1,\n"Bounded": 2,\n"Eventually": 3,\n"Customized": 4,',
example=1,
)
schema_: schema.CollectionSchema = Field(..., alias='schema')
shards_num: Optional[int] = Field(
None,
description='Once set, no modification is allowed (Optional)\nhttps://github.com/milvus-io/milvus/issues/6690',
example=1,
)

View File

@ -1,72 +0,0 @@
# generated by datamodel-codegen:
# filename: openapi.json
# timestamp: 2022-12-08T02:46:08+00:00
from __future__ import annotations
from typing import Any, List, Optional
from pydantic import BaseModel, Field
from models import common
class FieldData(BaseModel):
field: Optional[Any] = Field(
None,
description='Types that are assignable to Field:\n\t*FieldData_Scalars\n\t*FieldData_Vectors',
)
field_id: Optional[int] = None
field_name: Optional[str] = None
type: Optional[int] = Field(
None,
description='0: "None",\n1: "Bool",\n2: "Int8",\n3: "Int16",\n4: "Int32",\n5: "Int64",\n10: "Float",\n11: "Double",\n20: "String",\n21: "VarChar",\n100: "BinaryVector",\n101: "FloatVector",',
)
class FieldSchema(BaseModel):
autoID: Optional[bool] = None
data_type: int = Field(
...,
description='0: "None",\n1: "Bool",\n2: "Int8",\n3: "Int16",\n4: "Int32",\n5: "Int64",\n10: "Float",\n11: "Double",\n20: "String",\n21: "VarChar",\n100: "BinaryVector",\n101: "FloatVector",',
example=101,
)
description: Optional[str] = Field(
None, example='embedded vector of book introduction'
)
fieldID: Optional[int] = None
index_params: Optional[List[common.KeyValuePair]] = None
is_primary_key: Optional[bool] = Field(None, example=False)
name: str = Field(..., example='book_intro')
type_params: Optional[List[common.KeyValuePair]] = None
class IDs(BaseModel):
idField: Optional[Any] = Field(
None,
description='Types that are assignable to IdField:\n\t*IDs_IntId\n\t*IDs_StrId',
)
class LongArray(BaseModel):
data: Optional[List[int]] = None
class SearchResultData(BaseModel):
fields_data: Optional[List[FieldData]] = None
ids: Optional[IDs] = None
num_queries: Optional[int] = None
scores: Optional[List[float]] = None
top_k: Optional[int] = None
topks: Optional[List[int]] = None
class CollectionSchema(BaseModel):
autoID: Optional[bool] = Field(
None,
description='deprecated later, keep compatible with c++ part now',
example=False,
)
description: Optional[str] = Field(None, example='Test book search')
fields: Optional[List[FieldSchema]] = None
name: str = Field(..., example='book')

View File

@ -1,587 +0,0 @@
# generated by datamodel-codegen:
# filename: openapi.json
# timestamp: 2022-12-08T02:46:08+00:00
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from models import common, schema
class AlterAliasRequest(BaseModel):
alias: Optional[str] = None
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = None
db_name: Optional[str] = None
class BoolResponse(BaseModel):
status: Optional[common.Status] = None
value: Optional[bool] = None
class CalcDistanceResults(BaseModel):
array: Optional[Any] = Field(
None,
description='num(op_left)*num(op_right) distance values, "HAMMIN" return integer distance\n\nTypes that are assignable to Array:\n\t*CalcDistanceResults_IntDist\n\t*CalcDistanceResults_FloatDist',
)
status: Optional[common.Status] = None
class CompactionMergeInfo(BaseModel):
sources: Optional[List[int]] = None
target: Optional[int] = None
class CreateAliasRequest(BaseModel):
alias: Optional[str] = None
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = None
db_name: Optional[str] = None
class CreateCredentialRequest(BaseModel):
base: Optional[common.MsgBase] = None
created_utc_timestamps: Optional[int] = Field(None, description='create time')
modified_utc_timestamps: Optional[int] = Field(None, description='modify time')
password: Optional[str] = Field(None, description='ciphertext password')
username: Optional[str] = Field(None, description='username')
class CreateIndexRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(
None, description='The particular collection name you want to create index.'
)
db_name: Optional[str] = Field(None, description='Not useful for now')
extra_params: Optional[List[common.KeyValuePair]] = Field(
None,
description='Support keys: index_type,metric_type, params. Different index_type may has different params.',
)
field_name: Optional[str] = Field(
None, description='The vector field name in this particular collection'
)
index_name: Optional[str] = Field(
None,
description="Version before 2.0.2 doesn't contain index_name, we use default index name.",
)
class CreatePartitionRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(
None, description='The collection name in milvus'
)
db_name: Optional[str] = Field(None, description='Not useful for now')
partition_name: Optional[str] = Field(
None, description='The partition name you want to create.'
)
class DeleteCredentialRequest(BaseModel):
base: Optional[common.MsgBase] = None
username: Optional[str] = Field(None, description='Not useful for now')
class DeleteRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = None
db_name: Optional[str] = None
expr: Optional[str] = None
hash_keys: Optional[List[int]] = None
partition_name: Optional[str] = None
class DescribeIndexRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(
None, description='The particular collection name in Milvus'
)
db_name: Optional[str] = Field(None, description='Not useful for now')
field_name: Optional[str] = Field(
None, description='The vector field name in this particular collection'
)
index_name: Optional[str] = Field(
None, description='No need to set up for now @2021.06.30'
)
class DropAliasRequest(BaseModel):
alias: Optional[str] = None
base: Optional[common.MsgBase] = None
db_name: Optional[str] = None
class DropIndexRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(None, description='must')
db_name: Optional[str] = None
field_name: Optional[str] = None
index_name: Optional[str] = Field(
None, description='No need to set up for now @2021.06.30'
)
class DropPartitionRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(
None, description='The collection name in milvus'
)
db_name: Optional[str] = Field(None, description='Not useful for now')
partition_name: Optional[str] = Field(
None, description='The partition name you want to drop'
)
class FlushRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_names: Optional[List[str]] = None
db_name: Optional[str] = None
class FlushResponse(BaseModel):
coll_segIDs: Optional[Dict[str, schema.LongArray]] = None
db_name: Optional[str] = None
status: Optional[common.Status] = None
class GetCollectionStatisticsResponse(BaseModel):
stats: Optional[List[common.KeyValuePair]] = Field(
None, description='Collection statistics data'
)
status: Optional[common.Status] = None
class GetCompactionPlansRequest(BaseModel):
compactionID: Optional[int] = None
class GetCompactionPlansResponse(BaseModel):
mergeInfos: Optional[List[CompactionMergeInfo]] = None
state: Optional[int] = None
status: Optional[common.Status] = None
class GetCompactionStateRequest(BaseModel):
compactionID: Optional[int] = None
class GetCompactionStateResponse(BaseModel):
completedPlanNo: Optional[int] = None
executingPlanNo: Optional[int] = None
state: Optional[int] = None
status: Optional[common.Status] = None
timeoutPlanNo: Optional[int] = None
class GetFlushStateRequest(BaseModel):
segmentIDs: Optional[List[int]] = Field(None, alias='segment_ids')
class GetFlushStateResponse(BaseModel):
flushed: Optional[bool] = None
status: Optional[common.Status] = None
class GetImportStateRequest(BaseModel):
task: Optional[int] = Field(None, description='id of an import task')
class GetImportStateResponse(BaseModel):
id: Optional[int] = Field(None, description='id of an import task')
id_list: Optional[List[int]] = Field(
None, description='auto generated ids if the primary key is autoid'
)
infos: Optional[List[common.KeyValuePair]] = Field(
None,
description='more informations about the task, progress percent, file path, failed reason, etc.',
)
row_count: Optional[int] = Field(
None,
description='if the task is finished, this value is how many rows are imported. if the task is not finished, this value is how many rows are parsed. return 0 if failed.',
)
state: Optional[int] = Field(
None, description='is this import task finished or not'
)
status: Optional[common.Status] = None
class GetIndexBuildProgressRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(
None, description='The collection name in milvus'
)
db_name: Optional[str] = Field(None, description='Not useful for now')
field_name: Optional[str] = Field(
None, description='The vector field name in this collection'
)
index_name: Optional[str] = Field(None, description='Not useful for now')
class GetIndexBuildProgressResponse(BaseModel):
indexed_rows: Optional[int] = None
status: Optional[common.Status] = None
total_rows: Optional[int] = None
class GetIndexStateRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(None, description='must')
db_name: Optional[str] = None
field_name: Optional[str] = None
index_name: Optional[str] = Field(
None, description='No need to set up for now @2021.06.30'
)
class GetIndexStateResponse(BaseModel):
fail_reason: Optional[str] = None
state: Optional[int] = None
status: Optional[common.Status] = None
class GetMetricsRequest(BaseModel):
base: Optional[common.MsgBase] = None
request: Optional[str] = Field(None, description='request is of jsonic format')
class GetMetricsResponse(BaseModel):
component_name: Optional[str] = Field(
None, description='metrics from which component'
)
response: Optional[str] = Field(None, description='response is of jsonic format')
status: Optional[common.Status] = None
class GetPartitionStatisticsRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(
None, description='The collection name in milvus'
)
db_name: Optional[str] = Field(None, description='Not useful for now')
partition_name: Optional[str] = Field(
None, description='The partition name you want to collect statistics'
)
class GetPartitionStatisticsResponse(BaseModel):
stats: Optional[List[common.KeyValuePair]] = None
status: Optional[common.Status] = None
class GetPersistentSegmentInfoRequest(BaseModel):
base: Optional[common.MsgBase] = None
collectionName: Optional[str] = Field(None, alias="collection_name", description='must')
dbName: Optional[str] = Field(None, alias="db_name")
class GetQuerySegmentInfoRequest(BaseModel):
base: Optional[common.MsgBase] = None
collectionName: Optional[str] = Field(None, alias="collection_name", description='must')
dbName: Optional[str] = Field(None, alias="db_name")
class GetReplicasRequest(BaseModel):
base: Optional[common.MsgBase] = None
collectionID: Optional[int] = Field(None, alias="collection_id")
with_shard_nodes: Optional[bool] = None
class HasPartitionRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(
None, description='The collection name in milvus'
)
db_name: Optional[str] = Field(None, description='Not useful for now')
partition_name: Optional[str] = Field(
None, description='The partition name you want to check'
)
class ImportRequest(BaseModel):
channel_names: Optional[List[str]] = Field(
None, description='channel names for the collection'
)
collection_name: Optional[str] = Field(None, description='target collection')
files: Optional[List[str]] = Field(None, description='file paths to be imported')
options: Optional[List[common.KeyValuePair]] = Field(
None, description='import options, bucket, etc.'
)
partition_name: Optional[str] = Field(None, description='target partition')
row_based: Optional[bool] = Field(
None, description='the file is row-based or column-based'
)
class ImportResponse(BaseModel):
status: Optional[common.Status] = None
tasks: Optional[List[int]] = Field(None, description='id array of import tasks')
class IndexDescription(BaseModel):
field_name: Optional[str] = Field(None, description='The vector field name')
index_name: Optional[str] = Field(None, description='Index name')
indexID: Optional[int] = Field(None, description='Index id')
params: Optional[List[common.KeyValuePair]] = Field(
None, description='Will return index_type, metric_type, params(like nlist).'
)
class ListCredUsersRequest(BaseModel):
base: Optional[common.MsgBase] = None
class ListCredUsersResponse(BaseModel):
status: Optional[common.Status] = None
usernames: Optional[List[str]] = Field(None, description='username array')
class ListImportTasksRequest(BaseModel):
pass
class ListImportTasksResponse(BaseModel):
status: Optional[common.Status] = None
tasks: Optional[List[GetImportStateResponse]] = Field(
None, description='list of all import tasks'
)
class LoadBalanceRequest(BaseModel):
base: Optional[common.MsgBase] = None
collectionName: Optional[str] = None
dst_nodeIDs: Optional[List[int]] = None
sealed_segmentIDs: Optional[List[int]] = None
src_nodeID: Optional[int] = None
class LoadPartitionsRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(
None, description='The collection name in milvus'
)
db_name: Optional[str] = Field(None, description='Not useful for now')
partition_names: Optional[List[str]] = Field(
None, description='The partition names you want to load'
)
replica_number: Optional[int] = Field(
None, description='The replicas number you would load, 1 by default'
)
class ManualCompactionRequest(BaseModel):
collectionID: Optional[int] = None
timetravel: Optional[int] = None
class ManualCompactionResponse(BaseModel):
compactionID: Optional[int] = None
status: Optional[common.Status] = None
class PersistentSegmentInfo(BaseModel):
collectionID: Optional[int] = None
num_rows: Optional[int] = None
partitionID: Optional[int] = None
segmentID: Optional[int] = None
state: Optional[int] = None
class QueryRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = None
db_name: Optional[str] = None
expr: Optional[str] = None
guarantee_timestamp: Optional[int] = Field(None, description='guarantee_timestamp')
output_fields: Optional[List[str]] = None
partition_names: Optional[List[str]] = None
travel_timestamp: Optional[int] = None
class QueryResults(BaseModel):
collection_name: Optional[str] = None
fields_data: Optional[List[schema.FieldData]] = None
status: Optional[common.Status] = None
class QuerySegmentInfo(BaseModel):
collectionID: Optional[int] = None
index_name: Optional[str] = None
indexID: Optional[int] = None
mem_size: Optional[int] = None
nodeID: Optional[int] = None
num_rows: Optional[int] = None
partitionID: Optional[int] = None
segmentID: Optional[int] = None
state: Optional[int] = None
class ReleasePartitionsRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(
None, description='The collection name in milvus'
)
db_name: Optional[str] = Field(None, description='Not useful for now')
partition_names: Optional[List[str]] = Field(
None, description='The partition names you want to release'
)
class SearchRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(None, description='must')
db_name: Optional[str] = None
dsl: Optional[str] = Field(None, description='must')
dsl_type: Optional[int] = Field(None, description='must')
guarantee_timestamp: Optional[int] = Field(None, description='guarantee_timestamp')
output_fields: Optional[List[str]] = None
partition_names: Optional[List[str]] = Field(None, description='must')
placeholder_group: Optional[List[int]] = Field(
None, description='serialized `PlaceholderGroup`'
)
search_params: Optional[List[common.KeyValuePair]] = Field(None, description='must')
travel_timestamp: Optional[int] = None
class SearchResults(BaseModel):
collection_name: Optional[str] = None
results: Optional[schema.SearchResultData] = None
status: Optional[common.Status] = None
class ShardReplica(BaseModel):
dm_channel_name: Optional[str] = None
leader_addr: Optional[str] = Field(None, description='IP:port')
leaderID: Optional[int] = None
node_ids: Optional[List[int]] = Field(
None,
description='optional, DO NOT save it in meta, set it only for GetReplicas()\nif with_shard_nodes is true',
)
class ShowCollectionsResponse(BaseModel):
collection_ids: Optional[List[int]] = Field(None, description='Collection Id array')
collection_names: Optional[List[str]] = Field(
None, description='Collection name array'
)
created_timestamps: Optional[List[int]] = Field(
None, description='Hybrid timestamps in milvus'
)
created_utc_timestamps: Optional[List[int]] = Field(
None, description='The utc timestamp calculated by created_timestamp'
)
inMemory_percentages: Optional[List[int]] = Field(
None, description='Load percentage on querynode when type is InMemory'
)
status: Optional[common.Status] = None
class ShowPartitionsRequest(BaseModel):
base: Optional[common.MsgBase] = None
collection_name: Optional[str] = Field(
None,
description='The collection name you want to describe, you can pass collection_name or collectionID',
)
collectionID: Optional[int] = Field(None, description='The collection id in milvus')
db_name: Optional[str] = Field(None, description='Not useful for now')
partition_names: Optional[List[str]] = Field(
None,
description="When type is InMemory, will return these patitions' inMemory_percentages.(Optional)",
)
type: Optional[int] = Field(
None, description='Decide return Loaded partitions or All partitions(Optional)'
)
class ShowPartitionsResponse(BaseModel):
created_timestamps: Optional[List[int]] = Field(
None, description='All hybrid timestamps'
)
created_utc_timestamps: Optional[List[int]] = Field(
None, description='All utc timestamps calculated by created_timestamps'
)
inMemory_percentages: Optional[List[int]] = Field(
None, description='Load percentage on querynode'
)
partition_names: Optional[List[str]] = Field(
None, description='All partition names for this collection'
)
partitionIDs: Optional[List[int]] = Field(
None, description='All partition ids for this collection'
)
status: Optional[common.Status] = None
class UpdateCredentialRequest(BaseModel):
base: Optional[common.MsgBase] = None
created_utc_timestamps: Optional[int] = Field(None, description='create time')
modified_utc_timestamps: Optional[int] = Field(None, description='modify time')
newPassword: Optional[str] = Field(None, description='new password')
oldPassword: Optional[str] = Field(None, description='old password')
username: Optional[str] = Field(None, description='username')
class DescribeCollectionResponse(BaseModel):
aliases: Optional[List[str]] = Field(
None, description='The aliases of this collection'
)
collection_name: Optional[str] = Field(None, description='The collection name')
collectionID: Optional[int] = Field(None, description='The collection id')
consistency_level: Optional[int] = Field(
None,
description='The consistency level that the collection used, modification is not supported now.',
)
created_timestamp: Optional[int] = Field(
None, description='Hybrid timestamp in milvus'
)
created_utc_timestamp: Optional[int] = Field(
None, description='The utc timestamp calculated by created_timestamp'
)
physical_channel_names: Optional[List[str]] = Field(
None, description='System design related, users should not perceive'
)
schema_: Optional[schema.CollectionSchema] = Field(None, alias='schema')
shards_num: Optional[int] = Field(None, description='The shards number you set.')
start_positions: Optional[List[common.KeyDataPair]] = Field(
None, description='The message ID/posititon when collection is created'
)
status: Optional[common.Status] = None
virtual_channel_names: Optional[List[str]] = Field(
None, description='System design related, users should not perceive'
)
class DescribeIndexResponse(BaseModel):
index_descriptions: Optional[List[IndexDescription]] = Field(
None,
description='All index informations, for now only return tha latest index you created for the collection.',
)
status: Optional[common.Status] = None
class GetPersistentSegmentInfoResponse(BaseModel):
infos: Optional[List[PersistentSegmentInfo]] = None
status: Optional[common.Status] = None
class GetQuerySegmentInfoResponse(BaseModel):
infos: Optional[List[QuerySegmentInfo]] = None
status: Optional[common.Status] = None
class ReplicaInfo(BaseModel):
collectionID: Optional[int] = None
node_ids: Optional[List[int]] = Field(None, description='include leaders')
partition_ids: Optional[List[int]] = Field(
None, description='empty indicates to load collection'
)
replicaID: Optional[int] = None
shard_replicas: Optional[List[ShardReplica]] = None
class GetReplicasResponse(BaseModel):
replicas: Optional[List[ReplicaInfo]] = None
status: Optional[common.Status] = None

View File

@ -1,8 +1,5 @@
[pytest]
addopts = --host 10.101.178.131 --html=/tmp/ci_logs/report.html --self-contained-html -v
# python3 -W ignore -m pytest
addopts = --strict --host 127.0.0.1 --port 19530 --username root --password Milvus --log-cli-level=INFO --capture=no
log_format = [%(asctime)s - %(levelname)s - %(name)s]: %(message)s (%(filename)s:%(lineno)s)
log_date_format = %Y-%m-%d %H:%M:%S
@ -10,3 +7,8 @@ log_date_format = %Y-%m-%d %H:%M:%S
filterwarnings =
ignore::DeprecationWarning
markers =
L0 : 'L0 case, high priority'
L1 : 'L1 case, second priority'

View File

@ -1,2 +1,10 @@
decorest~=0.1.0
pydantic~=1.10.2
requests~=2.26.0
urllib3==1.26.16
loguru~=0.5.3
pytest~=7.2.0
pyyaml~=6.0
numpy~=1.24.3
allure-pytest>=2.8.18
Faker==19.2.0
pymilvus~=2.2.9
scikit-learn~=1.1.3

View File

@ -0,0 +1,395 @@
import datetime
import random
import time
from utils.util_log import test_log as logger
from utils.utils import gen_collection_name
import pytest
from api.milvus import CollectionClient
from base.testbase import TestBase
import threading
@pytest.mark.L0
class TestCreateCollection(TestBase):
@pytest.mark.parametrize("vector_field", [None, "vector", "emb"])
@pytest.mark.parametrize("primary_field", [None, "id", "doc_id"])
@pytest.mark.parametrize("metric_type", ["L2", "IP"])
@pytest.mark.parametrize("dim", [32, 32768])
@pytest.mark.parametrize("db_name", ["prod", "default"])
def test_create_collections_default(self, dim, metric_type, primary_field, vector_field, db_name):
"""
target: test create collection
method: create a collection with a simple schema
expected: create collection success
"""
self.create_database(db_name)
name = gen_collection_name()
dim = 128
client = self.collection_client
client.db_name = db_name
payload = {
"collectionName": name,
"dimension": dim,
"metricType": metric_type,
"primaryField": primary_field,
"vectorField": vector_field,
}
if primary_field is None:
del payload["primaryField"]
if vector_field is None:
del payload["vectorField"]
rsp = client.collection_create(payload)
assert rsp['code'] == 200
rsp = client.collection_list()
all_collections = rsp['data']
assert name in all_collections
# describe collection
rsp = client.collection_describe(name)
assert rsp['code'] == 200
assert rsp['data']['collectionName'] == name
def test_create_collections_concurrent_with_same_param(self):
"""
target: test create collection with same param
method: concurrent create collections with same param with multi thread
expected: create collections all success
"""
concurrent_rsp = []
def create_collection(c_name, vector_dim, c_metric_type):
collection_payload = {
"collectionName": c_name,
"dimension": vector_dim,
"metricType": c_metric_type,
}
rsp = client.collection_create(collection_payload)
concurrent_rsp.append(rsp)
logger.info(rsp)
name = gen_collection_name()
dim = 128
metric_type = "L2"
client = self.collection_client
threads = []
for i in range(10):
t = threading.Thread(target=create_collection, args=(name, dim, metric_type,))
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
time.sleep(10)
success_cnt = 0
for rsp in concurrent_rsp:
if rsp["code"] == 200:
success_cnt += 1
logger.info(concurrent_rsp)
assert success_cnt == 10
rsp = client.collection_list()
all_collections = rsp['data']
assert name in all_collections
# describe collection
rsp = client.collection_describe(name)
assert rsp['code'] == 200
assert rsp['data']['collectionName'] == name
assert f"FloatVector({dim})" in str(rsp['data']['fields'])
def test_create_collections_concurrent_with_different_param(self):
"""
target: test create collection with different param
method: concurrent create collections with different param with multi thread
expected: only one collection can success
"""
concurrent_rsp = []
def create_collection(c_name, vector_dim, c_metric_type):
collection_payload = {
"collectionName": c_name,
"dimension": vector_dim,
"metricType": c_metric_type,
}
rsp = client.collection_create(collection_payload)
concurrent_rsp.append(rsp)
logger.info(rsp)
name = gen_collection_name()
dim = 128
client = self.collection_client
threads = []
for i in range(0, 5):
t = threading.Thread(target=create_collection, args=(name, dim + i, "L2",))
threads.append(t)
for i in range(5, 10):
t = threading.Thread(target=create_collection, args=(name, dim + i, "IP",))
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
time.sleep(10)
success_cnt = 0
for rsp in concurrent_rsp:
if rsp["code"] == 200:
success_cnt += 1
logger.info(concurrent_rsp)
assert success_cnt == 1
rsp = client.collection_list()
all_collections = rsp['data']
assert name in all_collections
# describe collection
rsp = client.collection_describe(name)
assert rsp['code'] == 200
assert rsp['data']['collectionName'] == name
def test_create_collections_with_invalid_api_key(self):
"""
target: test create collection with invalid api key(wrong username and password)
method: create collections with invalid api key
expected: create collection failed
"""
name = gen_collection_name()
dim = 128
client = self.collection_client
client.api_key = "illegal_api_key"
payload = {
"collectionName": name,
"dimension": dim,
}
rsp = client.collection_create(payload)
assert rsp['code'] == 1800
@pytest.mark.parametrize("name", [" ", "test_collection_" * 100, "test collection", "test/collection", "test\collection"])
def test_create_collections_with_invalid_collection_name(self, name):
"""
target: test create collection with invalid collection name
method: create collections with invalid collection name
expected: create collection failed with right error message
"""
dim = 128
client = self.collection_client
payload = {
"collectionName": name,
"dimension": dim,
}
rsp = client.collection_create(payload)
assert rsp['code'] == 1
@pytest.mark.L0
class TestListCollections(TestBase):
def test_list_collections_default(self):
"""
target: test list collection with a simple schema
method: create collections and list them
expected: created collections are in list
"""
client = self.collection_client
name_list = []
for i in range(2):
name = gen_collection_name()
dim = 128
payload = {
"collectionName": name,
"dimension": dim,
}
time.sleep(1)
rsp = client.collection_create(payload)
assert rsp['code'] == 200
name_list.append(name)
rsp = client.collection_list()
all_collections = rsp['data']
for name in name_list:
assert name in all_collections
def test_list_collections_with_invalid_api_key(self):
"""
target: test list collection with an invalid api key
method: list collection with invalid api key
expected: raise error with right error code and message
"""
client = self.collection_client
name_list = []
for i in range(2):
name = gen_collection_name()
dim = 128
payload = {
"collectionName": name,
"dimension": dim,
}
time.sleep(1)
rsp = client.collection_create(payload)
assert rsp['code'] == 200
name_list.append(name)
client = self.collection_client
client.api_key = "illegal_api_key"
rsp = client.collection_list()
assert rsp['code'] == 1800
@pytest.mark.L0
class TestDescribeCollection(TestBase):
def test_describe_collections_default(self):
"""
target: test describe collection with a simple schema
method: describe collection
expected: info of description is same with param passed to create collection
"""
name = gen_collection_name()
dim = 128
client = self.collection_client
payload = {
"collectionName": name,
"dimension": dim,
}
rsp = client.collection_create(payload)
assert rsp['code'] == 200
rsp = client.collection_list()
all_collections = rsp['data']
assert name in all_collections
# describe collection
rsp = client.collection_describe(name)
assert rsp['code'] == 200
assert rsp['data']['collectionName'] == name
assert f"FloatVector({dim})" in str(rsp['data']['fields'])
def test_describe_collections_with_invalid_api_key(self):
"""
target: test describe collection with invalid api key
method: describe collection with invalid api key
expected: raise error with right error code and message
"""
name = gen_collection_name()
dim = 128
client = self.collection_client
payload = {
"collectionName": name,
"dimension": dim,
}
rsp = client.collection_create(payload)
assert rsp['code'] == 200
rsp = client.collection_list()
all_collections = rsp['data']
assert name in all_collections
# describe collection
illegal_client = CollectionClient(self.url, "illegal_api_key")
rsp = illegal_client.collection_describe(name)
assert rsp['code'] == 1800
def test_describe_collections_with_invalid_collection_name(self):
"""
target: test describe collection with invalid collection name
method: describe collection with invalid collection name
expected: raise error with right error code and message
"""
name = gen_collection_name()
dim = 128
client = self.collection_client
payload = {
"collectionName": name,
"dimension": dim,
}
rsp = client.collection_create(payload)
assert rsp['code'] == 200
rsp = client.collection_list()
all_collections = rsp['data']
assert name in all_collections
# describe collection
invalid_name = "invalid_name"
rsp = client.collection_describe(invalid_name)
assert rsp['code'] == 1
@pytest.mark.L0
class TestDropCollection(TestBase):
def test_drop_collections_default(self):
"""
Drop a collection with a simple schema
target: test drop collection with a simple schema
method: drop collection
expected: dropped collection was not in collection list
"""
clo_list = []
for i in range(5):
time.sleep(1)
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f_%f")
payload = {
"collectionName": name,
"dimension": 128,
}
rsp = self.collection_client.collection_create(payload)
assert rsp['code'] == 200
clo_list.append(name)
rsp = self.collection_client.collection_list()
all_collections = rsp['data']
for name in clo_list:
assert name in all_collections
for name in clo_list:
time.sleep(0.2)
payload = {
"collectionName": name,
}
rsp = self.collection_client.collection_drop(payload)
assert rsp['code'] == 200
rsp = self.collection_client.collection_list()
all_collections = rsp['data']
for name in clo_list:
assert name not in all_collections
def test_drop_collections_with_invalid_api_key(self):
"""
target: test drop collection with invalid api key
method: drop collection with invalid api key
expected: raise error with right error code and message; collection still in collection list
"""
name = gen_collection_name()
dim = 128
client = self.collection_client
payload = {
"collectionName": name,
"dimension": dim,
}
rsp = client.collection_create(payload)
assert rsp['code'] == 200
rsp = client.collection_list()
all_collections = rsp['data']
assert name in all_collections
# drop collection
payload = {
"collectionName": name,
}
illegal_client = CollectionClient(self.url, "invalid_api_key")
rsp = illegal_client.collection_drop(payload)
assert rsp['code'] == 1800
rsp = client.collection_list()
all_collections = rsp['data']
assert name in all_collections
def test_drop_collections_with_invalid_collection_name(self):
"""
target: test drop collection with invalid collection name
method: drop collection with invalid collection name
expected: raise error with right error code and message
"""
name = gen_collection_name()
dim = 128
client = self.collection_client
payload = {
"collectionName": name,
"dimension": dim,
}
rsp = client.collection_create(payload)
assert rsp['code'] == 200
rsp = client.collection_list()
all_collections = rsp['data']
assert name in all_collections
# drop collection
invalid_name = "invalid_name"
payload = {
"collectionName": invalid_name,
}
rsp = client.collection_drop(payload)
assert rsp['code'] == 100

View File

@ -1,49 +0,0 @@
from time import sleep
from common import common_type as ct
from common import common_func as cf
from base.client_base import TestBase
from utils.util_log import test_log as log
class TestDefault(TestBase):
def test_e2e(self):
collection_name, schema = self.init_collection()
nb = ct.default_nb
# insert
res = self.entity_service.insert(collection_name=collection_name, fields_data=cf.gen_fields_data(schema, nb=nb),
num_rows=nb)
log.info(f"insert {nb} rows into collection {collection_name}, response: {res}")
# flush
res = self.entity_service.flush(collection_names=[collection_name])
log.info(f"flush collection {collection_name}, response: {res}")
# create index for vector field
vector_field_name = cf.get_vector_field(schema)
vector_index_params = cf.gen_index_params(index_type="HNSW")
res = self.index_service.create_index(collection_name=collection_name, field_name=vector_field_name,
extra_params=vector_index_params)
log.info(f"create index for vector field {vector_field_name}, response: {res}")
# load
res = self.collection_service.load_collection(collection_name=collection_name)
log.info(f"load collection {collection_name}, response: {res}")
sleep(5)
# search
vectors = cf.gen_vectors(nq=ct.default_nq, schema=schema)
res = self.entity_service.search(collection_name=collection_name, vectors=vectors,
output_fields=[ct.default_int64_field_name],
search_params=cf.gen_search_params())
log.info(f"search collection {collection_name}, response: {res}")
# hybrid search
res = self.entity_service.search(collection_name=collection_name, vectors=vectors,
output_fields=[ct.default_int64_field_name],
search_params=cf.gen_search_params(),
dsl=ct.default_dsl)
log.info(f"hybrid search collection {collection_name}, response: {res}")
# query
res = self.entity_service.query(collection_name=collection_name, expr=ct.default_expr)
log.info(f"query collection {collection_name}, response: {res}")

View File

@ -0,0 +1,984 @@
import datetime
import random
import time
from sklearn import preprocessing
import numpy as np
import sys
import json
import time
from utils import constant
from utils.utils import gen_collection_name
from utils.util_log import test_log as logger
import pytest
from api.milvus import VectorClient
from base.testbase import TestBase
from utils.utils import (get_data_by_fields, get_data_by_payload, get_common_fields_by_data)
class TestInsertVector(TestBase):
@pytest.mark.L0
@pytest.mark.parametrize("insert_round", [2, 1])
@pytest.mark.parametrize("nb", [100, 10, 1])
@pytest.mark.parametrize("dim", [32, 128])
@pytest.mark.parametrize("primary_field", ["id", "url"])
@pytest.mark.parametrize("vector_field", ["vector", "embedding"])
@pytest.mark.parametrize("db_name", ["prod", "default"])
def test_insert_vector_with_simple_payload(self, db_name, vector_field, primary_field, nb, dim, insert_round):
"""
Insert a vector with a simple payload
"""
self.update_database(db_name=db_name)
# create a collection
name = gen_collection_name()
collection_payload = {
"collectionName": name,
"dimension": dim,
"primaryField": primary_field,
"vectorField": vector_field,
}
rsp = self.collection_client.collection_create(collection_payload)
assert rsp['code'] == 200
rsp = self.collection_client.collection_describe(name)
logger.info(f"rsp: {rsp}")
assert rsp['code'] == 200
# insert data
for i in range(insert_round):
data = get_data_by_payload(collection_payload, nb)
payload = {
"collectionName": name,
"data": data,
}
body_size = sys.getsizeof(json.dumps(payload))
logger.info(f"body size: {body_size / 1024 / 1024} MB")
rsp = self.vector_client.vector_insert(payload)
assert rsp['code'] == 200
assert rsp['data']['insertCount'] == nb
logger.info("finished")
@pytest.mark.L0
@pytest.mark.parametrize("insert_round", [10])
def test_insert_vector_with_multi_round(self, insert_round):
"""
Insert a vector with a simple payload
"""
# create a collection
name = gen_collection_name()
collection_payload = {
"collectionName": name,
"dimension": 768,
}
rsp = self.collection_client.collection_create(collection_payload)
assert rsp['code'] == 200
rsp = self.collection_client.collection_describe(name)
logger.info(f"rsp: {rsp}")
assert rsp['code'] == 200
# insert data
nb = 300
for i in range(insert_round):
data = get_data_by_payload(collection_payload, nb)
payload = {
"collectionName": name,
"data": data,
}
body_size = sys.getsizeof(json.dumps(payload))
logger.info(f"body size: {body_size / 1024 / 1024} MB")
rsp = self.vector_client.vector_insert(payload)
assert rsp['code'] == 200
assert rsp['data']['insertCount'] == nb
logger.info("finished")
def test_insert_vector_with_invalid_api_key(self):
"""
Insert a vector with invalid api key
"""
# create a collection
name = gen_collection_name()
dim = 128
payload = {
"collectionName": name,
"dimension": dim,
}
rsp = self.collection_client.collection_create(payload)
assert rsp['code'] == 200
rsp = self.collection_client.collection_describe(name)
assert rsp['code'] == 200
# insert data
nb = 10
data = [
{
"vector": [np.float64(random.random()) for _ in range(dim)],
} for _ in range(nb)
]
payload = {
"collectionName": name,
"data": data,
}
body_size = sys.getsizeof(json.dumps(payload))
logger.info(f"body size: {body_size / 1024 / 1024} MB")
client = self.vector_client
client.api_key = "invalid_api_key"
rsp = client.vector_insert(payload)
assert rsp['code'] == 1800
def test_insert_vector_with_invalid_collection_name(self):
"""
Insert a vector with an invalid collection name
"""
# create a collection
name = gen_collection_name()
dim = 128
payload = {
"collectionName": name,
"dimension": dim,
}
rsp = self.collection_client.collection_create(payload)
assert rsp['code'] == 200
rsp = self.collection_client.collection_describe(name)
assert rsp['code'] == 200
# insert data
nb = 100
data = get_data_by_payload(payload, nb)
payload = {
"collectionName": "invalid_collection_name",
"data": data,
}
body_size = sys.getsizeof(json.dumps(payload))
logger.info(f"body size: {body_size / 1024 / 1024} MB")
rsp = self.vector_client.vector_insert(payload)
assert rsp['code'] == 1
def test_insert_vector_with_invalid_database_name(self):
"""
Insert a vector with an invalid database name
"""
# create a collection
name = gen_collection_name()
dim = 128
payload = {
"collectionName": name,
"dimension": dim,
}
rsp = self.collection_client.collection_create(payload)
assert rsp['code'] == 200
rsp = self.collection_client.collection_describe(name)
assert rsp['code'] == 200
# insert data
nb = 10
data = get_data_by_payload(payload, nb)
payload = {
"collectionName": name,
"data": data,
}
body_size = sys.getsizeof(json.dumps(payload))
logger.info(f"body size: {body_size / 1024 / 1024} MB")
success = False
rsp = self.vector_client.vector_insert(payload, db_name="invalid_database")
assert rsp['code'] == 800
def test_insert_vector_with_mismatch_dim(self):
"""
Insert a vector with mismatch dim
"""
# create a collection
name = gen_collection_name()
dim = 32
payload = {
"collectionName": name,
"dimension": dim,
}
rsp = self.collection_client.collection_create(payload)
assert rsp['code'] == 200
rsp = self.collection_client.collection_describe(name)
assert rsp['code'] == 200
# insert data
nb = 1
data = [
{
"vector": [np.float64(random.random()) for _ in range(dim + 1)],
} for i in range(nb)
]
payload = {
"collectionName": name,
"data": data,
}
body_size = sys.getsizeof(json.dumps(payload))
logger.info(f"body size: {body_size / 1024 / 1024} MB")
rsp = self.vector_client.vector_insert(payload)
assert rsp['code'] == 1804
assert rsp['message'] == "fail to deal the insert data"
class TestSearchVector(TestBase):
@pytest.mark.L0
@pytest.mark.parametrize("metric_type", ["IP", "L2"])
def test_search_vector_with_simple_payload(self, metric_type):
"""
Search a vector with a simple payload
"""
name = gen_collection_name()
self.name = name
self.init_collection(name, metric_type=metric_type)
# search data
dim = 128
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
payload = {
"collectionName": name,
"vector": vector_to_search,
}
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
limit = int(payload.get("limit", 100))
assert len(res) == limit
ids = [item['id'] for item in res]
assert len(ids) == len(set(ids))
distance = [item['distance'] for item in res]
if metric_type == "L2":
assert distance == sorted(distance)
if metric_type == "IP":
assert distance == sorted(distance, reverse=True)
@pytest.mark.L0
@pytest.mark.parametrize("sum_limit_offset", [16384, 16385])
def test_search_vector_with_exceed_sum_limit_offset(self, sum_limit_offset):
"""
Search a vector with a simple payload
"""
max_search_sum_limit_offset = constant.MAX_SUM_OFFSET_AND_LIMIT
name = gen_collection_name()
self.name = name
nb = sum_limit_offset + 2000
metric_type = "IP"
limit = 100
self.init_collection(name, metric_type=metric_type, nb=nb, batch_size=2000)
# search data
dim = 128
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
payload = {
"collectionName": name,
"vector": vector_to_search,
"limit": limit,
"offset": sum_limit_offset-limit,
}
rsp = self.vector_client.vector_search(payload)
if sum_limit_offset > max_search_sum_limit_offset:
assert rsp['code'] == 1
return
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
limit = int(payload.get("limit", 100))
assert len(res) == limit
ids = [item['id'] for item in res]
assert len(ids) == len(set(ids))
distance = [item['distance'] for item in res]
if metric_type == "L2":
assert distance == sorted(distance)
if metric_type == "IP":
assert distance == sorted(distance, reverse=True)
@pytest.mark.L0
@pytest.mark.parametrize("level", [0, 1, 2])
@pytest.mark.parametrize("offset", [0, 10, 100])
@pytest.mark.parametrize("limit", [1, 100])
@pytest.mark.parametrize("metric_type", ["L2", "IP"])
def test_search_vector_with_complex_payload(self, limit, offset, level, metric_type):
"""
Search a vector with a simple payload
"""
name = gen_collection_name()
self.name = name
nb = limit + offset + 100
dim = 128
schema_payload, data = self.init_collection(name, dim=dim, nb=nb, metric_type=metric_type)
vector_field = schema_payload.get("vectorField")
# search data
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
payload = {
"collectionName": name,
"vector": vector_to_search,
"outputFields": output_fields,
"filter": "uid >= 0",
"limit": limit,
"offset": offset,
}
rsp = self.vector_client.vector_search(payload)
if offset + limit > constant.MAX_SUM_OFFSET_AND_LIMIT:
assert rsp['code'] == 90126
return
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
assert len(res) == limit
for item in res:
assert item.get("uid") >= 0
for field in output_fields:
assert field in item
@pytest.mark.L0
@pytest.mark.parametrize("filter_expr", ["uid >= 0", "uid >= 0 and uid < 100", "uid in [1,2,3]"])
def test_search_vector_with_complex_int_filter(self, filter_expr):
"""
Search a vector with a simple payload
"""
name = gen_collection_name()
self.name = name
nb = 200
dim = 128
limit = 100
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
vector_field = schema_payload.get("vectorField")
# search data
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
payload = {
"collectionName": name,
"vector": vector_to_search,
"outputFields": output_fields,
"filter": filter_expr,
"limit": limit,
"offset": 0,
}
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
assert len(res) <= limit
for item in res:
uid = item.get("uid")
eval(filter_expr)
@pytest.mark.L0
@pytest.mark.parametrize("filter_expr", ["name > \"placeholder\"", "name like \"placeholder%\""])
def test_search_vector_with_complex_varchar_filter(self, filter_expr):
"""
Search a vector with a simple payload
"""
name = gen_collection_name()
self.name = name
nb = 200
dim = 128
limit = 100
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
names = []
for item in data:
names.append(item.get("name"))
names.sort()
logger.info(f"names: {names}")
mid = len(names) // 2
prefix = names[mid][0:2]
vector_field = schema_payload.get("vectorField")
# search data
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
filter_expr = filter_expr.replace("placeholder", prefix)
logger.info(f"filter_expr: {filter_expr}")
payload = {
"collectionName": name,
"vector": vector_to_search,
"outputFields": output_fields,
"filter": filter_expr,
"limit": limit,
"offset": 0,
}
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
assert len(res) <= limit
for item in res:
name = item.get("name")
logger.info(f"name: {name}")
if ">" in filter_expr:
assert name > prefix
if "like" in filter_expr:
assert name.startswith(prefix)
@pytest.mark.L0
@pytest.mark.parametrize("filter_expr", ["uid < 100 and name > \"placeholder\"",
"uid < 100 and name like \"placeholder%\""
])
def test_search_vector_with_complex_int64_varchar_and_filter(self, filter_expr):
"""
Search a vector with a simple payload
"""
name = gen_collection_name()
self.name = name
nb = 200
dim = 128
limit = 100
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
names = []
for item in data:
names.append(item.get("name"))
names.sort()
logger.info(f"names: {names}")
mid = len(names) // 2
prefix = names[mid][0:2]
vector_field = schema_payload.get("vectorField")
# search data
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
filter_expr = filter_expr.replace("placeholder", prefix)
logger.info(f"filter_expr: {filter_expr}")
payload = {
"collectionName": name,
"vector": vector_to_search,
"outputFields": output_fields,
"filter": filter_expr,
"limit": limit,
"offset": 0,
}
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
assert len(res) <= limit
for item in res:
uid = item.get("uid")
name = item.get("name")
logger.info(f"name: {name}")
uid_expr = filter_expr.split("and")[0]
assert eval(uid_expr) is True
varchar_expr = filter_expr.split("and")[1]
if ">" in varchar_expr:
assert name > prefix
if "like" in varchar_expr:
assert name.startswith(prefix)
@pytest.mark.parametrize("limit", [0, 16385])
def test_search_vector_with_invalid_limit(self, limit):
"""
Search a vector with a simple payload
"""
name = gen_collection_name()
self.name = name
dim = 128
schema_payload, data = self.init_collection(name, dim=dim)
vector_field = schema_payload.get("vectorField")
# search data
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
payload = {
"collectionName": name,
"vector": vector_to_search,
"outputFields": output_fields,
"filter": "uid >= 0",
"limit": limit,
"offset": 0,
}
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 1
@pytest.mark.parametrize("offset", [-1, 100_001])
def test_search_vector_with_invalid_offset(self, offset):
"""
Search a vector with a simple payload
"""
name = gen_collection_name()
self.name = name
dim = 128
schema_payload, data = self.init_collection(name, dim=dim)
vector_field = schema_payload.get("vectorField")
# search data
dim = 128
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
payload = {
"collectionName": name,
"vector": vector_to_search,
"outputFields": output_fields,
"filter": "uid >= 0",
"limit": 100,
"offset": offset,
}
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 1
def test_search_vector_with_illegal_api_key(self):
"""
Search a vector with an illegal api key
"""
pass
def test_search_vector_with_invalid_collection_name(self):
"""
Search a vector with an invalid collection name
"""
pass
def test_search_vector_with_invalid_output_field(self):
"""
Search a vector with an invalid output field
"""
pass
@pytest.mark.parametrize("invalid_expr", ["invalid_field > 0", "12-s", "中文", "a", " "])
def test_search_vector_with_invalid_expression(self, invalid_expr):
"""
Search a vector with an invalid expression
"""
pass
def test_search_vector_with_invalid_vector_field(self):
"""
Search a vector with an invalid vector field for ann search
"""
pass
@pytest.mark.parametrize("dim_offset", [1, -1])
def test_search_vector_with_mismatch_vector_dim(self, dim_offset):
"""
Search a vector with a mismatch vector dim
"""
pass
class TestQueryVector(TestBase):
@pytest.mark.L0
@pytest.mark.parametrize("expr", ["10+20 <= uid < 20+30", "uid in [1,2,3,4]",
"uid > 0", "uid >= 0", "uid > 0",
"uid > -100 and uid < 100"])
@pytest.mark.parametrize("include_output_fields", [True, False])
@pytest.mark.parametrize("partial_fields", [True, False])
def test_query_vector_with_int64_filter(self, expr, include_output_fields, partial_fields):
"""
Query a vector with a simple payload
"""
name = gen_collection_name()
self.name = name
schema_payload, data = self.init_collection(name)
output_fields = get_common_fields_by_data(data)
if partial_fields:
output_fields = output_fields[:len(output_fields) // 2]
if "uid" not in output_fields:
output_fields.append("uid")
else:
output_fields = output_fields
# query data
payload = {
"collectionName": name,
"filter": expr,
"limit": 100,
"offset": 0,
"outputFields": output_fields
}
if not include_output_fields:
payload.pop("outputFields")
if 'vector' in output_fields:
output_fields.remove("vector")
time.sleep(5)
rsp = self.vector_client.vector_query(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
for r in res:
uid = r['uid']
assert eval(expr) is True
for field in output_fields:
assert field in r
@pytest.mark.L0
@pytest.mark.parametrize("filter_expr", ["name > \"placeholder\"", "name like \"placeholder%\""])
@pytest.mark.parametrize("include_output_fields", [True, False])
def test_query_vector_with_varchar_filter(self, filter_expr, include_output_fields):
"""
Query a vector with a complex payload
"""
name = gen_collection_name()
self.name = name
nb = 200
dim = 128
limit = 100
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
names = []
for item in data:
names.append(item.get("name"))
names.sort()
logger.info(f"names: {names}")
mid = len(names) // 2
prefix = names[mid][0:2]
# search data
output_fields = get_common_fields_by_data(data)
filter_expr = filter_expr.replace("placeholder", prefix)
logger.info(f"filter_expr: {filter_expr}")
payload = {
"collectionName": name,
"outputFields": output_fields,
"filter": filter_expr,
"limit": limit,
"offset": 0,
}
if not include_output_fields:
payload.pop("outputFields")
rsp = self.vector_client.vector_query(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
assert len(res) <= limit
for item in res:
name = item.get("name")
logger.info(f"name: {name}")
if ">" in filter_expr:
assert name > prefix
if "like" in filter_expr:
assert name.startswith(prefix)
@pytest.mark.parametrize("sum_of_limit_offset", [16384, 16385])
def test_query_vector_with_large_sum_of_limit_offset(self, sum_of_limit_offset):
"""
Query a vector with sum of limit and offset larger than max value
"""
max_sum_of_limit_offset = 16384
name = gen_collection_name()
filter_expr = "name > \"placeholder\""
self.name = name
nb = 200
dim = 128
limit = 100
offset = sum_of_limit_offset - limit
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
names = []
for item in data:
names.append(item.get("name"))
names.sort()
logger.info(f"names: {names}")
mid = len(names) // 2
prefix = names[mid][0:2]
# search data
output_fields = get_common_fields_by_data(data)
filter_expr = filter_expr.replace("placeholder", prefix)
logger.info(f"filter_expr: {filter_expr}")
payload = {
"collectionName": name,
"outputFields": output_fields,
"filter": filter_expr,
"limit": limit,
"offset": offset,
}
rsp = self.vector_client.vector_query(payload)
if sum_of_limit_offset > max_sum_of_limit_offset:
assert rsp['code'] == 1
return
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
assert len(res) <= limit
for item in res:
name = item.get("name")
logger.info(f"name: {name}")
if ">" in filter_expr:
assert name > prefix
if "like" in filter_expr:
assert name.startswith(prefix)
class TestGetVector(TestBase):
@pytest.mark.L0
def test_get_vector_with_simple_payload(self):
"""
Search a vector with a simple payload
"""
name = gen_collection_name()
self.name = name
self.init_collection(name)
# search data
dim = 128
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
payload = {
"collectionName": name,
"vector": vector_to_search,
}
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
limit = int(payload.get("limit", 100))
assert len(res) == limit
ids = [item['id'] for item in res]
assert len(ids) == len(set(ids))
payload = {
"collectionName": name,
"outputFields": ["*"],
"id": ids[0],
}
rsp = self.vector_client.vector_get(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {res}")
logger.info(f"res: {len(res)}")
for item in res:
assert item['id'] == ids[0]
@pytest.mark.L0
@pytest.mark.parametrize("id_field_type", ["list", "one"])
@pytest.mark.parametrize("include_invalid_id", [True, False])
@pytest.mark.parametrize("include_output_fields", [True, False])
def test_get_vector_complex(self, id_field_type, include_output_fields, include_invalid_id):
name = gen_collection_name()
self.name = name
nb = 200
dim = 128
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
output_fields = get_common_fields_by_data(data)
uids = []
for item in data:
uids.append(item.get("uid"))
payload = {
"collectionName": name,
"outputFields": output_fields,
"filter": f"uid in {uids}",
}
rsp = self.vector_client.vector_query(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
ids = []
for r in res:
ids.append(r['id'])
logger.info(f"ids: {len(ids)}")
id_to_get = None
if id_field_type == "list":
id_to_get = ids
if id_field_type == "one":
id_to_get = ids[0]
if include_invalid_id:
if isinstance(id_to_get, list):
id_to_get[-1] = 0
else:
id_to_get = 0
# get by id list
payload = {
"collectionName": name,
"outputFields": output_fields,
"id": id_to_get
}
rsp = self.vector_client.vector_get(payload)
assert rsp['code'] == 200
res = rsp['data']
if isinstance(id_to_get, list):
if include_invalid_id:
assert len(res) == len(id_to_get) - 1
else:
assert len(res) == len(id_to_get)
else:
if include_invalid_id:
assert len(res) == 0
else:
assert len(res) == 1
for r in rsp['data']:
if isinstance(id_to_get, list):
assert r['id'] in id_to_get
else:
assert r['id'] == id_to_get
if include_output_fields:
for field in output_fields:
assert field in r
class TestDeleteVector(TestBase):
@pytest.mark.L0
@pytest.mark.parametrize("include_invalid_id", [True, False])
@pytest.mark.parametrize("id_field_type", ["list", "one"])
def test_delete_vector_default(self, id_field_type, include_invalid_id):
name = gen_collection_name()
self.name = name
nb = 200
dim = 128
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
time.sleep(1)
output_fields = get_common_fields_by_data(data)
uids = []
for item in data:
uids.append(item.get("uid"))
payload = {
"collectionName": name,
"outputFields": output_fields,
"filter": f"uid in {uids}",
}
rsp = self.vector_client.vector_query(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
ids = []
for r in res:
ids.append(r['id'])
logger.info(f"ids: {len(ids)}")
id_to_get = None
if id_field_type == "list":
id_to_get = ids
if id_field_type == "one":
id_to_get = ids[0]
if include_invalid_id:
if isinstance(id_to_get, list):
id_to_get.append(0)
else:
id_to_get = 0
if isinstance(id_to_get, list):
if len(id_to_get) >= 100:
id_to_get = id_to_get[-100:]
# delete by id list
payload = {
"collectionName": name,
"id": id_to_get
}
rsp = self.vector_client.vector_delete(payload)
assert rsp['code'] == 200
logger.info(f"delete res: {rsp}")
# verify data deleted
if not isinstance(id_to_get, list):
id_to_get = [id_to_get]
payload = {
"collectionName": name,
"filter": f"id in {id_to_get}",
}
time.sleep(5)
rsp = self.vector_client.vector_query(payload)
assert rsp['code'] == 200
assert len(rsp['data']) == 0
def test_delete_vector_with_invalid_api_key(self):
"""
Delete a vector with an invalid api key
"""
name = gen_collection_name()
self.name = name
nb = 200
dim = 128
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
output_fields = get_common_fields_by_data(data)
uids = []
for item in data:
uids.append(item.get("uid"))
payload = {
"collectionName": name,
"outputFields": output_fields,
"filter": f"uid in {uids}",
}
rsp = self.vector_client.vector_query(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
ids = []
for r in res:
ids.append(r['id'])
logger.info(f"ids: {len(ids)}")
id_to_get = ids
# delete by id list
payload = {
"collectionName": name,
"id": id_to_get
}
client = self.vector_client
client.api_key = "invalid_api_key"
rsp = client.vector_delete(payload)
assert rsp['code'] == 1800
def test_delete_vector_with_invalid_collection_name(self):
"""
Delete a vector with an invalid collection name
"""
name = gen_collection_name()
self.name = name
self.init_collection(name, dim=128, nb=3000)
# query data
# expr = f"id in {[i for i in range(10)]}".replace("[", "(").replace("]", ")")
expr = "id > 0"
payload = {
"collectionName": name,
"filter": expr,
"limit": 3000,
"offset": 0,
"outputFields": ["id", "uid"]
}
rsp = self.vector_client.vector_query(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
id_list = [r['id'] for r in res]
delete_expr = f"id in {[i for i in id_list[:10]]}"
# query data before delete
payload = {
"collectionName": name,
"filter": delete_expr,
"limit": 3000,
"offset": 0,
"outputFields": ["id", "uid"]
}
rsp = self.vector_client.vector_query(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
# delete data
payload = {
"collectionName": name + "_invalid",
"filter": delete_expr,
}
rsp = self.vector_client.vector_delete(payload)
assert rsp['code'] == 1
def test_delete_vector_with_non_primary_key(self):
"""
Delete a vector with a non-primary key, expect no data were deleted
"""
name = gen_collection_name()
self.name = name
self.init_collection(name, dim=128, nb=300)
expr = "uid > 0"
payload = {
"collectionName": name,
"filter": expr,
"limit": 3000,
"offset": 0,
"outputFields": ["id", "uid"]
}
rsp = self.vector_client.vector_query(payload)
assert rsp['code'] == 200
res = rsp['data']
logger.info(f"res: {len(res)}")
id_list = [r['uid'] for r in res]
delete_expr = f"uid in {[i for i in id_list[:10]]}"
# query data before delete
payload = {
"collectionName": name,
"filter": delete_expr,
"limit": 3000,
"offset": 0,
"outputFields": ["id", "uid"]
}
rsp = self.vector_client.vector_query(payload)
assert rsp['code'] == 200
res = rsp['data']
num_before_delete = len(res)
logger.info(f"res: {len(res)}")
# delete data
payload = {
"collectionName": name,
"filter": delete_expr,
}
rsp = self.vector_client.vector_delete(payload)
# query data after delete
payload = {
"collectionName": name,
"filter": delete_expr,
"limit": 3000,
"offset": 0,
"outputFields": ["id", "uid"]
}
time.sleep(1)
rsp = self.vector_client.vector_query(payload)
assert len(rsp["data"]) == num_before_delete

View File

@ -0,0 +1,2 @@
MAX_SUM_OFFSET_AND_LIMIT = 16384

View File

@ -1,4 +1,5 @@
import logging
from loguru import logger as loguru_logger
import sys
from config.log_config import log_config
@ -54,4 +55,6 @@ log_debug = log_config.log_debug
log_info = log_config.log_info
log_err = log_config.log_err
log_worker = log_config.log_worker
test_log = TestLog('ci_test', log_debug, log_info, log_err, log_worker).log
self_defined_log = TestLog('ci_test', log_debug, log_info, log_err, log_worker).log
loguru_log = loguru_logger
test_log = self_defined_log

View File

@ -1,38 +0,0 @@
import time
from datetime import datetime
import functools
from utils.util_log import test_log as log
DEFAULT_FMT = '[{start_time}] [{elapsed:0.8f}s] {collection_name} {func_name} -> {res!r}'
def trace(fmt=DEFAULT_FMT, prefix='test', flag=True):
def decorate(func):
@functools.wraps(func)
def inner_wrapper(*args, **kwargs):
# args[0] is an instance of ApiCollectionWrapper class
flag = args[0].active_trace
if flag:
start_time = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
t0 = time.perf_counter()
res, result = func(*args, **kwargs)
elapsed = time.perf_counter() - t0
end_time = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
func_name = func.__name__
collection_name = args[0].collection.name
# arg_lst = [repr(arg) for arg in args[1:]][:100]
# arg_lst.extend(f'{k}={v!r}' for k, v in kwargs.items())
# arg_str = ', '.join(arg_lst)[:200]
log_str = f"[{prefix}]" + fmt.format(**locals())
# TODO: add report function in this place, like uploading to influxdb
# it is better a async way to do this, in case of blocking the request processing
log.info(log_str)
return res, result
else:
res, result = func(*args, **kwargs)
return res, result
return inner_wrapper
return decorate

View File

@ -0,0 +1,155 @@
import random
import time
import random
import string
from faker import Faker
import numpy as np
from sklearn import preprocessing
import requests
from loguru import logger
import datetime
fake = Faker()
def random_string(length=8):
letters = string.ascii_letters
return ''.join(random.choice(letters) for _ in range(length))
def gen_collection_name(prefix="test_collection", length=8):
name = f'{prefix}_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + random_string(length=length)
return name
def admin_password():
return "Milvus"
def invalid_cluster_name():
res = [
"demo" * 100,
"demo" + "!",
"demo" + "@",
]
return res
def wait_cluster_be_ready(cluster_id, client, timeout=120):
t0 = time.time()
while True and time.time() - t0 < timeout:
rsp = client.cluster_describe(cluster_id)
if rsp['code'] == 200:
if rsp['data']['status'] == "RUNNING":
return time.time() - t0
time.sleep(1)
logger.debug("wait cluster to be ready, cost time: %s" % (time.time() - t0))
return -1
def gen_data_by_type(field):
data_type = field["type"]
if data_type == "bool":
return random.choice([True, False])
if data_type == "int8":
return random.randint(-128, 127)
if data_type == "int16":
return random.randint(-32768, 32767)
if data_type == "int32":
return random.randint(-2147483648, 2147483647)
if data_type == "int64":
return random.randint(-9223372036854775808, 9223372036854775807)
if data_type == "float32":
return np.float64(random.random()) # Object of type float32 is not JSON serializable, so set it as float64
if data_type == "float64":
return np.float64(random.random())
if "varchar" in data_type:
length = int(data_type.split("(")[1].split(")")[0])
return "".join([chr(random.randint(97, 122)) for _ in range(length)])
if "floatVector" in data_type:
dim = int(data_type.split("(")[1].split(")")[0])
return preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
return None
def get_data_by_fields(fields, nb):
# logger.info(f"fields: {fields}")
fields_not_auto_id = []
for field in fields:
if not field.get("autoId", False):
fields_not_auto_id.append(field)
# logger.info(f"fields_not_auto_id: {fields_not_auto_id}")
data = []
for i in range(nb):
tmp = {}
for field in fields_not_auto_id:
tmp[field["name"]] = gen_data_by_type(field)
data.append(tmp)
return data
def get_random_json_data(uid=None):
# gen random dict data
if uid is None:
uid = 0
data = {"uid": uid, "name": fake.name(), "address": fake.address(), "text": fake.text(), "email": fake.email(),
"phone_number": fake.phone_number(),
"json": {
"name": fake.name(),
"address": fake.address()
}
}
for i in range(random.randint(1, 10)):
data["key" + str(random.randint(1, 100_000))] = "value" + str(random.randint(1, 100_000))
return data
def get_data_by_payload(payload, nb=100):
dim = payload.get("dimension", 128)
vector_field = payload.get("vectorField", "vector")
data = []
if nb == 1:
data = [{
vector_field: preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist(),
**get_random_json_data()
}]
else:
for i in range(nb):
data.append({
vector_field: preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist(),
**get_random_json_data(uid=i)
})
return data
def get_common_fields_by_data(data, exclude_fields=None):
fields = set()
if isinstance(data, dict):
data = [data]
if not isinstance(data, list):
raise Exception("data must be list or dict")
common_fields = set(data[0].keys())
for d in data:
keys = set(d.keys())
common_fields = common_fields.intersection(keys)
if exclude_fields is not None:
exclude_fields = set(exclude_fields)
common_fields = common_fields.difference(exclude_fields)
return list(common_fields)
def get_all_fields_by_data(data, exclude_fields=None):
fields = set()
for d in data:
keys = list(d.keys())
fields.union(keys)
if exclude_fields is not None:
exclude_fields = set(exclude_fields)
fields = fields.difference(exclude_fields)
return list(fields)